Minor update to the multiple peer downloading (still not working).
[quix0rs-apt-p2p.git] / apt_p2p / policies.py
1 # -*- test-case-name: twisted.test.test_policies -*-
2 # Copyright (c) 2001-2007 Twisted Matrix Laboratories.
3 # See LICENSE for details.
4
5
6 """
7 Resource limiting policies.
8
9 @seealso: See also L{twisted.protocols.htb} for rate limiting.
10 """
11
12 # system imports
13 import sys, operator
14
15 # twisted imports
16 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
17 from twisted.internet import reactor, error
18 from twisted.python import log
19 from zope.interface import providedBy, directlyProvides
20
21
class ProtocolWrapper(Protocol):
    """
    Wrap a protocol instance and act as its transport as well.

    Transport-side calls made by the wrapped protocol are forwarded to the
    real transport, and protocol events delivered by the real transport are
    forwarded to the wrapped protocol.  Any attribute not defined here is
    looked up on the real transport.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        # Advertise every interface the real transport provides, on top of
        # our own, so this wrapper can stand in for the transport.
        interfaces = providedBy(self) + providedBy(transport)
        directlyProvides(self, *interfaces)
        Protocol.makeConnection(self, transport)

    # Transport relaying

    def write(self, data):
        """Forward C{data} to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Forward a sequence of strings to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Flag this wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return the remote address reported by the real transport."""
        return self.transport.getPeer()

    def getHost(self):
        """Return the local address reported by the real transport."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Forward stopConsuming to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only invoked when normal lookup fails: delegate to the transport.
        return getattr(self.transport, name)

    # Protocol relaying

    def connectionMade(self):
        """Register with the factory, then connect the wrapped protocol to us."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Hand incoming C{data} to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)
77
78
class WrappingFactory(ClientFactory):
    """
    Wrap a factory and the protocols it builds, tracking the live ones.

    Lifecycle and client-connection callbacks are forwarded to the wrapped
    factory.  Every protocol it builds is wrapped in C{self.protocol}.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.wrappedFactory = wrappedFactory
        # Live wrapper protocols; used as a set (the value is always 1).
        self.protocols = {}

    def doStart(self):
        """Start the wrapped factory first, then this one."""
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory first, then this one."""
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    def startedConnecting(self, connector):
        """Forward to the wrapped factory."""
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        """Forward to the wrapped factory."""
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        """Forward to the wrapped factory."""
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped protocol for C{addr} and wrap it."""
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner)

    def registerProtocol(self, p):
        """Called by protocol to register itself."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by protocols when they go away."""
        self.protocols.pop(p)
