87e235f729fea40674086ef826975c12f88ef5c4
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from apt_p2p_conf import config
17 from apt_p2p_Khashmir.bencode import bencode
18
class FileDownloader(static.File):
    """A static file resource adapted to answer requests coming from apt.

    The regular L{static.File} rendering is attempted first. When a manager
    is available, a successful response is passed to it so the file's
    freshness can be checked, while anything else is passed to it so the
    file can be obtained another way. Without a manager the plain
    L{static.File} result is returned untouched.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        """Store the manager and initialize the underlying file resource.

        @param path: the path of the file or directory to serve
        @param manager: the main program to query (may be None)
        @param defaultType: passed through to L{static.File}
        @param ignoredExts: passed through to L{static.File}
        @param processors: passed through to L{static.File}
        @param indexNames: passed through to L{static.File}
        """
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render the request normally, then post-process the outcome.

        @param req: the HTTP request to render
        @return: the post-processed response, or a deferred that will fire
            with it when the parent rendering was itself deferred
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        outcome = super(FileDownloader, self).renderHTTP(req)

        # An immediate result is post-processed right away
        if not isinstance(outcome, defer.Deferred):
            return self._renderHTTP_done(outcome, req)

        # A deferred result is post-processed once it fires
        outcome.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                             callbackArgs = (req, ), errbackArgs = (req, ))
        return outcome

    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the response from the normal rendering.

        @param resp: the response produced by L{static.File}
        @param req: the HTTP request that was rendered
        @return: C{resp} itself when there is no manager, otherwise
            whatever the manager returns for this request
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        # The request URI is turned into an http URL for the manager
        # (presumably req.uri starts with '/', giving 'http://...')
        url = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            # Found in the cache, let the manager check that it is fresh
            modified = resp.headers.getHeader('Last-Modified')
            return self.manager.check_freshness(req, url, modified, resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, url)

    def _renderHTTP_error(self, err, req):
        """Handle a failure of the normal rendering.

        @param err: the failure that occurred
        @param req: the HTTP request that was being rendered
        @return: C{err} itself when there is no manager, otherwise whatever
            the manager returns for this request
        """
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            return err

        return self.manager.get_resp(req, 'http:/' + req.uri)

    def createSimilarFile(self, path):
        """Create a resource like this one for another path.

        The new resource shares this one's manager and settings, and gets
        its own copy of the index names.

        @param path: the path the new resource should serve
        """
        factory = self.__class__
        return factory(path, self.manager, self.defaultType, self.ignoredExts,
                       self.processors, self.indexNames[:])
70         
class FileUploaderStream(stream.FileStream):
    """A file stream adapted for sending files to peers.

    The file is handed out in small chunks, which makes it easier to
    throttle the streaming to peers.

    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Read the next chunk of the file.

        Only plain string reads are used here (no SendFileBuffer or mmap),
        so the C{sendfile} argument is ignored.

        @param sendfile: unused
        @return: the next chunk of at most L{CHUNK_SIZE} bytes, or None
            once the stream is finished
        @raise RuntimeError: if the file ends while more data is expected
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Everything has been sent, drop the reference to the file
            self.f = None
            return None

        wanted = min(remaining, self.CHUNK_SIZE)

        # Always seek first, the position is tracked in self.start
        self.f.seek(self.start)
        chunk = self.f.read(wanted)
        received = len(chunk)
        if not received:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        self.length -= received
        self.start += received
        return chunk
104
105
106 class FileUploader(static.File):
107     """Modified to make it suitable for peer requests.
108     
109     Uses the modified L{FileUploaderStream} to stream the file for throttling,
110     and doesn't do any listing of directory contents.
111     """
112
113     def render(self, req):
114         if not self.fp.exists():
115             return responsecode.NOT_FOUND
116
117         if self.fp.isdir():
118             # Don't try to render a directory listing
119             return responsecode.NOT_FOUND
120
121         try:
122             f = self.fp.open()
123         except IOError, e:
124             import errno
125             if e[0] == errno.EACCES:
126                 return responsecode.FORBIDDEN
127             elif e[0] == errno.ENOENT:
128                 return responsecode.NOT_FOUND
129             else:
130                 raise
131
132         response = http.Response()
133         # Use the modified FileStream
134         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
135
136         for (header, value) in (
137             ("content-type", self.contentType()),
138             ("content-encoding", self.contentEncoding()),
139         ):
140             if value is not None:
141                 response.headers.setHeader(header, value)
142
143         return response
144
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Determines whether or not to throttle the upload based on the type of stream.
    Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemoryStream},
    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream}.

    @ivar stats: the object to report sent bytes to via C{sentBytes}, or
        None to not report anything (a class attribute unless overridden)
    @type throttle: C{boolean}
    @ivar throttle: whether writes on this connection are throttled
    """

    stats = None

    def __init__(self, factory, wrappedProtocol):
        """Initialize the protocol with throttling switched off.

        @param factory: the factory that created this protocol
        @param wrappedProtocol: the protocol being wrapped
        """
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write the data, throttling and counting it only for uploads.

        @param data: the string to write
        """
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            if self.stats:
                self.stats.sentBytes(len(data))
        else:
            # Bypass the throttling and go straight to the wrapper
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write the strings, throttling and counting them only for uploads.

        @param seq: the sequence of strings to write
        """
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats:
                # sum() gives 0 for an empty sequence, where reduce()
                # without an initial value would raise a TypeError
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            # Bypass the throttling and go straight to the wrapper
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register the producer and switch throttling on for upload streams.

        @param producer: the producer being registered, its C{stream}
            attribute (if any) determines whether to throttle
        @param streaming: passed through to L{ThrottlingProtocol}
        """
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        producerStream = getattr(producer, 'stream', None)
        # NOTE(review): throttling is only ever switched on here, nothing in
        # this class switches it off again for later producers
        if isinstance(producerStream, (FileUploaderStream, stream.MemoryStream)):
            self.throttle = True
180
181
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int}
    @ivar uploadLimit: the write limit passed to the throttling factory,
        or None if the configured UPLOAD_LIMIT is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager

        # The configured limit is scaled by 1024 (presumably KiB -> bytes,
        # confirm against the configuration documentation)
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        self.uploadLimit = None
        if limit > 0:
            self.uploadLimit = limit * 1024
        self.factory = None

    def getHTTPFactory(self):
        """Initialize and get the factory for this HTTP server.

        The factory is created on the first call and reused afterwards.

        @return: the throttling factory wrapping the HTTP factory
        """
        if self.factory is None:
            self.factory = channel.HTTPFactory(server.Site(self),
                                               maxPipeline = 10,
                                               betweenRequestsTimeOut = 60)
            self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
            if self.manager:
                # This sets the attribute on the protocol class itself
                self.factory.protocol.stats = self.manager.stats
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics.

        @param ctx: the request being rendered (unused)
        @return: an HTML response with the manager's statistics, or a
            placeholder page when there is no manager
        """
        if self.manager:
            body = self.manager.getStats()
        else:
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        Requests whose first segment is '~' are requests for shared files
        or piece strings and may come from anywhere, everything else is
        only allowed from the local host.

        @param request: the incoming HTTP request
        @param segments: the (unquoted) path segments of the request
        @return: a tuple of the resource to use and the remaining segments,
            or C{(None, ())} if the request is refused
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            # (the slice skips what is presumably the leading '/~/')
            key = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(key)
            if files:
                # If it is a file, return it
                if 'path' in files[0]:
                    log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
                    return FileUploader(files[0]['path'].path), ()
                else:
                    # It's not for a file, but for a piece string, so return that
                    log.msg('Sending torrent string %s to %s' % (b2a_hex(key), request.remoteAddr))
                    return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
            else:
                # Falls through to the checks below: remote peers are then
                # refused, a local request ends up at the statistics page
                log.msg('Hash could not be found in database: %r' % key)

        # Only local requests (apt) get past this point
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        # Block access to index .diff files (for now)
        if 'Packages.diff' in segments or 'Sources.diff' in segments:
            return None, ()

        if len(name) > 1:
            # It's a request from apt
            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page
        return self, ()
285
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case doubles as the database given to L{TopLevel}, answering
    hash lookups through its own L{lookupHash} method.
    """

    client = None
    pending_calls = []
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        """Create the server under test, with no manager."""
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Stand in for the database lookup.

        @param hash: the hash to look up
        @return: a single piece string entry for the torrent hash, a single
            path entry for the file hash, and nothing for any other hash
        """
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        if hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        return []

    def create_request(self, host, path):
        """Build a GET request for a path that appears to come from a host.

        @param host: the remote host the request should claim to be from
        @param path: the path to request
        @return: the parsed request
        """
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())

        # A minimal stand-in for the remote address of the request
        class addr:
            host = ''
            port = 0

        remote = addr()
        remote.host = host
        remote.port = 23456
        req.remoteAddr = remote
        req._parseURL()
        return req

    def _locate(self, req):
        """Find the resource for the request, which must not be None."""
        res = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(res, None)
        return res

    def _render(self, res, req, code):
        """Render the resource and check the response code that results."""
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, code)
        return df

    def _check_refused(self, req):
        """Check that finding the resource for the request is an HTTP error."""
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def test_unauthorized(self):
        """Non-local hosts can not make apt requests."""
        self._check_refused(self.create_request('128.0.0.1', '/foo/bar'))

    def test_Packages_diff(self):
        """Requests for Packages.diff files are refused."""
        self._check_refused(self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index'))

    def test_Statistics(self):
        """The root page renders successfully for the local host."""
        req = self.create_request('127.0.0.1', '/')
        return self._render(self._locate(req), req, 200)

    def test_apt_download(self):
        """Local apt requests get a downloader, which here responds with 404."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = self._locate(req)
        self.failUnless(isinstance(res, FileDownloader))
        return self._render(res, req, 404)

    def test_torrent_upload(self):
        """A peer requesting the torrent hash is sent the piece string data."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = self._locate(req)
        self.failUnless(isinstance(res, static.Data))
        return self._render(res, req, 200)

    def test_file_upload(self):
        """A peer requesting the file hash is sent the file."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = self._locate(req)
        self.failUnless(isinstance(res, FileUploader))
        return self._render(res, req, 200)

    def test_missing_hash(self):
        """A peer requesting an unknown hash is refused."""
        self._check_refused(self.create_request('123.45.67.89',
                                                '/~/' + quote_plus('foobar')))

    def check_resp(self, resp, code):
        """Check the response's code, passing the response on unchanged."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        """Cancel any calls still pending and drop the server."""
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        if self.client:
            self.client = None
380
if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath

    class DB:
        """A stand-in database that knows a single piece string and a single file."""

        def lookupHash(self, hash):
            """Return the piece string for 'pieces', and a fixed file otherwise."""
            if hash == 'pieces':
                return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]

    # TopLevel takes exactly (directory, db, manager), there is no manager here
    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()

    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)