Fix some documentation errors.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
index d252a6386aef205ed397e83579559eb2a33139db..c92df92d2baf461651d1dbcd6bab67f60601ec44 100644 (file)
@@ -1,15 +1,20 @@
 
 """Serve local requests from apt and remote requests from peers."""
 
-from urllib import unquote_plus
+from urllib import quote_plus, unquote_plus
 from binascii import b2a_hex
+import operator
 
 from twisted.python import log
 from twisted.internet import defer
 from twisted.web2 import server, http, resource, channel, stream
 from twisted.web2 import static, http_headers, responsecode
+from twisted.trial import unittest
+from twisted.python.filepath import FilePath
 
-from policies import ThrottlingFactory
+from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
+from Streams import UploadStream, FileUploadStream, PiecesUploadStream
+from apt_p2p_conf import config
 from apt_p2p_Khashmir.bencode import bencode
 
 class FileDownloader(static.File):
@@ -26,12 +31,31 @@ class FileDownloader(static.File):
     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
         self.manager = manager
         super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
-        
+    
+    def locateChild(self, req, segments):
+        child, segments = super(FileDownloader, self).locateChild(req, segments)
+        # Make sure we always call renderHTTP()
+        if isinstance(child, FileDownloader):
+            return child, segments
+        else:
+            return self, server.StopTraversal
+            
     def renderHTTP(self, req):
         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
+        
+        # Make sure the file is in the DB and unchanged
+        if self.manager and not self.manager.db.isUnchanged(self.fp):
+            if self.fp.exists() and self.fp.isfile():
+                self.fp.remove()
+            return self._renderHTTP_done(http.Response(404,
+                        {'content-type': http_headers.MimeType('text', 'html')},
+                        '<html><body><p>File found but it has changed.</body></html>'),
+                        req)
+            
         resp = super(FileDownloader, self).renderHTTP(req)
         if isinstance(resp, defer.Deferred):
-            resp.addCallback(self._renderHTTP_done, req)
+            resp.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
+                              callbackArgs = (req, ), errbackArgs = (req, ))
         else:
             resp = self._renderHTTP_done(resp, req)
         return resp
@@ -42,56 +66,42 @@ class FileDownloader(static.File):
         if self.manager:
             path = 'http:/' + req.uri
             if resp.code >= 200 and resp.code < 400:
-                return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)
+                return self.manager.get_resp(req, path, resp)
             
             log.msg('Not found, trying other methods for %s' % req.uri)
             return self.manager.get_resp(req, path)
         
         return resp
 
+    def _renderHTTP_error(self, err, req):
+        log.msg('Failed to render %s: %r' % (req.uri, err))
+        log.err(err)
+        
+        if self.manager:
+            path = 'http:/' + req.uri
+            return self.manager.get_resp(req, path)
+        
+        return err
+
     def createSimilarFile(self, path):
         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                               self.processors, self.indexNames[:])
         
-class FileUploaderStream(stream.FileStream):
-    """Modified to make it suitable for streaming to peers.
+class PiecesUploader(static.Data):
+    """Modified to identify it for peer requests.
     
-    Streams the file is small chunks to make it easier to throttle the
-    streaming to peers.
-    
-    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
+    Uses the modified L{Streams.PiecesUploadStream} to stream the pieces for throttling.
     """
 
-    CHUNK_SIZE = 4*1024
-    
-    def read(self, sendfile=False):
-        if self.f is None:
-            return None
-
-        length = self.length
-        if length == 0:
-            self.f = None
-            return None
+    def render(self, req):
+        return http.Response(responsecode.OK,
+                             http_headers.Headers({'content-type': self.contentType()}),
+                             stream=PiecesUploadStream(self.data))
         
-        # Remove the SendFileBuffer and mmap use, just use string reads and writes
-
-        readSize = min(length, self.CHUNK_SIZE)
-
-        self.f.seek(self.start)
-        b = self.f.read(readSize)
-        bytesRead = len(b)
-        if not bytesRead:
-            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length))
-        else:
-            self.length -= bytesRead
-            self.start += bytesRead
-            return b
-
-
 class FileUploader(static.File):
     """Modified to make it suitable for peer requests.
     
-    Uses the modified L{FileUploaderStream} to stream the file for throttling,
+    Uses the modified L{Streams.FileUploadStream} to stream the file for throttling,
     and doesn't do any listing of directory contents.
     """
 
@@ -116,7 +126,7 @@ class FileUploader(static.File):
 
         response = http.Response()
         # Use the modified FileStream
-        response.stream = FileUploaderStream(f, 0, self.fp.getsize())
+        response.stream = FileUploadStream(f, 0, self.fp.getsize())
 
         for (header, value) in (
             ("content-type", self.contentType()),
@@ -127,6 +137,42 @@ class FileUploader(static.File):
 
         return response
 
+class UploadThrottlingProtocol(ThrottlingProtocol):
+    """Protocol for throttling uploads.
+    
+    Determines whether or not to throttle the upload based on the type of stream.
+    Uploads use instances of L{Streams.UploadStream}.
+    """
+    
+    stats = None
+
+    def __init__(self, factory, wrappedProtocol):
+        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
+        self.throttle = False
+
+    def write(self, data):
+        if self.throttle:
+            ThrottlingProtocol.write(self, data)
+            if self.stats:
+                self.stats.sentBytes(len(data))
+        else:
+            ProtocolWrapper.write(self, data)
+
+    def writeSequence(self, seq):
+        if self.throttle:
+            ThrottlingProtocol.writeSequence(self, seq)
+            if self.stats:
+                self.stats.sentBytes(reduce(operator.add, map(len, seq)))
+        else:
+            ProtocolWrapper.writeSequence(self, seq)
+
+    def registerProducer(self, producer, streaming):
+        ThrottlingProtocol.registerProducer(self, producer, streaming)
+        streamType = getattr(producer, 'stream', None)
+        if isinstance(streamType, UploadStream):
+            self.throttle = True
+
+
 class TopLevel(resource.Resource):
     """The HTTP server for all requests, both from peers and apt.
     
@@ -137,8 +183,7 @@ class TopLevel(resource.Resource):
     @type manager: L{apt_p2p.AptP2P}
     @ivar manager: the main program object to send requests to
     @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
-    @ivar factory: the factory to use to server HTTP requests
-    
+    @ivar factory: the factory to use to serve HTTP requests
     """
     
     addSlash = True
@@ -156,6 +201,9 @@ class TopLevel(resource.Resource):
         self.directory = directory
         self.db = db
         self.manager = manager
+        self.uploadLimit = None
+        if config.getint('DEFAULT', 'UPLOAD_LIMIT') > 0:
+            self.uploadLimit = int(config.getint('DEFAULT', 'UPLOAD_LIMIT')*1024)
         self.factory = None
 
     def getHTTPFactory(self):
@@ -164,17 +212,24 @@ class TopLevel(resource.Resource):
             self.factory = channel.HTTPFactory(server.Site(self),
                                                **{'maxPipeline': 10, 
                                                   'betweenRequestsTimeOut': 60})
-            self.factory = ThrottlingFactory(self.factory, writeLimit = 30*1024)
+            self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
+            self.factory.protocol = UploadThrottlingProtocol
+            if self.manager:
+                self.factory.protocol.stats = self.manager.stats
         return self.factory
 
     def render(self, ctx):
         """Render a web page with descriptive statistics."""
-        return http.Response(
-            200,
-            {'content-type': http_headers.MimeType('text', 'html')},
-            """<html><body>
-            <h2>Statistics</h2>
-            <p>TODO: eventually some stats will be shown here.</body></html>""")
+        if self.manager:
+            return http.Response(
+                200,
+                {'content-type': http_headers.MimeType('text', 'html')},
+                self.manager.getStats())
+        else:
+            return http.Response(
+                200,
+                {'content-type': http_headers.MimeType('text', 'html')},
+                '<html><body><p>Some Statistics</body></html>')
 
     def locateChild(self, request, segments):
         """Process the incoming request."""
@@ -188,7 +243,8 @@ class TopLevel(resource.Resource):
                 return None, ()
             
             # Find the file in the database
-            hash = unquote_plus(segments[1])
+            # Have to unquote_plus the uri, because the segments are unquoted by twisted
+            hash = unquote_plus(request.uri[3:])
             files = self.db.lookupHash(hash)
             if files:
                 # If it is a file, return it
@@ -198,25 +254,132 @@ class TopLevel(resource.Resource):
                 else:
                     # It's not for a file, but for a piece string, so return that
                     log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
-                    return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
+                    return PiecesUploader(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
             else:
-                log.msg('Hash could not be found in database: %s' % hash)
+                log.msg('Hash could not be found in database: %r' % hash)
+                return None, ()
 
-        # Only local requests (apt) get past this point
-        if request.remoteAddr.host != "127.0.0.1":
-            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
-            return None, ()
-            
         if len(name) > 1:
             # It's a request from apt
+
+            # Only local requests (apt) get past this point
+            if request.remoteAddr.host != "127.0.0.1":
+                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+                return None, ()
+
+            # Block access to index .diff files (for now)
+            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
+                return None, ()
+             
             return FileDownloader(self.directory.path, self.manager), segments[0:]
         else:
             # Will render the statistics page
+
+            # Only local requests for stats are allowed
+            if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
+                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+                return None, ()
+
             return self, ()
         
         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
         return None, ()
 
+class TestTopLevel(unittest.TestCase):
+    """Unit tests for the HTTP Server."""
+    
+    client = None
+    pending_calls = []
+    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
+    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'
+    
+    def setUp(self):
+        self.client = TopLevel(FilePath('/boot'), self, None)
+        
+    def lookupHash(self, hash):
+        if hash == self.torrent_hash:
+            return [{'pieces': self.torrent}]
+        elif hash == self.file_hash:
+            return [{'path': FilePath('/boot/grub/stage2')}]
+        else:
+            return []
+        
+    def create_request(self, host, path):
+        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
+        class addr:
+            host = ''
+            port = 0
+        req.remoteAddr = addr()
+        req.remoteAddr.host = host
+        req.remoteAddr.port = 23456
+        server.Request._parseURL(req)
+        return req
+        
+    def test_unauthorized(self):
+        req = self.create_request('128.0.0.1', '/foo/bar')
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+        
+    def test_Packages_diff(self):
+        req = self.create_request('127.0.0.1',
+                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+        
+    def test_Statistics(self):
+        req = self.create_request('127.0.0.1', '/')
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+        
+    def test_apt_download(self):
+        req = self.create_request('127.0.0.1',
+                '/ftp.us.debian.org/debian/dists/stable/Release')
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, FileDownloader))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 404)
+        return df
+        
+    def test_torrent_upload(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus(self.torrent_hash))
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, static.Data))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+        
+    def test_file_upload(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus(self.file_hash))
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, FileUploader))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+    
+    def test_missing_hash(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus('foobar'))
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+
+    def check_resp(self, resp, code):
+        self.failUnlessEqual(resp.code, code)
+        return resp
+        
+    def tearDown(self):
+        for p in self.pending_calls:
+            if p.active():
+                p.cancel()
+        self.pending_calls = []
+        if self.client:
+            self.client = None
+
 if __name__ == '__builtin__':
     # Running from twistd -ny HTTPServer.py
     # Then test with: