1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories.
3 # See LICENSE for details.
"""Resource limiting policies.

@seealso: See also L{twisted.protocols.htb} for rate limiting.
"""
import operator
import sys

from zope.interface import implements, providedBy, directlyProvides

from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
from twisted.internet.interfaces import ITransport
from twisted.internet import reactor, error
from twisted.python import log
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    An instance is the protocol for the real transport and, at the same
    time, the transport handed to C{wrappedProtocol}.  Data and connection
    events are relayed in both directions; any transport attribute not
    defined here is looked up on the real transport via C{__getattr__}.

    @ivar factory: the factory that built this wrapper; it must provide
        C{registerProtocol} and C{unregisterProtocol}, which are called from
        C{connectionMade} and C{connectionLost}.
    @ivar wrappedProtocol: the protocol being wrapped.
    """

    def __init__(self, factory, wrappedProtocol):
        self.wrappedProtocol = wrappedProtocol
        self.factory = factory

    def makeConnection(self, transport):
        # Make the wrapper provide every interface the real transport
        # provides (in addition to its own) so it can stand in for it.
        directlyProvides(self, *providedBy(self) + providedBy(transport))
        Protocol.makeConnection(self, transport)

    # Transport relaying: calls made by the wrapped protocol on its
    # "transport" (this wrapper) are passed on to the real transport.

    def write(self, data):
        self.transport.write(data)

    def writeSequence(self, data):
        self.transport.writeSequence(data)

    def loseConnection(self):
        # Stored on the wrapper itself; until this is set, reads of
        # `disconnecting` fall through __getattr__ to the real transport.
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        return self.transport.getPeer()

    def getHost(self):
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        self.transport.unregisterProducer()

    def stopConsuming(self):
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: delegate every
        # other attribute to the real transport.
        return getattr(self.transport, name)

    # Protocol relaying: connection events are passed on to the wrapped
    # protocol.

    def connectionMade(self):
        self.factory.registerProtocol(self)
        # The wrapper itself is the wrapped protocol's transport.
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Every protocol built by C{wrappedFactory} is wrapped in an instance of
    C{self.protocol}.  Wrappers register themselves while connected and are
    kept as the keys of C{self.protocols}.  Start/stop and client-connection
    events are forwarded to the wrapped factory.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Connected wrapper protocols, used as a set (every value is 1).
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol for C{addr} and wrap it.

        @return: a C{self.protocol} instance wrapping the new protocol, or
            C{None} if the wrapped factory returned C{None}.
        """
        wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
        if wrappedProtocol is None:
            # Propagate the refusal rather than wrapping None; a wrapper
            # around None would fail as soon as the connection is made.
            return None
        return self.protocol(self, wrappedProtocol)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Reports every byte read and written to its factory.  Lets the factory
    pause and resume reads (via the transport) and writes (via the
    registered producer, if any).
    """

    # Producer registered through registerProducer, or None.  Defined on
    # the class so that reading it never falls through to
    # ProtocolWrapper.__getattr__, which would look the name up on the
    # transport instead.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # sum() also handles an empty sequence (total of 0 bytes).
        self.factory.registerWritten(sum(map(len, seq)))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        # Forget the producer so later throttleWrites/unthrottleWrites
        # calls do not pause or resume one that is no longer registered.
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    # throttling hooks, invoked by the factory

    def throttleReads(self):
        self.transport.pauseProducing()

    def unthrottleReads(self):
        self.transport.resumeProducing()

    def throttleWrites(self):
        if self.producer is not None:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        if self.producer is not None:
            self.producer.resumeProducing()
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Handles returned by reactor.callLater; None when nothing of that
        # kind is scheduled.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to tell us more bytes were written."""
        self.writtenThisSecond += length

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """Checks if we've passed the read bandwidth limit.

        Reschedules itself to run again in one second.  If more than
        C{readLimit} bytes were read since the last check, reads on all
        protocols are paused for (bytes read / readLimit) - 1 seconds.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = reactor.callLater(throttleTime,
                                                       self.unthrottleReads)
        # reset for next round
        self.readThisSecond = 0
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """Checks if we've passed the write bandwidth limit.

        Reschedules itself to run again in one second.  If more than
        C{writeLimit} bytes were written since the last check, writes on
        all protocols are paused for (bytes written / writeLimit) - 1
        seconds.
        """
        if self.writtenThisSecond > self.writeLimit:
            self.throttleWrites()
            throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
            self.unthrottleWritesID = reactor.callLater(throttleTime,
                                                        self.unthrottleWrites)
        # reset for next round
        self.writtenThisSecond = 0
        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols:
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols:
            p.unthrottleReads()

    def throttleWrites(self):
        """Throttle writes on all protocols."""
        log.msg("Throttling writes on %s" % self)
        for p in self.protocols:
            p.throttleWrites()

    def unthrottleWrites(self):
        """Stop throttling writes on all protocols."""
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in self.protocols:
            p.unthrottleWrites()

    def buildProtocol(self, addr):
        """Build a throttled protocol, or return None at the connection cap.

        The first accepted connection starts the once-a-second bandwidth
        checks for whichever limits are set.
        """
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None

        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        self.connectionCount += 1
        p = WrappingFactory.buildProtocol(self, addr)
        if p is None:
            # The wrapped factory refused the connection: give the slot back
            # (stopping the timers again if this was the only connection).
            self._connectionGone()
        return p

    def unregisterProtocol(self, p):
        WrappingFactory.unregisterProtocol(self, p)
        self._connectionGone()

    def _connectionGone(self):
        """Drop one connection; cancel all scheduled calls when none remain."""
        self.connectionCount -= 1
        if self.connectionCount == 0:
            for name in ("unthrottleReadsID", "checkReadBandwidthID",
                         "unthrottleWritesID", "checkWriteBandwidthID"):
                delayedCall = getattr(self, name)
                if delayedCall is not None:
                    delayedCall.cancel()
                    # Forget the cancelled call so a later idle period does
                    # not try to cancel it a second time.
                    setattr(self, name, None)