"""Serve local requests from apt and remote requests from peers."""
-from urllib import unquote_plus
+from urllib import quote_plus, unquote_plus
from binascii import b2a_hex
+import operator
from twisted.python import log
from twisted.internet import defer
from twisted.web2 import server, http, resource, channel, stream
from twisted.web2 import static, http_headers, responsecode
+from twisted.trial import unittest
+from twisted.python.filepath import FilePath
from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
+from apt_p2p_conf import config
from apt_p2p_Khashmir.bencode import bencode
class FileDownloader(static.File):
log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
resp = super(FileDownloader, self).renderHTTP(req)
if isinstance(resp, defer.Deferred):
- resp.addCallback(self._renderHTTP_done, req)
+ resp.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
+ callbackArgs = (req, ), errbackArgs = (req, ))
else:
resp = self._renderHTTP_done(resp, req)
return resp
return resp
+ def _renderHTTP_error(self, err, req):
+ log.msg('Failed to render %s: %r' % (req.uri, err))
+ log.err(err)
+
+ if self.manager:
+ path = 'http:/' + req.uri
+ return self.manager.get_resp(req, path)
+
+ return err
+
def createSimilarFile(self, path):
return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
self.processors, self.indexNames[:])
class FileUploaderStream(stream.FileStream):
"""Modified to make it suitable for streaming to peers.
- Streams the file is small chunks to make it easier to throttle the
+ Streams the file in small chunks to make it easier to throttle the
streaming to peers.
@ivar CHUNK_SIZE: the size of chunks of data to send at a time
Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
"""
+
+ stats = None
def __init__(self, factory, wrappedProtocol):
ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
def write(self, data):
if self.throttle:
ThrottlingProtocol.write(self, data)
+ if self.stats:
+ self.stats.sentBytes(len(data))
else:
ProtocolWrapper.write(self, data)
+ def writeSequence(self, seq):
+ if self.throttle:
+ ThrottlingProtocol.writeSequence(self, seq)
+ if self.stats:
+ self.stats.sentBytes(reduce(operator.add, map(len, seq)))
+ else:
+ ProtocolWrapper.writeSequence(self, seq)
+
def registerProducer(self, producer, streaming):
ThrottlingProtocol.registerProducer(self, producer, streaming)
streamType = getattr(producer, 'stream', None)
self.directory = directory
self.db = db
self.manager = manager
+ self.uploadLimit = None
+ if config.getint('DEFAULT', 'UPLOAD_LIMIT') > 0:
+ self.uploadLimit = int(config.getint('DEFAULT', 'UPLOAD_LIMIT')*1024)
self.factory = None
def getHTTPFactory(self):
self.factory = channel.HTTPFactory(server.Site(self),
**{'maxPipeline': 10,
'betweenRequestsTimeOut': 60})
- self.factory = ThrottlingFactory(self.factory, writeLimit = 30*1024)
+ self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
self.factory.protocol = UploadThrottlingProtocol
+ if self.manager:
+ self.factory.protocol.stats = self.manager.stats
return self.factory
def render(self, ctx):
"""Render a web page with descriptive statistics."""
- return http.Response(
- 200,
- {'content-type': http_headers.MimeType('text', 'html')},
- self.manager.getStats())
+ if self.manager:
+ return http.Response(
+ 200,
+ {'content-type': http_headers.MimeType('text', 'html')},
+ self.manager.getStats())
+ else:
+ return http.Response(
+ 200,
+ {'content-type': http_headers.MimeType('text', 'html')},
+ '<html><body><p>Some Statistics</body></html>')
def locateChild(self, request, segments):
"""Process the incoming request."""
return None, ()
# Find the file in the database
- hash = unquote_plus(segments[1])
+ # Have to unquote_plus the uri, because the segments are unquoted by twisted
+ hash = unquote_plus(request.uri[3:])
files = self.db.lookupHash(hash)
if files:
# If it is a file, return it
log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
else:
- log.msg('Hash could not be found in database: %s' % hash)
+ log.msg('Hash could not be found in database: %r' % hash)
- # Only local requests (apt) get past this point
- if request.remoteAddr.host != "127.0.0.1":
- log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
- return None, ()
-
if len(name) > 1:
# It's a request from apt
+
+ # Only local requests (apt) get past this point
+ if request.remoteAddr.host != "127.0.0.1":
+ log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+ return None, ()
+
+ # Block access to index .diff files (for now)
+ if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
+ return None, ()
+
return FileDownloader(self.directory.path, self.manager), segments[0:]
else:
# Will render the statistics page
+
+ # Only local requests for stats are allowed
+ if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
+ log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
+ return None, ()
+
return self, ()
log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
return None, ()
+class TestTopLevel(unittest.TestCase):
+ """Unit tests for the HTTP Server."""
+
+ client = None
+ pending_calls = []
+ torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
+ torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+ file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'
+
+ def setUp(self):
+ self.client = TopLevel(FilePath('/boot'), self, None)
+
+ def lookupHash(self, hash):
+ if hash == self.torrent_hash:
+ return [{'pieces': self.torrent}]
+ elif hash == self.file_hash:
+ return [{'path': FilePath('/boot/grub/stage2')}]
+ else:
+ return []
+
+ def create_request(self, host, path):
+ req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
+ class addr:
+ host = ''
+ port = 0
+ req.remoteAddr = addr()
+ req.remoteAddr.host = host
+ req.remoteAddr.port = 23456
+ server.Request._parseURL(req)
+ return req
+
+ def test_unauthorized(self):
+ req = self.create_request('128.0.0.1', '/foo/bar')
+ self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+
+ def test_Packages_diff(self):
+ req = self.create_request('127.0.0.1',
+ '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
+ self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+
+ def test_Statistics(self):
+ req = self.create_request('127.0.0.1', '/')
+ res = req._getChild(None, self.client, req.postpath)
+ self.failIfEqual(res, None)
+ df = defer.maybeDeferred(res.renderHTTP, req)
+ df.addCallback(self.check_resp, 200)
+ return df
+
+ def test_apt_download(self):
+ req = self.create_request('127.0.0.1',
+ '/ftp.us.debian.org/debian/dists/stable/Release')
+ res = req._getChild(None, self.client, req.postpath)
+ self.failIfEqual(res, None)
+ self.failUnless(isinstance(res, FileDownloader))
+ df = defer.maybeDeferred(res.renderHTTP, req)
+ df.addCallback(self.check_resp, 404)
+ return df
+
+ def test_torrent_upload(self):
+ req = self.create_request('123.45.67.89',
+ '/~/' + quote_plus(self.torrent_hash))
+ res = req._getChild(None, self.client, req.postpath)
+ self.failIfEqual(res, None)
+ self.failUnless(isinstance(res, static.Data))
+ df = defer.maybeDeferred(res.renderHTTP, req)
+ df.addCallback(self.check_resp, 200)
+ return df
+
+ def test_file_upload(self):
+ req = self.create_request('123.45.67.89',
+ '/~/' + quote_plus(self.file_hash))
+ res = req._getChild(None, self.client, req.postpath)
+ self.failIfEqual(res, None)
+ self.failUnless(isinstance(res, FileUploader))
+ df = defer.maybeDeferred(res.renderHTTP, req)
+ df.addCallback(self.check_resp, 200)
+ return df
+
+ def test_missing_hash(self):
+ req = self.create_request('123.45.67.89',
+ '/~/' + quote_plus('foobar'))
+ self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)
+
+ def check_resp(self, resp, code):
+ self.failUnlessEqual(resp.code, code)
+ return resp
+
+ def tearDown(self):
+ for p in self.pending_calls:
+ if p.active():
+ p.cancel()
+ self.pending_calls = []
+ if self.client:
+ self.client = None
+
if __name__ == '__builtin__':
# Running from twistd -ny HTTPServer.py
# Then test with: