3d43fa04b48fa24ad1886c02c939203c0db916f0
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from apt_p2p_conf import config
17 from apt_p2p_Khashmir.bencode import bencode
18
class FileDownloader(static.File):
    """Modified to make it suitable for apt requests.

    Tries to find requests in the cache. Found files are first checked for
    freshness before being sent. Requests for unfound and stale files are
    forwarded to the main program for downloading.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render via L{static.File}, then post-process the cache lookup result.

        A Deferred result gets L{_renderHTTP_done} / L{_renderHTTP_error}
        chained onto it and is returned; any other result is passed straight
        to L{_renderHTTP_done}.
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        result = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(result, defer.Deferred):
            return self._renderHTTP_done(result, req)
        result.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                            callbackArgs=(req,), errbackArgs=(req,))
        return result

    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the response rendered from the cache.

        Without a manager the cached response is returned unchanged. With a
        manager, a 2xx/3xx response is handed to C{check_freshness}, and
        anything else is sent to C{get_resp} to be fetched another way.
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        # Assumes req.uri starts with '/', giving 'http://<host>/<path>'.
        url = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            return self.manager.check_freshness(req, url, resp.headers.getHeader('Last-Modified'), resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, url)

    def _renderHTTP_error(self, err, req):
        """Log a rendering failure and fall back to the manager, if any.

        Without a manager the failure C{err} is returned, so it keeps
        propagating down the errback chain.
        """
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            return err
        return self.manager.get_resp(req, 'http:/' + req.uri)

    def createSimilarFile(self, path):
        """Create a child resource of this class that shares this one's settings."""
        return self.__class__(path, self.manager,
                              defaultType=self.defaultType,
                              ignoredExts=self.ignoredExts,
                              processors=self.processors,
                              indexNames=self.indexNames[:])
70         
class FileUploaderStream(stream.FileStream):
    """Modified to make it suitable for streaming to peers.

    Streams the file in small chunks to make it easier to throttle the
    streaming to peers.

    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
    """

    CHUNK_SIZE = 4 * 1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or None once it is exhausted.

        Reads at most L{CHUNK_SIZE} bytes starting at C{self.start}, then
        advances C{start} and shrinks C{length} by the amount read. The
        C{sendfile} argument is accepted but ignored: data is always returned
        as a plain string (no SendFileBuffer or mmap).

        @raise RuntimeError: if the file ends before C{length} bytes were read
        """
        fileobj = self.f
        if fileobj is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Drop the file so any further call takes the None path above.
            self.f = None
            return None

        fileobj.seek(self.start)
        chunk = fileobj.read(min(remaining, self.CHUNK_SIZE))
        got = len(chunk)
        if got == 0:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (fileobj, remaining))

        self.start += got
        self.length -= got
        return chunk
104
105
class FileUploader(static.File):
    """Modified to make it suitable for peer requests.

    Uses the modified L{FileUploaderStream} to stream the file for throttling,
    and doesn't do any listing of directory contents.
    """

    def render(self, req):
        """Serve the file's contents, or a response code if it can't be sent.

        Missing paths and directories give NOT_FOUND (never a directory
        listing). EACCES on open gives FORBIDDEN, ENOENT gives NOT_FOUND, and
        any other IOError is re-raised.
        """
        if not self.fp.exists() or self.fp.isdir():
            return responsecode.NOT_FOUND

        try:
            f = self.fp.open()
        except IOError as e:
            import errno
            code = e.args[0]
            if code == errno.EACCES:
                return responsecode.FORBIDDEN
            if code == errno.ENOENT:
                return responsecode.NOT_FOUND
            raise

        response = http.Response()
        # Small fixed-size chunks keep the upload throttleable.
        response.stream = FileUploaderStream(f, 0, self.fp.getsize())

        content_type = self.contentType()
        content_encoding = self.contentEncoding()
        if content_type is not None:
            response.headers.setHeader("content-type", content_type)
        if content_encoding is not None:
            response.headers.setHeader("content-encoding", content_encoding)

        return response
144
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Determines whether or not to throttle the upload based on the type of stream.
    Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemoryStream},
    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream}.

    @cvar stats: when not None, its C{sentBytes(count)} method is called with
        the size of every throttled write. It is assigned on the class, so
        all instances share it.
    @ivar throttle: whether writes go through the throttling layer (set once
        a producer with an upload stream is registered)
    """

    stats = None

    def __init__(self, factory, wrappedProtocol):
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write C{data}, throttled and counted if this is an upload."""
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            # Look up via self: there is no module-level 'stats' name.
            if self.stats is not None:
                self.stats.sentBytes(len(data))
        else:
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write every string in C{seq}, throttled and counted if this is an upload."""
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats is not None:
                # sum() handles an empty sequence, unlike reduce() without an initializer.
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register C{producer}, enabling throttling if it streams an upload."""
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        producer_stream = getattr(producer, 'stream', None)
        if isinstance(producer_stream, (FileUploaderStream, stream.MemoryStream)):
            self.throttle = True
180
181
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: the configured UPLOAD_LIMIT multiplied by 1024
        (presumably kB/s converted to bytes/s), or None when unlimited
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        upload_limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        self.uploadLimit = None
        if upload_limit > 0:
            self.uploadLimit = upload_limit * 1024
        self.factory = None

    def getHTTPFactory(self):
        """Initialize and get the factory for this HTTP server.

        The factory is built on first use and cached: a web2 HTTPFactory
        wrapped in a ThrottlingFactory whose write limit is L{uploadLimit}.
        """
        if self.factory is None:
            http_factory = channel.HTTPFactory(server.Site(self),
                                               maxPipeline=10,
                                               betweenRequestsTimeOut=60)
            self.factory = ThrottlingFactory(http_factory, writeLimit=self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
            # Sets a class attribute shared by all UploadThrottlingProtocol
            # instances. getattr tolerates a missing manager (None).
            self.factory.protocol.stats = getattr(self.manager, 'stats', None)
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics."""
        if self.manager:
            body = self.manager.getStats()
        else:
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        C{/~/<hash>} requests (from peers) are served from the database.
        Every other request is only accepted from 127.0.0.1 (apt). Requests
        through a C{Packages.diff}/C{Sources.diff} segment are refused. A
        multi-character first segment is handled by L{FileDownloader}, and
        anything else gets the statistics page.

        @return: a C{(resource, remaining segments)} pair, with None as the
            resource when the request is refused or cannot be satisfied
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0] if segments else ''

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            file_hash = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(file_hash)
            if not files:
                # Stop here: falling through would give local requests the
                # statistics page instead of a failure.
                log.msg('Hash could not be found in database: %r' % file_hash)
                return None, ()

            # If it is a file, return it
            if 'path' in files[0]:
                log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
                return FileUploader(files[0]['path'].path), ()

            # It's not for a file, but for a piece string, so return that
            log.msg('Sending torrent string %s to %s' % (b2a_hex(file_hash), request.remoteAddr))
            return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()

        # Only local requests (apt) get past this point
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        # Block access to index .diff files (for now)
        if 'Packages.diff' in segments or 'Sources.diff' in segments:
            return None, ()

        if len(name) > 1:
            # It's a request from apt
            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page
        return self, ()
