]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht/policies.py
Fixed the ThrottlingFactory to work with web2 static streams from the web server.
[quix0rs-apt-p2p.git] / apt_dht / policies.py
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 #
6
7 """Resource limiting policies.
8
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
10 """
11
12 # system imports
13 import sys, operator
14
15 # twisted imports
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet.interfaces import ITransport
18 from twisted.internet import reactor, error
19 from twisted.python import log
20 from zope.interface import implements, providedBy, directlyProvides
21
class ProtocolWrapper(Protocol):
    """Sit between a transport and a protocol, posing as each to the other.

    The wrapped protocol is given this object as its transport, while the
    real transport delivers its events to this object as its protocol.
    """

    # Flipped to 1 once loseConnection() has been requested.
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        """Remember the owning factory and the protocol being wrapped."""
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """Advertise the transport's interfaces as our own, then connect."""
        combinedInterfaces = providedBy(self) + providedBy(transport)
        directlyProvides(self, *combinedInterfaces)
        Protocol.makeConnection(self, transport)

    # ---- transport side: calls made by the wrapped protocol ----

    def write(self, data):
        """Forward C{data} to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Forward a sequence of data chunks to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Note that we are disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return whatever the real transport reports as its peer."""
        return self.transport.getPeer()

    def getHost(self):
        """Return whatever the real transport reports as its host."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the current producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Relay the stopConsuming request to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        """Fall back to the real transport for attributes not found here."""
        return getattr(self.transport, name)

    # ---- protocol side: events delivered by the real transport ----

    def connectionMade(self):
        """Register with the factory, then connect the wrapped protocol to us."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Hand received data on to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
77
78
class WrappingFactory(ClientFactory):
    """Factory that wraps another factory and tracks the wrappers it builds.

    Factory-level events are relayed to the wrapped factory, and every
    protocol the wrapped factory builds is enclosed in C{self.protocol}.
    """

    # Wrapper class used by buildProtocol(); subclasses may override it.
    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        """Store the wrapped factory and start with no known protocols."""
        self.protocols = {}
        self.wrappedFactory = wrappedFactory

    def doStart(self):
        """Start the wrapped factory first, then this one."""
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory first, then this one."""
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        """Relay the event to the wrapped factory."""
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        """Relay the event to the wrapped factory."""
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        """Relay the event to the wrapped factory."""
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and enclose it in a wrapper."""
        innerProtocol = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, innerProtocol)

    def registerProtocol(self, p):
        """Called by a protocol to add itself to the tracked set."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by a protocol when it goes away, to drop it from the set."""
        del self.protocols[p]
