Remove some unnecessary log messages and use better Exceptions.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from apt_p2p_conf import config
17 from apt_p2p_Khashmir.bencode import bencode
18
class FileDownloader(static.File):
    """Static file resource adapted for serving apt's requests.
    
    The request is first rendered by the stock L{static.File} code. When a
    manager is available, a successful (2xx/3xx) response is passed to the
    manager for a freshness check, and any other outcome (an error response
    or a failure while rendering) is handed to the manager so it can obtain
    the file by other means.
    
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query 
    """
    
    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        # The manager is stored before the parent initializer runs, as in
        # the rest of this class every instance must carry one (or None).
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
        
    def renderHTTP(self, req):
        """Render the request, then post-process the outcome.
        
        The parent may answer synchronously or with a Deferred; both cases
        end up in L{_renderHTTP_done} (or L{_renderHTTP_error} on failure).
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        result = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(result, defer.Deferred):
            # Synchronous answer: post-process it right away
            return self._renderHTTP_done(result, req)

        result.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                            callbackArgs = (req, ), errbackArgs = (req, ))
        return result
        
    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the response produced by the parent class.
        
        Without a manager the response is returned untouched.
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))
        
        if not self.manager:
            return resp

        # req.uri begins with a slash, so this yields an 'http://...' string
        url = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            # Found locally, let the manager check that it is still fresh
            last_modified = resp.headers.getHeader('Last-Modified')
            return self.manager.check_freshness(req, url, last_modified, resp)
            
        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, url)

    def _renderHTTP_error(self, err, req):
        """Log a rendering failure and fall back to the manager, if any.
        
        Without a manager the failure is returned so it keeps propagating.
        """
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)
        
        if not self.manager:
            return err

        return self.manager.get_resp(req, 'http:/' + req.uri)

    def createSimilarFile(self, path):
        """Create a resource of the same class for C{path}.
        
        The manager and the other settings of this instance are carried
        over; the index names are copied rather than shared.
        """
        return self.__class__(path, self.manager, self.defaultType,
                              self.ignoredExts, self.processors,
                              self.indexNames[:])
70         
class UploadStream:
    """Identifier for streams that are uploaded to peers.
    
    Marker class with no behaviour of its own. L{UploadThrottlingProtocol}
    tests a producer's stream against it with C{isinstance} to decide
    whether the connection should be throttled.
    """
73     
class FileUploaderStream(stream.FileStream, UploadStream):
    """File stream that hands out its data in small pieces.
    
    Reading at most L{CHUNK_SIZE} bytes per call makes it easier to throttle
    the streaming to peers. Inheriting from L{UploadStream} marks the stream
    as an upload.
    
    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
    """

    CHUNK_SIZE = 4*1024
    
    def read(self, sendfile=False):
        """Read the next chunk of the file.
        
        Plain string reads are used instead of the SendFileBuffer and mmap
        handling of the parent class; C{sendfile} is accepted but ignored.
        
        @return: the next chunk of data, or None once the stream is finished
        @raise RuntimeError: if the file ends before the expected length
            has been read
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Everything was sent, drop the reference to the file
            self.f = None
            return None

        self.f.seek(self.start)
        chunk = self.f.read(min(remaining, self.CHUNK_SIZE))
        count = len(chunk)
        if count:
            # Advance the window over the part of the file still to be sent
            self.length -= count
            self.start += count
            return chunk

        raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))
107
class PiecesUploaderStream(stream.MemoryStream, UploadStream):
    """Modified to identify it for streaming to peers.
    
    Adds no behaviour to L{stream.MemoryStream}; it only mixes in
    L{UploadStream} so the stream is recognized as an upload.
    """
