Use the apt_p2p_conf config import rather than passing parameters around.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6
7 from twisted.python import log
8 from twisted.internet import defer
9 from twisted.web2 import server, http, resource, channel, stream
10 from twisted.web2 import static, http_headers, responsecode
11 from twisted.trial import unittest
12 from twisted.python.filepath import FilePath
13
14 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
15 from apt_p2p_conf import config
16 from apt_p2p_Khashmir.bencode import bencode
17
class FileDownloader(static.File):
    """Static file resource adapted for requests coming from apt.

    The request is first rendered from the local cache directory. A
    successful result is handed to the manager to be checked for freshness
    before it is sent; any other result is handed to the manager so that it
    can try to satisfy the request by other means.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        """Remember the manager, then initialize the underlying static file."""
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render the request from the cache and post-process the result.

        The parent class may return either a response or a deferred; in
        both cases the result is passed through L{_renderHTTP_done}.
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        result = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(result, defer.Deferred):
            return self._renderHTTP_done(result, req)
        result.addCallback(self._renderHTTP_done, req)
        return result

    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the response found in the cache.

        Without a manager the response is returned untouched. Otherwise a
        response with a code from 200 up to (but not including) 400 is
        checked for freshness, and every other response is replaced by
        whatever the manager returns for the request.
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        # req.uri starts with the slash that completes the 'http://' prefix
        url = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            return self.manager.check_freshness(req, url, resp.headers.getHeader('Last-Modified'), resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, url)

    def createSimilarFile(self, path):
        """Create a resource of the same class for C{path}, sharing the manager."""
        cls = self.__class__
        return cls(path, self.manager, self.defaultType, self.ignoredExts,
                   self.processors, self.indexNames[:])
58         
class FileUploaderStream(stream.FileStream):
    """File stream that hands out its data in small pieces.

    Reading the file in small chunks makes it easier to throttle the
    streaming to peers.

    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or None when there is no more.

        The C{sendfile} argument is accepted but not used: data is always
        read with plain string reads (no SendFileBuffer or mmap).

        @raise RuntimeError: if the file yields no data although more
            bytes were expected
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Everything has been delivered, drop the file reference
            self.f = None
            return None

        wanted = min(remaining, self.CHUNK_SIZE)

        self.f.seek(self.start)
        chunk = self.f.read(wanted)
        received = len(chunk)
        if received:
            # Advance the window past the data just read
            self.length -= received
            self.start += received
            return chunk

        raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))
92
93
94 class FileUploader(static.File):
95     """Modified to make it suitable for peer requests.
96     
97     Uses the modified L{FileUploaderStream} to stream the file for throttling,
98     and doesn't do any listing of directory contents.
99     """
100
101     def render(self, req):
102         if not self.fp.exists():
103             return responsecode.NOT_FOUND
104
105         if self.fp.isdir():
106             # Don't try to render a directory listing
107             return responsecode.NOT_FOUND
108
109         try:
110             f = self.fp.open()
111         except IOError, e:
112             import errno
113             if e[0] == errno.EACCES:
114                 return responsecode.FORBIDDEN
115             elif e[0] == errno.ENOENT:
116                 return responsecode.NOT_FOUND
117             else:
118                 raise
119
120         response = http.Response()
121         # Use the modified FileStream
122         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
123
124         for (header, value) in (
125             ("content-type", self.contentType()),
126             ("content-encoding", self.contentEncoding()),
127         ):
128             if value is not None:
129                 response.headers.setHeader(header, value)
130
131         return response
132
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol that throttles only the uploads to peers.

    Whether to throttle is decided from the type of stream being sent.
    Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemoryStream},
    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream}.

    @ivar throttle: whether writes on this connection are throttled
    """

    def __init__(self, factory, wrappedProtocol):
        """Start out unthrottled until a producer is registered."""
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write the data, going through the throttling code only if enabled."""
        if not self.throttle:
            ProtocolWrapper.write(self, data)
            return
        ThrottlingProtocol.write(self, data)

    def registerProducer(self, producer, streaming):
        """Register the producer and enable throttling for upload streams."""
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        produced = getattr(producer, 'stream', None)
        if isinstance(produced, (FileUploaderStream, stream.MemoryStream)):
            self.throttle = True
