# -*- test-case-name: twisted.test.test_policies -*-
# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
# See LICENSE for details.
"""Resource limiting policies.

@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""
import sys

from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
from twisted.internet.interfaces import ITransport
from twisted.internet import reactor, error
from twisted.python import log
from zope.interface import implements, providedBy, directlyProvides
class ProtocolWrapper(Protocol):
    """Wrap a protocol instance and act as its transport as well.

    The wrapper sits between a real transport and C{wrappedProtocol}:
    data arriving from the real transport is relayed to the wrapped
    protocol, and the wrapped protocol's transport calls (write,
    writeSequence, loseConnection, producer registration, ...) are relayed
    to the real transport.  Any attribute not found on the wrapper itself
    is looked up on the real transport (see L{__getattr__}).
    """

    def __init__(self, factory, wrappedProtocol):
        """
        @param factory: the wrapping factory; C{factory.registerProtocol}
            is called from L{connectionMade} and
            C{factory.unregisterProtocol} from L{connectionLost}.
        @param wrappedProtocol: the protocol instance being wrapped.
        """
        self.wrappedProtocol = wrappedProtocol
        self.factory = factory

    def makeConnection(self, transport):
        # Declare on this wrapper every interface the real transport
        # provides, because the wrapped protocol will use the wrapper as its
        # transport and may check it for those interfaces.
        directlyProvides(self, *providedBy(self) + providedBy(transport))
        Protocol.makeConnection(self, transport)

    # Transport relaying

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        # Remember locally that a disconnect was requested, then relay it.
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached for attributes not found by normal lookup: fall back
        # to the real transport so the wrapper can stand in for it.
        return getattr(self.transport, name)

    # Protocol relaying

    def connectionMade(self):
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """Wrap a factory and its protocols, and keep track of them.

    Each protocol built by C{wrappedFactory} is wrapped in an instance of
    L{protocol}; wrappers register themselves in C{self.protocols} while
    they are connected.  Start/stop and client-connection notifications
    are forwarded to the wrapped factory.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        """
        @param wrappedFactory: the factory whose protocols get wrapped.
        """
        self.wrappedFactory = wrappedFactory
        # Currently registered wrapper protocols, used as a set (every
        # value is 1).
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build a protocol with the wrapped factory and wrap it.

        @return: a new L{protocol} instance wrapping the result of
            C{self.wrappedFactory.buildProtocol(addr)}.
        """
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr))

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol wrapper used by L{ThrottlingFactory}.

    Incoming data is counted with C{factory.registerRead} before being
    relayed.  Outgoing data goes to the transport only while the factory
    grants write allowance (C{factory.registerWritten}).  Otherwise it is
    buffered here, and this protocol queues itself with the factory
    (C{factory.throttledWrites}).  The factory later drains the buffer one
    chunk at a time through L{unthrottleWrites}.
    """

    # Producer registered through registerProducer(), or None.  Declared on
    # the class so reading self.producer never falls through to
    # ProtocolWrapper.__getattr__, which would consult the real transport.
    producer = None

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        # Writes waiting for allowance, oldest first, and their total size.
        self._tempDataBuffer = []
        self._tempDataLength = 0
        # True from the moment data is buffered until the buffer is drained.
        self.throttled = False

    # wrap API for tracking bandwidth

    def write(self, data):
        """Write C{data} now if allowed, otherwise buffer it."""
        # While throttled, new data must queue behind the already-buffered
        # data to preserve byte order, so registerWritten is not consulted.
        if (not self.throttled) and self.factory.registerWritten(len(data)):
            ProtocolWrapper.write(self, data)

            if self.producer and not self.producer.paused:
                # Interrupt the flow so that others can have a chance.
                # We can only do this if it's not already paused, otherwise
                # we risk unpausing something that the server paused.
                # NOTE(review): assumes the producer exposes a `paused`
                # attribute -- confirm for the producers used with this.
                self.producer.pauseProducing()
                reactor.callLater(0, self.producer.resumeProducing)
        else:
            # Can't write now: buffer the data.
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """Write the items of C{seq} as allowance permits, buffering the rest.

        C{seq} is copied first.  The caller's sequence is therefore left
        untouched, and any iterable of strings (list, tuple, generator) is
        accepted.
        """
        items = list(seq)
        sent = 0
        if not self.throttled:
            # Write items one at a time while allowance is granted.
            while (sent < len(items)
                   and self.factory.registerWritten(len(items[sent]))):
                ProtocolWrapper.write(self, items[sent])
                sent += 1

        # Anything left over could not be written: buffer it.
        remaining = items[sent:]
        if remaining:
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        """Count C{data} against the read limit, then relay it."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """Remember C{producer} so it can be paused while throttled.

        Only streaming (push) producers are supported.
        """
        assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    # Throttling hooks, called by the factory

    def throttleReads(self):
        """Stop reading from the real transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the real transport."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """Mark writes as throttled and pause the producer, if any."""
        # Queue with the factory only on the transition into the throttled
        # state, so this protocol is in the factory's write queue at most
        # once.
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Write the oldest buffered chunk, if any.

        Once the buffer is empty the protocol is no longer throttled.  Its
        producer, if any, is resumed on a later reactor iteration.

        @return: the number of bytes still buffered.
        """
        if self._tempDataBuffer:
            data = self._tempDataBuffer[0]
            # Charge the allowance outside the assert: assert statements are
            # stripped under ``python -O`` and the charge must not be.
            granted = self.factory.registerWritten(len(data))
            assert granted
            del self._tempDataBuffer[0]
            self._tempDataLength -= len(data)
            ProtocolWrapper.write(self, data)
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more.
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if self.producer:
                # This might unpause something the server has also paused,
                # but it will get paused again on first write anyway.
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
class ThrottlingFactory(WrappingFactory):
    """Throttle bandwidth and number of connections.

      - At most C{maxConnectionCount} connections are built at once.
      - If C{readLimit} is set, the bytes read by all protocols are checked
        once per second.  When the limit was exceeded, every protocol's
        transport is paused for C{readThisSecond / readLimit - 1} seconds.
      - If C{writeLimit} is set, up to C{writeLimit} bytes of write
        allowance are added once per second, capped at C{writeLimit}.
        Writes beyond the allowance are buffered by the protocols.  They
        are drained round-robin, roughly L{CHUNK_SIZE} bytes per protocol
        per turn, as allowance becomes available.  A registered producer is
        paused while its protocol is throttled.
    """

    protocol = ThrottlingProtocol

    # Approximate number of bytes drained from one protocol's write buffer
    # before checkWriteBandwidth moves on to the next queued protocol.
    # NOTE(review): this definition was missing from the source although
    # checkWriteBandwidth reads it; the value is assumed -- confirm.
    CHUNK_SIZE = 4 * 1024

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        """
        @param wrappedFactory: the factory whose protocols are throttled.
        @param maxConnectionCount: maximum number of simultaneous
            connections; L{buildProtocol} returns C{None} beyond it.
        @param readLimit: max bytes to read per second, or C{None} for no
            read limit.
        @param writeLimit: max bytes to write per second, or C{None} for no
            write limit.
        """
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit  # max bytes we should read per second
        self.writeLimit = writeLimit  # max bytes we should write per second
        # Bytes read since the last checkReadBandwidth run.
        self.readThisSecond = 0
        # Remaining write allowance; may go negative (see registerWritten).
        self.writeAvailable = writeLimit
        # Protocols with buffered writes, in the order they will be served.
        self._writeQueue = []
        # Pending reactor.callLater handles, or None when nothing is pending.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None  # never scheduled by this class
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to tell us more bytes were written.

        @return: C{True} if the write may proceed now, in which case the
            allowance is charged C{length} and may go negative.  C{False}
            if the caller must buffer the data.
        """
        if self.writeLimit is None:
            # No write limit: checkWriteBandwidth is never scheduled, so a
            # buffered write would never be drained.  Always allow.
            return True
        # Any positive allowance admits the whole write.
        if self.writeAvailable > 0:
            self.writeAvailable -= length
            return True
        return False

    def throttledWrites(self, p):
        """Called by the protocol to queue it for later writing."""
        assert p not in self._writeQueue
        self._writeQueue.append(p)

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """Check whether the read limit was exceeded in the last second.

        If it was, reads are paused on all protocols and scheduled to
        resume after the proportional overshoot.  This check reschedules
        itself to run again in one second.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = reactor.callLater(throttleTime,
                                                       self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """Add some new available bandwidth, and check for protocols to unthrottle."""
        # Increase the available write bytes, but not higher than the limit.
        self.writeAvailable = min(self.writeLimit,
                                  self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again.
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol.
            p = self._writeQueue.pop(0)
            startAvailable = self.writeAvailable
            bytesLeft = 1

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol
            # has nothing left buffered.
            while (self.writeAvailable > 0
                   and startAvailable - self.writeAvailable < self.CHUNK_SIZE
                   and bytesLeft > 0):
                # Unthrottle a single write (from the protocol's buffer).
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it at the back.
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in list(self.protocols):
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in list(self.protocols):
            p.unthrottleReads()

    def buildProtocol(self, addr):
        """Build a throttled protocol, unless the connection limit is reached.

        The periodic bandwidth checks are started when the first
        connection is built.

        @return: the wrapping protocol, or C{None} to refuse the connection.
        """
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        log.msg("Max connection count reached!")
        return None

    def unregisterProtocol(self, p):
        """Forget C{p}; stop all timers once the last connection is gone."""
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        # A disconnected protocol must not be drained by checkWriteBandwidth.
        if p in self._writeQueue:
            self._writeQueue.remove(p)
        if self.connectionCount == 0:
            # Cancel pending calls and clear the handles.  A stale handle
            # would otherwise be cancelled a second time at the next
            # shutdown, which raises.
            for attr in ("unthrottleReadsID", "checkReadBandwidthID",
                         "unthrottleWritesID", "checkWriteBandwidthID"):
                call = getattr(self, attr)
                if call is not None:
                    call.cancel()
                    setattr(self, attr, None)