]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/HTTPServer.py
Fix a typo in commit e82e704e27.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
index c3c64b86b5434dcf21ff1a3d769e28733fffcf12..0c17d3fd8f6c5b34c2e4e70bbf7eadbf41d05a23 100644 (file)
@@ -1,15 +1,19 @@
 
 """Serve local requests from apt and remote requests from peers."""
 
 
 """Serve local requests from apt and remote requests from peers."""
 
-from urllib import unquote_plus
+from urllib import quote_plus, unquote_plus
 from binascii import b2a_hex
 from binascii import b2a_hex
+import operator
 
 from twisted.python import log
 from twisted.internet import defer
 from twisted.web2 import server, http, resource, channel, stream
 from twisted.web2 import static, http_headers, responsecode
 
 from twisted.python import log
 from twisted.internet import defer
 from twisted.web2 import server, http, resource, channel, stream
 from twisted.web2 import static, http_headers, responsecode
+from twisted.trial import unittest
+from twisted.python.filepath import FilePath
 
 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
 
 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
+from apt_p2p_conf import config
 from apt_p2p_Khashmir.bencode import bencode
 
 class FileDownloader(static.File):
 from apt_p2p_Khashmir.bencode import bencode
 
 class FileDownloader(static.File):
@@ -26,12 +30,31 @@ class FileDownloader(static.File):
     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
         self.manager = manager
         super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
         self.manager = manager
         super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
-        
+    
+    def locateChild(self, req, segments):
+        child, segments = super(FileDownloader, self).locateChild(req, segments)
+        # Make sure we always call renderHTTP()
+        if isinstance(child, FileDownloader):
+            return child, segments
+        else:
+            return self, server.StopTraversal
+            
     def renderHTTP(self, req):
         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
     def renderHTTP(self, req):
         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
+        
+        # Make sure the file is in the DB and unchanged
+        if self.manager and not self.manager.db.isUnchanged(self.fp):
+            if self.fp.exists() and self.fp.isfile():
+                self.fp.remove()
+            return self._renderHTTP_done(http.Response(404,
+                        {'content-type': http_headers.MimeType('text', 'html')},
+                        '<html><body><p>File found but it has changed.</body></html>'),
+                        req)
+            
         resp = super(FileDownloader, self).renderHTTP(req)
         if isinstance(resp, defer.Deferred):
         resp = super(FileDownloader, self).renderHTTP(req)
         if isinstance(resp, defer.Deferred):
-            resp.addCallback(self._renderHTTP_done, req)
+            resp.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
+                              callbackArgs = (req, ), errbackArgs = (req, ))
         else:
             resp = self._renderHTTP_done(resp, req)
         return resp
         else:
             resp = self._renderHTTP_done(resp, req)
         return resp
@@ -42,18 +65,31 @@ class FileDownloader(static.File):
         if self.manager:
             path = 'http:/' + req.uri
             if resp.code >= 200 and resp.code < 400:
         if self.manager:
             path = 'http:/' + req.uri
             if resp.code >= 200 and resp.code < 400:
-                return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)
+                return self.manager.get_resp(req, path, resp)
             
             log.msg('Not found, trying other methods for %s' % req.uri)
             return self.manager.get_resp(req, path)
         
         return resp
 
             
             log.msg('Not found, trying other methods for %s' % req.uri)
             return self.manager.get_resp(req, path)
         
         return resp
 
+    def _renderHTTP_error(self, err, req):
+        log.msg('Failed to render %s: %r' % (req.uri, err))
+        log.err(err)
+        
+        if self.manager:
+            path = 'http:/' + req.uri
+            return self.manager.get_resp(req, path)
+        
+        return err
+
     def createSimilarFile(self, path):
         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                               self.processors, self.indexNames[:])
         
     def createSimilarFile(self, path):
         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                               self.processors, self.indexNames[:])
         
-class FileUploaderStream(stream.FileStream):
+class UploadStream:
+    """Identifier for streams that are uploaded to peers."""
+    
+class FileUploaderStream(stream.FileStream, UploadStream):
     """Modified to make it suitable for streaming to peers.
     
     Streams the file in small chunks to make it easier to throttle the
     """Modified to make it suitable for streaming to peers.
     
     Streams the file in small chunks to make it easier to throttle the