110
class PiecesUploader(static.Data):
    """In-memory data resource used to answer peer requests.
    
    The data is wrapped in a L{PiecesUploaderStream} so that the response is
    identified as an upload and can be throttled.
    """

    def render(self, req):
        """Build a 200 response carrying the stored data.
        
        @return: an L{http.Response} whose content type is the one this
            resource was created with
        """
        headers = http_headers.Headers({'content-type': self.contentType()})
        body = PiecesUploaderStream(self.data)
        return http.Response(responsecode.OK, headers, stream=body)
121         
122 class FileUploader(static.File):
123     """Modified to make it suitable for peer requests.
124     
125     Uses the modified L{FileUploaderStream} to stream the file for throttling,
126     and doesn't do any listing of directory contents.
127     """
128
129     def render(self, req):
130         if not self.fp.exists():
131             return responsecode.NOT_FOUND
132
133         if self.fp.isdir():
134             # Don't try to render a directory listing
135             return responsecode.NOT_FOUND
136
137         try:
138             f = self.fp.open()
139         except IOError, e:
140             import errno
141             if e[0] == errno.EACCES:
142                 return responsecode.FORBIDDEN
143             elif e[0] == errno.ENOENT:
144                 return responsecode.NOT_FOUND
145             else:
146                 raise
147
148         response = http.Response()
149         # Use the modified FileStream
150         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
151
152         for (header, value) in (
153             ("content-type", self.contentType()),
154             ("content-encoding", self.contentEncoding()),
155         ):
156             if value is not None:
157                 response.headers.setHeader(header, value)
158
159         return response
160
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.
    
    Determines whether or not to throttle the upload based on the type of
    stream. Streams derived from L{UploadStream} (L{FileUploaderStream} and
    L{PiecesUploaderStream}) are throttled and counted in the statistics,
    all other streams are written without throttling.
    
    @cvar stats: the statistics object that is told about the bytes sent
        to peers, or None to not keep statistics
    @ivar throttle: whether the data written on this connection is throttled
    """
    
    stats = None

    def __init__(self, factory, wrappedProtocol):
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        # Nothing is throttled until an upload stream gets registered
        self.throttle = False

    def write(self, data):
        """Write the data, throttling and counting it only for uploads."""
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            if self.stats:
                self.stats.sentBytes(len(data))
        else:
            # Skip the throttling done by the direct parent class
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write the sequence, throttling and counting it only for uploads."""
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats:
                # sum() copes with an empty sequence, for which reduce()
                # without an initial value raises a TypeError
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            # Skip the throttling done by the direct parent class
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register the producer and enable throttling for upload streams.
        
        The producer's C{stream} attribute (if it has one) is checked for
        being an L{UploadStream}.
        """
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        streamType = getattr(producer, 'stream', None)
        if isinstance(streamType, UploadStream):
            self.throttle = True
196
197
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.
    
    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int}
    @ivar uploadLimit: the write limit given to the throttling factory,
        or None if the configured UPLOAD_LIMIT is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """
    
    addSlash = True
    
    def __init__(self, directory, db, manager):
        """Initialize the instance.
        
        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.uploadLimit = None
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        if limit > 0:
            # Presumably converts KiB to bytes -- confirm against the
            # documentation of the UPLOAD_LIMIT option
            self.uploadLimit = limit * 1024
        self.factory = None

    def getHTTPFactory(self):
        """Initialize and get the factory for this HTTP server.
        
        The factory is created on the first call and reused afterwards.
        """
        if self.factory is None:
            self.factory = channel.HTTPFactory(server.Site(self),
                                               **{'maxPipeline': 10, 
                                                  'betweenRequestsTimeOut': 60})
            self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
            if self.manager:
                # NOTE: this sets an attribute of the protocol class, so it
                # is shared by every connection using that class
                self.factory.protocol.stats = self.manager.stats
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            return http.Response(
                200,
                {'content-type': http_headers.MimeType('text', 'html')},
                self.manager.getStats())
        else:
            # No main program to query, return a placeholder page
            return http.Response(
                200,
                {'content-type': http_headers.MimeType('text', 'html')},
                '<html><body><p>Some Statistics</body></html>')

    def locateChild(self, request, segments):
        """Process the incoming request.
        
        Requests whose first segment is '~' are peer requests for a shared
        hash, longer first segments are apt requests (local only), and
        anything else gets the statistics page.
        
        @return: the resource to render and the remaining segments, or
            (None, ()) when the request is refused or nothing is found
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]
        
        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()
            
            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            key = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(key)
            if not files:
                # Report the hash as not found instead of falling through
                # to the statistics page below
                log.msg('Hash could not be found in database: %r' % key)
                return None, ()

            # If it is a file, return it
            if 'path' in files[0]:
                log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
                return FileUploader(files[0]['path'].path), ()

            # It's not for a file, but for a piece string, so return that
            log.msg('Sending torrent string %s to %s' % (b2a_hex(key), request.remoteAddr))
            return PiecesUploader(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()

        if len(name) > 1:
            # It's a request from apt

            # Only local requests (apt) get past this point
            if request.remoteAddr.host != "127.0.0.1":
                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
                return None, ()

            # Block access to index .diff files (for now)
            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
                return None, ()
             
            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page

        # Only local requests for stats are allowed
        if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        return self, ()
308
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.
    
    The test case itself plays the role of the database given to
    L{TopLevel}, by implementing L{lookupHash}.
    """
    
    client = None
    pending_calls = []
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'
    
    def setUp(self):
        # No manager is used, and this object is the database
        self.client = TopLevel(FilePath('/boot'), self, None)
        
    def lookupHash(self, hash):
        """Fake database lookup knowing one piece string and one file."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        if hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        return []
        
    def create_request(self, host, path):
        """Build a GET request for the path that appears to come from host."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        class addr:
            host = ''
            port = 0
        peer = addr()
        peer.host = host
        peer.port = 23456
        req.remoteAddr = peer
        server.Request._parseURL(req)
        return req

    def _assert_rejected(self, req):
        """Check that looking up the resource for the request fails."""
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def _locate(self, req):
        """Find the resource for the request, which must exist."""
        found = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(found, None)
        return found

    def _render(self, res, req, code):
        """Render the resource and check the response code once available."""
        d = defer.maybeDeferred(res.renderHTTP, req)
        d.addCallback(self.check_resp, code)
        return d
        
    def test_unauthorized(self):
        # apt requests are only accepted from the local host
        self._assert_rejected(self.create_request('128.0.0.1', '/foo/bar'))
        
    def test_Packages_diff(self):
        self._assert_rejected(self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index'))
        
    def test_Statistics(self):
        req = self.create_request('127.0.0.1', '/')
        return self._render(self._locate(req), req, 200)
        
    def test_apt_download(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = self._locate(req)
        self.failUnless(isinstance(res, FileDownloader))
        return self._render(res, req, 404)
        
    def test_torrent_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = self._locate(req)
        self.failUnless(isinstance(res, static.Data))
        return self._render(res, req, 200)
        
    def test_file_upload(self):
        # Expects a 200, so the file returned by lookupHash has to exist
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = self._locate(req)
        self.failUnless(isinstance(res, FileUploader))
        return self._render(res, req, 200)
    
    def test_missing_hash(self):
        self._assert_rejected(self.create_request('123.45.67.89',
                                                  '/~/' + quote_plus('foobar')))

    def check_resp(self, resp, code):
        """Callback checking the code of a response, which is passed on."""
        self.failUnlessEqual(resp.code, code)
        return resp
        
    def tearDown(self):
        # Cancel whatever delayed calls are still waiting
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        if self.client:
            self.client = None
403
# Demo server, only set up when the module is loaded by twistd (see below).
if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath
    
    class DB:
        """Stand-in database answering every hash lookup."""

        def lookupHash(self, hash):
            # The hash 'pieces' gives a piece string, anything else a file
            if hash == 'pieces':
                return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
    
    # Serve the user's home directory, without a manager
    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()
    
    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)