cc6233ba955b192a31b15b94a46c620ede753bda
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6
7 from twisted.python import log
8 from twisted.internet import defer
9 from twisted.web2 import server, http, resource, channel, stream
10 from twisted.web2 import static, http_headers, responsecode
11 from twisted.trial import unittest
12 from twisted.python.filepath import FilePath
13
14 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
15 from apt_p2p_Khashmir.bencode import bencode
16
class FileDownloader(static.File):
    """Serve apt's requests out of the local cache, consulting the main program.

    Files found in the cache are checked for freshness by the manager before
    being sent; requests for missing or stale files are handed to the manager
    so it can download them.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query; when it is false (e.g. None)
        the cache's own response is returned unchanged
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        # Remember the manager, then let static.File set up everything else.
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render the request from the cache and post-process the result.

        static.File may answer synchronously or with a Deferred; either way
        the response ends up going through L{_renderHTTP_done}.
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        result = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(result, defer.Deferred):
            return self._renderHTTP_done(result, req)
        result.addCallback(self._renderHTTP_done, req)
        return result

    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the cache's initial response.

        @return: C{resp} itself when there is no manager; otherwise the result
            of the manager's C{check_freshness} for 2xx/3xx responses, or of
            its C{get_resp} for any other status code
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        # Rebuild the remote URL; presumably req.uri looks like
        # '/<host>/<path>' (as in the unit tests), giving 'http://<host>/<path>'.
        path = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, path)

    def createSimilarFile(self, path):
        """Create a child resource sharing this one's manager and settings."""
        names = self.indexNames[:]
        return self.__class__(path, self.manager, self.defaultType,
                              self.ignoredExts, self.processors, names)
57         
class FileUploaderStream(stream.FileStream):
    """A file stream that hands out data in small pieces for peer uploads.

    Reading at most L{CHUNK_SIZE} bytes per call keeps each write small,
    which makes the upload easier to throttle.

    @ivar CHUNK_SIZE: the maximum number of bytes returned by one L{read}
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or None once it is exhausted.

        @raise RuntimeError: if the file ends before C{self.length} bytes
            have been read
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Done: drop the file so later calls return None immediately.
            self.f = None
            return None

        # No SendFileBuffer or mmap here, just plain string reads, so every
        # chunk is an ordinary string of bounded size.
        self.f.seek(self.start)
        chunk = self.f.read(min(remaining, self.CHUNK_SIZE))
        if not chunk:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        self.length -= len(chunk)
        self.start += len(chunk)
        return chunk
91
92
93 class FileUploader(static.File):
94     """Modified to make it suitable for peer requests.
95     
96     Uses the modified L{FileUploaderStream} to stream the file for throttling,
97     and doesn't do any listing of directory contents.
98     """
99
100     def render(self, req):
101         if not self.fp.exists():
102             return responsecode.NOT_FOUND
103
104         if self.fp.isdir():
105             # Don't try to render a directory listing
106             return responsecode.NOT_FOUND
107
108         try:
109             f = self.fp.open()
110         except IOError, e:
111             import errno
112             if e[0] == errno.EACCES:
113                 return responsecode.FORBIDDEN
114             elif e[0] == errno.ENOENT:
115                 return responsecode.NOT_FOUND
116             else:
117                 raise
118
119         response = http.Response()
120         # Use the modified FileStream
121         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
122
123         for (header, value) in (
124             ("content-type", self.contentType()),
125             ("content-encoding", self.contentEncoding()),
126         ):
127             if value is not None:
128                 response.headers.setHeader(header, value)
129
130         return response
131
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Determines whether or not to throttle the upload based on the type of stream.
    Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemoryStream},
    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream}.

    @type throttle: C{boolean}
    @ivar throttle: whether writes currently go through the throttling
        layer; recomputed each time a new producer is registered
    """

    def __init__(self, factory, wrappedProtocol):
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write C{data}, throttled only while serving an upload stream."""
        if self.throttle:
            ThrottlingProtocol.write(self, data)
        else:
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write a sequence of strings under the same rule as L{write}.

        Without this override the inherited ThrottlingProtocol version would
        always be used, so unthrottled (apt) traffic written this way would
        still be throttled.
        """
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
        else:
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register C{producer} and decide whether its output is throttled.

        The flag is recomputed for every producer. A persistent connection
        serves several responses in turn, and one small MemoryStream reply
        (such as an error or statistics page) must not leave later apt
        downloads on that connection throttled.
        """
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        streamType = getattr(producer, 'stream', None)
        self.throttle = isinstance(streamType, (FileUploaderStream, stream.MemoryStream))
155
156
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int}
    @ivar uploadLimit: the write limit given to the L{policies.ThrottlingFactory},
        or None when no positive limit was configured
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager, uploadLimit):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        @type uploadLimit: C{int} or C{float}
        @param uploadLimit: the upload limit in units of 1024 bytes (it is
            multiplied by 1024); values <= 0 leave the limit unset (None)
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.uploadLimit = None
        if uploadLimit > 0:
            self.uploadLimit = int(uploadLimit*1024)
        self.factory = None

    def getHTTPFactory(self):
        """Initialize (on first call) and return the factory for this HTTP server.

        The HTTP channel factory is wrapped in a ThrottlingFactory whose
        protocol is L{UploadThrottlingProtocol}, so only uploads are throttled.
        """
        if self.factory is None:
            httpFactory = channel.HTTPFactory(server.Site(self),
                                              maxPipeline=10,
                                              betweenRequestsTimeOut=60)
            factory = ThrottlingFactory(httpFactory, writeLimit=self.uploadLimit)
            factory.protocol = UploadThrottlingProtocol
            # Only publish the factory once it is fully configured.
            self.factory = factory
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            body = self.manager.getStats()
        else:
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        Requests for C{/~/<quoted hash>} are answered for any client. If the
        database entry has a C{'path'}, a L{FileUploader} for it is returned;
        otherwise its bencoded C{'pieces'} string is returned. All other
        requests are only served to 127.0.0.1. For those, index C{.diff}
        files are refused, a first segment longer than one character is
        passed to a L{FileDownloader}, and anything else renders the
        statistics page.

        @return: a (resource, remaining segments) tuple, where the resource
            is C{None} when the request is refused or nothing was found
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database.
            # The raw uri is unquoted here (skipping the leading '/~/') because
            # twisted has already unquoted the segments with plain unquote.
            hash_ = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(hash_)
            if not files:
                # Stop here: falling through would serve the statistics page
                # to local clients asking for an unknown hash.
                log.msg('Hash could not be found in database: %r' % hash_)
                return None, ()

            # If it is a file, return it
            if 'path' in files[0]:
                log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
                return FileUploader(files[0]['path'].path), ()

            # It's not for a file, but for a piece string, so return that
            log.msg('Sending torrent string %s to %s' % (b2a_hex(hash_), request.remoteAddr))
            return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()

        # Only local requests (apt) get past this point
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        # Block access to index .diff files (for now)
        if 'Packages.diff' in segments or 'Sources.diff' in segments:
            return None, ()

        if len(name) > 1:
            # It's a request from apt
            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page
        return self, ()
258
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case itself is passed to L{TopLevel} as the database: its
    L{lookupHash} method knows one piece-string hash and one file hash.
    """

    client = None
    # Fixed binary hashes and piece string answered by lookupHash below
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        # Per-test list; a list defined on the class would be shared by all
        # test instances.
        self.pending_calls = []
        self.client = TopLevel(FilePath('/boot'), self, None, 0)

    def lookupHash(self, hash):
        """Fake database lookup used by L{TopLevel}."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        elif hash == self.file_hash:
            return [{'path': FilePath('/boot/initrd')}]
        else:
            return []

    def create_request(self, host, path):
        """Build a parsed GET request for C{path} coming from C{host}."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        class addr:
            host = ''
            port = 0
        req.remoteAddr = addr()
        req.remoteAddr.host = host
        req.remoteAddr.port = 23456
        server.Request._parseURL(req)
        return req

    def test_unauthorized(self):
        req = self.create_request('128.0.0.1', '/foo/bar')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Packages_diff(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Statistics(self):
        req = self.create_request('127.0.0.1', '/')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_apt_download(self):
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 404)
        return df

    def test_torrent_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, static.Data))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_file_upload(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_missing_hash(self):
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def check_resp(self, resp, code):
        """Assert that C{resp} has status C{code} and pass it along."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.pending_calls = []
        self.client = None
353
if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    # FilePath comes from the module-level import; no need to import it again.

    class DB:
        """Fake database for manual testing.

        The hash 'pieces' maps to a piece string; every other hash maps to
        the file ~/school/optout.
        """
        def lookupHash(self, key):
            if key == 'pieces':
                return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]

    # No manager and no upload limit: serves cached files from the home directory.
    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None, 0)
    factory = t.getHTTPFactory()

    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)