Another attempt at throttling, still not working.
authorCameron Dale <camrdale@gmail.com>
Tue, 26 Feb 2008 01:10:58 +0000 (17:10 -0800)
committerCameron Dale <camrdale@gmail.com>
Tue, 26 Feb 2008 01:10:58 +0000 (17:10 -0800)
apt_dht/HTTPServer.py
apt_dht/policies.py [new file with mode: 0644]

index 20f94cd..4d62e5b 100644 (file)
@@ -4,10 +4,11 @@ from urllib import unquote_plus
 from twisted.python import log
 from twisted.internet import defer
 #from twisted.protocols import htb
-#from twisted.protocols.policies import ThrottlingFactory
-from twisted.web2 import server, http, resource, channel
+from twisted.web2 import server, http, resource, channel, stream
 from twisted.web2 import static, http_headers, responsecode
 
+from policies import ThrottlingFactory
+
 class FileDownloader(static.File):
     
     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
@@ -40,7 +41,64 @@ class FileDownloader(static.File):
         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                               self.processors, self.indexNames[:])
         
-        
+class FileUploaderStream(stream.FileStream):
+
+    CHUNK_SIZE = 16*1024
+    
+    def read(self, sendfile=False):
+        if self.f is None:
+            return None
+
+        length = self.length
+        if length == 0:
+            self.f = None
+            return None
+
+        readSize = min(length, self.CHUNK_SIZE)
+
+        self.f.seek(self.start)
+        b = self.f.read(readSize)
+        bytesRead = len(b)
+        if not bytesRead:
+            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length))
+        else:
+            self.length -= bytesRead
+            self.start += bytesRead
+            return b
+
+
+class FileUploader(static.File):
+
+    def render(self, req):
+        if not self.fp.exists():
+            return responsecode.NOT_FOUND
+
+        if self.fp.isdir():
+            return responsecode.NOT_FOUND
+
+        try:
+            f = self.fp.open()
+        except IOError, e:
+            import errno
+            if e[0] == errno.EACCES:
+                return responsecode.FORBIDDEN
+            elif e[0] == errno.ENOENT:
+                return responsecode.NOT_FOUND
+            else:
+                raise
+
+        response = http.Response()
+        response.stream = FileUploaderStream(f, 0, self.fp.getsize())
+
+        for (header, value) in (
+            ("content-type", self.contentType()),
+            ("content-encoding", self.contentEncoding()),
+        ):
+            if value is not None:
+                response.headers.setHeader(header, value)
+
+        return response
+
 class TopLevel(resource.Resource):
     addSlash = True
     
@@ -65,7 +123,7 @@ class TopLevel(resource.Resource):
 #            serverFilter.buckets[None] = serverBucket
 #
 #            self.factory.protocol = htb.ShapedProtocolFactory(self.factory.protocol, serverFilter)
-#            self.factory = ThrottlingFactory(self.factory, writeLimit = 300*1024)
+            self.factory = ThrottlingFactory(self.factory, writeLimit = 3*1024)
         return self.factory
 
     def render(self, ctx):
@@ -87,7 +145,7 @@ class TopLevel(resource.Resource):
             files = self.db.lookupHash(hash)
             if files:
                 log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
-                return static.File(files[0]['path'].path), ()
+                return FileUploader(files[0]['path'].path), ()
             else:
                 log.msg('Hash could not be found in database: %s' % hash)
         
@@ -104,9 +162,19 @@ class TopLevel(resource.Resource):
         return None, ()
 
 if __name__ == '__builtin__':
-    # Running from twistd -y
-    t = TopLevel('/home', None)
-    t.setDirectories({'~1': '/tmp', '~2': '/var/log'})
+    # Running from twistd -ny HTTPServer.py
+    # Then test with:
+    #   wget -S 'http://localhost:18080/~/whatever'
+    #   wget -S 'http://localhost:18080/.xsession-errors'
+
+    import os.path
+    from twisted.python.filepath import FilePath
+    
+    class DB:
+        def lookupHash(self, hash):
+            return [{'path': FilePath(os.path.expanduser('~/.xsession-errors'))}]
+    
+    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
     factory = t.getHTTPFactory()
     
     # Standard twisted application Boilerplate
