Another attempt at throttling, still not working.
[quix0rs-apt-p2p.git] / apt_dht / policies.py
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2004 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5 #
6
7 """Resource limiting policies.
8
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
10 """
11
12 # system imports
13 import sys, operator
14
15 # twisted imports
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet.interfaces import ITransport
18 from twisted.internet import reactor, error
19 from twisted.python import log
20 from zope.interface import implements, providedBy, directlyProvides
21
class ProtocolWrapper(Protocol):
    """Wraps protocol instances and acts as their transport as well.

    Events coming from the real C{transport} (connection made, data
    received, connection lost) are relayed to C{wrappedProtocol}.  Calls the
    wrapped protocol makes on what it believes is its transport (writes,
    producer registration, disconnection, address queries) are relayed to
    the real C{transport}.  Any other attribute that is not found on the
    wrapper is looked up on the real transport via C{__getattr__}.
    """

    # Set to 1 by loseConnection().
    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """Attach to C{transport}, also declaring the interfaces it provides."""
        # The wrapper directly provides the union of its own interfaces and
        # the transport's, so it can stand in for the real transport.
        combined = providedBy(self) + providedBy(transport)
        directlyProvides(self, *combined)
        Protocol.makeConnection(self, transport)

    # -- Transport side: calls made by the wrapped protocol ----------------

    def write(self, data):
        """Send C{data} through the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Send the sequence of strings C{data} through the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Flag the wrapper as disconnecting, then close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return the remote address as reported by the real transport."""
        return self.transport.getPeer()

    def getHost(self):
        """Return the local address as reported by the real transport."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the current producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Relay stopConsuming() to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails on the wrapper:
        # everything else is served by the real transport.
        return getattr(self.transport, name)

    # -- Protocol side: events delivered by the real transport -------------

    def connectionMade(self):
        """Register with the factory, then start the wrapped protocol on us."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Hand incoming C{data} to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
77
78
class WrappingFactory(ClientFactory):
    """Wraps a factory and its protocols, and keeps track of them.

    Each protocol built by C{wrappedFactory} is wrapped in an instance of
    C{self.protocol} (L{ProtocolWrapper} by default).  Wrappers register
    themselves in C{self.protocols} when connected.  Factory start/stop and
    client-connection callbacks are forwarded to the wrapped factory.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Registered wrapper protocols, used as a set (every value is 1).
        self.protocols = {}

    def doStart(self):
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol for C{addr} and wrap it.

        @return: a new C{self.protocol} wrapping the built protocol, or
            C{None} if the wrapped factory returned C{None} (refused the
            connection).
        """
        wrapped = self.wrappedFactory.buildProtocol(addr)
        if wrapped is None:
            # Propagate the refusal.  Wrapping None would yield a wrapper
            # that fails as soon as it forwards an event to its protocol.
            return None
        return self.protocol(self, wrapped)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        del self.protocols[p]
