]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/HTTPServer.py
WIP on sending multiple KRPC requests before timeout.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
index 7f74788d8b1476f1224ef153346def8dd9a5ec3f..3d43fa04b48fa24ad1886c02c939203c0db916f0 100644 (file)
@@ -3,6 +3,7 @@
 
 from urllib import quote_plus, unquote_plus
 from binascii import b2a_hex
+import operator
 
 from twisted.python import log
 from twisted.internet import defer
@@ -12,6 +13,7 @@ from twisted.trial import unittest
 from twisted.python.filepath import FilePath
 
 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
+from apt_p2p_conf import config
 from apt_p2p_Khashmir.bencode import bencode
 
 class FileDownloader(static.File):
@@ -33,7 +35,8 @@ class FileDownloader(static.File):
         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
         resp = super(FileDownloader, self).renderHTTP(req)
         if isinstance(resp, defer.Deferred):
-            resp.addCallback(self._renderHTTP_done, req)
+            resp.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
+                              callbackArgs = (req, ), errbackArgs = (req, ))
         else:
             resp = self._renderHTTP_done(resp, req)
         return resp
@@ -51,6 +54,16 @@ class FileDownloader(static.File):
         
         return resp
 
+    def _renderHTTP_error(self, err, req):
+        log.msg('Failed to render %s: %r' % (req.uri, err))
+        log.err(err)
+        
+        if self.manager:
+            path = 'http:/' + req.uri
+            return self.manager.get_resp(req, path)
+        
+        return err
+
     def createSimilarFile(self, path):
         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                               self.processors, self.indexNames[:])
@@ -136,6 +149,8 @@ class UploadThrottlingProtocol(ThrottlingProtocol):
     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
     """
+    
+    stats = None
 
     def __init__(self, factory, wrappedProtocol):
         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
@@ -144,9 +159,19 @@ class UploadThrottlingProtocol(ThrottlingProtocol):
     def write(self, data):
         if self.throttle:
             ThrottlingProtocol.write(self, data)
+            if stats:
+                stats.sentBytes(len(data))
         else:
             ProtocolWrapper.write(self, data)
 
+    def writeSequence(self, seq):
+        if self.throttle:
+            ThrottlingProtocol.writeSequence(self, seq)
+            if stats:
+                stats.sentBytes(reduce(operator.add, map(len, seq)))
+        else:
+            ProtocolWrapper.writeSequence(self, seq)
+
     def registerProducer(self, producer, streaming):
         ThrottlingProtocol.registerProducer(self, producer, streaming)
         streamType = getattr(producer, 'stream', None)
@@ -169,7 +194,7 @@ class TopLevel(resource.Resource):
     
     addSlash = True
     
-    def __init__(self, directory, db, manager, uploadLimit):
+    def __init__(self, directory, db, manager):
         """Initialize the instance.
         
         @type directory: L{twisted.python.filepath.FilePath}
@@ -183,8 +208,8 @@ class TopLevel(resource.Resource):
         self.db = db
         self.manager = manager
         self.uploadLimit = None
-        if uploadLimit > 0:
-            self.uploadLimit = int(uploadLimit*1024)
+        if config.getint('DEFAULT', 'UPLOAD_LIMIT') > 0:
+            self.uploadLimit = int(config.getint('DEFAULT', 'UPLOAD_LIMIT')*1024)
         self.factory = None
 
     def getHTTPFactory(self):
@@ -195,6 +220,7 @@ class TopLevel(resource.Resource):
                                                   'betweenRequestsTimeOut': 60})
             self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
             self.factory.protocol = UploadThrottlingProtocol
+            self.factory.protocol.stats = self.manager.stats
         return self.factory
 
     def render(self, ctx):
@@ -266,7 +292,7 @@ class TestTopLevel(unittest.TestCase):
     file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'
     
     def setUp(self):
-        self.client = TopLevel(FilePath('/boot'), self, None, 0)
+        self.client = TopLevel(FilePath('/boot'), self, None)
         
     def lookupHash(self, hash):
         if hash == self.torrent_hash: