Make the upload limit a config option.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import unquote_plus
5 from binascii import b2a_hex
6
7 from twisted.python import log
8 from twisted.internet import defer
9 from twisted.web2 import server, http, resource, channel, stream
10 from twisted.web2 import static, http_headers, responsecode
11
12 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
13 from apt_p2p_Khashmir.bencode import bencode
14
15 class FileDownloader(static.File):
16     """Modified to make it suitable for apt requests.
17     
18     Tries to find requests in the cache. Found files are first checked for
19     freshness before being sent. Requests for unfound and stale files are
20     forwarded to the main program for downloading.
21     
22     @type manager: L{apt_p2p.AptP2P}
23     @ivar manager: the main program to query 
24     """
25     
26     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
27         self.manager = manager
28         super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
29         
30     def renderHTTP(self, req):
31         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
32         resp = super(FileDownloader, self).renderHTTP(req)
33         if isinstance(resp, defer.Deferred):
34             resp.addCallback(self._renderHTTP_done, req)
35         else:
36             resp = self._renderHTTP_done(resp, req)
37         return resp
38         
39     def _renderHTTP_done(self, resp, req):
40         log.msg('Initial response to %s: %r' % (req.uri, resp))
41         
42         if self.manager:
43             path = 'http:/' + req.uri
44             if resp.code >= 200 and resp.code < 400:
45                 return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)
46             
47             log.msg('Not found, trying other methods for %s' % req.uri)
48             return self.manager.get_resp(req, path)
49         
50         return resp
51
52     def createSimilarFile(self, path):
53         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
54                               self.processors, self.indexNames[:])
55         
56 class FileUploaderStream(stream.FileStream):
57     """Modified to make it suitable for streaming to peers.
58     
59     Streams the file is small chunks to make it easier to throttle the
60     streaming to peers.
61     
62     @ivar CHUNK_SIZE: the size of chunks of data to send at a time
63     """
64
65     CHUNK_SIZE = 4*1024
66     
67     def read(self, sendfile=False):
68         if self.f is None:
69             return None
70
71         length = self.length
72         if length == 0:
73             self.f = None
74             return None
75         
76         # Remove the SendFileBuffer and mmap use, just use string reads and writes
77
78         readSize = min(length, self.CHUNK_SIZE)
79
80         self.f.seek(self.start)
81         b = self.f.read(readSize)
82         bytesRead = len(b)
83         if not bytesRead:
84             raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length))
85         else:
86             self.length -= bytesRead
87             self.start += bytesRead
88             return b
89
90
91 class FileUploader(static.File):
92     """Modified to make it suitable for peer requests.
93     
94     Uses the modified L{FileUploaderStream} to stream the file for throttling,
95     and doesn't do any listing of directory contents.
96     """
97
98     def render(self, req):
99         if not self.fp.exists():
100             return responsecode.NOT_FOUND
101
102         if self.fp.isdir():
103             # Don't try to render a directory listing
104             return responsecode.NOT_FOUND
105
106         try:
107             f = self.fp.open()
108         except IOError, e:
109             import errno
110             if e[0] == errno.EACCES:
111                 return responsecode.FORBIDDEN
112             elif e[0] == errno.ENOENT:
113                 return responsecode.NOT_FOUND
114             else:
115                 raise
116
117         response = http.Response()
118         # Use the modified FileStream
119         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
120
121         for (header, value) in (
122             ("content-type", self.contentType()),
123             ("content-encoding", self.contentEncoding()),
124         ):
125             if value is not None:
126                 response.headers.setHeader(header, value)
127
128         return response
129
130 class UploadThrottlingProtocol(ThrottlingProtocol):
131     """Protocol for throttling uploads.
132     
133     Determines whether or not to throttle the upload based on the type of stream.
134     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
135     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
136     """
137
138     def __init__(self, factory, wrappedProtocol):
139         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
140         self.throttle = False
141
142     def write(self, data):
143         if self.throttle:
144             ThrottlingProtocol.write(self, data)
145         else:
146             ProtocolWrapper.write(self, data)
147
148     def registerProducer(self, producer, streaming):
149         ThrottlingProtocol.registerProducer(self, producer, streaming)
150         streamType = getattr(producer, 'stream', None)
151         if isinstance(streamType, FileUploaderStream) or isinstance(streamType, stream.MemoryStream):
152             self.throttle = True
153
154
155 class TopLevel(resource.Resource):
156     """The HTTP server for all requests, both from peers and apt.
157     
158     @type directory: L{twisted.python.filepath.FilePath}
159     @ivar directory: the directory to check for cached files
160     @type db: L{db.DB}
161     @ivar db: the database to use for looking up files and hashes
162     @type manager: L{apt_p2p.AptP2P}
163     @ivar manager: the main program object to send requests to
164     @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
165     @ivar factory: the factory to use to serve HTTP requests
166     """
167     
168     addSlash = True
169     
170     def __init__(self, directory, db, manager, uploadLimit):
171         """Initialize the instance.
172         
173         @type directory: L{twisted.python.filepath.FilePath}
174         @param directory: the directory to check for cached files
175         @type db: L{db.DB}
176         @param db: the database to use for looking up files and hashes
177         @type manager: L{apt_p2p.AptP2P}
178         @param manager: the main program object to send requests to
179         """
180         self.directory = directory
181         self.db = db
182         self.manager = manager
183         self.uploadLimit = None
184         if uploadLimit > 0:
185             self.uploadLimit = int(uploadLimit*1024)
186         self.factory = None
187
188     def getHTTPFactory(self):
189         """Initialize and get the factory for this HTTP server."""
190         if self.factory is None:
191             self.factory = channel.HTTPFactory(server.Site(self),
192                                                **{'maxPipeline': 10, 
193                                                   'betweenRequestsTimeOut': 60})
194             self.factory = ThrottlingFactory(self.factory, writeLimit = self.uploadLimit)
195             self.factory.protocol = UploadThrottlingProtocol
196         return self.factory
197
198     def render(self, ctx):
199         """Render a web page with descriptive statistics."""
200         return http.Response(
201             200,
202             {'content-type': http_headers.MimeType('text', 'html')},
203             self.manager.getStats())
204
205     def locateChild(self, request, segments):
206         """Process the incoming request."""
207         log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
208         name = segments[0]
209         
210         # If the request is for a shared file (from a peer)
211         if name == '~':
212             if len(segments) != 2:
213                 log.msg('Got a malformed request from %s' % request.remoteAddr)
214                 return None, ()
215             
216             # Find the file in the database
217             hash = unquote_plus(segments[1])
218             files = self.db.lookupHash(hash)
219             if files:
220                 # If it is a file, return it
221                 if 'path' in files[0]:
222                     log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
223                     return FileUploader(files[0]['path'].path), ()
224                 else:
225                     # It's not for a file, but for a piece string, so return that
226                     log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
227                     return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
228             else:
229                 log.msg('Hash could not be found in database: %s' % hash)
230
231         # Only local requests (apt) get past this point
232         if request.remoteAddr.host != "127.0.0.1":
233             log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
234             return None, ()
235             
236         if len(name) > 1:
237             # It's a request from apt
238             return FileDownloader(self.directory.path, self.manager), segments[0:]
239         else:
240             # Will render the statistics page
241             return self, ()
242         
243         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
244         return None, ()
245
246 if __name__ == '__builtin__':
247     # Running from twistd -ny HTTPServer.py
248     # Then test with:
249     #   wget -S 'http://localhost:18080/~/whatever'
250     #   wget -S 'http://localhost:18080/~/pieces'
251
252     import os.path
253     from twisted.python.filepath import FilePath
254     
255     class DB:
256         def lookupHash(self, hash):
257             if hash == 'pieces':
258                 return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
259             return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
260     
261     t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
262     factory = t.getHTTPFactory()
263     
264     # Standard twisted application Boilerplate
265     from twisted.application import service, strports
266     application = service.Application("demoserver")
267     s = strports.service('tcp:18080', factory)
268     s.setServiceParent(application)