"""
Resource limiting policies.

@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""
12 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
13 from twisted.internet import reactor, error
14 from twisted.python import log
15 from zope.interface import providedBy, directlyProvides
class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    Events coming from the real transport (connectionMade, dataReceived,
    connectionLost) are relayed to C{wrappedProtocol}; transport calls made
    by C{wrappedProtocol} (write, loseConnection, ...) are relayed to the
    real transport.  Any attribute not defined on the wrapper is looked up
    on the real transport by L{__getattr__}.

    @ivar factory: the wrapping factory; C{registerProtocol} is called on it
        when the connection is made and C{unregisterProtocol} when it is
        lost.
    @ivar wrappedProtocol: the protocol being wrapped.
    """

    # Set to 1 by loseConnection().  Declared on the class so that reading it
    # earlier yields 0 instead of being delegated to the transport by
    # __getattr__.
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.wrappedProtocol = wrappedProtocol
        self.factory = factory

    def makeConnection(self, transport):
        # Declare every interface the real transport provides on this wrapper
        # too, so the wrapped protocol can treat it as an equivalent transport.
        directlyProvides(self, *providedBy(self) + providedBy(transport))
        Protocol.makeConnection(self, transport)

    # Transport relaying

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached for attributes not found on the wrapper itself:
        # delegate them to the real transport.
        return getattr(self.transport, name)

    # Protocol relaying

    def connectionMade(self):
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """
    Wraps a factory and its protocols, and keeps track of them.

    Factory lifecycle and client-connection callbacks are forwarded to
    C{wrappedFactory}; every protocol it builds is wrapped in an instance of
    L{protocol}.

    @ivar wrappedFactory: the factory being wrapped.
    @ivar protocols: the currently connected wrapper protocols, used as a set
        (keys are protocols, values are always 1).
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Filled by registerProtocol / emptied by unregisterProtocol.
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr))

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Reports every read and write to the factory.  Data that cannot be written
    because the factory's write budget is exhausted is kept in a local buffer
    and flushed later, one chunk per call, when the factory calls
    L{unthrottleWrites}.

    @ivar throttled: C{True} while this protocol has buffered data queued
        with the factory.
    @ivar producer: the streaming producer registered through
        L{registerProducer}, or C{None}.
    """

    # Declared on the class so the ``if self.producer`` checks below work
    # before any producer is registered, without going through
    # ProtocolWrapper.__getattr__ (which would consult the transport).
    producer = None

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        # Chunks waiting for write budget (oldest first) and their total size.
        self._tempDataBuffer = []
        self._tempDataLength = 0
        self.throttled = False

    # wrap API for tracking bandwidth

    def write(self, data):
        """
        Write C{data} now if the factory's write budget allows it; otherwise
        buffer it until the factory unthrottles this protocol.
        """
        # factory.registerWritten returns None when there is no write limit,
        # False when the bytes were charged to the budget, and True when the
        # budget is exhausted and the data must be buffered.
        paused = False
        if not self.throttled:
            paused = self.factory.registerWritten(len(data))
            if not paused:
                ProtocolWrapper.write(self, data)

                # IPushProducer does not define ``paused``; when a producer
                # does not expose it we cannot tell whether the server paused
                # it, so treat it as paused and leave it alone.
                if (paused is not None and self.producer
                        and not getattr(self.producer, "paused", True)):
                    # Interrupt the flow so that others can have a chance.
                    # We can only do this if it's not already paused,
                    # otherwise we risk unpausing something that the server
                    # paused.
                    self.producer.pauseProducing()
                    self.factory.callLater(0, self.producer.resumeProducing)

        if self.throttled or paused:
            # Can't write, buffer the data.
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """
        Write the chunks of C{seq} one at a time while the write budget
        allows; buffer whatever remains once it runs out (or everything, if
        this protocol is already throttled).
        """
        written = 0
        if not self.throttled:
            # Write each chunk separately so the budget is checked per chunk.
            while (written < len(seq)
                   and not self.factory.registerWritten(len(seq[written]))):
                ProtocolWrapper.write(self, seq[written])
                written += 1

        # If there's some left, we must have been paused or throttled.
        remaining = seq[written:]
        if remaining:
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        self.transport.pauseProducing()

    def unthrottleReads(self):
        self.transport.resumeProducing()

    def _throttleWrites(self):
        # If we haven't yet, queue for unthrottling.
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """
        Flush the oldest buffered chunk to the transport.  Called by the
        factory when write budget is available.

        @return: the number of bytes still buffered after this call.
        """
        if self._tempDataBuffer:
            chunk = self._tempDataBuffer[0]
            # Charge the bytes outside the assert: assert statements are
            # removed under ``python -O`` and the accounting must still run.
            paused = self.factory.registerWritten(len(chunk))
            assert not paused
            self._tempDataLength -= len(chunk)
            ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more.
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if self.producer:
                # This might unpause something the server has also paused,
                # but it will get paused again on first write anyway.
                self.factory.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Writes beyond the available write budget are buffered by each
    L{ThrottlingProtocol} and flushed by L{checkWriteBandwidth}; a streaming
    producer registered with a throttled protocol is paused meanwhile.
    Reads are throttled by pausing the transport of every connection.

    @ivar maxConnectionCount: connections beyond this number are refused.
    @ivar readLimit: maximum bytes to read per second, or C{None}.
    @ivar writeLimit: maximum bytes to write per second, or C{None}.
    """

    protocol = ThrottlingProtocol

    # Maximum number of bytes flushed from one queued protocol per pass of
    # checkWriteBandwidth before it is requeued behind the other protocols.
    # NOTE(review): the original value of this constant was not visible in
    # the reviewed source; 2 ** 14 is a placeholder -- confirm.
    CHUNK_SIZE = 2 ** 14

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit  # max bytes we should read per second
        self.writeLimit = writeLimit  # max bytes we should write per second
        self.readThisSecond = 0
        self.writeAvailable = writeLimit
        # Protocols with buffered data, in the order they will be flushed.
        self._writeQueue = []
        # Pending delayed calls (or None); cancelled when the last connection
        # goes away.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)

    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.

        @return: C{None} if there is no write limit; C{False} if the bytes
            were charged against the available budget and may be written;
            C{True} if the bytes could not be written and the protocol should
            pause itself.
        """
        # Check if there are bytes available to write.
        if self.writeLimit is None:
            return None
        if self.writeAvailable > 0:
            # May go negative; checkWriteBandwidth refills from there.
            self.writeAvailable -= length
            return False
        return True

    def throttledWrites(self, p):
        """
        Called by the protocol to queue it for later writing.
        """
        assert p not in self._writeQueue
        self._writeQueue.append(p)

    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """
        Checks if we've passed bandwidth limits.

        If more than C{readLimit} bytes were read since the previous check,
        reads are paused on all protocols for a time proportional to the
        excess.  Reschedules itself to run again in one second.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """
        Add some new available bandwidth, and check for protocols to
        unthrottle.  Reschedules itself to run again in one second.
        """
        # Increase the available write bytes, but not higher than the limit.
        self.writeAvailable = min(self.writeLimit,
                                  self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again.
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol.
            p = self._writeQueue.pop(0)
            startAvailable = self.writeAvailable

            # Unthrottle writes (one buffered chunk per call) until
            # CHUNK_SIZE is reached or the protocol is unbuffered; at least
            # one chunk is always flushed.
            bytesLeft = p.unthrottleWrites()
            while (bytesLeft > 0 and self.writeAvailable > 0
                   and startAvailable - self.writeAvailable < self.CHUNK_SIZE):
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it.
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def buildProtocol(self, addr):
        if self.connectionCount == 0:
            # First connection: start the periodic bandwidth checks.
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        log.msg("Max connection count reached!")
        return None

    def unregisterProtocol(self, p):
        WrappingFactory.unregisterProtocol(self, p)
        # A disconnected protocol can no longer be flushed; drop it so it is
        # neither kept alive nor written to after the connection is gone.
        if p in self._writeQueue:
            self._writeQueue.remove(p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            # No connections left: stop all scheduled bandwidth work.  Each
            # handle is reset to None after cancelling so a later 1 -> 0
            # cycle never cancels the same DelayedCall twice (which raises
            # AlreadyCancelled).
            if self.unthrottleReadsID is not None:
                self.unthrottleReadsID.cancel()
                self.unthrottleReadsID = None
            if self.checkReadBandwidthID is not None:
                self.checkReadBandwidthID.cancel()
                self.checkReadBandwidthID = None
            if self.unthrottleWritesID is not None:
                self.unthrottleWritesID.cancel()
                self.unthrottleWritesID = None
            if self.checkWriteBandwidthID is not None:
                self.checkWriteBandwidthID.cancel()
                self.checkWriteBandwidthID = None
class SpewingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that passes every chunk of traffic, in both directions,
    to L{log.msg} before relaying it.
    """

    def dataReceived(self, data):
        # Inbound: log, then hand the bytes to the wrapped protocol.
        line = "Received: %r" % data
        log.msg(line)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        # Outbound: log, then hand the bytes to the real transport.
        line = "Sending: %r" % data
        log.msg(line)
        ProtocolWrapper.write(self, data)
class SpewingFactory(WrappingFactory):
    """
    Wrapping factory whose protocols are L{SpewingProtocol} instances, so
    all traffic of the wrapped factory's connections is logged.
    """

    protocol = SpewingProtocol
class LimitConnectionsByPeer(WrappingFactory):
    """
    Refuse new connections from a peer host that already has
    C{maxConnectionsPerPeer} live connections.

    Stability: Unstable
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        # Maps peer host -> number of live connections from that host.
        self.peerConnections = {}

    def _peerHost(self, address):
        """
        Return the host part of C{address}.

        Accepts address objects exposing a C{host} attribute (such as
        L{twisted.internet.address.IPv4Address}) as well as plain
        C{(host, port)} tuples, so buildProtocol and unregisterProtocol
        always agree on the key.
        """
        host = getattr(address, "host", None)
        if host is None:
            host = address[0]
        return host

    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            # Too many connections from this peer: refuse this one.
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        # Let the base class forget the protocol too; otherwise every closed
        # connection stays referenced from self.protocols.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        self.peerConnections[peerHost] -= 1
        if self.peerConnections[peerHost] == 0:
            del self.peerConnections[peerHost]
class LimitTotalConnectionsFactory(ServerFactory):
    """Factory that limits the number of simultaneous connections.

    API Stability: Unstable

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
    connectionLimit is exceeded.  If C{None} (the default value), excess
    connections will be closed immediately.
    """

    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        if (self.connectionLimit is None or
                self.connectionCount < self.connectionLimit):
            # Build the normal protocol.
            wrappedProtocol = self.protocol()
        elif self.overflowProtocol is None:
            # Just drop the connection.
            return None
        else:
            # Too many connections, so build the overflow protocol.
            wrappedProtocol = self.overflowProtocol()

        wrappedProtocol.factory = self
        protocol = ProtocolWrapper(self, wrappedProtocol)
        # Overflow connections are counted too; unregisterProtocol
        # decrements for every wrapper built here.
        self.connectionCount += 1
        return protocol

    def registerProtocol(self, p):
        # Connections are already counted in buildProtocol.
        pass

    def unregisterProtocol(self, p):
        self.connectionCount -= 1
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}; its C{callLater} method is used to
            schedule the timeout.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)

    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                  self._timedOut)

    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled (or already fired), this does
        nothing.
        """
        if self.timeoutCall is not None:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                pass
            self.timeoutCall = None

    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.
        """
        if self.timeoutCall is not None:
            self.timeoutCall.reset(self.timeoutPeriod)

    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)

    def _timedOut(self):
        # Forget the spent call before acting on it: resetting a DelayedCall
        # that has already run raises AlreadyCalled, and the wrapped protocol
        # may still write (triggering resetTimeout) while the connection is
        # being closed.
        self.timeoutCall = None
        self.timeoutFunc()

    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
class TimeoutFactory(WrappingFactory):
    """
    Factory for L{TimeoutProtocol}: every protocol built by C{wrappedFactory}
    is wrapped so that, by default, its connection is closed after
    C{timeoutPeriod} seconds without activity.
    """

    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30 * 60):
        """
        @param wrappedFactory: the factory whose protocols are wrapped.
        @param timeoutPeriod: idle timeout, in seconds, handed to each
            protocol (default: 30 minutes).
        """
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner, timeoutPeriod=self.timeoutPeriod)

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that writes one line to C{logfile} per event: C{*} on
    connect, C{C <n>: ...} for data received and for connection loss, and
    C{S}/C{SV <n>: ...} for data written and for loseConnection.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of each logged chunk of data; longer
            chunks are truncated.  C{None} (or 0) disables truncation.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number

    def _log(self, line):
        self.logfile.write(line + '\n')
        # Flush per line so the log is up to date while the connection lives.
        self.logfile.flush()

    def _mungeData(self, data):
        """
        Return C{data}, truncated with an elision marker if it is longer
        than C{lengthLimit}.
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            # The 12-character marker counts towards lengthLimit; clamp the
            # cut so a limit shorter than the marker doesn't produce a
            # negative slice (which would keep almost all of the data).
            data = data[:max(self.lengthLimit - 12, 0)] + '<... elided>'
        return data

    # IProtocol

    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)

    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        # NOTE(review): the log file is never closed by this class or its
        # factory; confirm who owns it (one file is opened per connection).
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)

    # ITransport

    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)

    def writeSequence(self, iovec):
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)

    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
