1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories.
3 # See LICENSE for details.
7 """Resource limiting policies.
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet.interfaces import ITransport
18 from twisted.internet import reactor, error
19 from twisted.python import log
20 from zope.interface import implements, providedBy, directlyProvides
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    Events delivered by the real transport (connectionMade, dataReceived,
    connectionLost) are relayed to C{wrappedProtocol}; transport calls made
    by the wrapped protocol are relayed to the real transport.  Any other
    attribute lookup that fails on the wrapper is forwarded to the real
    transport by L{__getattr__}.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        # Declare every interface the real transport provides on top of our
        # own, so interface checks made against the wrapper succeed.
        interfaces = providedBy(self) + providedBy(transport)
        directlyProvides(self, *interfaces)
        Protocol.makeConnection(self, transport)

    # --- Transport relaying (called by the wrapped protocol) ---

    def write(self, data):
        """Relay C{data} to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Relay a sequence of strings to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Mark the wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return the real transport's peer address."""
        return self.transport.getPeer()

    def getHost(self):
        """Return the real transport's host address."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Relay stopConsuming to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal lookup on the wrapper fails.
        transport = self.transport
        return getattr(transport, name)

    # --- Protocol relaying (called by the real transport) ---

    def connectionMade(self):
        """Register with the factory, then connect the wrapped protocol to us."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Hand received bytes to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Each protocol built by C{wrappedFactory} is wrapped in an instance of
    C{self.protocol}.  Wrappers add themselves to C{self.protocols} through
    L{registerProtocol} and remove themselves through L{unregisterProtocol}.
    (ProtocolWrapper does this from connectionMade / connectionLost.)
    Client-side connection events are forwarded to the wrapped factory.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Live wrapping protocols, used as a set (the values are unused).
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build a protocol with the wrapped factory and wrap it.

        @return: an instance of C{self.protocol}, or None when the wrapped
            factory refused the connection by returning None.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # Propagate the refusal.  A wrapper around None would fail later
            # in connectionMade/connectionLost with an AttributeError.
            return None
        return self.protocol(self, wrappedProtocol)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Every write is first offered to C{factory.registerWritten}.  If the
    factory refuses it, or this protocol is already throttled, the data is
    kept in C{_tempDataBuffer}.  The protocol then queues itself with
    C{factory.throttledWrites} and pauses its registered producer, if any.
    The factory later drains the buffer one chunk at a time through
    L{unthrottleWrites}.  Bytes read are reported to C{factory.registerRead}.
    """

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        # Chunks waiting to be written (oldest first) and their total size.
        self._tempDataBuffer = []
        self._tempDataLength = 0
        # True from the moment this protocol is queued with the factory until
        # its buffer has been fully drained.
        self.throttled = False
        # Producer registered through registerProducer, or None.  Initialised
        # here so that reading it never falls through
        # ProtocolWrapper.__getattr__ to the real transport.
        self.producer = None

    def write(self, data):
        """Write C{data} now if bandwidth allows, otherwise buffer it."""
        # Check if we can write
        if (not self.throttled) and self.factory.registerWritten(len(data)):
            ProtocolWrapper.write(self, data)

            # Interrupt the flow so that others can have a chance.  We can
            # only do this if the producer is known not to be paused already,
            # otherwise we risk unpausing something that the Server paused.
            # 'paused' is not part of the producer interfaces, so producers
            # without it are left alone instead of raising AttributeError.
            if self.producer and not getattr(self.producer, "paused", True):
                self.producer.pauseProducing()
                reactor.callLater(0, self.producer.resumeProducing)
        else:
            # Can't write, buffer the data
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """Write each string of C{seq}, buffering whatever cannot be sent now."""
        # Work on a private copy: the caller's list must not be consumed,
        # and tuples or other sequences have no pop().
        pending = list(seq)
        sent = 0
        if not self.throttled:
            # Write each element separately while bandwidth remains
            while (sent < len(pending)
                   and self.factory.registerWritten(len(pending[sent]))):
                ProtocolWrapper.write(self, pending[sent])
                sent += 1

        # If there's some left, we were (or have just become) throttled
        remaining = pending[sent:]
        if remaining:
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        """Report the bytes to the factory, then relay them."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        # Remember the producer so write throttling can pause/resume it.
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading from the real transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the real transport."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """Queue this protocol for unthrottling and pause its producer."""
        # If we haven't yet, queue for unthrottling
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Write one buffered chunk.

        Called by the factory while write bandwidth is available.  When the
        buffer becomes empty the protocol is no longer throttled, and its
        producer (if any) is resumed on the next reactor iteration.

        @return: the number of bytes still buffered.
        """
        if self._tempDataBuffer:
            chunk = self._tempDataBuffer.pop(0)
            # Do the accounting outside the assert so it still happens when
            # asserts are stripped (python -O).
            allowed = self.factory.registerWritten(len(chunk))
            assert allowed
            self._tempDataLength -= len(chunk)
            ProtocolWrapper.write(self, chunk)
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if self.producer:
                # This might unpause something the Server has also paused, but
                # it will get paused again on first write anyway
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Connections beyond C{maxConnectionCount} are refused.

    Reads: once more than C{readLimit} bytes have been read within a
    one-second check interval, reads are paused on every protocol's
    transport for a time proportional to the overshoot.

    Writes: a write exceeding the remaining C{writeLimit} allowance is
    buffered by its protocol and flushed by C{checkWriteBandwidth}, which
    runs once a second.  A producer registered on a throttled protocol is
    paused until that protocol's buffer has drained.
    """

    protocol = ThrottlingProtocol
215 def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
216 WrappingFactory.__init__(self, wrappedFactory)
217 self.connectionCount = 0
218 self.maxConnectionCount = maxConnectionCount
219 self.readLimit = readLimit # max bytes we should read per second
220 self.writeLimit = writeLimit # max bytes we should write per second
221 self.readThisSecond = 0
222 self.writeAvailable = writeLimit
223 self._writeQueue = []
224 self.unthrottleReadsID = None
225 self.checkReadBandwidthID = None
226 self.unthrottleWritesID = None
227 self.checkWriteBandwidthID = None
229 def registerWritten(self, length):
230 """Called by protocol to tell us more bytes were written."""
231 # Check if there are bytes available to write
232 if self.writeAvailable > 0:
233 self.writeAvailable -= length
238 def throttledWrites(self, p):
239 """Called by the protocol to queue it for later writing."""
240 assert p not in self._writeQueue
241 self._writeQueue.append(p)
243 def registerRead(self, length):
244 """Called by protocol to tell us more bytes were read."""
245 self.readThisSecond += length
247 def checkReadBandwidth(self):
248 """Checks if we've passed bandwidth limits."""
249 if self.readThisSecond > self.readLimit:
251 throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
252 self.unthrottleReadsID = reactor.callLater(throttleTime,
253 self.unthrottleReads)
254 self.readThisSecond = 0
255 self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)
257 def checkWriteBandwidth(self):
258 """Add some new available bandwidth, and check for protocols to unthrottle."""
259 # Increase the available write bytes, but not higher than the limit
260 self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)
262 # Write from the queue until it's empty or we're throttled again
263 while self.writeAvailable > 0 and self._writeQueue:
264 # Get the first queued protocol
265 p = self._writeQueue.pop(0)
266 _tempWriteAvailable = self.writeAvailable
269 # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
270 while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
271 # Unthrottle a single write (from the protocol's buffer)
272 bytesLeft = p.unthrottleWrites()
274 # If the protocol is not done, requeue it
276 self._writeQueue.append(p)
278 self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)
280 def throttleReads(self):
281 """Throttle reads on all protocols."""
282 log.msg("Throttling reads on %s" % self)
283 for p in self.protocols.keys():
286 def unthrottleReads(self):
287 """Stop throttling reads on all protocols."""
288 self.unthrottleReadsID = None
289 log.msg("Stopped throttling reads on %s" % self)
290 for p in self.protocols.keys():
293 def buildProtocol(self, addr):
294 if self.connectionCount == 0:
295 if self.readLimit is not None:
296 self.checkReadBandwidth()
297 if self.writeLimit is not None:
298 self.checkWriteBandwidth()
300 if self.connectionCount < self.maxConnectionCount:
301 self.connectionCount += 1
302 return WrappingFactory.buildProtocol(self, addr)
304 log.msg("Max connection count reached!")
307 def unregisterProtocol(self, p):
308 WrappingFactory.unregisterProtocol(self, p)
309 self.connectionCount -= 1
310 if self.connectionCount == 0:
311 if self.unthrottleReadsID is not None:
312 self.unthrottleReadsID.cancel()
313 if self.checkReadBandwidthID is not None:
314 self.checkReadBandwidthID.cancel()
315 if self.unthrottleWritesID is not None:
316 self.unthrottleWritesID.cancel()
317 if self.checkWriteBandwidthID is not None:
318 self.checkWriteBandwidthID.cancel()