1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
3 # See LICENSE for details.
7 Resource limiting policies.
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet import reactor, error
18 from twisted.python import log
19 from zope.interface import providedBy, directlyProvides
class ProtocolWrapper(Protocol):
    """
    Wraps protocol instances and acts as their transport as well.

    Transport-level calls made by the wrapped protocol (C{write},
    C{loseConnection}, producer registration, ...) are forwarded to the real
    transport.  Protocol-level events from the real transport
    (C{connectionMade}, C{dataReceived}, C{connectionLost}) are forwarded to
    the wrapped protocol.  The factory is told about the connection through
    C{registerProtocol} (when it is made) and C{unregisterProtocol} (when it
    is lost).
    """

    # Set to 1 by loseConnection().  Defined here so that reading it before
    # loseConnection() does not fall through __getattr__ to the transport.
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        """
        @param factory: the factory notified through C{registerProtocol} and
            C{unregisterProtocol}.
        @param wrappedProtocol: the protocol that receives forwarded events
            and sees this wrapper as its transport.
        """
        self.wrappedProtocol = wrappedProtocol
        self.factory = factory

    def makeConnection(self, transport):
        # We stand in for the real transport, so advertise every interface
        # it provides in addition to our own.
        directlyProvides(self, *providedBy(self) + providedBy(transport))
        Protocol.makeConnection(self, transport)

    # Transport relaying

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached for attributes not found normally: delegate them to
        # the real transport.
        return getattr(self.transport, name)

    # Protocol relaying

    def connectionMade(self):
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """
    Wraps a factory and its protocols, and keeps track of them.

    Every protocol built by the wrapped factory is wrapped in an instance of
    C{self.protocol} (L{ProtocolWrapper} by default).  Connected wrappers are
    tracked as the keys of the dict C{self.protocols}.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Filled by registerProtocol and emptied by unregisterProtocol.
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr))

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
117 class ThrottlingProtocol(ProtocolWrapper):
118 """Protocol for ThrottlingFactory."""
# Wrapper whose reads and writes are metered through its factory
# (registerRead / registerWritten).  Writes the factory refuses are kept in
# self._tempDataBuffer (total size in self._tempDataLength) and replayed one
# chunk at a time by unthrottleWrites(), which the factory's
# checkWriteBandwidth calls.
# NOTE(review): some lines of this class are not visible in this view (the
# conditions around the direct write in write(), the guard before the
# buffering in writeSequence(), the first statement of unregisterProducer
# and unthrottleWrites); the notes below only describe what is shown.
120 # wrap API for tracking bandwidth
122 def __init__(self, factory, wrappedProtocol):
123 ProtocolWrapper.__init__(self, factory, wrappedProtocol)
# Chunks waiting to be written while throttled, and their total length.
124 self._tempDataBuffer = []
125 self._tempDataLength = 0
# True from the moment writes are throttled until the buffer is flushed.
126 self.throttled = False
128 def write(self, data):
# Write immediately when the factory's budget allows it; otherwise buffer
# the data and ask the factory to queue us.
# NOTE(review): `paused` is only bound when we were not already throttled;
# the final `self.throttled or paused` test relies on short-circuiting to
# avoid reading it unbound.  registerWritten is documented to return
# True/False, so `paused is not None` is presumably always true - verify.
129 # Check if we can write
130 if not self.throttled:
131 paused = self.factory.registerWritten(len(data))
133 ProtocolWrapper.write(self, data)
# NOTE(review): assumes the registered producer exposes a `paused`
# attribute; not every producer may - confirm.
135 if paused is not None and hasattr(self, "producer") and self.producer and not self.producer.paused:
136 # Interrupt the flow so that others can have a chance
137 # We can only do this if it's not already paused otherwise we
138 # risk unpausing something that the Server paused
139 self.producer.pauseProducing()
140 reactor.callLater(0, self.producer.resumeProducing)
142 if self.throttled or paused:
143 # Can't write, buffer the data
144 self._tempDataBuffer.append(data)
145 self._tempDataLength += len(data)
146 self._throttleWrites()
148 def writeSequence(self, seq):
# Write chunks one by one while the factory allows it, then buffer the rest.
# NOTE(review): seq.pop(0) consumes the caller's list in place (and fails on
# tuples or other immutable sequences).  reduce() without an initial value
# raises TypeError on an empty seq unless a guard that is not visible here
# prevents it.  `operator` is not among the imports shown.
149 if not self.throttled:
150 # Write each sequence separately
151 while seq and not self.factory.registerWritten(len(seq[0])):
152 ProtocolWrapper.write(self, seq.pop(0))
154 # If there's some left, we must have been paused
156 self._tempDataBuffer.extend(seq)
157 self._tempDataLength += reduce(operator.add, map(len, seq))
158 self._throttleWrites()
160 def dataReceived(self, data):
# Account the received bytes against the factory's read budget first.
161 self.factory.registerRead(len(data))
162 ProtocolWrapper.dataReceived(self, data)
164 def registerProducer(self, producer, streaming):
# Only streaming producers are supported (checked by assert, which is
# stripped under `python -O`).  The producer is remembered so it can be
# paused and resumed while writes are throttled.
165 assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
166 self.producer = producer
167 ProtocolWrapper.registerProducer(self, producer, streaming)
169 def unregisterProducer(self):
# NOTE(review): the first statement of this method is not visible; confirm
# it clears self.producer, since _throttleWrites/unthrottleWrites keep
# pausing/resuming self.producer whenever it is set.
171 ProtocolWrapper.unregisterProducer(self)
174 def throttleReads(self):
# Stop reading from the underlying transport.
175 self.transport.pauseProducing()
177 def unthrottleReads(self):
# Resume reading from the underlying transport.
178 self.transport.resumeProducing()
180 def _throttleWrites(self):
# Queue ourselves with the factory (only once) and pause our producer.
181 # If we haven't yet, queue for unthrottling
182 if not self.throttled:
183 self.throttled = True
184 self.factory.throttledWrites(self)
186 if hasattr(self, "producer") and self.producer:
187 self.producer.pauseProducing()
189 def unthrottleWrites(self):
# Write one buffered chunk.  Called by the factory's checkWriteBandwidth
# when budget is available; returns the number of bytes still buffered.
# NOTE(review): the registerWritten() call below is inside an assert, so
# under `python -O` it is skipped and the chunk is never charged against
# the factory's write budget - write throttling then limits nothing.
# Consider calling it outside the assert and asserting on the result.
191 if self._tempDataBuffer:
192 assert not self.factory.registerWritten(len(self._tempDataBuffer[0]))
193 self._tempDataLength -= len(self._tempDataBuffer[0])
194 ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
195 assert self._tempDataLength >= 0
197 # If we wrote it all, start producing more
198 if not self._tempDataBuffer:
199 assert self._tempDataLength == 0
200 self.throttled = False
201 if hasattr(self, "producer") and self.producer:
202 # This might unpause something the Server has also paused, but
203 # it will get paused again on first write anyway
204 reactor.callLater(0, self.producer.resumeProducing)
206 return self._tempDataLength
209 class ThrottlingFactory(WrappingFactory):
# Wrapping factory that limits the number of simultaneous connections and
# the aggregate read and write bandwidth (bytes per second) of its
# ThrottlingProtocol instances.
# NOTE(review): several lines of this class are not visible in this view
# (docstring delimiters, the CHUNK_SIZE definition, the return statements of
# registerWritten, the loop bodies of throttleReads/unthrottleReads, ...);
# the notes below only describe what is shown.
211 Throttles bandwidth and number of connections.
213 Write bandwidth will only be throttled if there is a producer
217 protocol = ThrottlingProtocol
220 def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
221 readLimit=None, writeLimit=None):
# maxConnectionCount defaults to sys.maxint (Python 2 only); `sys` must be
# imported, which the import lines shown here do not do.  A readLimit or
# writeLimit of None disables read or write throttling respectively.
222 WrappingFactory.__init__(self, wrappedFactory)
223 self.connectionCount = 0
224 self.maxConnectionCount = maxConnectionCount
225 self.readLimit = readLimit # max bytes we should read per second
226 self.writeLimit = writeLimit # max bytes we should write per second
227 self.readThisSecond = 0
# Remaining write budget for the current period.  It can go negative,
# because registerWritten charges a whole write as long as any budget is
# left.
228 self.writeAvailable = writeLimit
# Protocols waiting for write budget; checkWriteBandwidth serves them from
# the front and re-appends those that still have data buffered.
229 self._writeQueue = []
# Handles returned by callLater for the scheduled checks, kept so they can
# be cancelled when the last connection goes away.  unthrottleWritesID is
# never assigned anything but None in this view.
230 self.unthrottleReadsID = None
231 self.checkReadBandwidthID = None
232 self.unthrottleWritesID = None
233 self.checkWriteBandwidthID = None
236 def callLater(self, period, func):
238 Wrapper around L{reactor.callLater} for test purpose.
240 return reactor.callLater(period, func)
243 def registerWritten(self, length):
245 Called by protocol to tell us more bytes were written.
246 Returns True if the bytes could not be written and the protocol should pause itself.
248 # Check if there are bytes available to write
249 if self.writeLimit is None:
251 elif self.writeAvailable > 0:
# Charge the whole write while any budget remains (this may overdraw).
252 self.writeAvailable -= length
# NOTE(review): the return statements of this method are not visible here;
# callers treat a true result as "could not write, pause".
258 def throttledWrites(self, p):
260 Called by the protocol to queue it for later writing.
262 assert p not in self._writeQueue
263 self._writeQueue.append(p)
266 def registerRead(self, length):
268 Called by protocol to tell us more bytes were read.
270 self.readThisSecond += length
273 def checkReadBandwidth(self):
# Re-schedules itself every second.  If more than readLimit bytes were read
# in the last period, unthrottleReads() is scheduled after
# readThisSecond / readLimit - 1 seconds, so that readThisSecond bytes spread
# over the whole interval average out to readLimit per second.
# NOTE(review): the statement that actually throttles reads when the limit
# is exceeded is not visible here - presumably self.throttleReads().
275 Checks if we've passed bandwidth limits.
277 if self.readThisSecond > self.readLimit:
279 throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
280 self.unthrottleReadsID = self.callLater(throttleTime,
281 self.unthrottleReads)
282 self.readThisSecond = 0
283 self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)
286 def checkWriteBandwidth(self):
# Re-schedules itself every second.  Refills the budget by writeLimit,
# capped at writeLimit (so an overdraft from the previous period is repaid
# first).  It then lets queued protocols flush their buffers, ending a
# protocol's turn once it has used CHUNK_SIZE bytes of budget.
# NOTE(review): `bytesLeft` is read by the inner loop condition before any
# visible assignment, and CHUNK_SIZE is not defined in this view - confirm
# both are initialised (e.g. bytesLeft before the inner loop).
288 Add some new available bandwidth, and check for protocols to unthrottle.
290 # Increase the available write bytes, but not higher than the limit
291 self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)
293 # Write from the queue until it's empty or we're throttled again
294 while self.writeAvailable > 0 and self._writeQueue:
295 # Get the first queued protocol
296 p = self._writeQueue.pop(0)
297 _tempWriteAvailable = self.writeAvailable
300 # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
301 while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
302 # Unthrottle a single write (from the protocol's buffer)
303 bytesLeft = p.unthrottleWrites()
305 # If the protocol is not done, requeue it
307 self._writeQueue.append(p)
309 self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)
312 def throttleReads(self):
314 Throttle reads on all protocols.
316 log.msg("Throttling reads on %s" % self)
# NOTE(review): the loop body is not visible here - presumably
# p.throttleReads() for each tracked protocol.
317 for p in self.protocols.keys():
321 def unthrottleReads(self):
323 Stop throttling reads on all protocols.
325 self.unthrottleReadsID = None
326 log.msg("Stopped throttling reads on %s" % self)
# NOTE(review): the loop body is not visible here - presumably
# p.unthrottleReads() for each tracked protocol.
327 for p in self.protocols.keys():
331 def buildProtocol(self, addr):
# The periodic bandwidth checks start when the first connection is built,
# even if that connection is then refused because maxConnectionCount is 0.
332 if self.connectionCount == 0:
333 if self.readLimit is not None:
334 self.checkReadBandwidth()
335 if self.writeLimit is not None:
336 self.checkWriteBandwidth()
338 if self.connectionCount < self.maxConnectionCount:
339 self.connectionCount += 1
340 return WrappingFactory.buildProtocol(self, addr)
342 log.msg("Max connection count reached!")
# NOTE(review): the refusal return for this branch is not visible here.
346 def unregisterProtocol(self, p):
# When the last connection goes away, stop the scheduled checks.
# NOTE(review): the handles are not reset to None and _writeQueue is not
# cleared, so protocols of closed connections may still be queued when the
# next first connection restarts checkWriteBandwidth.
347 WrappingFactory.unregisterProtocol(self, p)
348 self.connectionCount -= 1
349 if self.connectionCount == 0:
350 if self.unthrottleReadsID is not None:
351 self.unthrottleReadsID.cancel()
352 if self.checkReadBandwidthID is not None:
353 self.checkReadBandwidthID.cancel()
354 if self.unthrottleWritesID is not None:
355 self.unthrottleWritesID.cancel()
356 if self.checkWriteBandwidthID is not None:
357 self.checkWriteBandwidthID.cancel()
class SpewingProtocol(ProtocolWrapper):
    """
    L{ProtocolWrapper} that records, through C{log.msg}, every chunk of data
    it receives and every chunk written through it, before relaying it.
    """

    def dataReceived(self, data):
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)
class SpewingFactory(WrappingFactory):
    """
    L{WrappingFactory} whose wrappers are L{SpewingProtocol} instances, so
    all traffic passing through them is logged.
    """

    protocol = SpewingProtocol
class LimitConnectionsByPeer(WrappingFactory):
    """
    Limit the number of simultaneous connections from any single peer host.

    Stability: Unstable

    @cvar maxConnectionsPerPeer: number of connections accepted from one host
        at a time; further connections from that host are refused.
    @ivar peerConnections: mapping of peer host to its current number of
        connections, created by C{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}

    def _peerHost(self, address):
        """
        Return the host part of C{address}, used as the C{peerConnections} key.

        Both C{buildProtocol} and C{unregisterProtocol} go through this helper
        so a connection is counted and released under the same key.
        """
        host = getattr(address, 'host', None)
        if host is None:
            # NOTE(review): fallback for tuple-style addresses, where the host
            # sits at index 1 (as the previous unregisterProtocol assumed);
            # confirm which address types reach this factory.
            host = address[1]
        return host

    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            # This host is already at its limit: refuse the connection.
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        # Let WrappingFactory drop p from self.protocols too (as
        # ThrottlingFactory does); otherwise every closed connection would
        # stay referenced there.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        self.peerConnections[peerHost] -= 1
        if self.peerConnections[peerHost] == 0:
            del self.peerConnections[peerHost]
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Factory that limits the number of simultaneous connections.

    API Stability: Unstable

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """

    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        if (self.connectionLimit is None or
            self.connectionCount < self.connectionLimit):
            # Build the normal protocol
            wrappedProtocol = self.protocol()
        elif self.overflowProtocol is None:
            # Just drop the connection
            return None
        else:
            # Too many connections, so build the overflow protocol
            wrappedProtocol = self.overflowProtocol()

        wrappedProtocol.factory = self
        protocol = ProtocolWrapper(self, wrappedProtocol)
        # Counted here; released in unregisterProtocol, which the wrapper
        # calls from connectionLost.
        self.connectionCount += 1
        return protocol

    def registerProtocol(self, p):
        # Nothing to do: the connection was already counted in buildProtocol.
        pass

    def unregisterProtocol(self, p):
        self.connectionCount -= 1
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Any data written or received postpones the timeout.  When no activity is
    seen for C{timeoutPeriod} seconds, L{timeoutFunc} is called.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory} that also provides C{callLater} (such
            as L{TimeoutFactory}); it is used to schedule the timeout.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out, or C{None} to start without a timeout.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        # Initialised here so setTimeout(None) never falls through
        # __getattr__ to the transport.
        self.timeoutPeriod = None
        self.setTimeout(timeoutPeriod)

    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        # With no period ever configured there is nothing to schedule.
        if self.timeoutPeriod is not None:
            self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                      self.timeoutFunc)

    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall is not None:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                # The call already fired or was cancelled elsewhere.
                pass
            self.timeoutCall = None

    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.
        """
        if self.timeoutCall is not None:
            try:
                self.timeoutCall.reset(self.timeoutPeriod)
            except (error.AlreadyCalled, error.AlreadyCancelled):
                # The timeout has already fired: there is no pending call
                # left to postpone.
                pass

    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        # The connection is gone; make sure timeoutFunc cannot fire later.
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)

    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
class TimeoutFactory(WrappingFactory):
    """
    L{WrappingFactory} whose wrappers are L{TimeoutProtocol} instances.

    @ivar timeoutPeriod: seconds of inactivity allowed before each wrapped
        connection times out; passed to every new L{TimeoutProtocol}.
        Defaults to 30 minutes.
    """

    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner, timeoutPeriod=self.timeoutPeriod)

    def callLater(self, period, func):
        """
        Schedule C{func} after C{period} seconds via L{reactor.callLater}.

        Kept as a method so tests can substitute their own scheduler.
        """
        return reactor.callLater(period, func)
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    L{ProtocolWrapper} that writes one line to C{logfile} per relayed event.

    Lines starting with C{C} record events coming from the connection
    (C{dataReceived}, C{connectionLost}).  Lines starting with C{S} record
    calls made on the wrapper (C{write}, C{loseConnection}), and C{SV} lines
    record C{writeSequence}.  Each prefix is followed by the connection
    number.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum length of each logged data chunk, received
            or sent; C{None} or C{0} disables truncation.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number

    def _log(self, line):
        self.logfile.write(line + '\n')

    def _mungeData(self, data):
        """
        Return C{data}, truncated for logging if it exceeds C{lengthLimit}.

        Truncated data keeps its leading bytes and ends with C{'<... elided>'};
        the result is at most C{lengthLimit} long unless the limit is shorter
        than that marker.
        """
        marker = '<... elided>'
        if self.lengthLimit and len(data) > self.lengthLimit:
            # Clamp at 0: a negative slice bound would keep everything except
            # the last few bytes instead of a short prefix.
            keep = max(self.lengthLimit - len(marker), 0)
            data = data[:keep] + marker
        return data

    def connectionMade(self):
        return ProtocolWrapper.connectionMade(self)

    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)

    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)

    def writeSequence(self, iovec):
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)

    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
class TrafficLoggingFactory(WrappingFactory):
    """
    L{WrappingFactory} that logs each connection's traffic to its own file.

    Connection C{n} is logged to C{logfilePrefix + '-' + str(n)}, where C{n}
    counts the connections built since the last L{resetCounter} call,
    starting at 1.
    """

    protocol = TrafficLoggingProtocol

    # Number of the most recently built connection; see resetCounter.
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)

    def open(self, name):
        """
        Open and return the log file C{name} for writing.

        A separate method so it can be overridden, e.g. to log in memory.
        """
        # Inside this method `open` is the builtin, not this method.
        return open(name, 'w')

    def buildProtocol(self, addr):
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # The wrapped factory refused the connection; don't create a log
            # file that nothing would ever write to.
            return None
        # A distinct number per connection means a distinct log file; without
        # it each connection would reopen (and truncate) the same file.
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, wrappedProtocol,
                             logfile, self.lengthLimit, self._counter)

    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
656 """Mixin for protocols which wish to time out connections.
658 @cvar timeOut: The number of seconds after which to time out the connection, or C{None} to disable the timeout.
def callLater(self, period, func):
    """
    Schedule C{func} to be called after C{period} seconds.

    Delegates to L{reactor.callLater} and returns its result; kept as a
    method so it can be overridden, for example with a fake clock.
    """
    pending = reactor.callLater(period, func)
    return pending
def resetTimeout(self):
    """
    Restart the timeout countdown from C{timeOut} seconds.

    Does nothing when no timeout is pending or C{timeOut} is C{None}.
    """
    pending = self.__timeoutCall
    if pending is None or self.timeOut is None:
        return
    pending.reset(self.timeOut)
def setTimeout(self, period):
    """Change the timeout period

    @type period: C{int} or C{NoneType}
    @param period: The period, in seconds, to change the timeout to, or
    C{None} to disable the timeout.
    @return: the previous timeout period.
    """
    prev = self.timeOut
    self.timeOut = period

    if self.__timeoutCall is not None:
        if period is None:
            # Disabling: drop the pending call entirely.
            self.__timeoutCall.cancel()
            self.__timeoutCall = None
        else:
            # A call is already pending: just move its deadline.
            self.__timeoutCall.reset(period)
    elif period is not None:
        self.__timeoutCall = self.callLater(period, self.__timedOut)

    return prev
694 def __timedOut(self):
# Callback scheduled by setTimeout.  The delayed call has just fired, so
# forget it first (resetTimeout/setTimeout then see no pending call), and
# only then delegate to timeoutConnection(), which subclasses may override.
695 self.__timeoutCall = None
696 self.timeoutConnection()
def timeoutConnection(self):
    """
    Handle a timeout; by default drop the connection by calling
    C{self.transport.loseConnection()}.

    Override to define behavior other than dropping the connection.
    """
    transport = self.transport
    transport.loseConnection()