# -*- test-case-name: twisted.test.test_policies -*-
# Copyright (c) 2001-2007 Twisted Matrix Laboratories.
# See LICENSE for details.


"""
Resource limiting policies.

@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""
# system imports
import operator
import sys

# twisted imports
from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
from twisted.internet import reactor, error
from twisted.python import log
from zope.interface import providedBy, directlyProvides
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    Transport calls made by the wrapped protocol are relayed to the real
    transport, and protocol events from the real transport are relayed to
    the wrapped protocol.  The owning factory is told about the wrapper in
    connectionMade (registerProtocol) and connectionLost
    (unregisterProtocol).
    """

    # Set to 1 by loseConnection().
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        # The wrapper is handed to the wrapped protocol as its transport, so
        # it declares every interface the real transport provides as well
        # as its own.
        interfaces = providedBy(self) + providedBy(transport)
        directlyProvides(self, *interfaces)
        Protocol.makeConnection(self, transport)

    # -- Transport side: relay to the real transport ------------------------

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only consulted for attributes not found normally: anything the
        # wrapper does not define is looked up on the real transport.
        return getattr(self.transport, name)

    # -- Protocol side: relay to the wrapped protocol -----------------------

    def connectionMade(self):
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Start/stop and client-connection notifications are forwarded to the
    wrapped factory; every protocol it builds is wrapped in C{protocol}.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Live wrapper protocols, used as a set (every value is 1).
        self.protocols = {}

    def doStart(self):
        # The wrapped factory is started before this one ...
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        # ... and stopped before this one as well.
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and wrap it in C{protocol}."""
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Every write is charged against the factory's write budget through
    C{factory.registerWritten}.  Data that does not fit in the budget is
    kept in a local buffer, the registered producer (if any) is paused, and
    the protocol queues itself with the factory via
    C{factory.throttledWrites}; the factory later drains the buffer by
    calling L{unthrottleWrites}.
    """

    # Producer registered through registerProducer, or None.  Declared on
    # the class so that reading it before registration does not fall
    # through ProtocolWrapper.__getattr__ to the real transport.
    producer = None

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        # Chunks accepted from the wrapped protocol but not yet written.
        self._tempDataBuffer = []
        # Total length of the chunks held in _tempDataBuffer.
        self._tempDataLength = 0
        # True while this protocol is queued with the factory for writing.
        self.throttled = False

    def write(self, data):
        # Empty writes carry nothing.  Buffering one could leave a
        # zero-length chunk behind once _tempDataLength reaches 0, and the
        # factory would then drop this protocol while it is still throttled.
        if not data:
            return

        # registerWritten returns None when there is no write limit, False
        # when the bytes were charged to the current budget, and True when
        # the budget is exhausted.
        paused = False
        if not self.throttled:
            paused = self.factory.registerWritten(len(data))
            if not paused:
                ProtocolWrapper.write(self, data)
                # NOTE(review): assumes the producer exposes a `paused` flag.
                if (paused is not None and self.producer
                        and not self.producer.paused):
                    # Interrupt the flow so that others can have a chance.
                    # We can only do this if it's not already paused,
                    # otherwise we risk unpausing something that the server
                    # paused.
                    self.producer.pauseProducing()
                    reactor.callLater(0, self.producer.resumeProducing)

        if self.throttled or paused:
            # Can't write now: buffer the data until budget is available.
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        # Skip empty chunks for the same reason as in write().
        chunks = [chunk for chunk in seq if chunk]
        written = 0
        if not self.throttled:
            # Charge and write each chunk separately until the budget runs
            # out (registerWritten returning True).
            while (written < len(chunks) and
                   not self.factory.registerWritten(len(chunks[written]))):
                ProtocolWrapper.write(self, chunks[written])
                written += 1

        # Anything left over could not be written now: buffer it.
        remaining = chunks[written:]
        if remaining:
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        # Raised explicitly (same exception type as before) so the check
        # also holds when Python runs with -O and asserts are stripped.
        if not streaming:
            raise AssertionError("You can only use the ThrottlingProtocol with streaming (push) producers.")
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        self.transport.pauseProducing()

    def unthrottleReads(self):
        self.transport.resumeProducing()

    def _throttleWrites(self):
        # If we haven't yet, queue for unthrottling.
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """
        Write the oldest buffered chunk, charging it to the factory's budget.

        Called by the factory while it still has write budget available.

        @return: the number of bytes still buffered after this call.
        """
        if self._tempDataBuffer:
            data = self._tempDataBuffer.pop(0)
            # Charge the bytes outside of any assert: asserts are stripped
            # under "python -O", which used to skip this accounting and
            # disable the write limit.
            paused = self.factory.registerWritten(len(data))
            assert not paused
            self._tempDataLength -= len(data)
            ProtocolWrapper.write(self, data)
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more.
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if self.producer:
                # This might unpause something the server has also paused,
                # but it will get paused again on first write anyway.
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.
    """

    protocol = ThrottlingProtocol

    # Maximum number of bytes checkWriteBandwidth lets one queued protocol
    # drain before moving on to the next queued protocol.
    # NOTE(review): the original definition of this constant was lost from
    # this copy of the file (only its use, self.CHUNK_SIZE, survives); the
    # value below is a placeholder -- confirm it against the upstream code.
    CHUNK_SIZE = 4 * 1024


    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        """
        @param wrappedFactory: the factory whose protocols are wrapped.
        @param maxConnectionCount: connections beyond this count are refused.
        @param readLimit: max bytes we should read per second, or C{None}.
        @param writeLimit: max bytes we should write per second, or C{None}.
        """
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        # Bytes that may still be written this second.  It can go negative
        # because registerWritten charges whole writes.
        self.writeAvailable = writeLimit
        # Throttled protocols waiting for write budget, in FIFO order.
        self._writeQueue = []
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)


    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.

        @return: C{None} if there is no write limit; C{False} if the bytes
            were charged against the remaining budget and may be written;
            C{True} if the bytes could not be written and the protocol
            should pause itself.
        """
        # Check if there are bytes available to write
        if self.writeLimit is None:
            return None
        if self.writeAvailable > 0:
            # The whole write is charged, even if it overdraws the budget.
            self.writeAvailable -= length
            return False
        return True


    def throttledWrites(self, p):
        """
        Called by the protocol to queue it for later writing.
        """
        assert p not in self._writeQueue
        self._writeQueue.append(p)


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Checks if we've passed bandwidth limits.

        Runs once per second while there are connections.  If more than
        C{readLimit} bytes were read, reads are paused for the time needed
        to bring the average back under the limit.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Add some new available bandwidth, and check for protocols to unthrottle.

        Runs once per second while there are connections.
        """
        # Increase the available write bytes, but not higher than the limit
        self.writeAvailable = min(self.writeLimit,
                                  self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol
            p = self._writeQueue.pop(0)
            startAvailable = self.writeAvailable
            bytesLeft = 1  # any positive value, so the loop runs at least once

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol
            # is unbuffered
            while (self.writeAvailable > 0 and
                   startAvailable - self.writeAvailable < self.CHUNK_SIZE and
                   bytesLeft > 0):
                # Unthrottle a single write (from the protocol's buffer)
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()


    def buildProtocol(self, addr):
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None

        if self.connectionCount == 0:
            # First connection: start the once-per-second accounting loops.
            # This happens only for a connection that is actually accepted,
            # so the loops are always stopped again by unregisterProtocol.
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        self.connectionCount += 1
        return WrappingFactory.buildProtocol(self, addr)


    def unregisterProtocol(self, p):
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        # A protocol that went away while waiting for write budget must not
        # stay queued; otherwise its buffer would later be written to a
        # closed transport, using up budget.
        if p in self._writeQueue:
            self._writeQueue.remove(p)
        if self.connectionCount == 0:
            # Last connection gone: stop the periodic accounting loops.
            self._cancelPendingCalls()


    def _cancelPendingCalls(self):
        """Cancel and forget every pending accounting/unthrottle call."""
        for attr in ('unthrottleReadsID', 'checkReadBandwidthID',
                     'unthrottleWritesID', 'checkWriteBandwidthID'):
            call = getattr(self, attr)
            # Only cancel calls that are still pending, and always reset the
            # ID.  Otherwise a later idle transition would cancel the same
            # DelayedCall twice, which raises AlreadyCancelled.
            if call is not None and call.active():
                call.cancel()
            setattr(self, attr, None)
class SpewingProtocol(ProtocolWrapper):
    """Protocol wrapper that logs, via L{log.msg}, the data it receives and
    the data passed to its C{write} method."""

    def dataReceived(self, data):
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)
class SpewingFactory(WrappingFactory):
    """WrappingFactory whose connections are wrapped in L{SpewingProtocol}."""

    protocol = SpewingProtocol
class LimitConnectionsByPeer(WrappingFactory):
    """
    Limit simultaneous connections from each peer host to
    C{maxConnectionsPerPeer}; buildProtocol returns C{None} for any
    connection beyond that.

    Stability: Unstable
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        # Number of open connections, keyed by peer host.
        self.peerConnections = {}

    def _peerHost(self, address):
        """
        Return the key used to count connections coming from C{address}.

        buildProtocol and unregisterProtocol both go through this helper so
        that the key counted up is the same key counted down (they used to
        use addr[0] and getPeer()[1] respectively).
        """
        # Twisted's IP address objects expose the peer host as .host; fall
        # back to the address itself for address types without one.
        return getattr(address, 'host', address)

    def buildProtocol(self, addr):
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            # Too many connections from this host: refuse this one.
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        # Keep WrappingFactory's protocol registry in sync too; skipping
        # this leaked every closed protocol in self.protocols.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            # Last connection from this host: drop the entry entirely.
            self.peerConnections.pop(peerHost, None)
