From 5f71361ea619c97e77b9c5cd52894d2088d19b08 Mon Sep 17 00:00:00 2001 From: Cameron Dale Date: Mon, 25 Feb 2008 17:10:58 -0800 Subject: [PATCH] Another attempt at throttling, still not working. --- apt_dht/HTTPServer.py | 84 ++++++++++++-- apt_dht/policies.py | 260 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 336 insertions(+), 8 deletions(-) create mode 100644 apt_dht/policies.py diff --git a/apt_dht/HTTPServer.py b/apt_dht/HTTPServer.py index 20f94cd..4d62e5b 100644 --- a/apt_dht/HTTPServer.py +++ b/apt_dht/HTTPServer.py @@ -4,10 +4,11 @@ from urllib import unquote_plus from twisted.python import log from twisted.internet import defer #from twisted.protocols import htb -#from twisted.protocols.policies import ThrottlingFactory -from twisted.web2 import server, http, resource, channel +from twisted.web2 import server, http, resource, channel, stream from twisted.web2 import static, http_headers, responsecode +from policies import ThrottlingFactory + class FileDownloader(static.File): def __init__(self, path, manager, defaultType="text/plain", ignoredExts=(), processors=None, indexNames=None): @@ -40,7 +41,64 @@ class FileDownloader(static.File): return self.__class__(path, self.manager, self.defaultType, self.ignoredExts, self.processors, self.indexNames[:]) - +class FileUploaderStream(stream.FileStream): + + CHUNK_SIZE = 16*1024 + + def read(self, sendfile=False): + if self.f is None: + return None + + length = self.length + if length == 0: + self.f = None + return None + + readSize = min(length, self.CHUNK_SIZE) + + self.f.seek(self.start) + b = self.f.read(readSize) + bytesRead = len(b) + if not bytesRead: + raise RuntimeError("Ran out of data reading file %r, expected %d more bytes" % (self.f, length)) + else: + self.length -= bytesRead + self.start += bytesRead + return b + + +class FileUploader(static.File): + + def render(self, req): + if not self.fp.exists(): + return responsecode.NOT_FOUND + + if self.fp.isdir(): + return responsecode.NOT_FOUND + + try: + f = self.fp.open() + except IOError, e: + import errno + if e[0] == errno.EACCES: + return responsecode.FORBIDDEN + elif e[0] == errno.ENOENT: + return responsecode.NOT_FOUND + else: + raise + + response = http.Response() + response.stream = FileUploaderStream(f, 0, self.fp.getsize()) + + for (header, value) in ( + ("content-type", self.contentType()), + ("content-encoding", self.contentEncoding()), + ): + if value is not None: + response.headers.setHeader(header, value) + + return response + class TopLevel(resource.Resource): addSlash = True @@ -65,7 +123,7 @@ class TopLevel(resource.Resource): # serverFilter.buckets[None] = serverBucket # # self.factory.protocol = htb.ShapedProtocolFactory(self.factory.protocol, serverFilter) -# self.factory = ThrottlingFactory(self.factory, writeLimit = 300*1024) + self.factory = ThrottlingFactory(self.factory, writeLimit = 3*1024) return self.factory def render(self, ctx): @@ -87,7 +145,7 @@ class TopLevel(resource.Resource): files = self.db.lookupHash(hash) if files: log.msg('Sharing %s with %s' % (files[0]['path'].path, request.remoteAddr)) - return static.File(files[0]['path'].path), () + return FileUploader(files[0]['path'].path), () else: log.msg('Hash could not be found in database: %s' % hash) @@ -104,9 +162,19 @@ class TopLevel(resource.Resource): return None, () if __name__ == '__builtin__': - # Running from twistd -y - t = TopLevel('/home', None) - t.setDirectories({'~1': '/tmp', '~2': '/var/log'}) + # Running from twistd -ny HTTPServer.py + # Then test with: + # wget -S 'http://localhost:18080/~/whatever' + # wget -S 'http://localhost:18080/.xsession-errors' + + import os.path + from twisted.python.filepath import FilePath + + class DB: + def lookupHash(self, hash): + return [{'path': FilePath(os.path.expanduser('~/.xsession-errors'))}] + + t = TopLevel(FilePath(os.path.expanduser('~')), DB(), None) factory = t.getHTTPFactory() # Standard twisted application Boilerplate diff --git a/apt_dht/policies.py b/apt_dht/policies.py new file mode 100644 index 0000000..76a81d1 --- /dev/null +++ b/apt_dht/policies.py @@ -0,0 +1,260 @@ +# -*- test-case-name: twisted.test.test_policies -*- +# Copyright (c) 2001-2004 Twisted Matrix Laboratories. +# See LICENSE for details. + +# + +"""Resource limiting policies. + +@seealso: See also L{twisted.protocols.htb} for rate limiting. +""" + +# system imports +import sys, operator + +# twisted imports +from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory +from twisted.internet.interfaces import ITransport +from twisted.internet import reactor, error +from twisted.python import log +from zope.interface import implements, providedBy, directlyProvides + +class ProtocolWrapper(Protocol): + """Wraps protocol instances and acts as their transport as well.""" + + disconnecting = 0 + + def __init__(self, factory, wrappedProtocol): + self.wrappedProtocol = wrappedProtocol + self.factory = factory + + def makeConnection(self, transport): + directlyProvides(self, *providedBy(self) + providedBy(transport)) + Protocol.makeConnection(self, transport) + + # Transport relaying + + def write(self, data): + self.transport.write(data) + + def writeSequence(self, data): + self.transport.writeSequence(data) + + def loseConnection(self): + self.disconnecting = 1 + self.transport.loseConnection() + + def getPeer(self): + return self.transport.getPeer() + + def getHost(self): + return self.transport.getHost() + + def registerProducer(self, producer, streaming): + self.transport.registerProducer(producer, streaming) + + def unregisterProducer(self): + self.transport.unregisterProducer() + + def stopConsuming(self): + self.transport.stopConsuming() + + def __getattr__(self, name): + return getattr(self.transport, name) + + # Protocol relaying + + def connectionMade(self): + self.factory.registerProtocol(self) + self.wrappedProtocol.makeConnection(self) + + def dataReceived(self, data): + self.wrappedProtocol.dataReceived(data) + + def connectionLost(self, reason): + self.factory.unregisterProtocol(self) + self.wrappedProtocol.connectionLost(reason) + + +class WrappingFactory(ClientFactory): + """Wraps a factory and its protocols, and keeps track of them.""" + + protocol = ProtocolWrapper + + def __init__(self, wrappedFactory): + self.wrappedFactory = wrappedFactory + self.protocols = {} + + def doStart(self): + self.wrappedFactory.doStart() + ClientFactory.doStart(self) + + def doStop(self): + self.wrappedFactory.doStop() + ClientFactory.doStop(self) + + def startedConnecting(self, connector): + self.wrappedFactory.startedConnecting(connector) + + def clientConnectionFailed(self, connector, reason): + self.wrappedFactory.clientConnectionFailed(connector, reason) + + def clientConnectionLost(self, connector, reason): + self.wrappedFactory.clientConnectionLost(connector, reason) + + def buildProtocol(self, addr): + return self.protocol(self, self.wrappedFactory.buildProtocol(addr)) + + def registerProtocol(self, p): + """Called by protocol to register itself.""" + self.protocols[p] = 1 + + def unregisterProtocol(self, p): + """Called by protocols when they go away.""" + del self.protocols[p] + + +class ThrottlingProtocol(ProtocolWrapper): + """Protocol for ThrottlingFactory.""" + + # wrap API for tracking bandwidth + + def write(self, data): + self.factory.registerWritten(len(data)) + ProtocolWrapper.write(self, data) + + def writeSequence(self, seq): + self.factory.registerWritten(reduce(operator.add, map(len, seq))) + ProtocolWrapper.writeSequence(self, seq) + + def dataReceived(self, data): + self.factory.registerRead(len(data)) + ProtocolWrapper.dataReceived(self, data) + + def registerProducer(self, producer, streaming): + self.producer = producer + ProtocolWrapper.registerProducer(self, producer, streaming) + + def unregisterProducer(self): + del self.producer + ProtocolWrapper.unregisterProducer(self) + + + def throttleReads(self): + self.transport.pauseProducing() + + def unthrottleReads(self): + self.transport.resumeProducing() + + def throttleWrites(self): + if hasattr(self, "producer"): + self.producer.pauseProducing() + + def unthrottleWrites(self): + if hasattr(self, "producer"): + self.producer.resumeProducing() + + +class ThrottlingFactory(WrappingFactory): + """Throttles bandwidth and number of connections. + + Write bandwidth will only be throttled if there is a producer + registered. + """ + + protocol = ThrottlingProtocol + + def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None): + WrappingFactory.__init__(self, wrappedFactory) + self.connectionCount = 0 + self.maxConnectionCount = maxConnectionCount + self.readLimit = readLimit # max bytes we should read per second + self.writeLimit = writeLimit # max bytes we should write per second + self.readThisSecond = 0 + self.writtenThisSecond = 0 + self.unthrottleReadsID = None + self.checkReadBandwidthID = None + self.unthrottleWritesID = None + self.checkWriteBandwidthID = None + + def registerWritten(self, length): + """Called by protocol to tell us more bytes were written.""" + self.writtenThisSecond += length + + def registerRead(self, length): + """Called by protocol to tell us more bytes were read.""" + self.readThisSecond += length + + def checkReadBandwidth(self): + """Checks if we've passed bandwidth limits.""" + if self.readThisSecond > self.readLimit: + self.throttleReads() + throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0 + self.unthrottleReadsID = reactor.callLater(throttleTime, + self.unthrottleReads) + self.readThisSecond = 0 + self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth) + + def checkWriteBandwidth(self): + if self.writtenThisSecond > self.writeLimit: + self.throttleWrites() + throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0 + self.unthrottleWritesID = reactor.callLater(throttleTime, + self.unthrottleWrites) + # reset for next round + self.writtenThisSecond = 0 + self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth) + + def throttleReads(self): + """Throttle reads on all protocols.""" + log.msg("Throttling reads on %s" % self) + for p in self.protocols.keys(): + p.throttleReads() + + def unthrottleReads(self): + """Stop throttling reads on all protocols.""" + self.unthrottleReadsID = None + log.msg("Stopped throttling reads on %s" % self) + for p in self.protocols.keys(): + p.unthrottleReads() + + def throttleWrites(self): + """Throttle writes on all protocols.""" + log.msg("Throttling writes on %s" % self) + for p in self.protocols.keys(): + p.throttleWrites() + + def unthrottleWrites(self): + """Stop throttling writes on all protocols.""" + self.unthrottleWritesID = None + log.msg("Stopped throttling writes on %s" % self) + for p in self.protocols.keys(): + p.unthrottleWrites() + + def buildProtocol(self, addr): + if self.connectionCount == 0: + if self.readLimit is not None: + self.checkReadBandwidth() + if self.writeLimit is not None: + self.checkWriteBandwidth() + + if self.connectionCount < self.maxConnectionCount: + self.connectionCount += 1 + return WrappingFactory.buildProtocol(self, addr) + else: + log.msg("Max connection count reached!") + return None + + def unregisterProtocol(self, p): + WrappingFactory.unregisterProtocol(self, p) + self.connectionCount -= 1 + if self.connectionCount == 0: + if self.unthrottleReadsID is not None: + self.unthrottleReadsID.cancel() + if self.checkReadBandwidthID is not None: + self.checkReadBandwidthID.cancel() + if self.unthrottleWritesID is not None: + self.unthrottleWritesID.cancel() + if self.checkWriteBandwidthID is not None: + self.checkWriteBandwidthID.cancel() + -- 2.39.5