Only throttle uploads to peers, not to apt.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
index 03eea9942863bf7f3dd971341b12cb66019fd44c..b12ca713057c229d604689418298b835a4687885 100644 (file)
@@ -9,7 +9,7 @@ from twisted.internet import defer
 from twisted.web2 import server, http, resource, channel, stream
 from twisted.web2 import static, http_headers, responsecode
 
-from policies import ThrottlingFactory
+from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
 from apt_p2p_Khashmir.bencode import bencode
 
 class FileDownloader(static.File):
@@ -127,6 +127,31 @@ class FileUploader(static.File):
 
         return response
 
+class UploadThrottlingProtocol(ThrottlingProtocol):
+    """Protocol for throttling uploads.
+    
+    Determines whether or not to throttle the upload based on the type of stream.
+    Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
+    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
+    """
+
+    def __init__(self, factory, wrappedProtocol):
+        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
+        self.throttle = False
+
+    def write(self, data):
+        if self.throttle:
+            ThrottlingProtocol.write(self, data)
+        else:
+            ProtocolWrapper.write(self, data)
+
+    def registerProducer(self, producer, streaming):
+        ThrottlingProtocol.registerProducer(self, producer, streaming)
+        streamType = getattr(producer, 'stream', None)
+        if isinstance(streamType, FileUploaderStream) or isinstance(streamType, stream.MemoryStream):
+            self.throttle = True
+
+
 class TopLevel(resource.Resource):
     """The HTTP server for all requests, both from peers and apt.
     
@@ -164,6 +189,7 @@ class TopLevel(resource.Resource):
                                                **{'maxPipeline': 10, 
                                                   'betweenRequestsTimeOut': 60})
             self.factory = ThrottlingFactory(self.factory, writeLimit = 30*1024)
+            self.factory.protocol = UploadThrottlingProtocol
         return self.factory
 
     def render(self, ctx):