3c84c149b3c38dd1317b46031fddf0b602da6e33
[quix0rs-apt-p2p.git] / apt_dht / policies.py
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 #
6
7 """Resource limiting policies.
8
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
10 """
11
12 # system imports
13 import sys, operator
14
15 # twisted imports
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet.interfaces import ITransport
18 from twisted.internet import reactor, error
19 from twisted.python import log
20 from zope.interface import implements, providedBy, directlyProvides
21
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    Toward the real transport this object behaves like an ordinary protocol
    (connectionMade / dataReceived / connectionLost).  Toward the wrapped
    protocol it behaves like the transport: every transport call is
    forwarded unchanged, so subclasses can intercept just the calls they
    care about.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """Attach to C{transport}, also advertising its interfaces."""
        # Declare the transport's interfaces on the wrapper as well as our
        # own, so interface checks against the wrapper see what the real
        # transport offers.
        combined = providedBy(self) + providedBy(transport)
        directlyProvides(self, *combined)
        Protocol.makeConnection(self, transport)

    # Transport relaying

    def write(self, data):
        """Forward C{data} to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Forward a sequence of strings to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Mark this wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return whatever the real transport reports as its peer."""
        return self.transport.getPeer()

    def getHost(self):
        """Return whatever the real transport reports as its host."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Forward stopConsuming to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails: anything the
        # wrapper does not define itself is looked up on the real transport.
        return getattr(self.transport, name)

    # Protocol relaying

    def connectionMade(self):
        """Register with the factory, then hand ourselves to the wrapped
        protocol as its transport."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Pass received bytes straight to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