@@ -87,7 +123,20 @@ class FileUploaderStream(stream.FileStream):
             self.start += bytesRead
             return b
 
             self.start += bytesRead
             return b
 
+class PiecesUploaderStream(stream.MemoryStream, UploadStream):
+    """Modified to identify it for streaming to peers."""
 
 
+class PiecesUploader(static.Data):
+    """Modified to identify it for peer requests.
+    
+    Uses the modified L{PieceUploaderStream} to stream the pieces for throttling.
+    """
+
+    def render(self, req):
+        return http.Response(responsecode.OK,
+                             http_headers.Headers({'content-type': self.contentType()}),
+                             stream=PiecesUploaderStream(self.data))
+        
 class FileUploader(static.File):
     """Modified to make it suitable for peer requests.
     
 class FileUploader(static.File):
     """Modified to make it suitable for peer requests.
     
@@ -134,6 +183,8 @@ class UploadThrottlingProtocol(ThrottlingProtocol):
     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
     """
     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
     """
+    
+    stats = None
 
     def __init__(self, factory, wrappedProtocol):
         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
 
     def __init__(self, factory, wrappedProtocol):
         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
@@ -142,13 +193,23 @@ class UploadThrottlingProtocol(ThrottlingProtocol):
     def write(self, data):
         if self.throttle:
             ThrottlingProtocol.write(self, data)
     def write(self, data):
         if self.throttle:
             ThrottlingProtocol.write(self, data)
+            if self.stats:
+                self.stats.sentBytes(len(data))
         else:
             ProtocolWrapper.write(self, data)
 
         else:
             ProtocolWrapper.write(self, data)
 
+    def writeSequence(self, seq):
+        if self.throttle:
+            ThrottlingProtocol.writeSequence(self, seq)
+            if self.stats:
+                self.stats.sentBytes(reduce(operator.add, map(len, seq)))
+        else:
+            ProtocolWrapper.writeSequence(self, seq)
+
     def registerProducer(self, producer, streaming):
         ThrottlingProtocol.registerProducer(self, producer, streaming)
         streamType = getattr(producer, 'stream', None)
     def registerProducer(self, producer, streaming):
         ThrottlingProtocol.registerProducer(self, producer, streaming)
         streamType = getattr(producer, 'stream', None)
-        if isinstance(streamType, FileUploaderStream) or isinstance(streamType, stream.MemoryStream):
+        if isinstance(streamType, UploadStream):
             self.throttle = True
 
 
             self.throttle = True
 
 
@@ -167,7 +228,7 @@ class TopLevel(resource.Resource):
     
     addSlash = True
     
     
     addSlash = True
     
-    def __init__(self, directory, db, manager, uploadLimit):
+    def __init__(self, directory, db, manager):
         """Initialize the instance.
         
         @type directory: L{twisted.python.filepath.FilePath}
         """Initialize the instance.
         
         @type directory: L{twisted.python.filepath.FilePath}
@@ -181,8 +242,8 @@ class TopLevel(resource.Resource):
         self.db = db
         self.manager = manager
         self.uploadLimit = None
         self.db = db
         self.manager = manager
         self.uploadLimit = None
-        if uploadLimit > 0:
-            self.uploadLimit = int(uploadLimit*1024)
+        if config.getint('DEFAULT', 'UPLOAD_LIMIT') > 0:
+            self.uploadLimit = int(config.getint('DEFAULT', 'UPLOAD_LIMIT')*1024)
         self.factory = None
 
     def getHTTPFactory(self):
         self.factory = None
 
     def getHTTPFactory(self):
@@ -193,14 +254,22 @@ class TopLevel(resource.Resource):
                                                   'betweenRequestsTimeOut': 60})
             self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
             self.factory.protocol = UploadThrottlingProtocol
                                                   'betweenRequestsTimeOut': 60})
             self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
             self.factory.protocol = UploadThrottlingProtocol
+            if self.manager:
+                self.factory.protocol.stats = self.manager.stats
         return self.factory
 
     def render(self, ctx):
         """Render a web page with descriptive statistics."""
         return self.factory
 
     def render(self, ctx):
         """Render a web page with descriptive statistics."""