115
116
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Every read is reported to the factory, and every write first asks the
    factory for permission.  Writes that are refused are kept in a local
    buffer until the factory calls L{unthrottleWrites}.
    """

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        """
        @param factory: the factory that accounts for the bandwidth used
        @param wrappedProtocol: the protocol whose traffic is throttled
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self._tempDataBuffer = []    # chunks waiting to be written, in order
        self._tempDataLength = 0     # total length of the buffered chunks
        self.throttled = False       # True while queued at the factory
        # Always defined, so unregisterProducer() is safe to call even when
        # no producer was ever registered.
        self.producer = None

    def write(self, data):
        """Write C{data} now if the factory allows it, otherwise buffer it."""
        # Check if we can write
        if (not self.throttled) and self.factory.registerWritten(len(data)):
            ProtocolWrapper.write(self, data)

            # NOTE: the producer is expected to expose a 'paused' attribute.
            if self.producer and not self.producer.paused:
                # Interrupt the flow so that others can have a chance.
                # We can only do this if it's not already paused, otherwise
                # we risk unpausing something that the Server paused.
                self.producer.pauseProducing()
                reactor.callLater(0, self.producer.resumeProducing)
        else:
            # Can't write, buffer the data
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """Write as many chunks of C{seq} as allowed, and buffer the rest.

        @param seq: an iterable of data chunks; it is not modified
        """
        # Work on a private copy so the caller's sequence is left intact and
        # tuples or other iterables are accepted as well.
        chunks = list(seq)
        sent = 0

        if not self.throttled:
            # Write each chunk separately, for as long as the factory agrees
            while (sent < len(chunks)
                   and self.factory.registerWritten(len(chunks[sent]))):
                ProtocolWrapper.write(self, chunks[sent])
                sent += 1

        # If there's some left, we must have been throttled
        leftover = chunks[sent:]
        if leftover:
            self._tempDataBuffer.extend(leftover)
            self._tempDataLength += sum(map(len, leftover))
            self._throttleWrites()

    def dataReceived(self, data):
        """Report the bytes read to the factory, then pass the data on."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """Remember the producer so it can be paused while throttled."""
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """Forget the producer; harmless when none is registered."""
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop the transport from delivering more data."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Let the transport deliver data again."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """Queue this protocol at the factory and pause its producer."""
        # If we haven't yet, queue for unthrottling
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Write one buffered chunk; called by the factory when bytes are available.

        @return: the number of bytes still buffered after this write
        """
        # Write some data
        if self._tempDataBuffer:
            chunk = self._tempDataBuffer[0]
            # The accounting call must run unconditionally: placing it inside
            # the assert itself would skip it when Python runs with -O.
            allowed = self.factory.registerWritten(len(chunk))
            assert allowed, "unthrottled with no write bandwidth available"
            self._tempDataLength -= len(chunk)
            ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if self.producer:
                # This might unpause something the Server has also paused, but
                # it will get paused again on first write anyway
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
203
204
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Reads are counted over one-second periods; when a period exceeds
    C{readLimit}, reading is paused on every protocol for a while.  Writes
    draw on a budget that is refilled once per second up to C{writeLimit};
    protocols that find the budget exhausted buffer their data and queue
    here until bytes become available again.  A limit of C{None} disables
    the corresponding throttling.
    """

    protocol = ThrottlingProtocol
    # Most bytes a queued protocol may drain before the next one gets a turn.
    CHUNK_SIZE = 4*1024

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
        """
        @param wrappedFactory: the factory whose protocols are throttled
        @param maxConnectionCount: further connections are refused once this
            many are open
        @param readLimit: max bytes we should read per second, or C{None}
        @param writeLimit: max bytes we should write per second, or C{None}
        """
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writeAvailable = writeLimit
        self._writeQueue = []
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to ask whether C{length} bytes may be written.

        @return: C{True} if the write may go ahead (the bytes are then
            deducted from the budget), C{False} if it must be buffered
        """
        # Without a limit there is no budget to consult.  Comparing the None
        # budget against 0 would refuse every write, and since no periodic
        # check runs in that case the data would never be flushed.
        if self.writeLimit is None:
            return True

        # Check if there are bytes available to write
        if self.writeAvailable > 0:
            self.writeAvailable -= length
            return True

        return False

    def throttledWrites(self, p):
        """Called by the protocol to queue it for later writing."""
        assert p not in self._writeQueue
        self._writeQueue.append(p)

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """Checks if we've passed bandwidth limits, then reschedules itself."""
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Pause for the extra time the excess bytes should have taken
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = reactor.callLater(throttleTime,
                                                       self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """Add some new available bandwidth, and check for protocols to unthrottle."""
        # Increase the available write bytes, but not higher than the limit
        self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol
            p = self._writeQueue.pop(0)
            _tempWriteAvailable = self.writeAvailable
            bytesLeft = 1

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
            while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
                # Unthrottle a single write (from the protocol's buffer)
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def buildProtocol(self, addr):
        """Build a throttled protocol, or refuse when the limit is reached.

        @return: the wrapping protocol, or C{None} if C{maxConnectionCount}
            connections are already open
        """
        # Refuse before starting any periodic check, so that a rejected
        # connection cannot leave timers running with nobody to cancel them.
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None

        # The first connection starts the periodic bandwidth checks
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        self.connectionCount += 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """Forget a protocol; stop all timers when the last one goes away."""
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1

        # A protocol that is gone must not keep waiting for write bandwidth
        if p in self._writeQueue:
            self._writeQueue.remove(p)

        if self.connectionCount == 0:
            for name in ("unthrottleReadsID", "checkReadBandwidthID",
                         "unthrottleWritesID", "checkWriteBandwidthID"):
                delayed = getattr(self, name)
                if delayed is None:
                    continue
                # Only pending calls can be cancelled
                if delayed.active():
                    delayed.cancel()
                # Drop the handle so a later idle period does not try to
                # cancel the same call a second time.
                setattr(self, name, None)
319