Multiple peer downloading is mostly working now.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import unquote_plus
5 from binascii import b2a_hex
6
7 from twisted.python import log
8 from twisted.internet import defer
9 from twisted.web2 import server, http, resource, channel, stream
10 from twisted.web2 import static, http_headers, responsecode
11
12 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
13 from apt_p2p_Khashmir.bencode import bencode
14
15 class FileDownloader(static.File):
16     """Modified to make it suitable for apt requests.
17     
18     Tries to find requests in the cache. Found files are first checked for
19     freshness before being sent. Requests for unfound and stale files are
20     forwarded to the main program for downloading.
21     
22     @type manager: L{apt_p2p.AptP2P}
23     @ivar manager: the main program to query 
24     """
25     
26     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
27         self.manager = manager
28         super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
29         
30     def renderHTTP(self, req):
31         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
32         resp = super(FileDownloader, self).renderHTTP(req)
33         if isinstance(resp, defer.Deferred):
34             resp.addCallback(self._renderHTTP_done, req)
35         else:
36             resp = self._renderHTTP_done(resp, req)
37         return resp
38         
39     def _renderHTTP_done(self, resp, req):
40         log.msg('Initial response to %s: %r' % (req.uri, resp))
41         
42         if self.manager:
43             path = 'http:/' + req.uri
44             if resp.code >= 200 and resp.code < 400:
45                 return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)
46             
47             log.msg('Not found, trying other methods for %s' % req.uri)
48             return self.manager.get_resp(req, path)
49         
50         return resp
51
52     def createSimilarFile(self, path):
53         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
54                               self.processors, self.indexNames[:])
55         
56 class FileUploaderStream(stream.FileStream):
57     """Modified to make it suitable for streaming to peers.
58     
59     Streams the file in small chunks to make it easier to throttle the
60     streaming to peers.
61     
62     @ivar CHUNK_SIZE: the size of chunks of data to send at a time
63     """
64
65     CHUNK_SIZE = 4*1024
66     
67     def read(self, sendfile=False):
68         if self.f is None:
69             return None
70
71         length = self.length
72         if length == 0:
73             self.f = None
74             return None
75         
76         # Remove the SendFileBuffer and mmap use, just use string reads and writes
77
78         readSize = min(length, self.CHUNK_SIZE)
79
80         self.f.seek(self.start)
81         b = self.f.read(readSize)
82         bytesRead = len(b)
83         if not bytesRead:
84             raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length))
85         else:
86             self.length -= bytesRead
87             self.start += bytesRead
88             return b
89
90
91 class FileUploader(static.File):
92     """Modified to make it suitable for peer requests.
93     
94     Uses the modified L{FileUploaderStream} to stream the file for throttling,
95     and doesn't do any listing of directory contents.
96     """
97
98     def render(self, req):
99         if not self.fp.exists():
100             return responsecode.NOT_FOUND
101
102         if self.fp.isdir():
103             # Don't try to render a directory listing
104             return responsecode.NOT_FOUND
105
106         try:
107             f = self.fp.open()
108         except IOError, e:
109             import errno
110             if e[0] == errno.EACCES:
111                 return responsecode.FORBIDDEN
112             elif e[0] == errno.ENOENT:
113                 return responsecode.NOT_FOUND
114             else:
115                 raise
116
117         response = http.Response()
118         # Use the modified FileStream
119         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
120
121         for (header, value) in (
122             ("content-type", self.contentType()),
123             ("content-encoding", self.contentEncoding()),
124         ):
125             if value is not None:
126                 response.headers.setHeader(header, value)
127
128         return response
129
130 class UploadThrottlingProtocol(ThrottlingProtocol):
131     """Protocol for throttling uploads.
132     
133     Determines whether or not to throttle the upload based on the type of stream.
134     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
135     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
136     """
137
138     def __init__(self, factory, wrappedProtocol):
139         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
140         self.throttle = False
141
142     def write(self, data):
143         if self.throttle:
144             ThrottlingProtocol.write(self, data)
145         else:
146             ProtocolWrapper.write(self, data)
147
148     def registerProducer(self, producer, streaming):
149         ThrottlingProtocol.registerProducer(self, producer, streaming)
150         streamType = getattr(producer, 'stream', None)
151         if isinstance(streamType, FileUploaderStream) or isinstance(streamType, stream.MemoryStream):
152             self.throttle = True
153
154
155 class TopLevel(resource.Resource):
156     """The HTTP server for all requests, both from peers and apt.
157     
158     @type directory: L{twisted.python.filepath.FilePath}
159     @ivar directory: the directory to check for cached files
160     @type db: L{db.DB}
161     @ivar db: the database to use for looking up files and hashes
162     @type manager: L{apt_p2p.AptP2P}
163     @ivar manager: the main program object to send requests to
164     @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
165     @ivar factory: the factory to use to serve HTTP requests
166     """
167     
168     addSlash = True
169     
170     def __init__(self, directory, db, manager, uploadLimit):
171         """Initialize the instance.
172         
173         @type directory: L{twisted.python.filepath.FilePath}
174         @param directory: the directory to check for cached files
175         @type db: L{db.DB}
176         @param db: the database to use for looking up files and hashes
177         @type manager: L{apt_p2p.AptP2P}
178         @param manager: the main program object to send requests to
179         """
180         self.directory = directory
181         self.db = db
182         self.manager = manager
183         self.uploadLimit = None
184         if uploadLimit > 0:
185             self.uploadLimit = int(uploadLimit*1024)
186         self.factory = None
187
188     def getHTTPFactory(self):
189         """Initialize and get the factory for this HTTP server."""
190         if self.factory is None:
191             self.factory = channel.HTTPFactory(server.Site(self),
192                                                **{'maxPipeline': 10, 
193                                                   'betweenRequestsTimeOut': 60})
194             self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
195             self.factory.protocol = UploadThrottlingProtocol
196         return self.factory
197
198     def render(self, ctx):
199         """Render a web page with descriptive statistics."""
200         return http.Response(
201             200,
202             {'content-type': http_headers.MimeType('text', 'html')},
203             self.manager.getStats())
204
205     def locateChild(self, request, segments):
206         """Process the incoming request."""
207         log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
208         name = segments[0]
209         
210         # If the request is for a shared file (from a peer)
211         if name == '~':
212             if len(segments) != 2:
213                 log.msg('Got a malformed request from %s' % request.remoteAddr)
214                 return None, ()
215             
216             # Find the file in the database
217             # Have to unquote_plus the uri, because the segments are unquoted by twisted
218             hash = unquote_plus(request.uri[3:])
219             files = self.db.lookupHash(hash)
220             if files:
221                 # If it is a file, return it
222                 if 'path' in files[0]:
223                     log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
224                     return FileUploader(files[0]['path'].path), ()
225                 else:
226                     # It's not for a file, but for a piece string, so return that
227                     log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
228                     return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
229             else:
230                 log.msg('Hash could not be found in database: %r' % hash)
231
232         # Only local requests (apt) get past this point
233         if request.remoteAddr.host != "127.0.0.1":
234             log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
235             return None, ()
236         
237         # Block access to index .diff files (for now)
238         if 'Packages.diff' in segments or 'Sources.diff' in segments:
239             return None, ()
240          
241         if len(name) > 1:
242             # It's a request from apt
243             return FileDownloader(self.directory.path, self.manager), segments[0:]
244         else:
245             # Will render the statistics page
246             return self, ()
247         
248         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
249         return None, ()
250
251 if __name__ == '__builtin__':
252     # Running from twistd -ny HTTPServer.py
253     # Then test with:
254     #   wget -S 'http://localhost:18080/~/whatever'
255     #   wget -S 'http://localhost:18080/~/pieces'
256
257     import os.path
258     from twisted.python.filepath import FilePath
259     
260     class DB:
261         def lookupHash(self, hash):
262             if hash == 'pieces':
263                 return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
264             return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
265     
266     t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None, 0)
267     factory = t.getHTTPFactory()
268     
269     # Standard twisted application Boilerplate
270     from twisted.application import service, strports
271     application = service.Application("demoserver")
272     s = strports.service('tcp:18080', factory)
273     s.setServiceParent(application)