Upgrade policies to SVN version and fix a small bug.
authorCameron Dale <camrdale@gmail.com>
Fri, 29 Feb 2008 00:17:22 +0000 (16:17 -0800)
committerCameron Dale <camrdale@gmail.com>
Fri, 29 Feb 2008 00:17:22 +0000 (16:17 -0800)
Fixed the policies usage when writeLimit was None to not do anything.

apt_dht/policies.py

index 3c84c149b3c38dd1317b46031fddf0b602da6e33..e7bae81cbcf236741c58d8e977a6d469a4882536 100644 (file)
@@ -1,10 +1,10 @@
 # -*- test-case-name: twisted.test.test_policies -*-
-# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
+# Copyright (c) 2001-2007 Twisted Matrix Laboratories.
 # See LICENSE for details.
 
-#
 
-"""Resource limiting policies.
+"""
+Resource limiting policies.
 
 @seealso: See also L{twisted.protocols.htb} for rate limiting.
 """
@@ -14,10 +14,10 @@ import sys, operator
 
 # twisted imports
 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
-from twisted.internet.interfaces import ITransport
 from twisted.internet import reactor, error
 from twisted.python import log
-from zope.interface import implements, providedBy, directlyProvides
+from zope.interface import providedBy, directlyProvides
+
 
 class ProtocolWrapper(Protocol):
     """Wraps protocol instances and acts as their transport as well."""
@@ -127,16 +127,19 @@ class ThrottlingProtocol(ProtocolWrapper):
 
     def write(self, data):
         # Check if we can write
-        if (not self.throttled) and self.factory.registerWritten(len(data)):
-            ProtocolWrapper.write(self, data)
-            
-            if hasattr(self, "producer") and self.producer and not self.producer.paused:
-                # Interrupt the flow so that others can can have a chance
-                # We can only do this if it's not already paused otherwise we
-                # risk unpausing something that the Server paused
-                self.producer.pauseProducing()
-                reactor.callLater(0, self.producer.resumeProducing)
-        else:
+        if not self.throttled:
+            paused = self.factory.registerWritten(len(data))
+            if not paused:
+                ProtocolWrapper.write(self, data)
+                
+                if paused is not None and hasattr(self, "producer") and self.producer and not self.producer.paused:
+                    # Interrupt the flow so that others can can have a chance
+                    # We can only do this if it's not already paused otherwise we
+                    # risk unpausing something that the Server paused
+                    self.producer.pauseProducing()
+                    reactor.callLater(0, self.producer.resumeProducing)
+
+        if self.throttled or paused:
             # Can't write, buffer the data
             self._tempDataBuffer.append(data)
             self._tempDataLength += len(data)
@@ -145,10 +148,10 @@ class ThrottlingProtocol(ProtocolWrapper):
     def writeSequence(self, seq):
         if not self.throttled:
             # Write each sequence separately
-            while seq and self.factory.registerWritten(len(seq[0])):
+            while seq and not self.factory.registerWritten(len(seq[0])):
                 ProtocolWrapper.write(self, seq.pop(0))
 
-        # If there's some left, we must have been throttled
+        # If there's some left, we must have been paused
         if seq:
             self._tempDataBuffer.extend(seq)
             self._tempDataLength += reduce(operator.add, map(len, seq))
@@ -186,7 +189,7 @@ class ThrottlingProtocol(ProtocolWrapper):
     def unthrottleWrites(self):
         # Write some data
         if self._tempDataBuffer:
-            assert self.factory.registerWritten(len(self._tempDataBuffer[0]))
+            assert not self.factory.registerWritten(len(self._tempDataBuffer[0]))
             self._tempDataLength -= len(self._tempDataBuffer[0])
             ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
             assert self._tempDataLength >= 0
@@ -204,7 +207,8 @@ class ThrottlingProtocol(ProtocolWrapper):
 
 
 class ThrottlingFactory(WrappingFactory):
-    """Throttles bandwidth and number of connections.
+    """
+    Throttles bandwidth and number of connections.
 
     Write bandwidth will only be throttled if there is a producer
     registered.
@@ -213,7 +217,8 @@ class ThrottlingFactory(WrappingFactory):
     protocol = ThrottlingProtocol
     CHUNK_SIZE = 4*1024
 
-    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
+    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
+                 readLimit=None, writeLimit=None):
         WrappingFactory.__init__(self, wrappedFactory)
         self.connectionCount = 0
         self.maxConnectionCount = maxConnectionCount
@@ -227,36 +232,61 @@ class ThrottlingFactory(WrappingFactory):
         self.unthrottleWritesID = None
         self.checkWriteBandwidthID = None
 
+
+    def callLater(self, period, func):
+        """
+        Wrapper around L{reactor.callLater} for test purpose.
+        """
+        return reactor.callLater(period, func)
+
+
     def registerWritten(self, length):
-        """Called by protocol to tell us more bytes were written."""
+        """
+        Called by protocol to tell us more bytes were written.
+        Returns True if the bytes could not be written and the protocol should pause itself.
+        """
         # Check if there are bytes available to write
-        if self.writeAvailable > 0:
+        if self.writeLimit is None:
+            return None
+        elif self.writeAvailable > 0:
             self.writeAvailable -= length
-            return True
+            return False
         
-        return False
+        return True
+
     
     def throttledWrites(self, p):
-        """Called by the protocol to queue it for later writing."""
+        """
+        Called by the protocol to queue it for later writing.
+        """
         assert p not in self._writeQueue
         self._writeQueue.append(p)
 
+
     def registerRead(self, length):
-        """Called by protocol to tell us more bytes were read."""
+        """
+        Called by protocol to tell us more bytes were read.
+        """
         self.readThisSecond += length
 
+
     def checkReadBandwidth(self):
-        """Checks if we've passed bandwidth limits."""
+        """
+        Checks if we've passed bandwidth limits.
+        """
         if self.readThisSecond > self.readLimit:
             self.throttleReads()
             throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