115
116
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Every outgoing write is charged against the factory's shared write
    budget via C{factory.registerWritten}.  When the budget is exhausted,
    the data is buffered locally, the protocol queues itself with the
    factory (C{factory.throttledWrites}), and any registered producer is
    paused.  The factory later drains the buffer, one buffered write per
    call, through L{unthrottleWrites}.  Bytes read are reported to the
    factory through C{factory.registerRead}.
    """

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        # Writes that could not go out yet, and their total size in bytes.
        self._tempDataBuffer = []
        self._tempDataLength = 0
        # True while this protocol sits in the factory's write queue.
        self.throttled = False

    def write(self, data):
        """
        Write C{data} now if the write budget allows it, otherwise buffer it.
        """
        # registerWritten returns None (no write limit), False (charged and
        # allowed) or True (budget exhausted).  It is only consulted when we
        # are not already throttled.
        paused = None
        if not self.throttled:
            paused = self.factory.registerWritten(len(data))
            if not paused:
                ProtocolWrapper.write(self, data)

                if paused is not None and hasattr(self, "producer") and self.producer and not self.producer.paused:
                    # Interrupt the flow so that others can have a chance.
                    # We can only do this if it's not already paused, otherwise
                    # we risk unpausing something that the Server paused.
                    self.producer.pauseProducing()
                    reactor.callLater(0, self.producer.resumeProducing)

        if self.throttled or paused:
            # Can't write, buffer the data
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """
        Write each string in C{seq}, buffering whatever exceeds the budget.

        C{seq} may be any iterable of strings.  It is copied, so the caller's
        sequence is never modified.
        """
        # Work on a private copy: popping from the caller's list mutated it,
        # and tuples/generators have no pop() at all.
        pending = list(seq)
        written = 0
        if not self.throttled:
            # Write each string separately until the budget runs out
            while (written < len(pending) and
                   not self.factory.registerWritten(len(pending[written]))):
                ProtocolWrapper.write(self, pending[written])
                written += 1

        # If there's some left, we were paused (or already throttled)
        remaining = pending[written:]
        if remaining:
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        """Report the bytes read to the factory, then pass them on."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """Remember C{producer} so it can be paused while throttled."""
        assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """Forget the producer and unregister it from the transport."""
        # Assigning None (instead of ``del``) does not raise when no producer
        # was ever registered; every reader checks ``self.producer`` for
        # truthiness, so None means "no producer".
        self.producer = None
        ProtocolWrapper.unregisterProducer(self)


    def throttleReads(self):
        """Stop reading from the transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Resume reading from the transport."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """Queue this protocol with the factory (once) and pause its producer."""
        # If we haven't yet, queue for unthrottling
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if hasattr(self, "producer") and self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """
        Write the oldest buffered chunk, charging it to the write budget.

        Called by the factory once bandwidth is available again.

        @return: the number of bytes still buffered after this write.
        """
        # Write some data
        if self._tempDataBuffer:
            data = self._tempDataBuffer.pop(0)
            # Charge the budget outside the assert: under ``python -O``
            # asserts are stripped and the accounting would be skipped.
            paused = self.factory.registerWritten(len(data))
            assert not paused
            self._tempDataLength -= len(data)
            ProtocolWrapper.write(self, data)
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if hasattr(self, "producer") and self.producer:
                # This might unpause something the Server has also paused, but
                # it will get paused again on first write anyway
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
207
208
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Writes that exceed the per-second write budget are buffered by each
    L{ThrottlingProtocol} and drained round-robin by
    L{checkWriteBandwidth}.  A protocol's registered producer, if any, is
    additionally paused while that protocol is throttled.  Reads over the
    per-second read limit pause every protocol's transport for a while.
    """

    protocol = ThrottlingProtocol
    # Budget (in bytes) one queued protocol may use per turn in
    # checkWriteBandwidth before the next queued protocol is served; the
    # last write of a turn may overshoot it.
    CHUNK_SIZE = 4*1024

    def __init__(self, wrappedFactory, maxConnectionCount=sys.maxint,
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writeAvailable = writeLimit
        # Throttled protocols waiting for write bandwidth, in FIFO order.
        self._writeQueue = []
        # Pending DelayedCalls (or None when nothing is scheduled).
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)


    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.

        @return: C{None} if there is no write limit; C{False} if the bytes
            were charged to the budget and may be written; C{True} if the
            budget is exhausted and the protocol should buffer and pause.
            A write is allowed whenever any budget remains, so the budget
            can go negative.
        """
        # Check if there are bytes available to write
        if self.writeLimit is None:
            return None
        elif self.writeAvailable > 0:
            self.writeAvailable -= length
            return False

        return True


    def throttledWrites(self, p):
        """
        Called by the protocol to queue it for later writing.
        """
        assert p not in self._writeQueue
        self._writeQueue.append(p)


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Checks if we've passed bandwidth limits; reschedules itself every second.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Pause for the fraction of a second the limit was overshot by.
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Add some new available bandwidth, and check for protocols to unthrottle.

        Reschedules itself every second.
        """
        # Increase the available write bytes, but not higher than the limit
        self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol
            p = self._writeQueue.pop(0)
            _tempWriteAvailable = self.writeAvailable
            bytesLeft = 1

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
            while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
                # Unthrottle a single write (from the protocol's buffer)
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()


    def buildProtocol(self, addr):
        """
        Build a throttled protocol, starting the bandwidth timers for the
        first connection; returns C{None} once maxConnectionCount is reached.
        """
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        else:
            log.msg("Max connection count reached!")
            return None


    def unregisterProtocol(self, p):
        """
        Forget a closed protocol; stop all timers when none remain.
        """
        WrappingFactory.unregisterProtocol(self, p)
        self.connectionCount -= 1
        # A closed protocol must not be drained by checkWriteBandwidth.
        if p in self._writeQueue:
            self._writeQueue.remove(p)
        if self.connectionCount == 0:
            self._cancelBandwidthCalls()


    def _cancelBandwidthCalls(self):
        """Cancel every pending bandwidth timer and clear its handle."""
        for name in ('unthrottleReadsID', 'checkReadBandwidthID',
                     'unthrottleWritesID', 'checkWriteBandwidthID'):
            call = getattr(self, name)
            if call is not None:
                # A handle that already fired or was cancelled must not be
                # cancelled again (the reactor raises AlreadyCalled /
                # AlreadyCancelled); clearing it avoids stale handles on
                # the next connection cycle.
                if call.active():
                    call.cancel()
                setattr(self, name, None)