class LimitTotalConnectionsFactory(ServerFactory):
    """Factory that limits the number of simultaneous connections.

    API Stability: Unstable

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
    connectionLimit is exceeded. If C{None} (the default value), excess
    connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        withinLimit = (self.connectionLimit is None or
                       self.connectionCount < self.connectionLimit)
        if withinLimit:
            # Build the normal protocol.
            protocolClass = self.protocol
        elif self.overflowProtocol is not None:
            # Too many connections, so build the overflow protocol.
            protocolClass = self.overflowProtocol
        else:
            # Too many connections and no overflow protocol: drop it.
            return None

        wrappedProtocol = protocolClass()
        wrappedProtocol.factory = self
        protocol = ProtocolWrapper(self, wrappedProtocol)
        # Counted here and released in unregisterProtocol.
        self.connectionCount += 1
        return protocol

    def registerProtocol(self, p):
        # Counting already happened in buildProtocol.
        pass

    def unregisterProtocol(self, p):
        self.connectionCount -= 1
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Any write, writeSequence or dataReceived restarts the countdown; when
    it expires, L{timeoutFunc} is called.

    Stability: Unstable
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}.  Its C{callLater} method is used to
            schedule the timeout.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                  self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled or has already fired, this only
        forgets it.
        """
        if self.timeoutCall is not None:
            # A DelayedCall that already ran or was cancelled must not be
            # cancelled again (that raises AlreadyCalled/AlreadyCancelled).
            if self.timeoutCall.active():
                self.timeoutCall.cancel()
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        Does nothing once the timeout has fired or has been cancelled.
        Previously, activity arriving after the timeout fired (but before
        the connection closed) made reset() raise AlreadyCalled.
        """
        if self.timeoutCall is not None and self.timeoutCall.active():
            self.timeoutCall.reset(self.timeoutPeriod)


    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}. Override this if you want
        something else to happen.
        """
        self.loseConnection()