-            self.unthrottleReadsID = reactor.callLater(throttleTime,
-                                                       self.unthrottleReads)
+            self.unthrottleReadsID = self.callLater(throttleTime,
+                                                    self.unthrottleReads)
         self.readThisSecond = 0
-        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)
+        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)
+
 
     def checkWriteBandwidth(self):
-        """Add some new available bandwidth, and check for protocols to unthrottle."""
+        """
+        Add some new available bandwidth, and check for protocols to unthrottle.
+        """
         # Increase the available write bytes, but not higher than the limit
         self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)
         
@@ -276,21 +306,28 @@ class ThrottlingFactory(WrappingFactory):
             if bytesLeft > 0:
                 self._writeQueue.append(p)
 
-        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)
+        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)
+
 
     def throttleReads(self):
-        """Throttle reads on all protocols."""
+        """
+        Throttle reads on all protocols.
+        """
         log.msg("Throttling reads on %s" % self)
         for p in self.protocols.keys():
             p.throttleReads()
 
+
     def unthrottleReads(self):
-        """Stop throttling reads on all protocols."""
+        """
+        Stop throttling reads on all protocols.
+        """
         self.unthrottleReadsID = None
         log.msg("Stopped throttling reads on %s" % self)
         for p in self.protocols.keys():
             p.unthrottleReads()
 
+
     def buildProtocol(self, addr):
         if self.connectionCount == 0:
             if self.readLimit is not None:
@@ -305,6 +342,7 @@ class ThrottlingFactory(WrappingFactory):
             log.msg("Max connection count reached!")
             return None
 
+
     def unregisterProtocol(self, p):
         WrappingFactory.unregisterProtocol(self, p)
         self.connectionCount -= 1
@@ -318,3 +356,347 @@ class ThrottlingFactory(WrappingFactory):
             if self.checkWriteBandwidthID is not None:
                 self.checkWriteBandwidthID.cancel()
 