358
359
360
class SpewingProtocol(ProtocolWrapper):
    """Log every chunk of data received and sent, then pass it through."""

    def dataReceived(self, data):
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)

    def write(self, data):
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)
369
370
371
class SpewingFactory(WrappingFactory):
    """Wrapping factory whose protocols log all traffic (SpewingProtocol)."""

    protocol = SpewingProtocol
374
375
376
class LimitConnectionsByPeer(WrappingFactory):
    """
    Limit the number of simultaneous connections from each peer host.

    Stability: Unstable
    """

    # Maximum number of concurrent connections accepted from one host.
    maxConnectionsPerPeer = 5

    def startFactory(self):
        # Maps peer host -> number of currently open connections.
        self.peerConnections = {}

    def _peerHost(self, addr):
        """
        Return the host part of C{addr}.

        Used for both the address given to buildProtocol and the one from
        C{p.getPeer()}, so both sides agree on the dictionary key.  Objects
        with a C{host} attribute are used directly; otherwise the first item
        of C{addr} is taken, which assumes a C{(host, port)}-style tuple --
        TODO confirm for non-IP transports.
        """
        host = getattr(addr, 'host', None)
        if host is None:
            host = addr[0]
        return host

    def buildProtocol(self, addr):
        """Build a protocol unless C{addr}'s host is already at its limit."""
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """Forget C{p} and release its slot in the per-host count."""
        # Keep WrappingFactory.protocols in sync; the override used to skip
        # this, leaking every closed protocol.
        WrappingFactory.unregisterProtocol(self, p)
        peerHost = self._peerHost(p.getPeer())
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            self.peerConnections.pop(peerHost, None)
398
399
class LimitTotalConnectionsFactory(ServerFactory):
    """Factory that limits the number of simultaneous connections.

    API Stability: Unstable

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        """
        Build C{self.protocol} while under the limit, else the overflow
        protocol, else C{None} to drop the connection.  The result is
        wrapped in a L{ProtocolWrapper} and counted.
        """
        withinLimit = (self.connectionLimit is None
                       or self.connectionCount < self.connectionLimit)
        if withinLimit:
            # Build the normal protocol
            protocolClass = self.protocol
        elif self.overflowProtocol is not None:
            # Too many connections, so build the overflow protocol
            protocolClass = self.overflowProtocol
        else:
            # Just drop the connection
            return None

        wrappedProtocol = protocolClass()
        wrappedProtocol.factory = self
        wrapper = ProtocolWrapper(self, wrappedProtocol)
        self.connectionCount += 1
        return wrapper

    def registerProtocol(self, p):
        """Nothing to record; counting happens in buildProtocol."""
        pass

    def unregisterProtocol(self, p):
        """Release the slot taken in buildProtocol."""
        self.connectionCount -= 1
440
441
442
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Stability: Unstable
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}; must provide C{callLater}.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled or has already fired, this
        does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                pass
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        Does nothing if the timeout has already fired or been cancelled.
        """
        # After timeoutFunc has run, timeoutCall still refers to the spent
        # DelayedCall; resetting it would raise AlreadyCalled.  Activity such
        # as late writes or data arriving while the connection closes must
        # not crash here.
        if self.timeoutCall and self.timeoutCall.active():
            self.timeoutCall.reset(self.timeoutPeriod)


    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
529
530
531
class TimeoutFactory(WrappingFactory):
    """
    Factory for TimeoutWrapper.

    Every protocol it builds is wrapped in a L{TimeoutProtocol} that uses
    this factory's C{timeoutPeriod} (in seconds).

    Stability: Unstable
    """
    protocol = TimeoutProtocol

    def __init__(self, wrappedFactory, timeoutPeriod=30 * 60):
        self.timeoutPeriod = timeoutPeriod
        WrappingFactory.__init__(self, wrappedFactory)

    def buildProtocol(self, addr):
        """Wrap the wrapped factory's protocol with an idle timeout."""
        period = self.timeoutPeriod
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner, timeoutPeriod=period)

    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)
556
557
558
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Wrapper that logs connection events and traffic to a file.

    Log lines are prefixed with C{C} for client-side events (data received,
    connection lost), C{S}/C{SV} for data written, and C{*} for connection
    setup or teardown.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the datareceived logged.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number

    def _log(self, line):
        """Append C{line} plus a newline to the log file and flush it."""
        out = self.logfile
        out.write(line + '\n')
        out.flush()

    def _mungeData(self, data):
        """Truncate C{data} to C{lengthLimit} with an elision marker."""
        limit = self.lengthLimit
        if not limit or len(data) <= limit:
            return data
        return data[:limit - 12] + '<... elided>'

    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)

    def dataReceived(self, data):
        entry = 'C %d: %r' % (self._number, self._mungeData(data))
        self._log(entry)
        return ProtocolWrapper.dataReceived(self, data)

    def connectionLost(self, reason):
        entry = 'C %d: %r' % (self._number, reason)
        self._log(entry)
        return ProtocolWrapper.connectionLost(self, reason)

    # ITransport
    def write(self, data):
        entry = 'S %d: %r' % (self._number, self._mungeData(data))
        self._log(entry)
        return ProtocolWrapper.write(self, data)

    def writeSequence(self, iovec):
        munged = list(map(self._mungeData, iovec))
        entry = 'SV %d: %r' % (self._number, munged)
        self._log(entry)
        return ProtocolWrapper.writeSequence(self, iovec)

    def loseConnection(self):
        entry = 'S %d: *' % (self._number,)
        self._log(entry)
        return ProtocolWrapper.loseConnection(self)
622
623
624
class TrafficLoggingFactory(WrappingFactory):
    """
    Wrapping factory that logs each connection's traffic to its own file.

    Connection N logs to C{logfilePrefix + '-' + str(N)}.
    """
    protocol = TrafficLoggingProtocol

    # Number of protocols built so far; also the id of the latest one.
    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)

    def open(self, name):
        """Open the log file C{name} for writing (overridable for tests)."""
        return file(name, 'w')

    def buildProtocol(self, addr):
        """Open a fresh log file and wrap the wrapped factory's protocol."""
        self._counter += 1
        number = self._counter
        logfile = self.open(self.logfilePrefix + '-' + str(number))
        inner = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, inner, logfile, self.lengthLimit, number)

    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
652
653
654
class TimeoutMixin:
    """Mixin for protocols which wish to timeout connections

    @cvar timeOut: The number of seconds after which to timeout the connection.
    """
    timeOut = None

    # Pending DelayedCall for the timeout, or None when none is scheduled.
    __timeoutCall = None

    def callLater(self, period, func):
        """Schedule C{func} after C{period} seconds (overridable for tests)."""
        return reactor.callLater(period, func)


    def resetTimeout(self):
        """Reset the timeout count down"""
        call = self.__timeoutCall
        if call is None or self.timeOut is None:
            return
        call.reset(self.timeOut)

    def setTimeout(self, period):
        """Change the timeout period

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
        C{None} to disable the timeout.
        @return: the previous timeout period.
        """
        previous, self.timeOut = self.timeOut, period

        call = self.__timeoutCall
        if call is None:
            # Nothing scheduled yet: start a countdown unless disabling.
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is None:
            call.cancel()
            self.__timeoutCall = None
        else:
            call.reset(period)

        return previous

    def __timedOut(self):
        # The call has fired; forget it before running the handler.
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """Called when the connection times out.
        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()