77
78
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Every protocol built by C{wrappedFactory} is wrapped in an instance of
    C{self.protocol}.  Live wrappers are recorded as keys of the
    C{protocols} dict (the dict is used as a set; each value is 1).
    Client-side connection callbacks are passed through to the wrapped
    factory.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        self.protocols = {}

    def doStart(self):
        """Start the wrapped factory first, then this one."""
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory first, then this one."""
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        """Relay to the wrapped factory."""
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        """Relay to the wrapped factory."""
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        """Relay to the wrapped factory."""
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and wrap it."""
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away.

        Raises C{KeyError} if C{p} was never registered.
        """
        self.protocols.pop(p)
115
116
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Reports every read and write to the factory.  A write the factory has
    no bandwidth for right now is kept in a local buffer.  The protocol
    then queues itself with the factory (C{throttledWrites}), and the
    factory later drains the buffer one entry at a time through
    L{unthrottleWrites}.  While writes are buffered, any registered
    producer is paused.  It is resumed (on the next reactor iteration)
    once the buffer is empty.
    """

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        # Writes waiting for bandwidth (oldest first) and their total length.
        self._tempDataBuffer = []
        self._tempDataLength = 0
        # True while this protocol sits in the factory's write queue.
        self.throttled = False
        # Always defined on the instance so lookups never fall through
        # ProtocolWrapper.__getattr__ to the real transport's attribute.
        self.producer = None

    def write(self, data):
        """Write C{data} now if bandwidth allows, otherwise buffer it."""
        if not self.throttled and self.factory.registerWritten(len(data)):
            ProtocolWrapper.write(self, data)

            # Interrupt the flow so that others can have a chance.
            # We can only do this if it's not already paused; otherwise we
            # risk unpausing something that the Server paused.  Producers
            # that don't expose 'paused' are left alone for the same reason.
            producer = self.producer
            if producer and not getattr(producer, 'paused', True):
                producer.pauseProducing()
                reactor.callLater(0, producer.resumeProducing)
        else:
            # Can't write (or already waiting to): buffer the data.
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """Write each string of C{seq}, buffering what can't be sent yet.

        The caller's sequence is never modified.  Any sequence type
        (list, tuple, ...) is accepted.
        """
        pending = list(seq)
        sent = 0
        if not self.throttled:
            # Write the pieces in order until the factory refuses one.
            for data in pending:
                if not self.factory.registerWritten(len(data)):
                    break
                ProtocolWrapper.write(self, data)
                sent += 1

        # If there's some left, we must have been throttled.
        remaining = pending[sent:]
        if remaining:
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        """Report the read to the factory, then deliver the data."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """Remember C{producer} so throttling can pause/resume it."""
        assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """Forget the producer and unregister it from the real transport."""
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Pause reading from the real transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the real transport."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """Queue for later writing (once) and pause the producer."""
        # If we haven't yet, queue for unthrottling.
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        # Stop the producer from generating data we can't send yet.
        if self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Write the oldest buffered entry, if any.

        Called by the factory only while it has bandwidth available.  When
        the buffer becomes empty, the protocol leaves the throttled state
        and schedules the producer to resume.

        @return: the number of bytes still buffered.
        """
        if self._tempDataBuffer:
            data = self._tempDataBuffer.pop(0)
            # Account for the bandwidth outside the assert so it still
            # happens when Python runs with -O (which strips asserts).
            accepted = self.factory.registerWritten(len(data))
            assert accepted, "unthrottleWrites called without available bandwidth"
            self._tempDataLength -= len(data)
            ProtocolWrapper.write(self, data)
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more.
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if self.producer:
                # This might unpause something the Server has also paused,
                # but it will get paused again on first write anyway.
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
204
205
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Reads: once a second, if more than C{readLimit} bytes were read during
    the past second, all connections' transports are paused for a while.

    Writes: protocols buffer any write the factory has no budget for and
    queue themselves here.  Once a second the budget is topped up by
    C{writeLimit} bytes (capped at C{writeLimit}).  Queued protocols are
    then drained in turn, about C{CHUNK_SIZE} bytes per protocol per turn,
    until the budget or the queue runs out.

    A limit of C{None} disables the corresponding throttling.
    """

    protocol = ThrottlingProtocol
    CHUNK_SIZE = 4*1024

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writeAvailable = writeLimit  # remaining write budget (may go negative)
        self._writeQueue = []  # protocols with buffered writes, in service order
        # Pending reactor.callLater handles (None when not scheduled).
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to tell us more bytes were written.

        @return: C{True} if the write may go ahead now, C{False} if the
            protocol must buffer it until bandwidth is available.
        """
        if self.writeLimit is None:
            # Unlimited.  Refusing here would strand the data forever,
            # because checkWriteBandwidth is only scheduled when a limit
            # is set.
            return True

        # Allow the write while any budget is left.  It may overdraw; the
        # deficit is paid back by the next refill.
        if self.writeAvailable > 0:
            self.writeAvailable -= length
            return True

        return False

    def throttledWrites(self, p):
        """Called by the protocol to queue it for later writing."""
        assert p not in self._writeQueue
        self._writeQueue.append(p)

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    def checkReadBandwidth(self):
        """Checks if we've passed bandwidth limits; reschedules itself every second.

        If the last second exceeded C{readLimit}, reads are paused for
        C{readThisSecond / readLimit - 1} seconds.  That is the extra time
        the bytes would have taken at C{readLimit}.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = reactor.callLater(throttleTime,
                                                       self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)

    def checkWriteBandwidth(self):
        """Add some new available bandwidth, and check for protocols to unthrottle."""
        # Increase the available write bytes, but not higher than the limit.
        self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again.
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol.
            p = self._writeQueue.pop(0)
            _tempWriteAvailable = self.writeAvailable
            bytesLeft = 1

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered.
            while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
                # Unthrottle a single write (from the protocol's buffer).
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it at the back.
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()

    def buildProtocol(self, addr):
        """Build a wrapped protocol unless C{maxConnectionCount} is reached.

        The first accepted connection starts the periodic bandwidth checks.
        """
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None

        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        self.connectionCount += 1
        return WrappingFactory.buildProtocol(self, addr)

    @staticmethod
    def _cancelPending(call):
        """Cancel the delayed call C{call} if it is still pending."""
        # Cancelling a call that already ran or was cancelled raises, so
        # check active() first.
        if call is not None and call.active():
            call.cancel()

    def unregisterProtocol(self, p):
        """Forget C{p}; stop all timers when the last connection goes away."""
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1

        # A closed connection must not stay queued: draining its buffer
        # would spend bandwidth on a transport that is gone.
        if p in self._writeQueue:
            self._writeQueue.remove(p)

        if self.connectionCount == 0:
            self._cancelPending(self.unthrottleReadsID)
            self._cancelPending(self.checkReadBandwidthID)
            self._cancelPending(self.unthrottleWritesID)
            self._cancelPending(self.checkWriteBandwidthID)
            # Forget the handles so a later shutdown doesn't cancel them again.
            self.unthrottleReadsID = None
            self.checkReadBandwidthID = None
            self.unthrottleWritesID = None
            self.checkWriteBandwidthID = None
320