More strict use of errbacks when using deferreds.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import quote_plus, unquote_plus
5 from binascii import b2a_hex
6
7 from twisted.python import log
8 from twisted.internet import defer
9 from twisted.web2 import server, http, resource, channel, stream
10 from twisted.web2 import static, http_headers, responsecode
11 from twisted.trial import unittest
12 from twisted.python.filepath import FilePath
13
14 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
15 from apt_p2p_conf import config
16 from apt_p2p_Khashmir.bencode import bencode
17
class FileDownloader(static.File):
    """Static file resource adapted for requests coming from apt.

    The request is first rendered by the parent L{static.File}. When a
    manager is set, the outcome is then handed to it: responses with a
    2xx/3xx code are passed on for a freshness check, anything else
    (including a failed render) is forwarded so the manager can obtain
    the file some other way.

    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program to query
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        """Store the manager and initialize the underlying static file.

        @param path: the path handed to L{static.File}
        @param manager: the main program to query; if it is a false value
            the parent's responses are returned untouched
        @param defaultType: passed through to L{static.File}
        @param ignoredExts: passed through to L{static.File}
        @param processors: passed through to L{static.File}
        @param indexNames: passed through to L{static.File}
        """
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render the request, then post-process whatever the parent produced.

        @param req: the incoming request
        @return: the post-processed response, or the parent's deferred with
            the post-processing callback and errback attached
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        resp = super(FileDownloader, self).renderHTTP(req)

        # A plain (non-deferred) result can be post-processed right away.
        if not isinstance(resp, defer.Deferred):
            return self._renderHTTP_done(resp, req)

        resp.addCallbacks(self._renderHTTP_done, self._renderHTTP_error,
                          callbackArgs=(req, ), errbackArgs=(req, ))
        return resp

    def _renderHTTP_done(self, resp, req):
        """Decide what to do with the parent's initial response.

        @param resp: the response produced by the parent class
        @param req: the request being served
        @return: C{resp} itself when there is no manager, otherwise
            whatever the manager returns
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        path = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            # The file was found locally, let the manager check its freshness
            return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, path)

    def _renderHTTP_error(self, err, req):
        """Handle a failure raised while rendering the request.

        The failure is logged; with a manager set the request is forwarded
        to it, otherwise the failure is returned so it keeps propagating.

        @param err: the failure passed to the errback
        @param req: the request being served
        """
        log.msg('Failed to render %s: %r' % (req.uri, err))
        log.err(err)

        if not self.manager:
            return err

        return self.manager.get_resp(req, 'http:/' + req.uri)

    def createSimilarFile(self, path):
        """Create another instance of this class for C{path}.

        The new instance gets the same manager and settings as this one,
        with its own copy of the index names.
        """
        index_names = self.indexNames[:]
        return self.__class__(path, self.manager, self.defaultType,
                              self.ignoredExts, self.processors, index_names)
69         
class FileUploaderStream(stream.FileStream):
    """File stream adapted for sending files to peers.

    The file is handed out in small chunks, which makes it easier to
    throttle the streaming to peers.

    @ivar CHUNK_SIZE: the size of chunks of data to send at a time
    """

    CHUNK_SIZE = 4*1024

    def read(self, sendfile=False):
        """Read the next chunk of the file.

        Only plain string reads are used here (no SendFileBuffer or mmap);
        the C{sendfile} argument is accepted but ignored.

        @return: the next chunk of at most L{CHUNK_SIZE} bytes, or C{None}
            once the stream has no file or no bytes left to send
        @raise RuntimeError: if the file yields no data while bytes are
            still expected
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Everything has been sent, drop the reference to the file
            self.f = None
            return None

        self.f.seek(self.start)
        chunk = self.f.read(min(remaining, self.CHUNK_SIZE))
        chunk_len = len(chunk)
        if not chunk_len:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        # Advance the window over the part of the file still to be sent
        self.length -= chunk_len
        self.start += chunk_len
        return chunk
103
104
class FileUploader(static.File):
    """Modified to make it suitable for peer requests.

    Uses the modified L{FileUploaderStream} to stream the file for throttling,
    and doesn't do any listing of directory contents.
    """

    def render(self, req):
        """Build the response that streams this file to a peer.

        @param req: the incoming request (not inspected here)
        @return: C{responsecode.NOT_FOUND} if the path is missing or is a
            directory, C{responsecode.FORBIDDEN} if it can not be opened
            due to permissions, otherwise an L{http.Response} streaming
            the whole file
        @raise IOError: if opening the file fails for any other reason
        """
        import errno

        if not self.fp.exists():
            return responsecode.NOT_FOUND

        if self.fp.isdir():
            # Don't try to render a directory listing
            return responsecode.NOT_FOUND

        try:
            f = self.fp.open()
        except IOError as e:
            # Use the errno attribute rather than indexing the exception:
            # indexing fails on an IOError created without arguments.
            if e.errno == errno.EACCES:
                return responsecode.FORBIDDEN
            elif e.errno == errno.ENOENT:
                return responsecode.NOT_FOUND
            else:
                raise

        response = http.Response()
        # Use the modified FileStream
        response.stream = FileUploaderStream(f, 0, self.fp.getsize())

        for (header, value) in (
            ("content-type", self.contentType()),
            ("content-encoding", self.contentEncoding()),
        ):
            # Only set the headers for which a value is known
            if value is not None:
                response.headers.setHeader(header, value)

        return response
143
class UploadThrottlingProtocol(ThrottlingProtocol):
    """Protocol that throttles only the uploads to peers.

    Whether to throttle is decided from the type of stream attached to the
    registered producer: L{FileUploaderStream} and
    L{twisted.web2.stream.MemoryStream} switch throttling on. According to
    the original notes, responses for apt use
    L{CacheManager.ProxyFileStream} or L{twisted.web2.stream.FileStream},
    which leave throttling off.

    @ivar throttle: whether written data goes through the throttled path
    """

    def __init__(self, factory, wrappedProtocol):
        """Initialize the protocol with throttling switched off."""
        ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
        self.throttle = False

    def write(self, data):
        """Write the data, through the throttle only when it is enabled."""
        if not self.throttle:
            # Bypass the throttling and write directly through the wrapper
            ProtocolWrapper.write(self, data)
        else:
            ThrottlingProtocol.write(self, data)

    def registerProducer(self, producer, streaming):
        """Register the producer and enable throttling for upload streams."""
        ThrottlingProtocol.registerProducer(self, producer, streaming)
        producer_stream = getattr(producer, 'stream', None)
        if isinstance(producer_stream, (FileUploaderStream, stream.MemoryStream)):
            self.throttle = True
167
168
class TopLevel(resource.Resource):
    """The HTTP server for all requests, both from peers and apt.

    @type directory: L{twisted.python.filepath.FilePath}
    @ivar directory: the directory to check for cached files
    @type db: L{db.DB}
    @ivar db: the database to use for looking up files and hashes
    @type manager: L{apt_p2p.AptP2P}
    @ivar manager: the main program object to send requests to
    @type uploadLimit: C{int} or C{None}
    @ivar uploadLimit: the write limit handed to the throttling factory,
        or None when the configured limit is not positive
    @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
    @ivar factory: the factory to use to serve HTTP requests
    """

    addSlash = True

    def __init__(self, directory, db, manager):
        """Initialize the instance.

        @type directory: L{twisted.python.filepath.FilePath}
        @param directory: the directory to check for cached files
        @type db: L{db.DB}
        @param db: the database to use for looking up files and hashes
        @type manager: L{apt_p2p.AptP2P}
        @param manager: the main program object to send requests to
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.uploadLimit = None
        # Read the setting only once. A positive value is scaled by 1024
        # (presumably configured in KiB and converted to bytes -- confirm).
        limit = config.getint('DEFAULT', 'UPLOAD_LIMIT')
        if limit > 0:
            self.uploadLimit = limit * 1024
        self.factory = None

    def getHTTPFactory(self):
        """Initialize and get the factory for this HTTP server.

        The factory is created on the first call and reused afterwards.
        """
        if self.factory is None:
            http_factory = channel.HTTPFactory(server.Site(self),
                                               maxPipeline=10,
                                               betweenRequestsTimeOut=60)
            # Wrap the HTTP factory so that uploads can be throttled
            self.factory = ThrottlingFactory(http_factory, writeLimit=self.uploadLimit)
            self.factory.protocol = UploadThrottlingProtocol
        return self.factory

    def render(self, ctx):
        """Render a web page with descriptive statistics.

        The page comes from the manager when there is one, otherwise a
        fixed placeholder page is returned.
        """
        if self.manager:
            body = self.manager.getStats()
        else:
            body = '<html><body><p>Some Statistics</body></html>'
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            body)

    def locateChild(self, request, segments):
        """Process the incoming request.

        @param request: the incoming request
        @param segments: the path segments of the requested URI
        @return: a (resource, remaining segments) tuple; the resource is
            None when the request is rejected
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]

        # If the request is for a shared file (from a peer)
        if name == '~':
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()

            # Find the file in the database
            # Have to unquote_plus the uri, because the segments are unquoted by twisted
            key = unquote_plus(request.uri[3:])
            files = self.db.lookupHash(key)
            if files:
                entry = files[0]
                # If it is a file, return it
                if 'path' in entry:
                    log.msg('Sharing %s with %s' % (entry['path'].path, request.remoteAddr))
                    return FileUploader(entry['path'].path), ()

                # It's not for a file, but for a piece string, so return that
                log.msg('Sending torrent string %s to %s' % (b2a_hex(key), request.remoteAddr))
                return static.Data(bencode({'t': entry['pieces']}), 'application/x-bencoded'), ()

            log.msg('Hash could not be found in database: %r' % key)
            # NOTE(review): an unknown hash falls through to the checks
            # below, so a request from 127.0.0.1 ends up on the statistics
            # page rather than being rejected -- confirm this is intended.

        # Only local requests (apt) get past this point
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        # Block access to index .diff files (for now)
        if 'Packages.diff' in segments or 'Sources.diff' in segments:
            return None, ()

        if len(name) > 1:
            # It's a request from apt
            return FileDownloader(self.directory.path, self.manager), segments[0:]

        # Will render the statistics page
        return self, ()
270
class TestTopLevel(unittest.TestCase):
    """Unit tests for the HTTP Server.

    The test case itself plays the role of the database: it is passed to
    L{TopLevel} as C{db} and answers the hash lookups through
    L{lookupHash}.
    """

    client = None
    pending_calls = []
    torrent_hash = '\xca \xb8\x0c\x00\xe7\x07\xf8~])+\x9d\xe5_B\xff\x1a\xc4!'
    torrent = 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    file_hash = '\xf8~])+\x9d\xe5_B\xff\x1a\xc4!\xca \xb8\x0c\x00\xe7\x07'

    def setUp(self):
        """Create the server under test, without a manager."""
        self.client = TopLevel(FilePath('/boot'), self, None)

    def lookupHash(self, hash):
        """Stand in for the database lookup with two known hashes."""
        if hash == self.torrent_hash:
            return [{'pieces': self.torrent}]
        if hash == self.file_hash:
            return [{'path': FilePath('/boot/grub/stage2')}]
        return []

    def create_request(self, host, path):
        """Build a GET request for C{path} that appears to come from C{host}."""
        req = server.Request(None, 'GET', path, (1,1), 0, http_headers.Headers())
        class addr:
            host = ''
            port = 0
        remote = addr()
        remote.host = host
        remote.port = 23456
        req.remoteAddr = remote
        server.Request._parseURL(req)
        return req

    def _locate(self, req):
        """Resolve the request against the server and return the resource."""
        return req._getChild(None, self.client, req.postpath)

    def _assert_rejected(self, req):
        """Check that resolving the request raises an HTTP error."""
        self.failUnlessRaises(http.HTTPError, req._getChild, None, self.client, req.postpath)

    def _render_and_check(self, res, req, code):
        """Render the resource and check the response code once available."""
        df = defer.maybeDeferred(res.renderHTTP, req)
        df.addCallback(self.check_resp, code)
        return df

    def test_unauthorized(self):
        """A non-local request for an apt file is rejected."""
        req = self.create_request('128.0.0.1', '/foo/bar')
        self._assert_rejected(req)

    def test_Packages_diff(self):
        """A local request for a Packages.diff index is rejected."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/unstable/main/binary-i386/Packages.diff/Index')
        self._assert_rejected(req)

    def test_Statistics(self):
        """A local request for the root gets a 200 response."""
        req = self.create_request('127.0.0.1', '/')
        res = self._locate(req)
        self.failIfEqual(res, None)
        return self._render_and_check(res, req, 200)

    def test_apt_download(self):
        """A local apt request is served by a FileDownloader, here with a 404."""
        req = self.create_request('127.0.0.1',
                '/ftp.us.debian.org/debian/dists/stable/Release')
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileDownloader))
        return self._render_and_check(res, req, 404)

    def test_torrent_upload(self):
        """A peer request for a known piece-string hash gets static data."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.torrent_hash))
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, static.Data))
        return self._render_and_check(res, req, 200)

    def test_file_upload(self):
        """A peer request for a known file hash is served by a FileUploader."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus(self.file_hash))
        res = self._locate(req)
        self.failIfEqual(res, None)
        self.failUnless(isinstance(res, FileUploader))
        return self._render_and_check(res, req, 200)

    def test_missing_hash(self):
        """A peer request for an unknown hash is rejected."""
        req = self.create_request('123.45.67.89',
                                  '/~/' + quote_plus('foobar'))
        self._assert_rejected(req)

    def check_resp(self, resp, code):
        """Check the response code and pass the response along."""
        self.failUnlessEqual(resp.code, code)
        return resp

    def tearDown(self):
        """Cancel any still-active pending calls and drop the server."""
        for call in self.pending_calls:
            if call.active():
                call.cancel()
        self.pending_calls = []
        if self.client:
            self.client = None
365
if __name__ == '__builtin__':
    # Running from twistd -ny HTTPServer.py
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/~/pieces'

    import os.path
    from twisted.python.filepath import FilePath

    class DB:
        """Minimal stand-in for the database, used only by this demo."""

        def lookupHash(self, hash):
            """Return a piece string for 'pieces', a fixed file otherwise."""
            if hash == 'pieces':
                return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]

    # TopLevel takes exactly (directory, db, manager); run without a manager.
    t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
    factory = t.getHTTPFactory()

    # Standard twisted application Boilerplate
    from twisted.application import service, strports
    application = service.Application("demoserver")
    s = strports.service('tcp:18080', factory)
    s.setServiceParent(application)