diff --git a/apt_dht/policies.py b/apt_dht/policies.py
new file mode 100644 (file)
index 0000000..76a81d1
--- /dev/null
@@ -0,0 +1,260 @@
+# -*- test-case-name: twisted.test.test_policies -*-
+# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
+# See LICENSE for details.
+
+#
+
+"""Resource limiting policies.
+
+@seealso: See also L{twisted.protocols.htb} for rate limiting.
+"""
+
+# system imports
+import sys, operator
+
+# twisted imports
+from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
+from twisted.internet.interfaces import ITransport
+from twisted.internet import reactor, error
+from twisted.python import log
+from zope.interface import implements, providedBy, directlyProvides
+
+class ProtocolWrapper(Protocol):
+    """Wraps protocol instances and acts as their transport as well."""
+
+    disconnecting = 0
+
+    def __init__(self, factory, wrappedProtocol):
+        self.wrappedProtocol = wrappedProtocol
+        self.factory = factory
+
+    def makeConnection(self, transport):
+        directlyProvides(self, *providedBy(self) + providedBy(transport))
+        Protocol.makeConnection(self, transport)
+
+    # Transport relaying
+
+    def write(self, data):
+        self.transport.write(data)
+
+    def writeSequence(self, data):
+        self.transport.writeSequence(data)
+
+    def loseConnection(self):
+        self.disconnecting = 1
+        self.transport.loseConnection()
+
+    def getPeer(self):
+        return self.transport.getPeer()
+
+    def getHost(self):
+        return self.transport.getHost()
+
+    def registerProducer(self, producer, streaming):
+        self.transport.registerProducer(producer, streaming)
+
+    def unregisterProducer(self):
+        self.transport.unregisterProducer()
+
+    def stopConsuming(self):
+        self.transport.stopConsuming()
+
+    def __getattr__(self, name):
+        return getattr(self.transport, name)
+
+    # Protocol relaying
+
+    def connectionMade(self):
+        self.factory.registerProtocol(self)
+        self.wrappedProtocol.makeConnection(self)
+
+    def dataReceived(self, data):
+        self.wrappedProtocol.dataReceived(data)
+
+    def connectionLost(self, reason):
+        self.factory.unregisterProtocol(self)
+        self.wrappedProtocol.connectionLost(reason)
+
+
+class WrappingFactory(ClientFactory):
+    """Wraps a factory and its protocols, and keeps track of them."""
+
+    protocol = ProtocolWrapper
+
+    def __init__(self, wrappedFactory):
+        self.wrappedFactory = wrappedFactory
+        self.protocols = {}
+
+    def doStart(self):
+        self.wrappedFactory.doStart()
+        ClientFactory.doStart(self)
+
+    def doStop(self):
+        self.wrappedFactory.doStop()
+        ClientFactory.doStop(self)
+
+    def startedConnecting(self, connector):
+        self.wrappedFactory.startedConnecting(connector)
+
+    def clientConnectionFailed(self, connector, reason):
+        self.wrappedFactory.clientConnectionFailed(connector, reason)
+
+    def clientConnectionLost(self, connector, reason):
+        self.wrappedFactory.clientConnectionLost(connector, reason)
+
+    def buildProtocol(self, addr):
+        return self.protocol(self, self.wrappedFactory.buildProtocol(addr))
+
+    def registerProtocol(self, p):
+        """Called by protocol to register itself."""
+        self.protocols[p] = 1
+
+    def unregisterProtocol(self, p):
+        """Called by protocols when they go away."""
+        del self.protocols[p]
+
+
+class ThrottlingProtocol(ProtocolWrapper):
+    """Protocol for ThrottlingFactory."""
+
+    # wrap API for tracking bandwidth
+
+    def write(self, data):
+        self.factory.registerWritten(len(data))
+        ProtocolWrapper.write(self, data)
+
+    def writeSequence(self, seq):
+        self.factory.registerWritten(reduce(operator.add, map(len, seq)))
+        ProtocolWrapper.writeSequence(self, seq)
+
+    def dataReceived(self, data):
+        self.factory.registerRead(len(data))
+        ProtocolWrapper.dataReceived(self, data)
+
+    def registerProducer(self, producer, streaming):
+        self.producer = producer
+        ProtocolWrapper.registerProducer(self, producer, streaming)
+
+    def unregisterProducer(self):
+        del self.producer
+        ProtocolWrapper.unregisterProducer(self)
+
+
+    def throttleReads(self):
+        self.transport.pauseProducing()
+
+    def unthrottleReads(self):
+        self.transport.resumeProducing()
+
+    def throttleWrites(self):
+        if hasattr(self, "producer"):
+            self.producer.pauseProducing()
+
+    def unthrottleWrites(self):
+        if hasattr(self, "producer"):
+            self.producer.resumeProducing()
+
+
+class ThrottlingFactory(WrappingFactory):
+    """Throttles bandwidth and number of connections.
+
+    Write bandwidth will only be throttled if there is a producer
+    registered.
+    """
+
+    protocol = ThrottlingProtocol
+
+    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
+        WrappingFactory.__init__(self, wrappedFactory)
+        self.connectionCount = 0
+        self.maxConnectionCount = maxConnectionCount
+        self.readLimit = readLimit # max bytes we should read per second
+        self.writeLimit = writeLimit # max bytes we should write per second
+        self.readThisSecond = 0
+        self.writtenThisSecond = 0
+        self.unthrottleReadsID = None
+        self.checkReadBandwidthID = None
+        self.unthrottleWritesID = None
+        self.checkWriteBandwidthID = None
+
+    def registerWritten(self, length):
+        """Called by protocol to tell us more bytes were written."""
+        self.writtenThisSecond += length
+
+    def registerRead(self, length):
+        """Called by protocol to tell us more bytes were read."""
+        self.readThisSecond += length
+
+    def checkReadBandwidth(self):
+        """Checks if we've passed bandwidth limits."""
+        if self.readThisSecond > self.readLimit:
+            self.throttleReads()
+            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
+            self.unthrottleReadsID = reactor.callLater(throttleTime,
+                                                       self.unthrottleReads)
+        self.readThisSecond = 0
+        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)
+
+    def checkWriteBandwidth(self):
+        if self.writtenThisSecond > self.writeLimit:
+            self.throttleWrites()
+            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
+            self.unthrottleWritesID = reactor.callLater(throttleTime,
+                                                        self.unthrottleWrites)
+        # reset for next round
+        self.writtenThisSecond = 0
+        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)
+
+    def throttleReads(self):
+        """Throttle reads on all protocols."""
+        log.msg("Throttling reads on %s" % self)
+        for p in self.protocols.keys():
+            p.throttleReads()
+
+    def unthrottleReads(self):
+        """Stop throttling reads on all protocols."""
+        self.unthrottleReadsID = None
+        log.msg("Stopped throttling reads on %s" % self)
+        for p in self.protocols.keys():
+            p.unthrottleReads()
+
+    def throttleWrites(self):
+        """Throttle writes on all protocols."""
+        log.msg("Throttling writes on %s" % self)
+        for p in self.protocols.keys():
+            p.throttleWrites()
+
+    def unthrottleWrites(self):
+        """Stop throttling writes on all protocols."""
+        self.unthrottleWritesID = None
+        log.msg("Stopped throttling writes on %s" % self)
+        for p in self.protocols.keys():
+            p.unthrottleWrites()
+
+    def buildProtocol(self, addr):
+        if self.connectionCount == 0:
+            if self.readLimit is not None:
+                self.checkReadBandwidth()
+            if self.writeLimit is not None:
+                self.checkWriteBandwidth()
+
+        if self.connectionCount < self.maxConnectionCount:
+            self.connectionCount += 1
+            return WrappingFactory.buildProtocol(self, addr)
+        else:
+            log.msg("Max connection count reached!")
+            return None
+
+    def unregisterProtocol(self, p):
+        WrappingFactory.unregisterProtocol(self, p)
+        self.connectionCount -= 1
+        if self.connectionCount == 0:
+            if self.unthrottleReadsID is not None:
+                self.unthrottleReadsID.cancel()
+            if self.checkReadBandwidthID is not None:
+                self.checkReadBandwidthID.cancel()
+            if self.unthrottleWritesID is not None:
+                self.unthrottleWritesID.cancel()
+            if self.checkWriteBandwidthID is not None:
+                self.checkWriteBandwidthID.cancel()
+