Block favicon.ico and allow remote stats requests (configurable).
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6 import operator
7
8 from twisted.python import log
9 from twisted.internet import defer
10 from twisted.web2 import server, http, resource, channel, stream
11 from twisted.web2 import static, http_headers, responsecode
12 from twisted.trial import unittest
13 from twisted.python.filepath import FilePath
14
15 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
16 from apt_p2p_conf import config
17 from apt_p2p_Khashmir.bencode import bencode
18
class FileDownloader(static.File):
    """A static file resource adapted to requests coming from apt.

    The request is first looked up in the cache directory. A file that is
    found has its freshness checked before it is sent; a request for a file
    that is missing or stale is handed to the main program for downloading.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        """Remember the manager, then initialize the underlying static file."""
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render from the cache, then post-process the outcome.

        The parent class may answer immediately or with a deferred; both
        cases end up in L{_renderHTTP_done} (or L{_renderHTTP_error} when a
        deferred fails).
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        outcome = super(FileDownloader, self).renderHTTP(req)
        if not isinstance(outcome, defer.Deferred):
            # Immediate answer: post-process it right away
            return self._renderHTTP_done(outcome, req)

        outcome.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                             callbackArgs = (req, ), errbackArgs = (req, ))
        return outcome

    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the initial response from the cache.

        Without a manager the response is returned untouched. Otherwise a
        response with a 2xx/3xx code is passed on for a freshness check and
        anything else is requested from the manager instead.
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        # NOTE(review): presumably req.uri starts with '/', which would
        # complete the 'http://' prefix -- confirm against the callers.
        url = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            last_modified = resp.headers.getHeader('Last-Modified')
            return self.manager.check_freshness(req, url, last_modified, resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, url)

    def _renderHTTP_error(self, err, req):
        """Log a rendering failure and fall back to the manager, if any."""
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            # Nobody to ask, so hand the failure back unchanged
            return err

        url = 'http:/' + req.uri
        return self.manager.get_resp(req, url)

    def createSimilarFile(self, path):
        """Create another instance for C{path} sharing this one's settings."""
        return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                              self.processors, self.indexNames[:])