class TimeoutFactory(WrappingFactory):
    """
    Factory for L{TimeoutProtocol}: every protocol built by the wrapped
    factory is wrapped in a TimeoutProtocol using C{timeoutPeriod}.

    Stability: Unstable
    """
    protocol = TimeoutProtocol


    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        """
        @param wrappedFactory: the factory whose protocols are wrapped.
        @param timeoutPeriod: number of seconds a connection may stay idle
            before it times out (default: 30 minutes).
        """
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)


    def buildProtocol(self, addr):
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             timeoutPeriod=self.timeoutPeriod)


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Protocol wrapper that writes one line to C{logfile} for each event.

    Lines starting with C{C} record dataReceived and connectionLost, C{S}
    records write and loseConnection, and C{SV} records writeSequence;
    connectionMade is logged as C{*}.  This class does not close
    C{logfile}.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the datareceived logged.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number


    def _log(self, line):
        # Flush after every line so the log is up to date while the
        # connection is still open.
        self.logfile.write(line + '\n')
        self.logfile.flush()


    def _mungeData(self, data):
        """
        Truncate C{data} for logging when it is longer than C{lengthLimit}.

        Truncated data keeps its first C{lengthLimit - 12} characters
        followed by the 12-character marker C{'<... elided>'}.
        """
        marker = '<... elided>'
        if self.lengthLimit and len(data) > self.lengthLimit:
            # Clamp at 0: with a lengthLimit shorter than the marker, a
            # negative slice index would keep almost all of the data and
            # make the "truncated" text longer than the original.
            keep = max(self.lengthLimit - len(marker), 0)
            data = data[:keep] + marker
        return data


    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)


    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)


    # ITransport
    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)


    def writeSequence(self, iovec):
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)


    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
