Fixed the ThrottlingFactory to work with web2 static streams from the web server.
authorCameron Dale <camrdale@gmail.com>
Wed, 27 Feb 2008 06:42:23 +0000 (22:42 -0800)
committerCameron Dale <camrdale@gmail.com>
Wed, 27 Feb 2008 06:42:23 +0000 (22:42 -0800)
apt_dht/HTTPServer.py
apt_dht/policies.py

index 4d62e5b..5c6904f 100644 (file)
@@ -43,7 +43,7 @@ class FileDownloader(static.File):
         
 class FileUploaderStream(stream.FileStream):
 
-    CHUNK_SIZE = 16*1024
+    CHUNK_SIZE = 4*1024
     
     def read(self, sendfile=False):
         if self.f is None:
@@ -123,7 +123,7 @@ class TopLevel(resource.Resource):
 #            serverFilter.buckets[None] = serverBucket
 #
 #            self.factory.protocol = htb.ShapedProtocolFactory(self.factory.protocol, serverFilter)
-            self.factory = ThrottlingFactory(self.factory, writeLimit = 3*1024)
+            self.factory = ThrottlingFactory(self.factory, writeLimit = 30*1024)
         return self.factory
 
     def render(self, ctx):
@@ -172,7 +172,7 @@ if __name__ == '__builtin__':
     
     class DB:
         def lookupHash(self, hash):
-            return [{'path': FilePath(os.path.expanduser('~/.xsession-errors'))}]
+            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
     
     t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
     factory = t.getHTTPFactory()
index 76a81d1..8254c9c 100644 (file)
@@ -119,13 +119,40 @@ class ThrottlingProtocol(ProtocolWrapper):
 
     # wrap API for tracking bandwidth
 
+    def __init__(self, factory, wrappedProtocol):
+        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
+        self._tempDataBuffer = []
+        self._tempDataLength = 0
+        self.throttled = False
+
     def write(self, data):
-        self.factory.registerWritten(len(data))
-        ProtocolWrapper.write(self, data)
+        # Check if we can write
+        if (not self.throttled) and self.factory.registerWritten(len(data)):
+            ProtocolWrapper.write(self, data)
+            
+            if hasattr(self, "producer") and self.producer and not self.producer.paused:
+                # Interrupt the flow so that others can can have a chance
+                # We can only do this if it's not already paused otherwise we
+                # risk unpausing something that the Server paused
+                self.producer.pauseProducing()
+                reactor.callLater(0, self.producer.resumeProducing)
+        else:
+            # Can't write, buffer the data
+            self._tempDataBuffer.append(data)
+            self._tempDataLength += len(data)
+            self._throttleWrites()
 
     def writeSequence(self, seq):
-        self.factory.registerWritten(reduce(operator.add, map(len, seq)))
-        ProtocolWrapper.writeSequence(self, seq)
+        if not self.throttled:
+            # Write each sequence separately
+            while seq and self.factory.registerWritten(len(seq[0])):
+                ProtocolWrapper.write(self, seq.pop(0))
+
+        # If there's some left, we must have been throttled
+        if seq:
+            self._tempDataBuffer.extend(seq)
+            self._tempDataLength += reduce(operator.add, map(len, seq))
+            self._throttleWrites()
 
     def dataReceived(self, data):
         self.factory.registerRead(len(data))
@@ -146,13 +173,33 @@ class ThrottlingProtocol(ProtocolWrapper):
     def unthrottleReads(self):
         self.transport.resumeProducing()
 
-    def throttleWrites(self):
-        if hasattr(self, "producer"):
+    def _throttleWrites(self):
+        # If we haven't yet, queue for unthrottling
+        if not self.throttled:
+            self.throttled = True
+            self.factory.throttledWrites(self)
+
+        if hasattr(self, "producer") and self.producer:
             self.producer.pauseProducing()
 
     def unthrottleWrites(self):