70         
class FileUploaderStream(stream.FileStream):
    """A file stream adapted to streaming to peers.

    The file is handed out in small chunks, which makes it easier to
    throttle the streaming to peers.

    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or None when nothing is left.

        The C{sendfile} argument is accepted but not used: data is always
        read as a plain string (no SendFileBuffer, no mmap).

        @raise RuntimeError: if the file yields no data although more bytes
            are still expected
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Everything was sent, so drop the reference to the file
            self.f = None
            return None

        wanted = min(remaining, self.CHUNK_SIZE)

        self.f.seek(self.start)
        chunk = self.f.read(wanted)
        count = len(chunk)
        if not count:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        # Advance the window past the chunk that was just read
        self.length -= count
        self.start += count
        return chunk
104
105
106 class FileUploader(static.File):
107     """Modified to make it suitable for peer requests.
108     
109     Uses the modified L{FileUploaderStream} to stream the file for throttling,
110     and doesn't do any listing of directory contents.
111     """
112
113     def render(self, req):
114         if not self.fp.exists():
115             return responsecode.NOT_FOUND
116
117         if self.fp.isdir():
118             # Don't try to render a directory listing
119             return responsecode.NOT_FOUND
120
121         try:
122             f = self.fp.open()
123         except IOError, e:
124             import errno
125             if e[0] == errno.EACCES:
126                 return responsecode.FORBIDDEN
127             elif e[0] == errno.ENOENT:
128                 return responsecode.NOT_FOUND
129             else:
130                 raise
131
132         response = http.Response()
133         # Use the modified FileStream
134         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
135
136         for (header, value) in (
137             ("content-type", self.contentType()),
138             ("content-encoding", self.contentEncoding()),
139         ):
140             if value is not None:
141                 response.headers.setHeader(header, value)
142
143         return response
144
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol for throttling uploads.

    Determines whether or not to throttle the upload based on the type of stream.
    Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemoryStream},
    apt uses L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream}.

    @ivar stats: optional object whose C{sentBytes} method is told how many
        bytes were written on throttled connections (None disables this)
    @ivar throttle: whether writes on this connection are throttled
    """

    stats = None

    def __init__(self, factory, wrappedProtocol):
        """Start out unthrottled until a producer is registered."""
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write data, throttled and counted only for upload streams."""
        if self.throttle:
            ThrottlingProtocol.write(self, data)
            if self.stats:
                self.stats.sentBytes(len(data))
        else:
            # Bypass the throttling layer entirely
            ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        """Write a sequence of strings, throttled and counted only for upload streams."""
        if self.throttle:
            ThrottlingProtocol.writeSequence(self, seq)
            if self.stats:
                # sum() copes with an empty sequence, where reduce() without
                # an initial value raises TypeError
                self.stats.sentBytes(sum(map(len, seq)))
        else:
            # Bypass the throttling layer entirely
            ProtocolWrapper.writeSequence(self, seq)

    def registerProducer(self, producer, streaming):
        """Register the producer and enable throttling for upload streams."""
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        streamType = getattr(producer, 'stream', None)
        if isinstance(streamType, (FileUploaderStream, stream.MemoryStream)):
            self.throttle = True
178         if isinstance(streamType, FileUploaderStream) or isinstance(streamType, stream.MemoryStream):
179             self.throttle = True
180
181
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: the write limit handed to the throttling factory,
        None when the configured UPLOAD_LIMIT is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.uploadLimit = None
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        if limit > 0:
            # NOTE(review): presumably the setting is in KiB and the factory
            # wants bytes -- confirm against the configuration documentation.
            self.uploadLimit = int(limit*1024)
        self.factory = None

    def getHTTPFactory(self):
        """Initialize and get the factory for this HTTP server.

        The factory is created on the first call and reused afterwards.
        """
        if self.factory is None:
            self.factory = channel.HTTPFactory(server.Site(self),
                                               **{'maxPipeline': 10,
                                                  'betweenRequestsTimeOut': 60})
            self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
            if self.manager:
                # 'protocol' is the class itself, so this sets the stats
                # object on UploadThrottlingProtocol for all connections
                self.factory.protocol.stats = self.manager.stats
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics.

        Without a manager a fixed placeholder page is returned instead.
        """
        if self.manager:
            return http.Response(
                200,
                {'content-type': http_headers.MimeType('text', 'html')},
                self.manager.getStats())
        else:
            return http.Response(
                200,
                {'content-type': http_headers.MimeType('text', 'html')},
                '<html><body><p>Some Statistics</body></html>')

    def locateChild(self, request, segments):
        """Process the incoming request.

        Three kinds of requests are told apart by the first path segment:
        C{'~'} is a peer asking for a shared file or piece string by hash, a
        segment longer than one character is a download request from apt, and
        anything else is a request for the statistics page.

        @return: the resource to render and the remaining path segments, or
            C{(None, ())} when the request is rejected
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            hash = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(hash)
            if not files:
                # Reject the request here; falling through would send an
                # unknown hash on to the statistics page below
                log.msg('Hash could not be found in database: %r' % hash)
                return None, ()

            # If it is a file, return it
            if 'path' in files[0]:
                log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
                return FileUploader(files[0]['path'].path), ()

            # It's not for a file, but for a piece string, so return that
            log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
            return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()

        if len(name) > 1:
            # It's a request from apt

            # Only local requests (apt) get past this point
            if request.remoteAddr.host != "127.0.0.1":
                log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
                return None, ()

            # Block access to index .diff files (for now), and to favicon.ico
            if 'Packages.diff' in segments or 'Sources.diff' in segments or name == 'favicon.ico':
                return None, ()

            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page

        # Remote requests for stats are only allowed when configured
        if not config.getboolean('DEFAULT', 'REMOTE_STATS') and request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        return self, ()
292
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case also plays the role of the database: it is passed to
    L{TopLevel} as C{db} and answers C{lookupHash} itself.
    """

    client = None
    pending_calls = []
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Stand in for the database lookup used by the server."""
        if hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        return []

    def create_request(self, host, path):
        """Build a GET request for C{path} that appears to come from C{host}."""
        class addr:
            host = ''
            port = 0
        peer = addr()
        peer.host = host
        peer.port = 23456

        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        req.remoteAddr = peer
        server.Request._parseURL(req)
        return req

    def _assert_rejected(self, host, path):
        """Check that locating the resource for the request raises an HTTP error."""
        req = self.create_request(host, path)
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def _assert_rendered(self, host, path, code, kind=None):
        """Locate and render the resource, then check the response code.

        When C{kind} is given, the located resource must be an instance of it.
        """
        req = self.create_request(host, path)
        child = req._getChild(None, self.client, req.postpath)
        self.failIfEqual(child, None)
        if kind is not None:
            self.failUnless(isinstance(child, kind))
        d = defer.maybeDeferred(child.renderHTTP, req)
        d.addCallback(self.check_resp, code)
        return d

    def test_unauthorized(self):
        self._assert_rejected('128.0.0.1', '/foo/bar')

    def test_Packages_diff(self):
        self._assert_rejected(
            '127.0.0.1',
            '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')

    def test_Statistics(self):
        return self._assert_rendered('127.0.0.1', '/', 200)

    def test_apt_download(self):
        return self._assert_rendered(
            '127.0.0.1',
            '/ftp.us.debian.org/debian/dists/stable/Release',
            404, FileDownloader)

    def test_torrent_upload(self):
        return self._assert_rendered(
            '123.45.67.89', '/~/' + quote_plus(self.torrent_hash),
            200, static.Data)

    def test_file_upload(self):
        return self._assert_rendered(
            '123.45.67.89', '/~/' + quote_plus(self.file_hash),
            200, FileUploader)

    def test_missing_hash(self):
        self._assert_rejected('123.45.67.89', '/~/' + quote_plus('foobar'))

    def check_resp(self, resp, code):
        """Callback: verify the response code and pass the response along."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        # Cancel whatever delayed calls are still outstanding
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        if self.client:
            self.client = None
387
if __name__ == '__builtin__':
    # Demo server for running this file directly with:
    #   twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath
    from twisted.application import service, strports

    class DB:
        """Minimal stand-in database for the demo server."""

        def lookupHash(self, hash):
            # Only the hash 'pieces' yields a piece string; every other
            # hash is answered with the same fixed file.
            if hash != 'pieces':
                return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
            return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]

    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()

    # Standard twisted application boilerplate
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)