Fix some documentation errors.
[quix0rs-apt-p2p.git] / apt_p2p / policies.py
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5
6 """
7 Resource limiting policies.
8
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
10 """
11
12 # system imports
13 import sys, operator
14
15 # twisted imports
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet import reactor, error
18 from twisted.python import log
19 from zope.interface import providedBy, directlyProvides
20
21
class ProtocolWrapper(Protocol):
    """
    Protocol that sits between a real transport and a wrapped protocol.

    Towards the wrapped protocol this object plays the transport role;
    towards the real transport it plays the protocol role.  Transport calls
    are forwarded to C{self.transport} and protocol events are forwarded to
    C{self.wrappedProtocol}.

    @ivar wrappedProtocol: the protocol instance being wrapped.
    @ivar factory: the factory notified when this wrapper connects and
        disconnects.
    @ivar disconnecting: set to 1 once L{loseConnection} has been called.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """
        Declare the transport's interfaces as provided by this wrapper (in
        addition to its own), then perform the normal connection setup.
        """
        combined = providedBy(self) + providedBy(transport)
        directlyProvides(self, *combined)
        Protocol.makeConnection(self, transport)

    # ---- transport side: forwarded to the real transport ----

    def write(self, data):
        """Pass C{data} on to the real transport."""
        realTransport = self.transport
        realTransport.write(data)

    def writeSequence(self, data):
        """Pass a sequence of data chunks on to the real transport."""
        realTransport = self.transport
        realTransport.writeSequence(data)

    def loseConnection(self):
        """Flag the wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        realTransport = self.transport
        realTransport.loseConnection()

    def getPeer(self):
        """Return whatever the real transport reports as its peer."""
        realTransport = self.transport
        return realTransport.getPeer()

    def getHost(self):
        """Return whatever the real transport reports as its host."""
        realTransport = self.transport
        return realTransport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        realTransport = self.transport
        realTransport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the current producer from the real transport."""
        realTransport = self.transport
        realTransport.unregisterProducer()

    def stopConsuming(self):
        """Forward C{stopConsuming} to the real transport."""
        realTransport = self.transport
        realTransport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal lookup fails: fall back to the transport.
        realTransport = self.transport
        return getattr(realTransport, name)

    # ---- protocol side: forwarded to the wrapped protocol ----

    def connectionMade(self):
        """
        Register with the factory, then connect the wrapped protocol using
        this wrapper as its transport.
        """
        self.factory.registerProtocol(self)
        inner = self.wrappedProtocol
        inner.makeConnection(self)

    def dataReceived(self, data):
        """Hand received data to the wrapped protocol."""
        inner = self.wrappedProtocol
        inner.dataReceived(data)

    def connectionLost(self, reason):
        """
        Unregister from the factory, then tell the wrapped protocol that the
        connection is gone.
        """
        self.factory.unregisterProtocol(self)
        inner = self.wrappedProtocol
        inner.connectionLost(reason)
77
78
class WrappingFactory(ClientFactory):
    """
    Factory that wraps another factory and every protocol it builds.

    @cvar protocol: wrapper class instantiated around each protocol built by
        the wrapped factory.
    @ivar wrappedFactory: the factory being wrapped.
    @ivar protocols: dict whose keys are the currently registered wrapper
        protocols (the values are always 1).
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.protocols = {}
        self.wrappedFactory = wrappedFactory

    # Start/stop bookkeeping is applied to the wrapped factory first and to
    # this factory second.

    def doStart(self):
        inner = self.wrappedFactory
        inner.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        inner = self.wrappedFactory
        inner.doStop()
        ClientFactory.doStop(self)

    # Client connection events are passed straight to the wrapped factory.

    def startedConnecting(self, connector):
        inner = self.wrappedFactory
        inner.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        inner = self.wrappedFactory
        inner.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        inner = self.wrappedFactory
        inner.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """
        Build a protocol with the wrapped factory and return it wrapped in
        C{self.protocol}.
        """
        wrapperClass = self.protocol
        return wrapperClass(self, self.wrappedFactory.buildProtocol(addr))

    def registerProtocol(self, p):
        """Called by a wrapper protocol to record itself as connected."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """
        Called by a wrapper protocol when it goes away.

        Raises C{KeyError} if C{p} was never registered.
        """
        self.protocols.pop(p)
115
116
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Received data is counted and reported to the factory.  Outgoing data is
    first offered to the factory's C{registerWritten}; when the factory
    refuses it, the data is buffered here and the protocol queues itself with
    the factory until L{unthrottleWrites} has drained the buffer.

    @ivar _tempDataBuffer: list of data chunks waiting to be written.
    @ivar _tempDataLength: total length of the chunks in C{_tempDataBuffer}.
    @ivar throttled: C{True} while this protocol is queued with the factory
        for later writing.
    @ivar producer: the producer passed to L{registerProducer}; only present
        while one is registered.
    """

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self._tempDataBuffer = []
        self._tempDataLength = 0
        self.throttled = False

    def write(self, data):
        """
        Write C{data} now if the factory has write budget, else buffer it.

        @param data: the data to send.
        """
        # registerWritten() answers None (no limit), False (accepted) or
        # True (refused); it is only consulted while we are not throttled.
        paused = False
        if not self.throttled:
            paused = self.factory.registerWritten(len(data))
            if not paused:
                ProtocolWrapper.write(self, data)

                # hasattr() can succeed through ProtocolWrapper.__getattr__,
                # which consults the transport, so the value is tested too.
                if (paused is not None and hasattr(self, "producer")
                        and self.producer and not self.producer.paused):
                    # Interrupt the flow so that others can have a chance.
                    # We can only do this if it's not already paused,
                    # otherwise we risk unpausing something that the Server
                    # paused.
                    self.producer.pauseProducing()
                    reactor.callLater(0, self.producer.resumeProducing)

        if self.throttled or paused:
            # Can't write, buffer the data
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """
        Write the chunks of C{seq} one at a time while the factory accepts
        them, and buffer whatever is left.

        @param seq: a sequence (supporting C{len} and slicing) of data chunks.
        """
        i = 0
        if not self.throttled:
            # Write each chunk separately until one is refused
            while (i < len(seq)
                   and not self.factory.registerWritten(len(seq[i]))):
                ProtocolWrapper.write(self, seq[i])
                i += 1

        # If there's some left, we must have been paused
        if i < len(seq):
            remaining = seq[i:]
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        """Report the number of bytes read to the factory, then relay."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """
        Remember C{producer} so it can be paused while writes are buffered,
        then register it with the transport.

        @param streaming: must be true; only push producers are supported.
        """
        assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """
        Forget the remembered producer and unregister it from the transport.

        Does not fail when no producer was registered on this instance.
        """
        # A plain "del self.producer" raises AttributeError when nothing was
        # registered; dropping the instance attribute directly does not.
        self.__dict__.pop("producer", None)
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Pause reading by pausing the transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading by resuming the transport."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """
        Queue this protocol with the factory (once) and pause the producer,
        if there is one.
        """
        # If we haven't yet, queue for unthrottling
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if hasattr(self, "producer") and self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """
        Write one buffered chunk, if any, and resume the producer once the
        buffer is empty.

        @return: the number of bytes still buffered.
        """
        # Write some data
        if self._tempDataBuffer:
            chunk = self._tempDataBuffer[0]
            # The accounting call must run even when asserts are stripped
            # (python -O), so it is made outside the assert statement.
            refused = self.factory.registerWritten(len(chunk))
            assert not refused
            self._tempDataLength -= len(chunk)
            ProtocolWrapper.write(self, self._tempDataBuffer.pop(0))
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if hasattr(self, "producer") and self.producer:
                # This might unpause something the Server has also paused,
                # but it will get paused again on first write anyway
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
209
210
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Reads are throttled by pausing every protocol's transport when more than
    C{readLimit} bytes were read in the last second.  Writes are throttled
    with a per-second byte budget: protocols whose writes are refused queue
    themselves here and are drained by L{checkWriteBandwidth}.

    @cvar CHUNK_SIZE: approximate number of bytes one queued protocol may
        write before the next queued protocol gets its turn.
    @ivar connectionCount: number of protocols built and not yet unregistered.
    @ivar maxConnectionCount: maximum number of simultaneous connections.
    @ivar readLimit: max bytes we should read per second, or C{None}.
    @ivar writeLimit: max bytes we should write per second, or C{None}.
    @ivar readThisSecond: bytes read since the last read bandwidth check.
    @ivar writeAvailable: bytes that may still be written before the next
        write bandwidth check (may go negative).
    @ivar _writeQueue: protocols waiting to write buffered data.
    """

    protocol = ThrottlingProtocol
    CHUNK_SIZE = 4*1024

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writeAvailable = writeLimit
        self._writeQueue = []
        # Handles of the pending delayed calls, or None when not scheduled.
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)


    def registerWritten(self, length):
        """
        Called by protocol to ask whether C{length} bytes may be written.

        @return: C{None} if there is no write limit, C{False} if the bytes
            were accepted (and deducted from the budget), C{True} if the
            bytes could not be written and the protocol should pause itself.
        """
        # Check if there are bytes available to write
        if self.writeLimit is None:
            return None
        elif self.writeAvailable > 0:
            self.writeAvailable -= length
            return False

        return True


    def throttledWrites(self, p):
        """
        Called by the protocol to queue it for later writing.
        """
        assert p not in self._writeQueue
        self._writeQueue.append(p)


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Check whether the read limit was exceeded during the last second,
        throttle reads if so, and schedule the next check.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Stay throttled in proportion to how far over the limit we went
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Add some new available bandwidth, and check for protocols to unthrottle.
        """
        # Increase the available write bytes, but not higher than the limit
        self.writeAvailable = min(self.writeLimit,
                                  self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol
            p = self._writeQueue.pop(0)
            _tempWriteAvailable = self.writeAvailable
            bytesLeft = 1

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol
            # is unbuffered
            while (self.writeAvailable > 0
                   and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE
                   and bytesLeft > 0):
                # Unthrottle a single write (from the protocol's buffer)
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()


    def buildProtocol(self, addr):
        """
        Build a throttled protocol, or return C{None} when the connection
        limit has been reached.

        The periodic bandwidth checks are started when the first connection
        is built.
        """
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        else:
            log.msg("Max connection count reached!")
            return None


    def unregisterProtocol(self, p):
        """
        Forget a disconnected protocol and stop the periodic checks once no
        connections remain.
        """
        WrappingFactory.unregisterProtocol(self, p)

        # A protocol that disconnected while throttled must not stay queued,
        # or later bandwidth checks would spend write budget on it.
        if p in self._writeQueue:
            self._writeQueue.remove(p)

        self.connectionCount -= 1
        if self.connectionCount == 0:
            for name in ("unthrottleReadsID", "checkReadBandwidthID",
                         "unthrottleWritesID", "checkWriteBandwidthID"):
                call = getattr(self, name)
                if call is not None:
                    # Clear the handle so a cancelled call is never cancelled
                    # a second time on a later idle transition.
                    setattr(self, name, None)
                    call.cancel()
360
361
362
class SpewingProtocol(ProtocolWrapper):
    """
    Wrapper protocol that logs all data passing through it in either
    direction.
    """

    def dataReceived(self, data):
        """Log the received data, then relay it to the wrapped protocol."""
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        """Log the outgoing data, then relay it to the transport."""
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)
371
372
373
class SpewingFactory(WrappingFactory):
    """
    Wrapping factory whose protocols are wrapped in L{SpewingProtocol}.
    """

    protocol = SpewingProtocol
376
377
378
class LimitConnectionsByPeer(WrappingFactory):
    """
    Wrapping factory that refuses connections from a peer host which already
    has C{maxConnectionsPerPeer} connections.

    Stability: Unstable

    @cvar maxConnectionsPerPeer: maximum simultaneous connections per host.
    @ivar peerConnections: dict mapping a peer host to its current number of
        connections; created in L{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}

    def _peerHost(self, addr, tupleIndex):
        """
        Return the host part of C{addr}.

        Address objects exposing a C{host} attribute are keyed on it, so the
        same key is used when counting up and when counting down.  Anything
        else is indexed with C{tupleIndex}, as the code did previously.
        """
        host = getattr(addr, 'host', None)
        if host is not None:
            return host
        return addr[tupleIndex]

    def buildProtocol(self, addr):
        """
        Build a wrapped protocol, or return C{None} if the peer host already
        has the maximum number of connections.
        """
        peerHost = self._peerHost(addr, 0)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """
        Forget a disconnected protocol and decrement its host's count.
        """
        # Without this the inherited registerProtocol() entries would never
        # be removed from self.protocols.
        WrappingFactory.unregisterProtocol(self, p)

        peerHost = self._peerHost(p.getPeer(), 1)
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            # Tolerate a host that is no longer tracked
            self.peerConnections.pop(peerHost, None)
400
401
class LimitTotalConnectionsFactory(ServerFactory):
    """
    Server factory that caps the number of simultaneous connections.

    API Stability: Unstable

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections; C{None} means no
        limit.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        """
        Build either the normal or the overflow protocol, wrapped in a
        L{ProtocolWrapper}, and count it as a connection.

        @return: the wrapper, or C{None} when the limit is reached and no
            overflow protocol is configured.
        """
        limit = self.connectionLimit
        belowLimit = limit is None or self.connectionCount < limit

        if belowLimit:
            # Build the normal protocol
            wrappedProtocol = self.protocol()
        else:
            if self.overflowProtocol is None:
                # Just drop the connection
                return None
            # Too many connections, so build the overflow protocol
            wrappedProtocol = self.overflowProtocol()

        wrappedProtocol.factory = self
        wrapper = ProtocolWrapper(self, wrappedProtocol)
        self.connectionCount += 1
        return wrapper

    def registerProtocol(self, p):
        """Nothing to record; connections are counted in buildProtocol."""
        pass

    def unregisterProtocol(self, p):
        """Count one connection fewer."""
        self.connectionCount -= 1
442
443
444
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Every read and write pushes the timeout back by C{timeoutPeriod}.

    Stability: Unstable

    @ivar timeoutPeriod: number of seconds of inactivity before timing out.
    @ivar timeoutCall: handle of the scheduled timeout call, or C{None}.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}; its C{callLater} is used to schedule
            the timeout.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod,
                                                  self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, or has already fired, this
        does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                pass
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        If the timeout has already fired (for instance data arrives while
        the connection is still closing) there is nothing to push back, so
        the stale handle is dropped instead of raising.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.reset(self.timeoutPeriod)
            except (error.AlreadyCalled, error.AlreadyCancelled):
                self.timeoutCall = None


    def write(self, data):
        """Count the write as activity, then relay it."""
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        """Count the write as activity, then relay it."""
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        """Count the read as activity, then relay it."""
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        """Stop the timeout, then relay the disconnection."""
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
531
532
533
class TimeoutFactory(WrappingFactory):
    """
    Factory for L{TimeoutProtocol}.

    Stability: Unstable

    @ivar timeoutPeriod: number of seconds handed to every protocol built.
    """
    protocol = TimeoutProtocol


    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        WrappingFactory.__init__(self, wrappedFactory)
        self.timeoutPeriod = timeoutPeriod


    def buildProtocol(self, addr):
        """
        Build a protocol with the wrapped factory and wrap it in
        C{self.protocol}, passing on this factory's timeout period.
        """
        wrapperClass = self.protocol
        return wrapperClass(self, self.wrappedFactory.buildProtocol(addr),
                            timeoutPeriod=self.timeoutPeriod)


    def callLater(self, period, func):
        """
        Schedule C{func} with L{reactor.callLater}; exists so tests can
        replace the scheduling.
        """
        delayedCall = reactor.callLater(period, func)
        return delayedCall
558
559
560
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Wrapper protocol that writes a line to a log file for every connection
    event and every chunk of data sent or received.
    """

    # Marker appended to data that was cut short in the log.
    _ELIDED = '<... elided>'

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the data logged per message;
            C{None} (or 0) disables truncation.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number


    def _log(self, line):
        """Write one line to the log file and flush it immediately."""
        self.logfile.write(line + '\n')
        self.logfile.flush()


    def _mungeData(self, data):
        """
        Return C{data} shortened for logging when it exceeds C{lengthLimit}.

        The elision marker takes the place of the tail of the data.  When the
        limit is shorter than the marker itself, only the marker is kept
        (a negative slice end would otherwise keep most of the data).
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            keep = max(self.lengthLimit - len(self._ELIDED), 0)
            data = data[:keep] + self._ELIDED
        return data


    # IProtocol
    def connectionMade(self):
        """Log the start of the connection, then relay."""
        self._log('*')
        return ProtocolWrapper.connectionMade(self)


    def dataReceived(self, data):
        """Log the (possibly shortened) received data, then relay."""
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        """Log the reason for the disconnection, then relay."""
        # NOTE(review): the log file is never closed here; confirm whether
        # the protocol or the code that opened it is meant to own it.
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)


    # ITransport
    def write(self, data):
        """Log the (possibly shortened) outgoing data, then relay."""
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)


    def writeSequence(self, iovec):
        """Log each (possibly shortened) outgoing chunk, then relay."""
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)


    def loseConnection(self):
        """Log that this side is closing the connection, then relay."""
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
624
625
626
class TrafficLoggingFactory(WrappingFactory):
    """
    Wrapping factory that gives every connection its own traffic log file.

    @ivar logfilePrefix: prefix of the log file names; the connection number
        is appended after a dash.
    @ivar lengthLimit: passed to each protocol as its logging length limit.
    @ivar _counter: number of the most recently built connection.
    """
    protocol = TrafficLoggingProtocol

    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)


    def open(self, name):
        """
        Open the log file called C{name} for writing.

        @return: the opened file object.
        """
        # Inside a method the bare name resolves to the builtin open(), not
        # to this method.  open() is the documented way to open a file; the
        # file() constructor used before is not.
        return open(name, 'w')


    def buildProtocol(self, addr):
        """
        Build a protocol with the wrapped factory and wrap it in
        C{self.protocol} with a freshly opened, numbered log file.
        """
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             logfile, self.lengthLimit, self._counter)


    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
654
655
656
class TimeoutMixin:
    """
    Mixin for protocols which wish to timeout connections.

    @cvar timeOut: The number of seconds after which to timeout the
        connection, or C{None} when no timeout is in effect.
    """
    timeOut = None

    # Handle of the pending timeout call, or None when nothing is scheduled.
    __timeoutCall = None

    def callLater(self, period, func):
        """Schedule C{func} to run after C{period} seconds on the reactor."""
        delayedCall = reactor.callLater(period, func)
        return delayedCall


    def resetTimeout(self):
        """
        Restart the countdown from the full C{timeOut} period.

        Does nothing when no timeout call is pending or C{timeOut} is
        C{None}.
        """
        pending = self.__timeoutCall
        if pending is None or self.timeOut is None:
            return
        pending.reset(self.timeOut)

    def setTimeout(self, period):
        """
        Change the timeout period.

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
        C{None} to disable the timeout.

        @return: the previous value of C{timeOut}.
        """
        previous = self.timeOut
        self.timeOut = period

        pending = self.__timeoutCall
        if pending is None:
            # Nothing scheduled yet: start a countdown if one is wanted.
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is None:
            # Timeout disabled: drop the scheduled call.
            pending.cancel()
            self.__timeoutCall = None
        else:
            # Re-aim the scheduled call at the new period.
            pending.reset(period)

        return previous

    def __timedOut(self):
        # The scheduled call has fired, so forget its handle before reacting.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """
        Called when the connection times out.

        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()