# wrap API for tracking bandwidth
+ def __init__(self, factory, wrappedProtocol):
+ ProtocolWrapper.__init__(self, factory, wrappedProtocol)
+ self._tempDataBuffer = []
+ self._tempDataLength = 0
+ self.throttled = False
+
def write(self, data):
- self.factory.registerWritten(len(data))
- ProtocolWrapper.write(self, data)
+ # Check if we can write
+ if (not self.throttled) and self.factory.registerWritten(len(data)):
+ ProtocolWrapper.write(self, data)
+
+ if hasattr(self, "producer") and self.producer and not self.producer.paused:
+ # Interrupt the flow so that others can can have a chance
+ # We can only do this if it's not already paused otherwise we
+ # risk unpausing something that the Server paused
+ self.producer.pauseProducing()
+ reactor.callLater(0, self.producer.resumeProducing)
+ else:
+ # Can't write, buffer the data
+ self._tempDataBuffer.append(data)
+ self._tempDataLength += len(data)
+ self._throttleWrites()
def writeSequence(self, seq):
- self.factory.registerWritten(reduce(operator.add, map(len, seq)))
- ProtocolWrapper.writeSequence(self, seq)
+ if not self.throttled:
+ # Write each sequence separately
+ while seq and self.factory.registerWritten(len(seq[0])):
+ ProtocolWrapper.write(self, seq.pop(0))
+
+ # If there's some left, we must have been throttled
+ if seq:
+ self._tempDataBuffer.extend(seq)
+ self._tempDataLength += reduce(operator.add, map(len, seq))
+ self._throttleWrites()
def dataReceived(self, data):
self.factory.registerRead(len(data))
def unthrottleReads(self):
self.transport.resumeProducing()
- def throttleWrites(self):
- if hasattr(self, "producer"):
+ def _throttleWrites(self):
+ # If we haven't yet, queue for unthrottling
+ if not self.throttled:
+ self.throttled = True
+ self.factory.throttledWrites(self)
+
+ if hasattr(self, "producer") and self.producer:
self.producer.pauseProducing()
def unthrottleWrites(self):
- if hasattr(self, "producer"):
- self.producer.resumeProducing()
+ # Write some data
+ if self._tempDataBuffer:
+ assert self.factory.registerWritten(len(self._tempDataBuffer[0]))
+ self._tempDataLength -= len(self._tempDataBuffer[0])
+ ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
+ assert self._tempDataLength >= 0
+
+ # If we wrote it all, start producing more
+ if not self._tempDataBuffer:
+ assert self._tempDataLength == 0
+ self.throttled = False
+ if hasattr(self, "producer") and self.producer:
+ # This might unpause something the Server has also paused, but
+ # it will get paused again on first write anyway
+ reactor.callLater(0, self.producer.resumeProducing)
+
+ return self._tempDataLength
class ThrottlingFactory(WrappingFactory):
"""
protocol = ThrottlingProtocol
+ CHUNK_SIZE = 4*1024
def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
WrappingFactory.__init__(self, wrappedFactory)
self.readLimit = readLimit # max bytes we should read per second
self.writeLimit = writeLimit # max bytes we should write per second
self.readThisSecond = 0
- self.writtenThisSecond = 0
+ self.writeAvailable = writeLimit
+ self._writeQueue = []
self.unthrottleReadsID = None
self.checkReadBandwidthID = None
self.unthrottleWritesID = None
def registerWritten(self, length):
"""Called by protocol to tell us more bytes were written."""
- self.writtenThisSecond += length
+ # Check if there are bytes available to write
+ if self.writeAvailable > 0:
+ self.writeAvailable -= length
+ return True
+
+ return False
+
+ def throttledWrites(self, p):
+ """Called by the protocol to queue it for later writing."""
+ assert p not in self._writeQueue
+ self._writeQueue.append(p)
def registerRead(self, length):
"""Called by protocol to tell us more bytes were read."""
self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)
def checkWriteBandwidth(self):
- if self.writtenThisSecond > self.writeLimit:
- self.throttleWrites()
- throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
- self.unthrottleWritesID = reactor.callLater(throttleTime,
- self.unthrottleWrites)
- # reset for next round
- self.writtenThisSecond = 0
+ """Add some new available bandwidth, and check for protocols to unthrottle."""
+ # Increase the available write bytes, but not higher than the limit
+ self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)
+
+ # Write from the queue until it's empty or we're throttled again
+ while self.writeAvailable > 0 and self._writeQueue:
+ # Get the first queued protocol
+ p = self._writeQueue.pop(0)
+ _tempWriteAvailable = self.writeAvailable
+ bytesLeft = 1
+
+ # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
+ while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
+ # Unthrottle a single write (from the protocol's buffer)
+ bytesLeft = p.unthrottleWrites()
+
+ # If the protocol is not done, requeue it
+ if bytesLeft > 0:
+ self._writeQueue.append(p)
+
self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)
def throttleReads(self):
for p in self.protocols.keys():
p.unthrottleReads()
- def throttleWrites(self):
- """Throttle writes on all protocols."""
- log.msg("Throttling writes on %s" % self)
- for p in self.protocols.keys():
- p.throttleWrites()
-
- def unthrottleWrites(self):
- """Stop throttling writes on all protocols."""
- self.unthrottleWritesID = None
- log.msg("Stopped throttling writes on %s" % self)
- for p in self.protocols.keys():
- p.unthrottleWrites()
-
def buildProtocol(self, addr):
if self.connectionCount == 0:
if self.readLimit is not None: