New TODO for hashing files with detached GPG signatures.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from Streams import UploadStream, FileUploadStream, PiecesUploadStream
17 from apt_p2p_conf import config
18 from apt_p2p_Khashmir.bencode import bencode
19
class FileDownloader(static.File):
    """A static file resource that hands apt requests to the main program.

    A request is first looked up in the local cache directory. A cached file
    is checked against the manager's database before it is used. Anything
    that is missing, changed, or fails to render is forwarded to the manager
    instead.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query (may be None, in which case
        responses are returned as rendered)
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts,
                                             processors, indexNames)

    def _mirrorPath(self, req):
        """Return the upstream URL apt asked for, as passed to the manager."""
        return 'http:/' + req.uri

    def locateChild(self, req, segments):
        found, remaining = super(FileDownloader, self).locateChild(req, segments)
        if not isinstance(found, FileDownloader):
            # Stop here so that our renderHTTP() is always the one called
            return self, server.StopTraversal
        return found, remaining

    def renderHTTP(self, req):
        """Render from the cache, deferring to the manager when needed."""
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))

        # A cached file the database no longer vouches for is deleted and
        # reported as missing, so the manager will fetch it again.
        if self.manager and not self.manager.db.isUnchanged(self.fp):
            if self.fp.exists() and self.fp.isfile():
                self.fp.remove()
            changed = http.Response(404,
                                    {'content-type': http_headers.MimeType('text', 'html')},
                                    '<html><body><p>File found but it has changed.</body></html>')
            return self._renderHTTP_done(changed, req)

        result = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(result, defer.Deferred):
            return self._renderHTTP_done(result, req)
        result.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                            callbackArgs=(req, ), errbackArgs=(req, ))
        return result

    def _renderHTTP_done(self, resp, req):
        """Pass the local response (or its absence) on to the manager."""
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        mirrorPath = self._mirrorPath(req)
        if 200 <= resp.code < 400:
            # Usable local response: the manager decides how to serve it
            return self.manager.get_resp(req, mirrorPath, resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, mirrorPath)

    def _renderHTTP_error(self, err, req):
        """Log a rendering failure and let the manager try other sources."""
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            return err
        return self.manager.get_resp(req, self._mirrorPath(req))

    def createSimilarFile(self, path):
        """Create a FileDownloader for C{path} sharing this one's settings."""
        cls = self.__class__
        return cls(path, self.manager, self.defaultType, self.ignoredExts,
                   self.processors, self.indexNames[:])
89         
class PiecesUploader(static.Data):
    """Serve a bencoded piece string to a peer.

    The body is streamed with the modified L{Streams.PiecesUploadStream} so
    that the upload can be throttled.
    """

    def render(self, req):
        headers = http_headers.Headers({'content-type': self.contentType()})
        body = PiecesUploadStream(self.data)
        return http.Response(responsecode.OK, headers, stream=body)
100         
101 class FileUploader(static.File):
102     """Modified to make it suitable for peer requests.
103     
104     Uses the modified L{Streams.FileUploadStream} to stream the file for throttling,
105     and doesn't do any listing of directory contents.
106     """
107
108     def render(self, req):
109         if not self.fp.exists():
110             return responsecode.NOT_FOUND
111
112         if self.fp.isdir():
113             # Don't try to render a directory listing
114             return responsecode.NOT_FOUND
115
116         try:
117             f = self.fp.open()
118         except IOError, e:
119             import errno
120             if e[0] == errno.EACCES:
121                 return responsecode.FORBIDDEN
122             elif e[0] == errno.ENOENT:
123                 return responsecode.NOT_FOUND
124             else:
125                 raise
126
127         response = http.Response()
128         # Use the modified FileStream
129         response.stream = FileUploadStream(f, 0, self.fp.getsize())
130
131         for (header, value) in (
132             ("content-type", self.contentType()),
133             ("content-encoding", self.contentEncoding()),
134         ):
135             if value is not None:
136                 response.headers.setHeader(header, value)
137
138         return response
139
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Writes pass through unthrottled until a producer is registered whose
    C{stream} attribute is an L{Streams.UploadStream}. From then on, every
    write on this connection goes through the throttling path. If C{stats}
    is set, those bytes are also reported to it.

    @cvar stats: an object with a C{sentBytes(count)} method that records
        uploaded bytes, or None to skip recording
    """

    stats = None

    def __init__(self, factory, wrappedProtocol):
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        # Set once an upload stream is produced; never cleared afterwards
        self.throttle = False

    def write(self, data):
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            if self.stats:
                self.stats.sentBytes(len(data))
        else:
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats:
                # sum() is 0 for an empty sequence. reduce() with no initial
                # value raised TypeError here.
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        # Only responses whose body is an upload stream switch on throttling
        streamType = getattr(producer, 'stream', None)
        if isinstance(streamType, UploadStream):
            self.throttle = True
174
175
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    Requests are dispatched on their first path segment (see L{locateChild}):
      - C{~}: a peer asking, by hash, for a shared file or piece string
      - a name longer than one character: a local apt request for a cached
        file
      - empty (C{/}): the statistics page
    Anything else is rejected as malformed.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: UPLOAD_LIMIT * 1024, used as the throttling write limit
        (the config value is presumably in KiB/s), or None when not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.uploadLimit = None
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        if limit > 0:
            self.uploadLimit = int(limit * 1024)
        self.factory = None

    def getHTTPFactory(self):
        """Initialize (once) and return the factory for this HTTP server.

        The web2 HTTP factory is wrapped in a L{ThrottlingFactory} limited to
        C{uploadLimit}. Its connections use L{UploadThrottlingProtocol}.
        """
        if self.factory is None:
            httpFactory = channel.HTTPFactory(server.Site(self),
                                              maxPipeline=10,
                                              betweenRequestsTimeOut=60)
            self.factory = ThrottlingFactory(httpFactory, writeLimit=self.uploadLimit)
            protocol = UploadThrottlingProtocol
            if self.manager:
                managerStats = self.manager.stats
                # Use a per-server subclass. Assigning stats onto
                # UploadThrottlingProtocol itself would change it for every
                # other server too, including ones without a manager.
                class StatsUploadThrottlingProtocol(UploadThrottlingProtocol):
                    stats = managerStats
                protocol = StatsUploadThrottlingProtocol
            self.factory.protocol = protocol
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            page = self.manager.getStats()
        else:
            page = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            page)

    def locateChild(self, request, segments):
        """Process the incoming request.

        @return: C{(resource, remainingSegments)}; C{(None, ())} rejects the
            request
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            return self._locatePeerChild(request, segments)

        if len(name) > 1:
            # It's a request from apt
            return self._locateAptChild(request, segments)

        if not name:
            # Will render the statistics page
            return self._locateStats(request)

        # Only the empty segment is the statistics page. Any other
        # one-character segment (besides '~') matches no handler.
        log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
        return None, ()

    def _locatePeerChild(self, request, segments):
        """Find the shared file or piece string a peer requested by hash."""
        if len(segments) != 2:
            log.msg('Got a malformed request from %s' % request.remoteAddr)
            return None, ()

        # Find the file in the database
        # Have to unquote_plus the uri, because the segments are unquoted by twisted
        fileHash = unquote_plus(request.uri[3:])
        files = self.db.lookupHash(fileHash)
        if not files:
            log.msg('Hash could not be found in database: %r' % fileHash)
            return None, ()

        found = files[0]
        if 'path' in found:
            # If it is a file, return it
            log.msg('Sharing %s with %s' % (found['path'].path, request.remoteAddr))
            return FileUploader(found['path'].path), ()

        # It's not for a file, but for a piece string, so return that
        log.msg('Sending torrent string %s to %s' % (b2a_hex(fileHash), request.remoteAddr))
        return PiecesUploader(bencode({'t': found['pieces']}), 'application/x-bencoded'), ()

    def _locateAptChild(self, request, segments):
        """Hand a local apt request to a L{FileDownloader} over the cache."""
        # Only local requests (apt) get past this point
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        # Block access to index .diff files (for now)
        if 'Packages.diff' in segments or 'Sources.diff' in segments or segments[0] == 'favicon.ico':
            return None, ()

        return FileDownloader(self.directory.path, self.manager), segments[0:]

    def _locateStats(self, request):
        """Return this resource for the statistics page, if access is allowed."""
        # Only local requests for stats are allowed
        if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        return self, ()
287
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case itself stands in for the server's database (see
    L{lookupHash}). No manager is used.
    """

    client = None
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        # A fresh list per test. A class-level [] would be one list shared by
        # every test instance.
        self.pending_calls = []
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Fake DB lookup: one known piece string, one known file, else nothing."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        elif hash == self.file_hash:
            # NOTE: test_file_upload expects 200, so this file must exist on the host
            return [{'path': FilePath('/boot/grub/stage2')}]
        else:
            return []

    def create_request(self, host, path):
        """Build a GET request for C{path} that appears to come from C{host}."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        class addr:
            host = ''
            port = 0
        req.remoteAddr = addr()
        req.remoteAddr.host = host
        req.remoteAddr.port = 23456
        server.Request._parseURL(req)
        return req

    def test_unauthorized(self):
        req = self.create_request('128.0.0.1', '/foo/bar')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Packages_diff(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Statistics(self):
        req = self.create_request('127.0.0.1', '/')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_apt_download(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 404)
        return df

    def test_torrent_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, static.Data))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_file_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_missing_hash(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def check_resp(self, resp, code):
        """Assert the response has status C{code} and pass it along."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        self.client = None
382
if __name__ == '__builtin__':
    # Demo server, started with:  twistd -ny HTTPServer.py
    # and exercised with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath

    class DemoDB:
        """Stand-in database.

        The hash 'pieces' maps to a fixed piece string. Every other hash maps
        to one fixed file under the home directory.
        """
        def lookupHash(self, hash):
            if hash != 'pieces':
                return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
            return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]

    demoServer = TopLevel(FilePath(os.path.expanduser('~')), DemoDB(), None)
    demoFactory = demoServer.getHTTPFactory()

    # Standard twisted application boilerplate; twistd looks up `application`
    from twisted.application import service, strports
    application = service.Application("demoserver")
    strports.service('tcp:18080', demoFactory).setServiceParent(application)