Add all files to the DB with their hashes.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from apt_p2p_conf import config
17 from apt_p2p_Khashmir.bencode import bencode
18
class FileDownloader(static.File):
    """Modified to make it suitable for apt requests.

    Tries to find requests in the cache. Found files are first checked for
    freshness (via the manager's DB) before being sent. Requests for unfound
    and stale files are forwarded to the main program for downloading.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query; may be C{None}, in which case
        this behaves like a plain L{static.File}
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def locateChild(self, req, segments):
        """Locate a child, guaranteeing that L{renderHTTP} of this class runs.

        If the parent lookup yields anything other than a L{FileDownloader},
        traversal stops at this resource instead.
        """
        child, remaining = super(FileDownloader, self).locateChild(req, segments)
        if not isinstance(child, FileDownloader):
            return self, server.StopTraversal
        return child, remaining

    def renderHTTP(self, req):
        """Render the cached file, handing the result to the manager.

        When the manager's DB reports that the cached file has changed, the
        file is deleted and a 404 response is fed to L{_renderHTTP_done}.
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))

        # Make sure the file is in the DB and unchanged
        if self.manager and not self.manager.db.isUnchanged(self.fp):
            if self.fp.exists() and self.fp.isfile():
                self.fp.remove()
            changed = http.Response(404,
                        {'content-type': http_headers.MimeType('text', 'html')},
                        '<html><body><p>File found but it has changed.</body></html>')
            return self._renderHTTP_done(changed, req)

        resp = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(resp, defer.Deferred):
            return self._renderHTTP_done(resp, req)

        resp.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                          callbackArgs=(req, ), errbackArgs=(req, ))
        return resp

    def _renderHTTP_done(self, resp, req):
        """Pass a rendered response on to the manager.

        2xx and 3xx responses are given to C{manager.get_resp} together with
        the request; any other code makes the manager handle the request
        without a response. With no manager, C{resp} is returned as-is.
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        path = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            return self.manager.get_resp(req, path, resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, path)

    def _renderHTTP_error(self, err, req):
        """Log a rendering failure and let the manager handle the request.

        With no manager, the failure C{err} is returned unchanged.
        """
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            return err
        return self.manager.get_resp(req, 'http:/' + req.uri)

    def createSimilarFile(self, path):
        """Create a L{FileDownloader} for C{path} sharing this one's settings."""
        return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                              self.processors, self.indexNames[:])

class UploadStream:
    """Marker mixin identifying streams that are uploaded to peers.

    L{UploadThrottlingProtocol.registerProducer} turns on throttling for a
    connection whose producer's C{stream} attribute is an instance of this.
    """

class FileUploaderStream(stream.FileStream, UploadStream):
    """Modified to make it suitable for streaming to peers.

    Reads the file in small chunks so that the streaming to peers can be
    throttled at a fine granularity.

    @ivar CHUNK_SIZE: the maximum number of bytes returned by one L{read}
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or C{None} once it is done.

        The SendFileBuffer and mmap paths are not used; data is always
        returned as a plain string of at most C{CHUNK_SIZE} bytes.

        @param sendfile: ignored
        @raise RuntimeError: if the file ends before C{self.length} more
            bytes could be read
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Drop the file so any further call returns None immediately.
            self.f = None
            return None

        self.f.seek(self.start)
        chunk = self.f.read(min(remaining, self.CHUNK_SIZE))
        count = len(chunk)
        if count == 0:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        self.length -= count
        self.start += count
        return chunk

class PiecesUploaderStream(stream.MemoryStream, UploadStream):
    """In-memory stream tagged as an L{UploadStream} so peer uploads of it are throttled."""

class PiecesUploader(static.Data):
    """Modified to identify it for peer requests.

    Uses the modified L{PiecesUploaderStream} to stream the pieces, so the
    upload is recognised as an L{UploadStream} for throttling.
    """

    def render(self, req):
        """Return a 200 response that streams C{self.data}."""
        headers = http_headers.Headers({'content-type': self.contentType()})
        body = PiecesUploaderStream(self.data)
        return http.Response(responsecode.OK, headers, stream=body)

140 class FileUploader(static.File):
141     """Modified to make it suitable for peer requests.
142     
143     Uses the modified L{FileUploaderStream} to stream the file for throttling,
144     and doesn't do any listing of directory contents.
145     """
146
147     def render(self, req):
148         if not self.fp.exists():
149             return responsecode.NOT_FOUND
150
151         if self.fp.isdir():
152             # Don't try to render a directory listing
153             return responsecode.NOT_FOUND
154
155         try:
156             f = self.fp.open()
157         except IOError, e:
158             import errno
159             if e[0] == errno.EACCES:
160                 return responsecode.FORBIDDEN
161             elif e[0] == errno.ENOENT:
162                 return responsecode.NOT_FOUND
163             else:
164                 raise
165
166         response = http.Response()
167         # Use the modified FileStream
168         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
169
170         for (header, value) in (
171             ("content-type", self.contentType()),
172             ("content-encoding", self.contentEncoding()),
173         ):
174             if value is not None:
175                 response.headers.setHeader(header, value)
176
177         return response
178
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Determines whether or not to throttle the upload based on the type of stream.
    Uploads use L{FileUploaderStream} or L{PiecesUploaderStream} (a
    L{twisted.web2.stream.MemoryStream}), both tagged as L{UploadStream};
    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream}.

    Throttled writes go through L{ThrottlingProtocol} and are counted in
    C{stats}; unthrottled writes go straight through L{ProtocolWrapper}.
    """

    # Shared statistics object (assigned on the class by
    # TopLevel.getHTTPFactory); its sentBytes() is called for throttled writes.
    stats = None

    def __init__(self, factory, wrappedProtocol):
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        # Writes are unthrottled until an upload stream is registered.
        self.throttle = False

    def write(self, data):
        """Write C{data}, throttling and counting it if this is an upload."""
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            if self.stats:
                self.stats.sentBytes(len(data))
        else:
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write a sequence of strings, throttling and counting them if uploading."""
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats:
                # sum() rather than reduce(): reduce() with no initial value
                # raises TypeError for an empty sequence.
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register C{producer}; enable throttling if it streams an upload.

        Once enabled, throttling is never switched off again for this
        connection.
        """
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        streamType = getattr(producer, 'stream', None)
        if isinstance(streamType, UploadStream):
            self.throttle = True

class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: the write limit handed to the throttling factory, or
        C{None} when UPLOAD_LIMIT is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.uploadLimit = None
        # Read the setting once; it is scaled by 1024 (presumably kB -> bytes).
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        if limit > 0:
            self.uploadLimit = int(limit*1024)
        self.factory = None

    def getHTTPFactory(self):
        """Initialize (once) and get the factory for this HTTP server."""
        if self.factory is None:
            httpFactory = channel.HTTPFactory(server.Site(self),
                                              maxPipeline=10,
                                              betweenRequestsTimeOut=60)
            self.factory = ThrottlingFactory(httpFactory, writeLimit=self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
            if self.manager:
                # Sets the stats on the protocol class, shared by all connections
                self.factory.protocol.stats = self.manager.stats
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            body = self.manager.getStats()
        else:
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        The first path segment selects the kind of request:
          - C{~}: a peer asking for a file or piece string by hash, accepted
            from any host;
          - a segment longer than one character: an apt request, accepted
            only from 127.0.0.1 and served by a L{FileDownloader};
          - anything else (e.g. C{/}): the statistics page, accepted from
            other hosts only when REMOTE_STATS is enabled.

        @return: a C{(resource, remaining segments)} tuple; the resource is
            C{None} when the request is rejected or nothing was found
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            fileHash = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(fileHash)
            if not files:
                log.msg('Hash could not be found in database: %r' % fileHash)
                return None, ()

            found = files[0]
            if 'path' in found:
                # If it is a file, return it
                path = found['path'].path
                log.msg('Sharing %s with %s' % (path, request.remoteAddr))
                return FileUploader(path), ()

            # It's not for a file, but for a piece string, so return that
            log.msg('Sending torrent string %s to %s' % (b2a_hex(fileHash), request.remoteAddr))
            return PiecesUploader(bencode({'t': found['pieces']}), 'application/x-bencoded'), ()

        if len(name) > 1:
            # It's a request from apt; only local requests get past this point
            if request.remoteAddr.host != "127.0.0.1":
                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
                return None, ()

            # Block access to index .diff files (for now)
            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
                return None, ()

            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Every other name (length 0 or 1, other than '~') renders the
        # statistics page, so there is no separate "malformed" fall-through.
        # Only local requests for stats are allowed unless REMOTE_STATS is set
        if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        return self, ()

class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case itself is passed to L{TopLevel} as its database, so its
    L{lookupHash} method supplies the canned hash lookups.
    """

    client = None
    pending_calls = []
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Stand-in for the database lookup used by L{TopLevel}."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        if hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        return []

    def create_request(self, host, path):
        """Build a parsed GET request for C{path} that claims to come from C{host}."""
        class FakeAddress:
            host = ''
            port = 0
        request = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        request.remoteAddr = FakeAddress()
        request.remoteAddr.host = host
        request.remoteAddr.port = 23456
        server.Request._parseURL(request)
        return request

    def _locate(self, req):
        """Resolve C{req} to a child resource of the server under test."""
        return req._getChild(None, self.client, req.postpath)

    def _render_expecting(self, res, req, code):
        """Render C{res} for C{req}, checking the response code is C{code}."""
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, code)
        return df

    def test_unauthorized(self):
        req = self.create_request('128.0.0.1', '/foo/bar')
        self.failUnlessRaises(http.HTTPError, self._locate, req)

    def test_Packages_diff(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self.failUnlessRaises(http.HTTPError, self._locate, req)

    def test_Statistics(self):
        req = self.create_request('127.0.0.1', '/')
        res = self._locate(req)
        self.failIfEqual(res, None)
        return self._render_expecting(res, req, 200)

    def test_apt_download(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        return self._render_expecting(res, req, 404)

    def test_torrent_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, static.Data))
        return self._render_expecting(res, req, 200)

    def test_file_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        return self._render_expecting(res, req, 200)

    def test_missing_hash(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self.failUnlessRaises(http.HTTPError, self._locate, req)

    def check_resp(self, resp, code):
        """Assert that C{resp} has status C{code} and pass it along."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        if self.client:
            self.client = None

if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath

    class DB:
        """Fake database: 'pieces' yields a piece string, any other hash one file."""
        def lookupHash(self, hash):
            if hash != 'pieces':
                return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
            return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]

    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()

    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)