156
157
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: the write limit passed to the throttling factory,
        None when the configured UPLOAD_LIMIT is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager

        # Read the option once; the configured value is scaled by 1024
        # (presumably KiB to bytes -- confirm against the config docs).
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        self.uploadLimit = None
        if limit > 0:
            self.uploadLimit = limit * 1024

        self.factory = None

    def getHTTPFactory(self):
        """Initialize and get the factory for this HTTP server.

        The factory is created on the first call and reused afterwards.
        """
        if self.factory is None:
            http_factory = channel.HTTPFactory(server.Site(self),
                                               maxPipeline=10,
                                               betweenRequestsTimeOut=60)
            # Wrap the HTTP factory so that writes can be throttled
            self.factory = ThrottlingFactory(http_factory, writeLimit = self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            body = self.manager.getStats()
        else:
            # No manager to ask, serve a placeholder page
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        Requests whose first segment is C{~} ask for a shared file or a
        piece string by hash and may come from any host. Everything else
        is only answered for requests from 127.0.0.1.

        @return: a tuple of the resource that handles the request and the
            remaining segments, or C{(None, ())} when the request is refused
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            hash_key = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(hash_key)
            if files:
                found = files[0]
                # If it is a file, return it
                if 'path' in found:
                    log.msg('Sharing %s with %s' % (found['path'].path, request.remoteAddr))
                    return FileUploader(found['path'].path), ()

                # It's not for a file, but for a piece string, so return that
                log.msg('Sending torrent string %s to %s' % (b2a_hex(hash_key), request.remoteAddr))
                return static.Data(bencode({'t': found['pieces']}), 'application/x-bencoded'), ()

            # An unknown hash is not rejected here, it falls through to
            # the checks below
            log.msg('Hash could not be found in database: %r' % hash_key)

        # Only local requests (apt) get past this point
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        # Block access to index .diff files (for now)
        if 'Packages.diff' in segments or 'Sources.diff' in segments:
            return None, ()

        if len(name) > 1:
            # It's a request from apt
            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page
        return self, ()
259
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case doubles as the database given to L{TopLevel}, answering
    hash lookups through its own L{lookupHash}.
    """

    client = None
    pending_calls = []
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        """Create the server under test with this test case as its database."""
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Stand in for the database: only the two known hashes are found."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        if hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        return []

    def create_request(self, host, path):
        """Build a GET request for C{path} that appears to come from C{host}."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        class addr:
            host = ''
            port = 0
        remote = addr()
        remote.host = host
        remote.port = 23456
        req.remoteAddr = remote
        server.Request._parseURL(req)
        return req

    def _locate(self, req):
        """Find the resource the server selects for the request."""
        return req._getChild(None, self.client, req.postpath)

    def _render_and_check(self, res, req, code):
        """Render the resource and check the code of the resulting response."""
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, code)
        return df

    def test_unauthorized(self):
        """A non-local host asking for an apt file is refused."""
        req = self.create_request('128.0.0.1', '/foo/bar')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Packages_diff(self):
        """Requests for Packages.diff index files are refused."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Statistics(self):
        """A local request for the root gets a 200 response."""
        req = self.create_request('127.0.0.1', '/')
        res = self._locate(req)
        self.failIfEqual(res, None)
        return self._render_and_check(res, req, 200)

    def test_apt_download(self):
        """A local request for a file is handled by a L{FileDownloader}."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        return self._render_and_check(res, req, 404)

    def test_torrent_upload(self):
        """A peer request for a piece string hash is answered with data."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, static.Data))
        return self._render_and_check(res, req, 200)

    def test_file_upload(self):
        """A peer request for a file hash is handled by a L{FileUploader}."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        return self._render_and_check(res, req, 200)

    def test_missing_hash(self):
        """A peer request for an unknown hash is refused."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def check_resp(self, resp, code):
        """Check the response code, passing the response on unchanged."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        """Cancel any calls still pending and drop the server."""
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        if self.client:
            self.client = None
354
if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath

    class DB:
        """Minimal stand-in for the database used by the demo server."""

        def lookupHash(self, hash):
            # The hash 'pieces' returns a piece string, anything else a file
            if hash == 'pieces':
                return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]

    # TopLevel takes only the directory, database and manager; the upload
    # limit is read from the config rather than passed as an argument.
    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()

    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)