115
116
class ThrottlingProtocol(ProtocolWrapper):
    """Protocol for ThrottlingFactory.

    Reports the size of everything written and read to the factory so it
    can measure bandwidth.  Also provides the hooks the factory calls to
    pause and resume reading and writing on this connection.
    """

    # Producer registered via registerProducer(), or None.  It is declared
    # on the class so lookups never fall through to
    # ProtocolWrapper.__getattr__, which forwards unknown names to the
    # transport.  Twisted transports have their own ``producer`` attribute
    # (None when idle), which used to make ``hasattr(self, "producer")``
    # true and led throttleWrites() to call pauseProducing() on None.
    producer = None

    # wrap API for tracking bandwidth

    def write(self, data):
        self.factory.registerWritten(len(data))
        ProtocolWrapper.write(self, data)

    def writeSequence(self, seq):
        # sum() returns 0 for an empty sequence, where reduce() without an
        # initial value raised TypeError.  Unlike the reduce builtin, sum()
        # also exists in Python 3.
        self.factory.registerWritten(sum(map(len, seq)))
        ProtocolWrapper.writeSequence(self, seq)

    def dataReceived(self, data):
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        # Reset instead of ``del``, so unregistering twice (or without a
        # prior registration) does not raise AttributeError.
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading from the real transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the real transport."""
        self.transport.resumeProducing()

    def throttleWrites(self):
        """Pause the registered producer, if there is one."""
        # NOTE(review): for a non-streaming (pull) producer the transport
        # itself calls resumeProducing() whenever its buffer drains, so
        # this pause may not hold -- confirm.
        if self.producer is not None:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """Resume the registered producer, if there is one."""
        if self.producer is not None:
            self.producer.resumeProducing()
156
157
class ThrottlingFactory(WrappingFactory):
    """Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    Bandwidth is measured in one-second windows.  While at least one
    connection exists, checkReadBandwidth / checkWriteBandwidth (for
    whichever limits are set) run once a second.  If the bytes counted in
    the last window exceed the limit, every protocol is paused long enough
    for the average rate to fall back to the limit.  All timers are
    cancelled, and the counters reset, when the last connection goes away.
    """

    protocol = ThrottlingProtocol

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint, readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writtenThisSecond = 0
        # Pending DelayedCalls, or None.  unthrottleReadsID and
        # unthrottleWritesID are non-None only while that direction is
        # throttled.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None

    def registerWritten(self, length):
        """Called by protocol to tell us more bytes were written."""
        self.writtenThisSecond += length

    def registerRead(self, length):
        """Called by protocol to tell us more bytes were read."""
        self.readThisSecond += length

    # -- timer bookkeeping ---------------------------------------------------

    def _cancelCall(self, name):
        """Cancel the DelayedCall stored in attribute C{name}, if pending.

        The attribute is reset to None either way.  This ensures a handle
        that has already fired or been cancelled is never cancelled again,
        since DelayedCall.cancel() raises in that case.
        """
        call = getattr(self, name)
        if call is not None and call.active():
            call.cancel()
        setattr(self, name, None)

    def _scheduleUnthrottle(self, name, delay, callback):
        """Run C{callback} C{delay} seconds from now; keep the handle in C{name}.

        If the call already stored in C{name} is still pending, it is
        postponed by C{delay} seconds instead of being replaced.  Replacing
        it would let the older call end the throttle early, and would lose
        its handle so it could no longer be cancelled.
        """
        call = getattr(self, name)
        if call is not None and call.active():
            call.delay(delay)
        else:
            setattr(self, name, reactor.callLater(delay, callback))

    def _startBandwidthChecks(self):
        """Start the periodic check for each configured limit."""
        if self.readLimit is not None:
            self.checkReadBandwidth()
        if self.writeLimit is not None:
            self.checkWriteBandwidth()

    def _stopBandwidthChecks(self):
        """Cancel every pending timer and reset the byte counters."""
        for name in ('unthrottleReadsID', 'checkReadBandwidthID',
                     'unthrottleWritesID', 'checkWriteBandwidthID'):
            self._cancelCall(name)
        # Bytes from a partial window must not be charged to whichever
        # connection arrives next.
        self.readThisSecond = 0
        self.writtenThisSecond = 0

    # -- periodic checks -----------------------------------------------------

    def checkReadBandwidth(self):
        """Checks if we've passed the read bandwidth limit.

        If the bytes read in the last window exceed C{readLimit}, reads are
        paused on all protocols for C{readThisSecond / readLimit - 1}
        seconds.  This brings the average over window plus pause back to
        C{readLimit}.  The counter is reset and the check rescheduled one
        second later.
        """
        readThisSecond, self.readThisSecond = self.readThisSecond, 0
        # All timers are scheduled before any protocol code runs.  An
        # exception while throttling then cannot stop the measurement loop
        # or leave protocols paused with no unthrottle pending.
        self.checkReadBandwidthID = reactor.callLater(1, self.checkReadBandwidth)
        if readThisSecond > self.readLimit:
            throttleTime = (float(readThisSecond) / self.readLimit) - 1.0
            self._scheduleUnthrottle('unthrottleReadsID', throttleTime,
                                     self.unthrottleReads)
            self.throttleReads()

    def checkWriteBandwidth(self):
        """Checks if we've passed the write bandwidth limit.

        Same scheme as checkReadBandwidth, applied to C{writtenThisSecond}
        and C{writeLimit}.
        """
        writtenThisSecond, self.writtenThisSecond = self.writtenThisSecond, 0
        # Timers first, for the same reason as in checkReadBandwidth.
        self.checkWriteBandwidthID = reactor.callLater(1, self.checkWriteBandwidth)
        if writtenThisSecond > self.writeLimit:
            throttleTime = (float(writtenThisSecond) / self.writeLimit) - 1.0
            self._scheduleUnthrottle('unthrottleWritesID', throttleTime,
                                     self.unthrottleWrites)
            self.throttleWrites()

    # -- applying throttles to all protocols ---------------------------------

    def throttleReads(self):
        """Throttle reads on all protocols."""
        log.msg("Throttling reads on %s" % self)
        for p in list(self.protocols):
            p.throttleReads()

    def unthrottleReads(self):
        """Stop throttling reads on all protocols."""
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in list(self.protocols):
            p.unthrottleReads()

    def throttleWrites(self):
        """Throttle writes on all protocols."""
        log.msg("Throttling writes on %s" % self)
        for p in list(self.protocols):
            p.throttleWrites()

    def unthrottleWrites(self):
        """Stop throttling writes on all protocols."""
        self.unthrottleWritesID = None
        log.msg("Stopped throttling writes on %s" % self)
        for p in list(self.protocols):
            p.unthrottleWrites()

    # -- connection accounting -----------------------------------------------

    def registerProtocol(self, p):
        """Register C{p}, pausing its reads if reads are currently throttled.

        Without this, a connection made during a throttle window would read
        at full speed until the window ended, bypassing C{readLimit}.
        """
        WrappingFactory.registerProtocol(self, p)
        if self.unthrottleReadsID is not None:
            p.throttleReads()

    def buildProtocol(self, addr):
        """Build a wrapped protocol unless the connection limit is reached.

        The bandwidth checks are started when the first connection is built.

        @return: the new protocol, or C{None} if C{maxConnectionCount}
            connections already exist or the wrapped factory refused.
        """
        if self.connectionCount >= self.maxConnectionCount:
            log.msg("Max connection count reached!")
            return None
        p = WrappingFactory.buildProtocol(self, addr)
        if p is None:
            # Refused by the wrapped factory.  Counting it would keep
            # connectionCount from ever returning to zero.
            return None
        if self.connectionCount == 0:
            self._startBandwidthChecks()
        self.connectionCount += 1
        return p

    def unregisterProtocol(self, p):
        """Forget C{p}; stop all bandwidth timers once no connections remain."""
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        if self.connectionCount == 0:
            self._stopBandwidthChecks()
260