Added 2 (commented) failed attempts to throttle the web server.
[quix0rs-apt-p2p.git] / apt_dht / HTTPServer.py
index 181da4e4575f20a1ab66a6fba71f70dd62b71db2..20f94cdfa5e3088193591455b269a1fbd6963b15 100644 (file)
@@ -1,7 +1,10 @@
-import os.path, time
+
+from urllib import unquote_plus
 
 from twisted.python import log
 from twisted.internet import defer
+#from twisted.protocols import htb
+#from twisted.protocols.policies import ThrottlingFactory
 from twisted.web2 import server, http, resource, channel
 from twisted.web2 import static, http_headers, responsecode
 
@@ -26,10 +29,10 @@ class FileDownloader(static.File):
         if self.manager:
             path = 'http:/' + req.uri
             if resp.code >= 200 and resp.code < 400:
-                return self.manager.check_freshness(path, resp.headers.getHeader('Last-Modified'), resp)
+                return self.manager.check_freshness(req, path, resp.headers.getHeader('Last-Modified'), resp)
             
             log.msg('Not found, trying other methods for %s' % req.uri)
-            return self.manager.get_resp(path)
+            return self.manager.get_resp(req, path)
         
         return resp
 
@@ -41,20 +44,30 @@ class FileDownloader(static.File):
 class TopLevel(resource.Resource):
     addSlash = True
     
-    def __init__(self, directory, manager):
+    def __init__(self, directory, db, manager):
         self.directory = directory
+        self.db = db
         self.manager = manager
-        self.subdirs = []
+        self.factory = None
+
+    def getHTTPFactory(self):
+        if self.factory is None:
+            self.factory = channel.HTTPFactory(server.Site(self),
+                                               **{'maxPipeline': 10, 
+                                                  'betweenRequestsTimeOut': 60})
+#            serverFilter = htb.HierarchicalBucketFilter()
+#            serverBucket = htb.Bucket()
+#
+#            # Cap total server traffic at 20 kB/s
+#            serverBucket.maxburst = 20000
+#            serverBucket.rate = 20000
+#
+#            serverFilter.buckets[None] = serverBucket
+#
+#            self.factory.protocol = htb.ShapedProtocolFactory(self.factory.protocol, serverFilter)
+#            self.factory = ThrottlingFactory(self.factory, writeLimit = 300*1024)
+        return self.factory
 
-    def addDirectory(self, directory):
-        path = "~" + str(len(self.subdirs))
-        self.subdirs.append(directory)
-        return path
-    
-    def removeDirectory(self, directory):
-        loc = self.subdirs.index(directory)
-        self.subdirs[loc] = ''
-        
     def render(self, ctx):
         return http.Response(
             200,
@@ -64,39 +77,40 @@ class TopLevel(resource.Resource):
             <p>TODO: eventually some stats will be shown here.</body></html>""")
 
     def locateChild(self, request, segments):
+        log.msg('Got HTTP request for %s from %s' % (request.uri, request.remoteAddr))
         name = segments[0]
-        if len(name) > 1 and name[0] == '~':
-            try:
-                loc = int(name[1:])
-            except:
-                log.msg('Not found: %s from %s' % (request.uri, request.remoteAddr))
+        if name == '~':
+            if len(segments) != 2:
+                log.msg('Got a malformed request from %s' % request.remoteAddr)
                 return None, ()
-            
-            if loc >= 0 and loc < len(self.subdirs) and self.subdirs[loc]:
-                log.msg('Sharing %s with %s' % (request.uri, request.remoteAddr))
-                return static.File(self.subdirs[loc]), segments[1:]
+            hash = unquote_plus(segments[1])
+            files = self.db.lookupHash(hash)
+            if files:
+                log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr))
+                return static.File(files[0]['path'].path), ()
             else:
-                log.msg('Not found: %s from %s' % (request.uri, request.remoteAddr))
-                return None, ()
+                log.msg('Hash could not be found in database: %s' % hash)
         
         if request.remoteAddr.host != "127.0.0.1":
             log.msg('Blocked illegal access to %s from %s' % (request.uri, request.remoteAddr))
             return None, ()
             
         if len(name) > 1:
-            return FileDownloader(self.directory, self.manager), segments[0:]
+            return FileDownloader(self.directory.path, self.manager), segments[0:]
         else:
             return self, ()
         
+        log.msg('Got a malformed request for "%s" from %s' % (request.uri, request.remoteAddr))
+        return None, ()
+
 if __name__ == '__builtin__':
     # Running from twistd -y
     t = TopLevel('/home', None)
-    t.addDirectory('/tmp')
-    t.addDirectory('/var/log')
-    site = server.Site(t)
+    t.setDirectories({'~1': '/tmp', '~2': '/var/log'})
+    factory = t.getHTTPFactory()
     
     # Standard twisted application Boilerplate
     from twisted.application import service, strports
     application = service.Application("demoserver")
-    s = strports.service('tcp:18080', channel.HTTPFactory(site))
+    s = strports.service('tcp:18080', factory)
     s.setServiceParent(application)