-        if hasattr(self, "producer"):
-            self.producer.resumeProducing()
+        # Write some data
+        if self._tempDataBuffer:
+            assert self.factory.registerWritten(len(self._tempDataBuffer[0]))
+            self._tempDataLength -= len(self._tempDataBuffer[0])
+            ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
+            assert self._tempDataLength >= 0
+
+        # If we wrote it all, start producing more
+        if not self._tempDataBuffer:
+            assert self._tempDataLength == 0
+            self.throttled = False
+            if hasattr(self, "producer") and self.producer:
+                # This might unpause something the Server has also paused, but
+                # it will get paused again on first write anyway
+                reactor.callLater(0, self.producer.resumeProducing)
+        
+        return self._tempDataLength
 
 
 class ThrottlingFactory(WrappingFactory):
@@ -163,6 +210,7 @@ class ThrottlingFactory(WrappingFactory):
     """
 
     protocol = ThrottlingProtocol
+    CHUNK_SIZE = 4*1024
 
     def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
         WrappingFactory.__init__(self, wrappedFactory)
@@ -171,7 +219,8 @@ class ThrottlingFactory(WrappingFactory):
         self.readLimit = readLimit # max bytes we should read per second
         self.writeLimit = writeLimit # max bytes we should write per second
         self.readThisSecond = 0
-        self.writtenThisSecond = 0
+        self.writeAvailable = writeLimit
+        self._writeQueue = []
         self.unthrottleReadsID = None
         self.checkReadBandwidthID = None
         self.unthrottleWritesID = None
@@ -179,7 +228,17 @@ class ThrottlingFactory(WrappingFactory):
 
     def registerWritten(self, length):
         """Called by protocol to tell us more bytes were written."""
-        self.writtenThisSecond += length
+        # Check if there are bytes available to write
+        if self.writeAvailable > 0:
+            self.writeAvailable -= length
+            return True
+        
+        return False
+    
+    def throttledWrites(self, p):
+        """Called by the protocol to queue it for later writing."""
+        assert p not in self._writeQueue
+        self._writeQueue.append(p)
 
     def registerRead(self, length):
         """Called by protocol to tell us more bytes were read."""
@@ -196,13 +255,26 @@ class ThrottlingFactory(WrappingFactory):
         self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)
 
     def checkWriteBandwidth(self):
-        if self.writtenThisSecond > self.writeLimit:
-            self.throttleWrites()
-            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
-            self.unthrottleWritesID = reactor.callLater(throttleTime,
-                                                        self.unthrottleWrites)
-        # reset for next round
-        self.writtenThisSecond = 0
+        """Add some new available bandwidth, and check for protocols to unthrottle."""
+        # Increase the available write bytes, but not higher than the limit
+        self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)
+        
+        # Write from the queue until it's empty or we're throttled again
+        while self.writeAvailable > 0 and self._writeQueue:
+            # Get the first queued protocol
+            p = self._writeQueue.pop(0)
+            _tempWriteAvailable = self.writeAvailable
+            bytesLeft = 1
+            
+            # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
+            while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
+                # Unthrottle a single write (from the protocol's buffer)
+                bytesLeft = p.unthrottleWrites()
+                
+            # If the protocol is not done, requeue it
+            if bytesLeft > 0:
+                self._writeQueue.append(p)
+
         self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)
 
     def throttleReads(self):
@@ -218,19 +290,6 @@ class ThrottlingFactory(WrappingFactory):
         for p in self.protocols.keys():
             p.unthrottleReads()
 
-    def throttleWrites(self):
-        """Throttle writes on all protocols."""
-        log.msg("Throttling writes on %s" % self)
-        for p in self.protocols.keys():
-            p.throttleWrites()
-
-    def unthrottleWrites(self):
-        """Stop throttling writes on all protocols."""
-        self.unthrottleWritesID = None
-        log.msg("Stopped throttling writes on %s" % self)
-        for p in self.protocols.keys():
-            p.unthrottleWrites()
-
     def buildProtocol(self, addr):
         if self.connectionCount == 0:
             if self.readLimit is not None: