Updated and added a lot of unittests.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
index f3d6de72525281df73660f35adc252f26d630414..cc6233ba955b192a31b15b94a46c620ede753bda 100644 (file)
@@ -1,13 +1,15 @@
 
 """Serve local requests from apt and remote requests from peers."""
 
-from urllib import unquote_plus
+from urllib import quote_plus, unquote_plus
 from binascii import b2a_hex
 
 from twisted.python import log
 from twisted.internet import defer
 from twisted.web2 import server, http, resource, channel, stream
 from twisted.web2 import static, http_headers, responsecode
+from twisted.trial import unittest
+from twisted.python.filepath import FilePath
 
 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
 from apt_p2p_Khashmir.bencode import bencode
@@ -197,10 +199,16 @@ class TopLevel(resource.Resource):
 
     def render(self, ctx):
         """Render a web page with descriptive statistics."""
-        return http.Response(
-            200,
-            {'content-type': http_headers.MimeType('text', 'html')},
-            self.manager.getStats())
+        if self.manager:
+            return http.Response(
+                200,
+                {'content-type': http_headers.MimeType('text', 'html')},
+                self.manager.getStats())
+        else:
+            return http.Response(
+                200,
+                {'content-type': http_headers.MimeType('text', 'html')},
+                '<html><body><p>Some Statistics</body></html>')
 
     def locateChild(self, request, segments):
         """Process the incoming request."""
@@ -248,6 +256,101 @@ class TopLevel(resource.Resource):
         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
         return None, ()
 
+class TestTopLevel(unittest.TestCase):
+    """Unit tests for the HTTP Server."""
+    
+    client = None
+    pending_calls = []
+    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
+    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'
+    
+    def setUp(self):
+        self.client = TopLevel(FilePath('/boot'), self, None, 0)
+        
+    def lookupHash(self, hash):
+        if hash == self.torrent_hash:
+            return [{'pieces': self.torrent}]
+        elif hash == self.file_hash:
+            return [{'path': FilePath('/boot/initrd')}]
+        else:
+            return []
+        
+    def create_request(self, host, path):
+        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
+        class addr:
+            host = ''
+            port = 0
+        req.remoteAddr = addr()
+        req.remoteAddr.host = host
+        req.remoteAddr.port = 23456
+        server.Request._parseURL(req)
+        return req
+        
+    def test_unauthorized(self):
+        req = self.create_request('128.0.0.1', '/foo/bar')
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+        
+    def test_Packages_diff(self):
+        req = self.create_request('127.0.0.1',
+                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+        
+    def test_Statistics(self):
+        req = self.create_request('127.0.0.1', '/')
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+        
+    def test_apt_download(self):
+        req = self.create_request('127.0.0.1',
+                '/ftp.us.debian.org/debian/dists/stable/Release')
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, FileDownloader))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 404)
+        return df
+        
+    def test_torrent_upload(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus(self.torrent_hash))
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, static.Data))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+        
+    def test_file_upload(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus(self.file_hash))
+        res = req._getChild(None, self.client, req.postpath)
+        self.failIfEqual(res, None)
+        self.failUnless(isinstance(res, FileUploader))
+        df = defer.maybeDeferred(res.renderHTTP, req)
+        df.addCallback(self.check_resp, 200)
+        return df
+    
+    def test_missing_hash(self):
+        req = self.create_request('123.45.67.89',
+                                  '/~/' + quote_plus('foobar'))
+        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+
+    def check_resp(self, resp, code):
+        self.failUnlessEqual(resp.code, code)
+        return resp
+        
+    def tearDown(self):
+        for p in self.pending_calls:
+            if p.active():
+                p.cancel()
+        self.pending_calls = []
+        if self.client:
+            self.client = None
+
 if __name__ == '__builtin__':
     # Running from twistd -ny HTTPServer.py
     # Then test with: