Fixed the ThrottlingFactory to work with web2 static streams from the web server.
[quix0rs-apt-p2p.git] / apt_dht / HTTPServer.py
1
2 from urllib import unquote_plus
3
4 from twisted.python import log
5 from twisted.internet import defer
6 #from twisted.protocols import htb
7 from twisted.web2 import server, http, resource, channel, stream
8 from twisted.web2 import static, http_headers, responsecode
9
10 from policies import ThrottlingFactory
11
class FileDownloader(static.File):
    """Serve files from a local directory, consulting a manager on each response.

    When a manager is set, every response produced by static.File is passed to
    it: successful (2xx/3xx) responses go to ``manager.check_freshness`` and
    anything else goes to ``manager.get_resp``.
    """

    def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None):
        """Create the resource.

        @param path: filesystem path of the file or directory to serve.
        @param manager: object providing ``check_freshness`` and ``get_resp``,
            or a false value (e.g. None) to serve only what is on disk.
        The remaining parameters are passed unchanged to static.File.
        """
        self.manager = manager
        super(FileDownloader, self).__init__(path, defaultType, ignoredExts, processors, indexNames)

    def renderHTTP(self, req):
        """Render via static.File, then post-process with _renderHTTP_done.

        Handles both an immediate response and a Deferred from the base class.
        """
        log.msg('Got request for %s from %s' % (req.uri, req.remoteAddr))
        resp = super(FileDownloader, self).renderHTTP(req)
        if isinstance(resp, defer.Deferred):
            resp.addCallback(self._renderHTTP_done, req)
        else:
            resp = self._renderHTTP_done(resp, req)
        return resp

    def _renderHTTP_done(self, resp, req):
        """Hand the initial response to the manager, if there is one.

        A render method may return a bare integer response code (for example
        responsecode.NOT_FOUND) instead of a response object; such codes are
        wrapped in http.Response before ``code``/``headers`` are inspected.
        """
        log.msg('Initial response to %s: %r' % (req.uri, resp))

        if not self.manager:
            return resp

        if isinstance(resp, int):
            # Bare response codes have no .code or .headers attributes.
            resp = http.Response(resp)

        # req.uri presumably starts with '/' (e.g. '/host/path'), which makes
        # this 'http://host/path'.
        path = 'http:/' + req.uri
        if 200 <= resp.code < 400:
            return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)

        log.msg('Not found, trying other methods for %s' % req.uri)
        return self.manager.get_resp(req, path)

    def createSimilarFile(self, path):
        """Create a child resource for path that shares this one's manager and settings."""
        return self.__class__(path, self.manager, self.defaultType, self.ignoredExts,
                              self.processors, self.indexNames[:])

class FileUploaderStream(stream.FileStream):
    """A file stream that always returns plain strings read with f.read().

    Each call to read() returns at most CHUNK_SIZE bytes. The ``sendfile``
    hint is ignored, so callers always receive ordinary string data.
    NOTE(review): presumably this exists so the ThrottlingFactory sees plain
    strings rather than the special buffers the base FileStream may return;
    confirm against twisted.web2.stream.
    """

    CHUNK_SIZE = 4 * 1024

    def read(self, sendfile=False):
        """Return the next chunk of the file, or None once it is exhausted.

        @raise RuntimeError: if the file ends before ``self.length`` bytes
            have been delivered.
        """
        if self.f is None:
            return None

        remaining = self.length
        if remaining == 0:
            # Everything has been delivered; drop the file reference.
            self.f = None
            return None

        self.f.seek(self.start)
        chunk = self.f.read(min(remaining, self.CHUNK_SIZE))
        if not chunk:
            raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, remaining))

        consumed = len(chunk)
        self.start += consumed
        self.length -= consumed
        return chunk

70 class FileUploader(static.File):
71
72     def render(self, req):
73         if not self.fp.exists():
74             return responsecode.NOT_FOUND
75
76         if self.fp.isdir():
77             return responsecode.NOT_FOUND
78
79         try:
80             f = self.fp.open()
81         except IOError, e:
82             import errno
83             if e[0] == errno.EACCES:
84                 return responsecode.FORBIDDEN
85             elif e[0] == errno.ENOENT:
86                 return responsecode.NOT_FOUND
87             else:
88                 raise
89
90         response = http.Response()
91         response.stream = FileUploaderStream(f, 0, self.fp.getsize())
92
93         for (header, value) in (
94             ("content-type", self.contentType()),
95             ("content-encoding", self.contentEncoding()),
96         ):
97             if value is not None:
98                 response.headers.setHeader(header, value)
99
100         return response
101
class TopLevel(resource.Resource):
    """Root resource of the HTTP server.

    /~/<hash> requests are answered from the hash database for any client.
    Every other request is accepted only from 127.0.0.1 and is served either
    by FileDownloader or by this resource's own statistics page.
    """
    addSlash = True

    def __init__(self, directory, db, manager):
        """
        @param directory: FilePath of the directory FileDownloader serves from.
        @param db: object whose lookupHash(hash) returns a list of dicts, each
            with a 'path' entry holding a FilePath (empty/false if unknown).
        @param manager: passed through to FileDownloader (may be None).
        """
        self.directory = directory
        self.db = db
        self.manager = manager
        self.factory = None

    def getHTTPFactory(self):
        """Return the protocol factory for this resource tree, creating it once.

        The web2 HTTPFactory is wrapped in a ThrottlingFactory with
        writeLimit=30*1024 (presumably bytes per second, as in
        twisted.protocols.policies -- confirm against the local policies module).
        """
        if self.factory is None:
            self.factory = channel.HTTPFactory(server.Site(self),
                                               maxPipeline=10,
                                               betweenRequestsTimeOut=60)
            self.factory = ThrottlingFactory(self.factory, writeLimit=30*1024)
        return self.factory

    def render(self, ctx):
        """Render the placeholder statistics page."""
        return http.Response(
            200,
            {'content-type': http_headers.MimeType('text', 'html')},
            """<html><body>
            <h2>Statistics</h2>
            <p>TODO: eventually some stats will be shown here.</body></html>""")

    def locateChild(self, request, segments):
        """Map a request path onto the resource that should serve it.

        Returns a (resource, remaining_segments) tuple; (None, ()) is used to
        reject the request.
        """
        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
        name = segments[0]
        if name == '~':
            # Peer request for a file by hash: /~/<quoted hash>
            if len(segments) != 2:
                log.msg('Got a malformed request from %s' % request.remoteAddr)
                return None, ()
            hash_ = unquote_plus(segments[1])
            files = self.db.lookupHash(hash_)
            if not files:
                # Reject here rather than falling through to the local-only
                # handling below, which would treat '~' as a mirror path.
                log.msg('Hash could not be found in database: %s' % hash_)
                return None, ()
            log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
            return FileUploader(files[0]['path'].path), ()

        # Everything else is only available to the local machine.
        if request.remoteAddr.host != "127.0.0.1":
            log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
            return None, ()

        if len(name) > 1:
            return FileDownloader(self.directory.path, self.manager), segments
        # An empty (root) or single-character first segment shows the stats page.
        return self, ()

if __name__ == '__builtin__':
    # Executed when loaded with `twistd -ny HTTPServer.py`, which runs this
    # file in a namespace whose __name__ is '__builtin__'.
    # Then test with:
    #   wget -S 'http://localhost:18080/~/whatever'
    #   wget -S 'http://localhost:18080/.xsession-errors'

    import os.path
    from twisted.python.filepath import FilePath
    from twisted.application import service, strports

    class DB:
        """Test stub: every hash resolves to the same local file."""
        def lookupHash(self, hash):
            return [{'path': FilePath(os.path.expanduser('~/school/optout'))}]

    top = TopLevel(FilePath(os.path.expanduser('~')), DB(), None)

    # Standard twisted application boilerplate; twistd looks up the
    # module-level name `application`.
    application = service.Application("demoserver")
    strports.service('tcp:18080', top.getHTTPFactory()).setServiceParent(application)