Only throttle uploads to peers, not to apt.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPServer.py
1
2 """Serve local requests from apt and remote requests from peers."""
3
4 from urllib import unquote_plus
5 from binascii import b2a_hex
6
7 from twisted.python import log
8 from twisted.internet import defer
9 from twisted.web2 import server, http, resource, channel, stream
10 from twisted.web2 import static, http_headers, responsecode
11
12 from policies import ThrottlingFactory, ThrottlingProtocol, ProtocolWrapper
13 from apt_p2p_Khashmir.bencode import bencode
14
15 class FileDownloader(static.File):
16     """Modified to make it suitable for apt requests.
17     
18     Tries to find requests in the cache. Found files are first checked for
19     freshness before being sent. Requests for unfound and stale files are
20     forwarded to the main program for downloading.
21     
22     @type manager: L{apt_p2p.AptP2P}
23     @ivar manager: the main program to query 
24     """
25     
26     def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
27         self.manager = manager
28         super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)
29         
30     def renderHTTP(self, req):
31         log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
32         resp = super(FileDownloader, self).renderHTTP(req)
33         if isinstance(resp, defer.Deferred):
34             resp.addCallback(self._renderHTTP_done, req)
35         else:
36             resp = self._renderHTTP_done(resp, req)
37         return resp
38         
39     def _renderHTTP_done(self, resp, req):
40         log.msg('Initial response to %s: %r' % (req.uri, resp))
41         
42         if self.manager:
43             path = 'http:/' + req.uri
44             if resp.code >= 200 and resp.code < 400:
45                 return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)
46             
47             log.msg('Not found, trying other methods for %s' % req.uri)
48             return self.manager.get_resp(req, path)
49         
50         return resp
51
52     def createSimilarFile(self, path):
53         return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
54                               self.processors, self.indexNames[:])
55         
56 class FileUploaderStream(stream.FileStream):
57     """Modified to make it suitable for streaming to peers.
58     
59     Streams the file is small chunks to make it easier to throttle the
60     streaming to peers.
61     
62     @ivar CHUNK_SIZE: the size of chunks of data to send at a time
63     """
64
65     CHUNK_SIZE = 4*1024
66     
67     def read(self, sendfile=False):
68         if self.f is None:
69             return None
70
71         length = self.length
72         if length == 0:
73             self.f = None
74             return None
75         
76         # Remove the SendFileBuffer and mmap use, just use string reads and writes
77
78         readSize = min(length, self.CHUNK_SIZE)
79
80         self.f.seek(self.start)
81         b = self.f.read(readSize)
82         bytesRead = len(b)
83         if not bytesRead:
84             raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length))
85         else:
86             self.length -= bytesRead
87             self.start += bytesRead
88             return b
89
90
91 class FileUploader(static.File):
92     """Modified to make it suitable for peer requests.
93     
94     Uses the modified L{FileUploaderStream} to stream the file for throttling,
95     and doesn't do any listing of directory contents.
96     """
97
98     def render(self, req):
99         if not self.fp.exists():
100             return responsecode.NOT_FOUND
101
102         if self.fp.isdir():
103             # Don't try to render a directory listing
104             return responsecode.NOT_FOUND
105
106         try:
107             f = self.fp.open()
108         except IOError, e:
109             import errno
110             if e[0] == errno.EACCES:
111                 return responsecode.FORBIDDEN
112             elif e[0] == errno.ENOENT:
113                 return responsecode.NOT_FOUND
114             else:
115                 raise
116
117         response = http.Response()
118         # Use the modified FileStream
119         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
120
121         for (header, value) in (
122             ("content-type", self.contentType()),
123             ("content-encoding", self.contentEncoding()),
124         ):
125             if value is not None:
126                 response.headers.setHeader(header, value)
127
128         return response
129
130 class UploadThrottlingProtocol(ThrottlingProtocol):
131     """Protocol for throttling uploads.
132     
133     Determines whether or not to throttle the upload based on the type of stream.
134     Uploads use L{FileUploaderStream} or L{twisted.web2.stream.MemorySTream},
135     apt uses L{CacheManager.ProxyFileStream} or L{twisted.web.stream.FileStream}.
136     """
137
138     def __init__(self, factory, wrappedProtocol):
139         ThrottlingProtocol.__init__(self, factory, wrappedProtocol)
140         self.throttle = False
141
142     def write(self, data):
143         if self.throttle:
144             ThrottlingProtocol.write(self, data)
145         else:
146             ProtocolWrapper.write(self, data)
147
148     def registerProducer(self, producer, streaming):
149         ThrottlingProtocol.registerProducer(self, producer, streaming)
150         streamType = getattr(producer, 'stream', None)
151         if isinstance(streamType, FileUploaderStream) or isinstance(streamType, stream.MemoryStream):
152             self.throttle = True
153
154
155 class TopLevel(resource.Resource):
156     """The HTTP server for all requests, both from peers and apt.
157     
158     @type directory: L{twisted.python.filepath.FilePath}
159     @ivar directory: the directory to check for cached files
160     @type db: L{db.DB}
161     @ivar db: the database to use for looking up files and hashes
162     @type manager: L{apt_p2p.AptP2P}
163     @ivar manager: the main program object to send requests to
164     @type factory: L{twisted.web2.channel.HTTPFactory} or L{policies.ThrottlingFactory}
165     @ivar factory: the factory to use to serve HTTP requests
166     """
167     
168     addSlash = True
169     
170     def __init__(self, directory, db, manager):
171         """Initialize the instance.
172         
173         @type directory: L{twisted.python.filepath.FilePath}
174         @param directory: the directory to check for cached files
175         @type db: L{db.DB}
176         @param db: the database to use for looking up files and hashes
177         @type manager: L{apt_p2p.AptP2P}
178         @param manager: the main program object to send requests to
179         """
180         self.directory = directory
181         self.db = db
182         self.manager = manager
183         self.factory = None
184
185     def getHTTPFactory(self):
186         """Initialize and get the factory for this HTTP server."""
187         if self.factory is None:
188             self.factory = channel.HTTPFactory(server.Site(self),
189                                                **{'maxPipeline': 10, 
190                                                   'betweenRequestsTimeOut': 60})
191             self.factory = ThrottlingFactory(self.factory, writeLimit = 30*1024)
192             self.factory.protocol = UploadThrottlingProtocol
193         return self.factory
194
195     def render(self, ctx):
196         """Render a web page with descriptive statistics."""
197         return http.Response(
198             200,
199             {'content-type': http_headers.MimeType('text', 'html')},
200             self.manager.getStats())
201
202     def locateChild(self, request, segments):
203         """Process the incoming request."""
204         log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
205         name = segments[0]
206         
207         # If the request is for a shared file (from a peer)
208         if name == '~':
209             if len(segments) != 2:
210                 log.msg('Got a malformed request from %s' % request.remoteAddr)
211                 return None, ()
212             
213             # Find the file in the database
214             hash = unquote_plus(segments[1])
215             files = self.db.lookupHash(hash)
216             if files:
217                 # If it is a file, return it
218                 if 'path' in files[0]:
219                     log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
220                     return FileUploader(files[0]['path'].path), ()
221                 else:
222                     # It's not for a file, but for a piece string, so return that
223                     log.msg('Sending torrent string %s to %s' % (b2a_hex(hash), request.remoteAddr))
224                     return static.Data(bencode({'t': files[0]['pieces']}), 'application/x-bencoded'), ()
225             else:
226                 log.msg('Hash could not be found in database: %s' % hash)
227
228         # Only local requests (apt) get past this point
229         if request.remoteAddr.host != "127.0.0.1":
230             log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
231             return None, ()
232             
233         if len(name) > 1:
234             # It's a request from apt
235             return FileDownloader(self.directory.path, self.manager), segments[0:]
236         else:
237             # Will render the statistics page
238             return self, ()
239         
240         log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
241         return None, ()
242
243 if __name__ == '__builtin__':
244     # Running from twistd -ny HTTPServer.py
245     # Then test with:
246     #   wget -S 'http://localhost:18080/~/whatever'
247     #   wget -S 'http://localhost:18080/~/pieces'
248
249     import os.path
250     from twisted.python.filepath import FilePath
251     
252     class DB:
253         def lookupHash(self, hash):
254             if hash == 'pieces':
255                 return [{'pieces': 'abcdefghij0123456789\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'}]
256             return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]
257     
258     t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)
259     factory = t.getHTTPFactory()
260     
261     # Standard twisted application Boilerplate
262     from twisted.application import service, strports
263     application = service.Application("demoserver")
264     s = strports.service('tcp:18080', factory)
265     s.setServiceParent(application)