Better identification of peer uploads by the HTTP server.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from apt_p2p_conf import config
17 from apt_p2p_Khashmir.bencode import bencode
18
class FileDownloader(static.File):
    """Serve apt's requests from the local cache, deferring to the manager.

    Files found in the cache are handed to the manager for a freshness check
    before being sent. Requests for files that are missing, or that fail to
    render, are handed to the manager to be fetched by other means.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query; when it is None (as in the unit
        tests) the plain static-file response is returned unchanged
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render the cached file, then let the manager post-process the result."""
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        rendered = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(rendered, defer.Deferred):
            # Synchronous result: post-process it right away.
            return self._renderHTTP_done(rendered, req)
        rendered.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                              callbackArgs=(req,), errbackArgs=(req,))
        return rendered

    def _renderHTTP_done(self, resp, req):
        """Pass successful responses to the freshness check, others to get_resp."""
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        # The first URI segment is the mirror's host name, so this rebuilds
        # the original http:// URL apt asked for.
        path = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            last_modified = resp.headers.getHeader('Last-Modified')
            return self.manager.check_freshness(req, path, last_modified, resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, path)

    def _renderHTTP_error(self, err, req):
        """Log a rendering failure and fall back to the manager, if any."""
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            return err
        return self.manager.get_resp(req, 'http:/' + req.uri)

    def createSimilarFile(self, path):
        """Create a child resource that shares this one's manager and settings."""
        cls = self.__class__
        return cls(path, self.manager, self.defaultType, self.ignoredExts,
                   self.processors, self.indexNames[:])
70         
class UploadStream:
    """Marker mix-in for response streams that carry data uploaded to peers.

    UploadThrottlingProtocol checks producers' streams against this type to
    decide which responses get throttled.
    """
73     
class FileUploaderStream(stream.FileStream, UploadStream):
    """File stream used when uploading a file to a peer.

    Reads plain strings in small fixed-size chunks (no sendfile or mmap
    buffers) so that the upload can be throttled at a fine granularity.

    @ivar CHUNK_SIZE: the maximum number of bytes returned by one read
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or None once it is exhausted.

        @param sendfile: accepted for interface compatibility; ignored
        @raise RuntimeError: if the file ends before C{self.length} bytes
            have been read
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Finished: drop the file so later reads return None straight away.
            self.f = None
            return None

        self.f.seek(self.start)
        chunk = self.f.read(min(remaining, self.CHUNK_SIZE))
        got = len(chunk)
        if got == 0:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        self.start += got
        self.length -= got
        return chunk
107
class PiecesUploaderStream(stream.MemoryStream, UploadStream):
    """In-memory stream sent to peers; the UploadStream mix-in marks it for throttling."""
110
class PiecesUploader(static.Data):
    """Serve an in-memory string (e.g. the bencoded piece string) to a peer.

    The response body is a L{PiecesUploaderStream}, so the upload is
    recognised as a peer upload and throttled.
    """

    def render(self, req):
        """Return a 200 response streaming C{self.data} with this resource's content type."""
        headers = http_headers.Headers({'content-type': self.contentType()})
        return http.Response(responsecode.OK, headers,
                             stream=PiecesUploaderStream(self.data))
121         
class FileUploader(static.File):
    """Serve a cached file to a peer.

    Uses L{FileUploaderStream} so the upload can be throttled, and never
    renders directory listings.
    """

    def render(self, req):
        """Return a response streaming the file, or an error response code.

        @return: an L{http.Response} on success; C{responsecode.NOT_FOUND} for
            a missing file or a directory; C{responsecode.FORBIDDEN} when the
            file cannot be opened for lack of permission
        @raise IOError: for any other error while opening the file
        """
        import errno
        import os

        # Missing paths and directories (no listings for peers) are both 404s.
        if not self.fp.exists() or self.fp.isdir():
            return responsecode.NOT_FOUND

        try:
            f = self.fp.open()
        except IOError as e:
            if e.errno == errno.EACCES:
                return responsecode.FORBIDDEN
            if e.errno == errno.ENOENT:
                return responsecode.NOT_FOUND
            raise

        response = http.Response()
        # Size the stream from the file we actually opened: FilePath's cached
        # stat may be stale (or describe a replaced file), which would make the
        # stream run out of data and raise RuntimeError mid-upload.
        size = os.fstat(f.fileno()).st_size
        response.stream = FileUploaderStream(f, 0, size)

        for (header, value) in (
            ("content-type", self.contentType()),
            ("content-encoding", self.contentEncoding()),
        ):
            if value is not None:
                response.headers.setHeader(header, value)

        return response
160
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Decides, for each registered producer, whether data written to the
    transport is throttled, based on the type of the producer's stream.
    Uploads to peers use L{UploadStream} subclasses (L{FileUploaderStream},
    L{PiecesUploaderStream}) and are throttled and counted in the statistics.
    Other responses (e.g. apt's L{CacheManager.ProxyFileStream} or
    L{twisted.web2.stream.FileStream}) are written straight through.
    """

    # Statistics object with a sentBytes(count) method, or None. Note that
    # TopLevel.getHTTPFactory assigns it on the class, not per instance.
    stats = None

    def __init__(self, factory, wrappedProtocol):
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write data, throttled and counted only while an upload is producing."""
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            if self.stats:
                self.stats.sentBytes(len(data))
        else:
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write a sequence of strings, throttled only while an upload is producing."""
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats:
                # sum() (not reduce) so an empty sequence counts as 0 bytes
                # instead of raising TypeError.
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register the producer and decide whether its output is throttled."""
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        streamType = getattr(producer, 'stream', None)
        log.msg('Registered a producer %r with type %r' % (producer, streamType))
        # Re-evaluate for every producer: one persistent (pipelined) connection
        # can carry both peer uploads and other responses, and an earlier
        # upload must not leave later responses throttled and counted.
        self.throttle = isinstance(streamType, UploadStream)
        if self.throttle:
            log.msg('Throttling')
198
199
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: the write limit given to the throttling factory, or
        None when the configured UPLOAD_LIMIT is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager

        # UPLOAD_LIMIT is scaled by 1024 (presumably KB/s to bytes/s; confirm
        # against the configuration docs). Non-positive values mean no limit.
        upload_limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        self.uploadLimit = None
        if upload_limit > 0:
            self.uploadLimit = int(upload_limit * 1024)
        self.factory = None

    def getHTTPFactory(self):
        """Initialize (once) and return the factory for this HTTP server."""
        if self.factory is None:
            http_factory = channel.HTTPFactory(server.Site(self),
                                               maxPipeline=10,
                                               betweenRequestsTimeOut=60)
            # Only cache the factory once it is fully wrapped.
            factory = ThrottlingFactory(http_factory, writeLimit=self.uploadLimit)
            factory.protocol = UploadThrottlingProtocol
            if self.manager:
                # NOTE: this sets the attribute on the protocol class, so it
                # is shared by every UploadThrottlingProtocol instance.
                factory.protocol.stats = self.manager.stats
            self.factory = factory
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            body = self.manager.getStats()
        else:
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        - C{/~/<quoted hash>}: a peer request for a shared file or piece
          string, served from any address.
        - any other path whose first segment is longer than one character:
          an apt request, accepted only from 127.0.0.1.
        - anything else: the statistics page, which only 127.0.0.1 may view
          unless REMOTE_STATS is enabled.

        @return: C{(resource, remaining segments)}; C{(None, ())} rejects the
            request
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            file_hash = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(file_hash)
            if not files:
                # Reject unknown hashes outright. Falling through would serve
                # the statistics page (a 200) to the peer in place of the file.
                log.msg('Hash could not be found in database: %r' % file_hash)
                return None, ()

            if 'path' in files[0]:
                # It is a file, so return it
                log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
                return FileUploader(files[0]['path'].path), ()

            # It's not for a file, but for a piece string, so return that
            log.msg('Sending torrent string %s to %s' % (b2a_hex(file_hash), request.remoteAddr))
            return PiecesUploader(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()

        if len(name) > 1:
            # It's a request from apt; only local requests get past this point
            if request.remoteAddr.host != "127.0.0.1":
                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
                return None, ()

            # Block access to index .diff files (for now)
            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
                return None, ()

            return FileDownloader(self.directory.path, self.manager), segments[:]

        # Will render the statistics page; only local requests are allowed
        # unless REMOTE_STATS is enabled
        if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        return self, ()
310
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case doubles as the database given to L{TopLevel}: its
    lookupHash knows exactly one torrent hash and one file hash.
    """

    client = None
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        # Delayed calls to cancel in tearDown. Created per test: a class-level
        # list would be shared by (and leak entries across) test instances.
        self.pending_calls = []
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Fake db lookup: a piece string, a file under /boot, or nothing."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        elif hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        else:
            return []

    def create_request(self, host, path):
        """Build a parsed GET request for path that appears to come from host."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())

        class FakeAddress:
            pass

        address = FakeAddress()
        address.host = host
        address.port = 23456
        req.remoteAddr = address
        server.Request._parseURL(req)
        return req

    def test_unauthorized(self):
        req = self.create_request('128.0.0.1', '/foo/bar')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Packages_diff(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Statistics(self):
        req = self.create_request('127.0.0.1', '/')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_apt_download(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 404)
        return df

    def test_torrent_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        # PiecesUploader (a static.Data) is what marks the upload for throttling.
        self.failUnless(isinstance(res, PiecesUploader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_file_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_missing_hash(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def check_resp(self, resp, code):
        """Assert the response has the expected status code and pass it on."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        self.client = None
405
if __name__ == '__builtin__':
    # Executed when loaded by `twistd -ny HTTPServer.py`; then try e.g.
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'
    import os.path

    class DB:
        """Stand-in database: 'pieces' yields a piece string, anything else one file."""
        def lookupHash(self, hash):
            if hash != 'pieces':
                return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
            return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]

    top_level = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = top_level.getHTTPFactory()

    # Standard twisted application boilerplate: twistd looks up `application`.
    from twisted.application import service, strports
    application = service.Application("demoserver")
    web_service = strports.service('tcp:18080', factory)
    web_service.setServiceParent(application)