284
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case itself is passed to L{TopLevel} as the database, so
    L{lookupHash} below supplies the canned lookups. The cache directory is
    /boot, and test_file_upload expects /boot/grub/stage2 to exist.
    """

    client = None
    # Immutable default; setUp gives each test its own list, so no mutable
    # list is shared between test instances through the class.
    pending_calls = ()
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        self.pending_calls = []
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Fake database lookup: one piece string, one file, otherwise nothing."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        elif hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        else:
            return []

    def create_request(self, host, path):
        """Build a parsed GET request for C{path} appearing to come from C{host}."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        class addr:
            host = ''
            port = 0
        req.remoteAddr = addr()
        req.remoteAddr.host = host
        req.remoteAddr.port = 23456
        server.Request._parseURL(req)
        return req

    def test_unauthorized(self):
        """Non-local requests outside /~/ are refused."""
        req = self.create_request('128.0.0.1', '/foo/bar')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Packages_diff(self):
        """Requests through a Packages.diff directory are refused."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_Statistics(self):
        """The root URL renders the statistics page."""
        req = self.create_request('127.0.0.1', '/')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_apt_download(self):
        """Local apt requests go to a FileDownloader (404 here: nothing cached, no manager)."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 404)
        return df

    def test_torrent_upload(self):
        """A piece-string hash is served to a peer as bencoded data."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, static.Data))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_file_upload(self):
        """A file hash is served to a peer through a FileUploader."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, 200)
        return df

    def test_missing_hash(self):
        """An unknown hash is refused."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def check_resp(self, resp, code):
        """Assert that C{resp} has status C{code} and pass it along."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.pending_calls = []
        self.client = None
379
if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path

    class DB:
        """Stand-in database for the demo server."""
        def lookupHash(self, hash):
            # 'pieces' yields a piece string; any other hash maps to one fixed file.
            if hash == 'pieces':
                return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]

    # TopLevel takes exactly (directory, db, manager); this demo has no manager.
    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()

    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)