-        return http.Response(
-            200,
-            {'content-type': http_headers.MimeType('text', 'html')},
-            self.manager.getStats())
+        if self.manager:
+            return http.Response(
+                200,
+                {'content-type': http_headers.MimeType('text', 'html')},
+                self.manager.getStats())
+        else:
+            return http.Response(
+                200,
+                {'content-type': http_headers.MimeType('text', 'html')},
+                '<html><body><p>Some Statistics</body></html>')
 
     def locateChild(self, request, segments):
         """Process the incoming request."""
 
     def locateChild(self, request, segments):
         """Process the incoming request."""
@@ -225,29 +294,132 @@ class TopLevel(resource.Resource):
                 else:
                     # It's not for a file, but for a piece string, so return that
                     log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
                 else:
                     # It's not for a file, but for a piece string, so return that
                     log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
-                    return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
+                    return PiecesUploader(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
             else:
             else:
-                log.msg('Hash could not be found in database: %s' % hash)
+                log.msg('Hash could not be found in database: %r' % hash)
+                return None, ()
 
 
-        # Only local requests (apt) get past this point
-        if request.remoteAddr.host != "127.0.0.1":
-            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
-            return None, ()
-        
-        # Block access to index .diff files (for now)
-        if 'Packages.diff' in segments or 'Sources.diff' in segments:
-            return None, ()
-         
         if len(name) > 1:
             # It's a request from apt
         if len(name) > 1:
             # It's a request from apt
+
+            # Only local requests (apt) get past this point
+            if request.remoteAddr.host != "127.0.0.1":
+                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+                return None, ()
+
+            # Block access to index .diff files (for now)
+            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
+                return None, ()
+             
             return FileDownloader(self.directory.path, self.manager), segments[0:]
         else:
             # Will render the statistics page
             return FileDownloader(self.directory.path, self.manager), segments[0:]
         else:
             # Will render the statistics page
+
+            # Only local requests for stats are allowed
+            if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
+                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+                return None, ()
+
             return self, ()
         
         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
         return None, ()
 
             return self, ()
         
         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
         return None, ()
 
+class TestTopLevel(unittest.TestCase):
+    """Unit tests for the HTTP Server."""
+    
+    client = None
+    pending_calls = []
+    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
+    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'
+    
+    def setUp(self):
+        self.client = TopLevel(FilePath('/boot'), self, None)
+        
+    def lookupHash(self, hash):
+        if hash == self.torrent_hash:
+            return [{'pieces': self.torrent}]
+        elif hash == self.file_hash:
+            return [{'path': FilePath('/boot/grub/stage2')}]
+        else:
+            return []
+        
+    def create_request(self, host, path):
+        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
+        class addr:
+            host = ''
+            port = 0
+        req.remoteAddr = addr()
+        req.remoteAddr.host = host
+        req.remoteAddr.port = 23456
+        server.Request._parseURL(req)
+        return req
+        
+    def test_unauthorized(self):
+        req = self.create_request('128.0.0.1', '/foo/bar')
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+        
+    def test_Packages_diff(self):
+        req = self.create_request('127.0.0.1',
+                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+        
+    def test_Statistics(self):
+        req = self.create_request('127.0.0.1', '/')
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+        
+    def test_apt_download(self):
+        req = self.create_request('127.0.0.1',
+                '/ftp.us.debian.org/debian/dists/stable/Release')
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, FileDownloader))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 404)
+        return df
+        
+    def test_torrent_upload(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus(self.torrent_hash))
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, static.Data))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+        
+    def test_file_upload(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus(self.file_hash))
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, FileUploader))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+    
+    def test_missing_hash(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus('foobar'))
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+
+    def check_resp(self, resp, code):
+        self.failUnlessEqual(resp.code, code)
+        return resp
+        
+    def tearDown(self):
+        for p in self.pending_calls:
+            if p.active():
+                p.cancel()
+        self.pending_calls = []
+        if self.client:
+            self.client = None
+
 if __name__ == '__builtin__':
     # Running from twistd -ny HTTPServer.py
     # Then test with:
 if __name__ == '__builtin__':
     # Running from twistd -ny HTTPServer.py
     # Then test with:
@@ -263,7 +435,7 @@ if __name__ == '__builtin__':
                 return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
             return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
     
                 return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
             return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
     
-    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None, 0)
+    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
     factory = t.getHTTPFactory()
     
     # Standard twisted application Boilerplate
     factory = t.getHTTPFactory()
     
     # Standard twisted application Boilerplate