+
+
+class SpewingProtocol(ProtocolWrapper):
+    def dataReceived(self, data):
+        log.msg("Received: %r" % data)
+        ProtocolWrapper.dataReceived(self,data)
+
+    def write(self, data):
+        log.msg("Sending: %r" % data)
+        ProtocolWrapper.write(self,data)
+
+
+
+class SpewingFactory(WrappingFactory):
+    protocol = SpewingProtocol
+
+
+
+class LimitConnectionsByPeer(WrappingFactory):
+    """Stability: Unstable"""
+
+    maxConnectionsPerPeer = 5
+
+    def startFactory(self):
+        self.peerConnections = {}
+
+    def buildProtocol(self, addr):
+        peerHost = addr[0]
+        connectionCount = self.peerConnections.get(peerHost, 0)
+        if connectionCount >= self.maxConnectionsPerPeer:
+            return None
+        self.peerConnections[peerHost] = connectionCount + 1
+        return WrappingFactory.buildProtocol(self, addr)
+
+    def unregisterProtocol(self, p):
+        peerHost = p.getPeer()[1]
+        self.peerConnections[peerHost] -= 1
+        if self.peerConnections[peerHost] == 0:
+            del self.peerConnections[peerHost]
+
+
+class LimitTotalConnectionsFactory(ServerFactory):
+    """Factory that limits the number of simultaneous connections.
+
+    API Stability: Unstable
+
+    @type connectionCount: C{int}
+    @ivar connectionCount: number of current connections.
+    @type connectionLimit: C{int} or C{None}
+    @cvar connectionLimit: maximum number of connections.
+    @type overflowProtocol: L{Protocol} or C{None}
+    @cvar overflowProtocol: Protocol to use for new connections when
+        connectionLimit is exceeded.  If C{None} (the default value), excess
+        connections will be closed immediately.
+    """
+    connectionCount = 0
+    connectionLimit = None
+    overflowProtocol = None
+
+    def buildProtocol(self, addr):
+        if (self.connectionLimit is None or
+            self.connectionCount < self.connectionLimit):
+                # Build the normal protocol
+                wrappedProtocol = self.protocol()
+        elif self.overflowProtocol is None:
+            # Just drop the connection
+            return None
+        else:
+            # Too many connections, so build the overflow protocol
+            wrappedProtocol = self.overflowProtocol()
+
+        wrappedProtocol.factory = self
+        protocol = ProtocolWrapper(self, wrappedProtocol)
+        self.connectionCount += 1
+        return protocol
+
+    def registerProtocol(self, p):
+        pass
+
+    def unregisterProtocol(self, p):
+        self.connectionCount -= 1
+
+
+
+class TimeoutProtocol(ProtocolWrapper):
+    """
+    Protocol that automatically disconnects when the connection is idle.
+
+    Stability: Unstable
+    """
+
+    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
+        """
+        Constructor.
+
+        @param factory: An L{IFactory}.
+        @param wrappedProtocol: A L{Protocol} to wrapp.
+        @param timeoutPeriod: Number of seconds to wait for activity before
+            timing out.
+        """
+        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
+        self.timeoutCall = None
+        self.setTimeout(timeoutPeriod)
+
+
+    def setTimeout(self, timeoutPeriod=None):
+        """
+        Set a timeout.
+
+        This will cancel any existing timeouts.
+
+        @param timeoutPeriod: If not C{None}, change the timeout period.
+            Otherwise, use the existing value.
+        """
+        self.cancelTimeout()
+        if timeoutPeriod is not None:
+            self.timeoutPeriod = timeoutPeriod
+        self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)
+
+
+    def cancelTimeout(self):
+        """
+        Cancel the timeout.
+
+        If the timeout was already cancelled, this does nothing.
+        """
+        if self.timeoutCall:
+            try:
+                self.timeoutCall.cancel()
+            except error.AlreadyCalled:
+                pass
+            self.timeoutCall = None
+
+
+    def resetTimeout(self):
+        """
+        Reset the timeout, usually because some activity just happened.
+        """
+        if self.timeoutCall:
+            self.timeoutCall.reset(self.timeoutPeriod)
+
+
+    def write(self, data):
+        self.resetTimeout()
+        ProtocolWrapper.write(self, data)
+
+
+    def writeSequence(self, seq):
+        self.resetTimeout()
+        ProtocolWrapper.writeSequence(self, seq)
+
+
+    def dataReceived(self, data):
+        self.resetTimeout()
+        ProtocolWrapper.dataReceived(self, data)
+
+
+    def connectionLost(self, reason):
+        self.cancelTimeout()
+        ProtocolWrapper.connectionLost(self, reason)
+
+
+    def timeoutFunc(self):
+        """
+        This method is called when the timeout is triggered.
+
+        By default it calls L{loseConnection}.  Override this if you want
+        something else to happen.
+        """
+        self.loseConnection()
+
+
+
+class TimeoutFactory(WrappingFactory):
+    """
+    Factory for TimeoutWrapper.
+
+    Stability: Unstable
+    """
+    protocol = TimeoutProtocol
+
+
+    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
+        self.timeoutPeriod = timeoutPeriod
+        WrappingFactory.__init__(self, wrappedFactory)
+
+
+    def buildProtocol(self, addr):
+        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
+                             timeoutPeriod=self.timeoutPeriod)
+
+
+    def callLater(self, period, func):
+        """
+        Wrapper around L{reactor.callLater} for test purpose.
+        """
+        return reactor.callLater(period, func)
+
+
+
+class TrafficLoggingProtocol(ProtocolWrapper):
+
+    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
+                 number=0):
+        """
+        @param factory: factory which created this protocol.
+        @type factory: C{protocol.Factory}.
+        @param wrappedProtocol: the underlying protocol.
+        @type wrappedProtocol: C{protocol.Protocol}.
+        @param logfile: file opened for writing used to write log messages.
+        @type logfile: C{file}
+        @param lengthLimit: maximum size of the datareceived logged.
+        @type lengthLimit: C{int}
+        @param number: identifier of the connection.
+        @type number: C{int}.
+        """
+        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
+        self.logfile = logfile
+        self.lengthLimit = lengthLimit
+        self._number = number
+
+
+    def _log(self, line):
+        self.logfile.write(line + '\n')
+        self.logfile.flush()
+
+
+    def _mungeData(self, data):
+        if self.lengthLimit and len(data) > self.lengthLimit:
+            data = data[:self.lengthLimit - 12] + '<... elided>'
+        return data
+
+
+    # IProtocol
+    def connectionMade(self):
+        self._log('*')
+        return ProtocolWrapper.connectionMade(self)
+
+
+    def dataReceived(self, data):
+        self._log('C %d: %r' % (self._number, self._mungeData(data)))
+        return ProtocolWrapper.dataReceived(self, data)
+
+
+    def connectionLost(self, reason):
+        self._log('C %d: %r' % (self._number, reason))
+        return ProtocolWrapper.connectionLost(self, reason)
+
+
+    # ITransport
+    def write(self, data):
+        self._log('S %d: %r' % (self._number, self._mungeData(data)))
+        return ProtocolWrapper.write(self, data)
+
+
+    def writeSequence(self, iovec):
+        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
+        return ProtocolWrapper.writeSequence(self, iovec)
+
+
+    def loseConnection(self):
+        self._log('S %d: *' % (self._number,))
+        return ProtocolWrapper.loseConnection(self)
+
+
+
+class TrafficLoggingFactory(WrappingFactory):
+    protocol = TrafficLoggingProtocol
+
+    _counter = 0
+
+    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
+        self.logfilePrefix = logfilePrefix
+        self.lengthLimit = lengthLimit
+        WrappingFactory.__init__(self, wrappedFactory)
+
+
+    def open(self, name):
+        return file(name, 'w')
+
+
+    def buildProtocol(self, addr):
+        self._counter += 1
+        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
+        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
+                             logfile, self.lengthLimit, self._counter)
+
+
+    def resetCounter(self):
+        """
+        Reset the value of the counter used to identify connections.
+        """
+        self._counter = 0
+
+
+
+class TimeoutMixin:
+    """Mixin for protocols which wish to timeout connections
+
+    @cvar timeOut: The number of seconds after which to timeout the connection.
+    """
+    timeOut = None
+
+    __timeoutCall = None
+
+    def callLater(self, period, func):
+        return reactor.callLater(period, func)
+
+
+    def resetTimeout(self):
+        """Reset the timeout count down"""
+        if self.__timeoutCall is not None and self.timeOut is not None:
+            self.__timeoutCall.reset(self.timeOut)
+
+    def setTimeout(self, period):
+        """Change the timeout period
+
+        @type period: C{int} or C{NoneType}
+        @param period: The period, in seconds, to change the timeout to, or
+        C{None} to disable the timeout.
+        """
+        prev = self.timeOut
+        self.timeOut = period
+
+        if self.__timeoutCall is not None:
+            if period is None:
+                self.__timeoutCall.cancel()
+                self.__timeoutCall = None
+            else:
+                self.__timeoutCall.reset(period)
+        elif period is not None:
+            self.__timeoutCall = self.callLater(period, self.__timedOut)
+
+        return prev
+
+    def __timedOut(self):
+        self.__timeoutCall = None
+        self.timeoutConnection()
+
+    def timeoutConnection(self):
+        """Called when the connection times out.
+        Override to define behavior other than dropping the connection.
+        """
+        self.transport.loseConnection()