class TrafficLoggingFactory(WrappingFactory):
    """
    WrappingFactory that logs each connection's traffic to its own file,
    named C{logfilePrefix + '-' + <connection number>}.
    """
    protocol = TrafficLoggingProtocol

    # Number of connections built so far; used to number the log files.
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)


    def open(self, name):
        """
        Open the log file called C{name} for writing.
        """
        # This resolves to the builtin open(), not this method.  Unlike
        # file(), open() exists on Python 2 and 3 and is equivalent on 2.
        return open(name, 'w')


    def buildProtocol(self, addr):
        self._counter += 1
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # The wrapped factory refused the connection: refuse it too,
            # instead of opening a log file and wrapping None (which would
            # fail as soon as the connection is made).
            return None
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, wrappedProtocol,
                             logfile, self.lengthLimit, self._counter)


    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
class TimeoutMixin:
    """Mixin for protocols which wish to timeout connections

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """
    timeOut = None

    # Pending timeout call, or None while no timeout is armed.
    __timeoutCall = None

    def callLater(self, period, func):
        """Schedule C{func} after C{period} seconds with L{reactor.callLater}."""
        return reactor.callLater(period, func)


    def resetTimeout(self):
        """Reset the timeout count down"""
        pending = self.__timeoutCall
        if pending is None or self.timeOut is None:
            # Nothing armed, or timeouts disabled: nothing to restart.
            return
        pending.reset(self.timeOut)

    def setTimeout(self, period):
        """Change the timeout period

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
        C{None} to disable the timeout.
        @return: the previous value of C{timeOut}.
        """
        previous, self.timeOut = self.timeOut, period
        pending = self.__timeoutCall

        if pending is None:
            # No timeout armed yet: arm one unless we are disabling.
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is None:
            # Disabling: drop the armed timeout.
            pending.cancel()
            self.__timeoutCall = None
        else:
            # Re-arm the existing timeout with the new period.
            pending.reset(period)

        return previous

    def __timedOut(self):
        # The call has fired; forget it before running the handler.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """Called when the connection times out.
        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()