class TrafficLoggingFactory(WrappingFactory):
    """
    Wrapping factory that logs each connection's traffic to its own file,
    named C{<logfilePrefix>-<n>} where C{n} numbers the connections built.
    """

    protocol = TrafficLoggingProtocol

    # Number of protocols built since creation or the last resetCounter().
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)

    def open(self, name):
        """
        Open the log file C{name} for writing.  Kept as a method so
        subclasses can override where the logs go.
        """
        # The builtin open() (not the Python 2-only file() constructor); this
        # method's own name does not shadow the builtin inside its body.
        return open(name, 'w')

    def buildProtocol(self, addr):
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             logfile, self.lengthLimit, self._counter)

    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
class TimeoutMixin:
    """Mixin for protocols which wish to timeout connections

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """

    timeOut = None

    # Pending delayed call for the timeout, or None when no timeout is
    # scheduled (never set, disabled, or already fired).
    __timeoutCall = None

    def callLater(self, period, func):
        return reactor.callLater(period, func)

    def resetTimeout(self):
        """Reset the timeout count down"""
        if self.__timeoutCall is not None and self.timeOut is not None:
            self.__timeoutCall.reset(self.timeOut)

    def setTimeout(self, period):
        """Change the timeout period

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
        C{None} to disable the timeout.
        @return: the previous timeout period.
        """
        prev = self.timeOut
        self.timeOut = period

        if self.__timeoutCall is not None:
            if period is None:
                # Disable: drop the pending call entirely.
                self.__timeoutCall.cancel()
                self.__timeoutCall = None
            else:
                # Re-arm the pending call with the new period.
                self.__timeoutCall.reset(period)
        elif period is not None:
            self.__timeoutCall = self.callLater(period, self.__timedOut)

        return prev

    def __timedOut(self):
        # The call has fired; forget it so resetTimeout/setTimeout never
        # touch a spent delayed call.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """Called when the connection times out.
        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()