From 58eccda74a195382b889cef4bef5351d7ade115f Mon Sep 17 00:00:00 2001 From: Cameron Dale Date: Thu, 28 Feb 2008 16:17:22 -0800 Subject: [PATCH] Upgrade policies to SVN version and fix a small bug. Fixed the policies usage when writeLimit was None to not do anything. --- apt_dht/policies.py | 450 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 416 insertions(+), 34 deletions(-) diff --git a/apt_dht/policies.py b/apt_dht/policies.py index 3c84c14..e7bae81 100644 --- a/apt_dht/policies.py +++ b/apt_dht/policies.py @@ -1,10 +1,10 @@ # -*- test-case-name: twisted.test.test_policies -*- -# Copyright (c) 2001-2004 Twisted Matrix Laboratories. +# Copyright (c) 2001-2007 Twisted Matrix Laboratories. # See LICENSE for details. -# -"""Resource limiting policies. +""" +Resource limiting policies. @seealso: See also L{twisted.protocols.htb} for rate limiting. """ @@ -14,10 +14,10 @@ import sys, operator # twisted imports from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory -from twisted.internet.interfaces import ITransport from twisted.internet import reactor, error from twisted.python import log -from zope.interface import implements, providedBy, directlyProvides +from zope.interface import providedBy, directlyProvides + class ProtocolWrapper(Protocol): """Wraps protocol instances and acts as their transport as well.""" @@ -127,16 +127,19 @@ class ThrottlingProtocol(ProtocolWrapper): def write(self, data): # Check if we can write - if (not self.throttled) and self.factory.registerWritten(len(data)): - ProtocolWrapper.write(self, data) - - if hasattr(self, "producer") and self.producer and not self.producer.paused: - # Interrupt the flow so that others can can have a chance - # We can only do this if it's not already paused otherwise we - # risk unpausing something that the Server paused - self.producer.pauseProducing() - reactor.callLater(0, self.producer.resumeProducing) - else: + if not self.throttled: + paused = self.factory.registerWritten(len(data)) + if not paused: + ProtocolWrapper.write(self, data) + + if paused is not None and hasattr(self, "producer") and self.producer and not self.producer.paused: + # Interrupt the flow so that others can can have a chance + # We can only do this if it's not already paused otherwise we + # risk unpausing something that the Server paused + self.producer.pauseProducing() + reactor.callLater(0, self.producer.resumeProducing) + + if self.throttled or paused: # Can't write, buffer the data self._tempDataBuffer.append(data) self._tempDataLength += len(data) @@ -145,10 +148,10 @@ class ThrottlingProtocol(ProtocolWrapper): def writeSequence(self, seq): if not self.throttled: # Write each sequence separately - while seq and self.factory.registerWritten(len(seq[0])): + while seq and not self.factory.registerWritten(len(seq[0])): ProtocolWrapper.write(self, seq.pop(0)) - # If there's some left, we must have been throttled + # If there's some left, we must have been paused if seq: self._tempDataBuffer.extend(seq) self._tempDataLength += reduce(operator.add, map(len, seq)) @@ -186,7 +189,7 @@ class ThrottlingProtocol(ProtocolWrapper): def unthrottleWrites(self): # Write some data if self._tempDataBuffer: - assert self.factory.registerWritten(len(self._tempDataBuffer[0])) + assert not self.factory.registerWritten(len(self._tempDataBuffer[0])) self._tempDataLength -= len(self._tempDataBuffer[0]) ProtocolWrapper.write(self, self._tempDataBuffer.pop(0)) assert self._tempDataLength >= 0 @@ -204,7 +207,8 @@ class ThrottlingProtocol(ProtocolWrapper): class ThrottlingFactory(WrappingFactory): - """Throttles bandwidth and number of connections. + """ + Throttles bandwidth and number of connections. Write bandwidth will only be throttled if there is a producer registered. @@ -213,7 +217,8 @@ class ThrottlingFactory(WrappingFactory): protocol = ThrottlingProtocol CHUNK_SIZE = 4*1024 - def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None): + def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, + readLimit=None, writeLimit=None): WrappingFactory.__init__(self, wrappedFactory) self.connectionCount = 0 self.maxConnectionCount = maxConnectionCount @@ -227,36 +232,61 @@ class ThrottlingFactory(WrappingFactory): self.unthrottleWritesID = None self.checkWriteBandwidthID = None + + def callLater(self, period, func): + """ + Wrapper around L{reactor.callLater} for test purpose. + """ + return reactor.callLater(period, func) + + def registerWritten(self, length): - """Called by protocol to tell us more bytes were written.""" + """ + Called by protocol to tell us more bytes were written. + Returns True if the bytes could not be written and the protocol should pause itself. + """ # Check if there are bytes available to write - if self.writeAvailable > 0: + if self.writeLimit is None: + return None + elif self.writeAvailable > 0: self.writeAvailable -= length - return True + return False - return False + return True + def throttledWrites(self, p): - """Called by the protocol to queue it for later writing.""" + """ + Called by the protocol to queue it for later writing. + """ assert p not in self._writeQueue self._writeQueue.append(p) + def registerRead(self, length): - """Called by protocol to tell us more bytes were read.""" + """ + Called by protocol to tell us more bytes were read. + """ self.readThisSecond += length + def checkReadBandwidth(self): - """Checks if we've passed bandwidth limits.""" + """ + Checks if we've passed bandwidth limits. + """ if self.readThisSecond > self.readLimit: self.throttleReads() throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0 - self.unthrottleReadsID = reactor.callLater(throttleTime, - self.unthrottleReads) + self.unthrottleReadsID = self.callLater(throttleTime, + self.unthrottleReads) self.readThisSecond = 0 - self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth) + self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth) + def checkWriteBandwidth(self): - """Add some new available bandwidth, and check for protocols to unthrottle.""" + """ + Add some new available bandwidth, and check for protocols to unthrottle. + """ # Increase the available write bytes, but not higher than the limit self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit) @@ -276,21 +306,28 @@ class ThrottlingFactory(WrappingFactory): if bytesLeft > 0: self._writeQueue.append(p) - self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth) + self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth) + def throttleReads(self): - """Throttle reads on all protocols.""" + """ + Throttle reads on all protocols. + """ log.msg("Throttling reads on %s" % self) for p in self.protocols.keys(): p.throttleReads() + def unthrottleReads(self): - """Stop throttling reads on all protocols.""" + """ + Stop throttling reads on all protocols. + """ self.unthrottleReadsID = None log.msg("Stopped throttling reads on %s" % self) for p in self.protocols.keys(): p.unthrottleReads() + def buildProtocol(self, addr): if self.connectionCount == 0: if self.readLimit is not None: @@ -305,6 +342,7 @@ class ThrottlingFactory(WrappingFactory): log.msg("Max connection count reached!") return None + def unregisterProtocol(self, p): WrappingFactory.unregisterProtocol(self, p) self.connectionCount -= 1 @@ -318,3 +356,347 @@ class ThrottlingFactory(WrappingFactory): if self.checkWriteBandwidthID is not None: self.checkWriteBandwidthID.cancel() + + +class SpewingProtocol(ProtocolWrapper): + def dataReceived(self, data): + log.msg("Received: %r" % data) + ProtocolWrapper.dataReceived(self,data) + + def write(self, data): + log.msg("Sending: %r" % data) + ProtocolWrapper.write(self,data) + + + +class SpewingFactory(WrappingFactory): + protocol = SpewingProtocol + + + +class LimitConnectionsByPeer(WrappingFactory): + """Stability: Unstable""" + + maxConnectionsPerPeer = 5 + + def startFactory(self): + self.peerConnections = {} + + def buildProtocol(self, addr): + peerHost = addr[0] + connectionCount = self.peerConnections.get(peerHost, 0) + if connectionCount >= self.maxConnectionsPerPeer: + return None + self.peerConnections[peerHost] = connectionCount + 1 + return WrappingFactory.buildProtocol(self, addr) + + def unregisterProtocol(self, p): + peerHost = p.getPeer()[1] + self.peerConnections[peerHost] -= 1 + if self.peerConnections[peerHost] == 0: + del self.peerConnections[peerHost] + + +class LimitTotalConnectionsFactory(ServerFactory): + """Factory that limits the number of simultaneous connections. + + API Stability: Unstable + + @type connectionCount: C{int} + @ivar connectionCount: number of current connections. + @type connectionLimit: C{int} or C{None} + @cvar connectionLimit: maximum number of connections. + @type overflowProtocol: L{Protocol} or C{None} + @cvar overflowProtocol: Protocol to use for new connections when + connectionLimit is exceeded. If C{None} (the default value), excess + connections will be closed immediately. + """ + connectionCount = 0 + connectionLimit = None + overflowProtocol = None + + def buildProtocol(self, addr): + if (self.connectionLimit is None or + self.connectionCount < self.connectionLimit): + # Build the normal protocol + wrappedProtocol = self.protocol() + elif self.overflowProtocol is None: + # Just drop the connection + return None + else: + # Too many connections, so build the overflow protocol + wrappedProtocol = self.overflowProtocol() + + wrappedProtocol.factory = self + protocol = ProtocolWrapper(self, wrappedProtocol) + self.connectionCount += 1 + return protocol + + def registerProtocol(self, p): + pass + + def unregisterProtocol(self, p): + self.connectionCount -= 1 + + + +class TimeoutProtocol(ProtocolWrapper): + """ + Protocol that automatically disconnects when the connection is idle. + + Stability: Unstable + """ + + def __init__(self, factory, wrappedProtocol, timeoutPeriod): + """ + Constructor. + + @param factory: An L{IFactory}. + @param wrappedProtocol: A L{Protocol} to wrapp. + @param timeoutPeriod: Number of seconds to wait for activity before + timing out. + """ + ProtocolWrapper.__init__(self, factory, wrappedProtocol) + self.timeoutCall = None + self.setTimeout(timeoutPeriod) + + + def setTimeout(self, timeoutPeriod=None): + """ + Set a timeout. + + This will cancel any existing timeouts. + + @param timeoutPeriod: If not C{None}, change the timeout period. + Otherwise, use the existing value. + """ + self.cancelTimeout() + if timeoutPeriod is not None: + self.timeoutPeriod = timeoutPeriod + self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc) + + + def cancelTimeout(self): + """ + Cancel the timeout. + + If the timeout was already cancelled, this does nothing. + """ + if self.timeoutCall: + try: + self.timeoutCall.cancel() + except error.AlreadyCalled: + pass + self.timeoutCall = None + + + def resetTimeout(self): + """ + Reset the timeout, usually because some activity just happened. + """ + if self.timeoutCall: + self.timeoutCall.reset(self.timeoutPeriod) + + + def write(self, data): + self.resetTimeout() + ProtocolWrapper.write(self, data) + + + def writeSequence(self, seq): + self.resetTimeout() + ProtocolWrapper.writeSequence(self, seq) + + + def dataReceived(self, data): + self.resetTimeout() + ProtocolWrapper.dataReceived(self, data) + + + def connectionLost(self, reason): + self.cancelTimeout() + ProtocolWrapper.connectionLost(self, reason) + + + def timeoutFunc(self): + """ + This method is called when the timeout is triggered. + + By default it calls L{loseConnection}. Override this if you want + something else to happen. + """ + self.loseConnection() + + + +class TimeoutFactory(WrappingFactory): + """ + Factory for TimeoutWrapper. + + Stability: Unstable + """ + protocol = TimeoutProtocol + + + def __init__(self, wrappedFactory, timeoutPeriod=30*60): + self.timeoutPeriod = timeoutPeriod + WrappingFactory.__init__(self, wrappedFactory) + + + def buildProtocol(self, addr): + return self.protocol(self, self.wrappedFactory.buildProtocol(addr), + timeoutPeriod=self.timeoutPeriod) + + + def callLater(self, period, func): + """ + Wrapper around L{reactor.callLater} for test purpose. + """ + return reactor.callLater(period, func) + + + +class TrafficLoggingProtocol(ProtocolWrapper): + + def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None, + number=0): + """ + @param factory: factory which created this protocol. + @type factory: C{protocol.Factory}. + @param wrappedProtocol: the underlying protocol. + @type wrappedProtocol: C{protocol.Protocol}. + @param logfile: file opened for writing used to write log messages. + @type logfile: C{file} + @param lengthLimit: maximum size of the datareceived logged. + @type lengthLimit: C{int} + @param number: identifier of the connection. + @type number: C{int}. + """ + ProtocolWrapper.__init__(self, factory, wrappedProtocol) + self.logfile = logfile + self.lengthLimit = lengthLimit + self._number = number + + + def _log(self, line): + self.logfile.write(line + '\n') + self.logfile.flush() + + + def _mungeData(self, data): + if self.lengthLimit and len(data) > self.lengthLimit: + data = data[:self.lengthLimit - 12] + '<... elided>' + return data + + + # IProtocol + def connectionMade(self): + self._log('*') + return ProtocolWrapper.connectionMade(self) + + + def dataReceived(self, data): + self._log('C %d: %r' % (self._number, self._mungeData(data))) + return ProtocolWrapper.dataReceived(self, data) + + + def connectionLost(self, reason): + self._log('C %d: %r' % (self._number, reason)) + return ProtocolWrapper.connectionLost(self, reason) + + + # ITransport + def write(self, data): + self._log('S %d: %r' % (self._number, self._mungeData(data))) + return ProtocolWrapper.write(self, data) + + + def writeSequence(self, iovec): + self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec])) + return ProtocolWrapper.writeSequence(self, iovec) + + + def loseConnection(self): + self._log('S %d: *' % (self._number,)) + return ProtocolWrapper.loseConnection(self) + + + +class TrafficLoggingFactory(WrappingFactory): + protocol = TrafficLoggingProtocol + + _counter = 0 + + def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None): + self.logfilePrefix = logfilePrefix + self.lengthLimit = lengthLimit + WrappingFactory.__init__(self, wrappedFactory) + + + def open(self, name): + return file(name, 'w') + + + def buildProtocol(self, addr): + self._counter += 1 + logfile = self.open(self.logfilePrefix + '-' + str(self._counter)) + return self.protocol(self, self.wrappedFactory.buildProtocol(addr), + logfile, self.lengthLimit, self._counter) + + + def resetCounter(self): + """ + Reset the value of the counter used to identify connections. + """ + self._counter = 0 + + + +class TimeoutMixin: + """Mixin for protocols which wish to timeout connections + + @cvar timeOut: The number of seconds after which to timeout the connection. + """ + timeOut = None + + __timeoutCall = None + + def callLater(self, period, func): + return reactor.callLater(period, func) + + + def resetTimeout(self): + """Reset the timeout count down""" + if self.__timeoutCall is not None and self.timeOut is not None: + self.__timeoutCall.reset(self.timeOut) + + def setTimeout(self, period): + """Change the timeout period + + @type period: C{int} or C{NoneType} + @param period: The period, in seconds, to change the timeout to, or + C{None} to disable the timeout. + """ + prev = self.timeOut + self.timeOut = period + + if self.__timeoutCall is not None: + if period is None: + self.__timeoutCall.cancel() + self.__timeoutCall = None + else: + self.__timeoutCall.reset(period) + elif period is not None: + self.__timeoutCall = self.callLater(period, self.__timedOut) + + return prev + + def __timedOut(self): + self.__timeoutCall = None + self.timeoutConnection() + + def timeoutConnection(self): + """Called when the connection times out. + Override to define behavior other than dropping the connection. + """ + self.transport.loseConnection() -- 2.39.5