]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/HTTPServer.py
Always try and find the mirror site, even if no updated files have been saved.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
index ee314ac66710e54424069e207fa41b98f2a088e9..1d1af488fdabddb289bb4cc9e8bfd430078b624e 100644 (file)
@@ -3,6 +3,7 @@
 
 from urllib import quote_plus, unquote_plus
 from binascii import b2a_hex
+import operator
 
 from twisted.python import log
 from twisted.internet import defer
@@ -148,6 +149,8 @@ class UploadThrottlingProtocol(ThrottlingProtocol):
     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
     """
+    
+    stats = None
 
     def __init__(self, factory, wrappedProtocol):
         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
@@ -156,9 +159,19 @@ class UploadThrottlingProtocol(ThrottlingProtocol):
     def write(self, data):
         if self.throttle:
             ThrottlingProtocol.write(self, data)
+            if self.stats:
+                self.stats.sentBytes(len(data))
         else:
             ProtocolWrapper.write(self, data)
 
+    def writeSequence(self, seq):
+        if self.throttle:
+            ThrottlingProtocol.writeSequence(self, seq)
+            if self.stats:
+                self.stats.sentBytes(reduce(operator.add, map(len, seq)))
+        else:
+            ProtocolWrapper.writeSequence(self, seq)
+
     def registerProducer(self, producer, streaming):
         ThrottlingProtocol.registerProducer(self, producer, streaming)
         streamType = getattr(producer, 'stream', None)
@@ -207,6 +220,8 @@ class TopLevel(resource.Resource):
                                                   'betweenRequestsTimeOut': 60})
             self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
             self.factory.protocol = UploadThrottlingProtocol
+            if self.manager:
+                self.factory.protocol.stats = self.manager.stats
         return self.factory
 
     def render(self, ctx):
@@ -249,20 +264,27 @@ class TopLevel(resource.Resource):
             else:
                 log.msg('Hash could not be found in database: %r' % hash)
 
-        # Only local requests (apt) get past this point
-        if request.remoteAddr.host != "127.0.0.1":
-            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
-            return None, ()
-        
-        # Block access to index .diff files (for now)
-        if 'Packages.diff' in segments or 'Sources.diff' in segments:
-            return None, ()
-         
         if len(name) > 1:
             # It's a request from apt
+
+            # Only local requests (apt) get past this point
+            if request.remoteAddr.host != "127.0.0.1":
+                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+                return None, ()
+
+            # Block access to index .diff files (for now)
+            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
+                return None, ()
+             
             return FileDownloader(self.directory.path, self.manager), segments[0:]
         else:
             # Will render the statistics page
+
+            # Only local requests for stats are allowed
+            if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
+                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+                return None, ()
+
             return self, ()
         
         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
@@ -378,7 +400,7 @@ if __name__ == '__builtin__':
                 return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
             return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
     
-    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None, 0)
+    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
     factory = t.getHTTPFactory()
     
     # Standard twisted application Boilerplate