Clean up the copyrights mentioned in the code.
[quix0rs-apt-p2p.git] / apt_p2p / policies.py
1
2 """
3 Resource limiting policies.
4
5 @seealso: See also L{twisted.protocols.htb} for rate limiting.
6 """
7
8 # system imports
9 import sys, operator
10
11 # twisted imports
12 from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
13 from twisted.internet import reactor, error
14 from twisted.python import log
15 from zope.interface import providedBy, directlyProvides
16
17
class ProtocolWrapper(Protocol):
    """
    Sits between a real transport and a wrapped protocol.

    The wrapped protocol is connected to this object as if it were the
    transport, while the real transport delivers its events to this object
    as if it were the protocol.  Attributes not defined here are looked up
    on the real transport.

    @ivar wrappedProtocol: the protocol instance being wrapped.
    @ivar factory: the factory that created this wrapper; it is told when
        the connection is made and when it is lost.
    @ivar disconnecting: set to 1 once L{loseConnection} has been called.
    """

    disconnecting = 0

    def __init__(self, factory, wrappedProtocol):
        self.factory = factory
        self.wrappedProtocol = wrappedProtocol

    def makeConnection(self, transport):
        """Attach to C{transport}, also declaring the interfaces it provides."""
        interfaces = providedBy(self) + providedBy(transport)
        directlyProvides(self, *interfaces)
        Protocol.makeConnection(self, transport)

    # Protocol side: events coming from the real transport

    def connectionMade(self):
        """Register with the factory, then connect the wrapped protocol to us."""
        self.factory.registerProtocol(self)
        self.wrappedProtocol.makeConnection(self)

    def dataReceived(self, data):
        """Hand received data straight to the wrapped protocol."""
        self.wrappedProtocol.dataReceived(data)

    def connectionLost(self, reason):
        """Unregister from the factory, then notify the wrapped protocol."""
        self.factory.unregisterProtocol(self)
        self.wrappedProtocol.connectionLost(reason)

    # Transport side: calls made by the wrapped protocol

    def write(self, data):
        """Forward C{data} to the real transport."""
        self.transport.write(data)

    def writeSequence(self, data):
        """Forward a sequence of data chunks to the real transport."""
        self.transport.writeSequence(data)

    def loseConnection(self):
        """Mark this wrapper as disconnecting and close the real transport."""
        self.disconnecting = 1
        self.transport.loseConnection()

    def getPeer(self):
        """Return whatever the real transport reports as its peer."""
        return self.transport.getPeer()

    def getHost(self):
        """Return whatever the real transport reports as its host."""
        return self.transport.getHost()

    def registerProducer(self, producer, streaming):
        """Register C{producer} with the real transport."""
        self.transport.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """Unregister the current producer from the real transport."""
        self.transport.unregisterProducer()

    def stopConsuming(self):
        """Forward C{stopConsuming} to the real transport."""
        self.transport.stopConsuming()

    def __getattr__(self, name):
        # Only reached when normal lookup fails: fall back to the transport
        return getattr(self.transport, name)
73
74
class WrappingFactory(ClientFactory):
    """
    Wraps another factory and the protocols it builds, tracking the
    wrappers that are currently connected.

    @cvar protocol: the wrapper class instantiated around each protocol
        built by the wrapped factory.
    @ivar wrappedFactory: the factory being wrapped.
    @ivar protocols: dictionary whose keys are the registered wrappers.
    """

    protocol = ProtocolWrapper

    def __init__(self, wrappedFactory):
        self.protocols = {}
        self.wrappedFactory = wrappedFactory

    # Factory lifecycle: the wrapped factory is started/stopped first

    def doStart(self):
        """Start the wrapped factory, then this one."""
        self.wrappedFactory.doStart()
        ClientFactory.doStart(self)

    def doStop(self):
        """Stop the wrapped factory, then this one."""
        self.wrappedFactory.doStop()
        ClientFactory.doStop(self)

    # Client connection events are passed through unchanged

    def startedConnecting(self, connector):
        self.wrappedFactory.startedConnecting(connector)

    def clientConnectionFailed(self, connector, reason):
        self.wrappedFactory.clientConnectionFailed(connector, reason)

    def clientConnectionLost(self, connector, reason):
        self.wrappedFactory.clientConnectionLost(connector, reason)

    def buildProtocol(self, addr):
        """Build the wrapped factory's protocol and return it wrapped."""
        wrapped = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, wrapped)

    # Bookkeeping of live wrappers

    def registerProtocol(self, p):
        """Called by a wrapper to add itself to L{protocols}."""
        self.protocols[p] = 1

    def unregisterProtocol(self, p):
        """Called by a wrapper to remove itself from L{protocols}."""
        del self.protocols[p]
111
112
class ThrottlingProtocol(ProtocolWrapper):
    """
    Protocol for ThrottlingFactory.

    Every write is first reported to the factory; when the factory refuses
    it, the data is buffered here and the protocol queues itself with the
    factory to be flushed later through L{unthrottleWrites}.

    @ivar _tempDataBuffer: list of data chunks waiting to be written.
    @ivar _tempDataLength: total length of the chunks in L{_tempDataBuffer}.
    @ivar throttled: C{True} while this protocol is queued with the factory
        and all writes are being buffered.
    """

    # wrap API for tracking bandwidth

    def __init__(self, factory, wrappedProtocol):
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self._tempDataBuffer = []
        self._tempDataLength = 0
        self.throttled = False

    def write(self, data):
        """
        Write C{data} if the factory allows it, otherwise buffer it.
        """
        paused = False

        # Check if we can write
        if not self.throttled:
            paused = self.factory.registerWritten(len(data))
            if not paused:
                ProtocolWrapper.write(self, data)

                # paused is None when the factory has no write limit at all
                if (paused is not None and hasattr(self, "producer")
                        and self.producer and not self.producer.paused):
                    # Interrupt the flow so that others can have a chance
                    # We can only do this if it's not already paused otherwise we
                    # risk unpausing something that the Server paused
                    self.producer.pauseProducing()
                    reactor.callLater(0, self.producer.resumeProducing)

        if self.throttled or paused:
            # Can't write, buffer the data
            self._tempDataBuffer.append(data)
            self._tempDataLength += len(data)
            self._throttleWrites()

    def writeSequence(self, seq):
        """
        Write the chunks of C{seq} one at a time, buffering everything from
        the first chunk the factory refuses.
        """
        i = 0
        if not self.throttled:
            # Write each sequence separately
            while i < len(seq) and not self.factory.registerWritten(len(seq[i])):
                ProtocolWrapper.write(self, seq[i])
                i += 1

        # If there's some left, we must have been paused
        if i < len(seq):
            remaining = seq[i:]
            self._tempDataBuffer.extend(remaining)
            self._tempDataLength += sum(map(len, remaining))
            self._throttleWrites()

    def dataReceived(self, data):
        """Report the number of bytes read to the factory, then relay them."""
        self.factory.registerRead(len(data))
        ProtocolWrapper.dataReceived(self, data)

    def registerProducer(self, producer, streaming):
        """Remember C{producer} so it can be paused while writes are throttled."""
        assert streaming, "You can only use the ThrottlingProtocol with streaming (push) producers."
        self.producer = producer
        ProtocolWrapper.registerProducer(self, producer, streaming)

    def unregisterProducer(self):
        """Forget the remembered producer and unregister it from the transport."""
        # Tolerate an unregister without a prior register on this wrapper;
        # a plain "del" would raise AttributeError in that case.
        self.__dict__.pop("producer", None)
        ProtocolWrapper.unregisterProducer(self)

    def throttleReads(self):
        """Stop reading by pausing the transport."""
        self.transport.pauseProducing()

    def unthrottleReads(self):
        """Start reading again by resuming the transport."""
        self.transport.resumeProducing()

    def _throttleWrites(self):
        """Queue this protocol with the factory (once) and pause the producer."""
        # If we haven't yet, queue for unthrottling
        if not self.throttled:
            self.throttled = True
            self.factory.throttledWrites(self)

        if hasattr(self, "producer") and self.producer:
            self.producer.pauseProducing()

    def unthrottleWrites(self):
        """
        Write at most one buffered chunk.

        @return: the number of bytes still buffered after this call.
        """
        # Write some data
        if self._tempDataBuffer:
            chunk = self._tempDataBuffer[0]
            # The accounting call must not live inside an assert: it would
            # be stripped when Python runs with -O, and the written bytes
            # would never be charged against the write limit.
            if self.factory.registerWritten(len(chunk)):
                # No bandwidth after all; keep everything buffered
                return self._tempDataLength
            self._tempDataBuffer.pop(0)
            self._tempDataLength -= len(chunk)
            ProtocolWrapper.write(self, chunk)
            assert self._tempDataLength >= 0

        # If we wrote it all, start producing more
        if not self._tempDataBuffer:
            assert self._tempDataLength == 0
            self.throttled = False
            if hasattr(self, "producer") and self.producer:
                # This might unpause something the Server has also paused, but
                # it will get paused again on first write anyway
                reactor.callLater(0, self.producer.resumeProducing)

        return self._tempDataLength
205
206
class ThrottlingFactory(WrappingFactory):
    """
    Throttles bandwidth and number of connections.

    Write bandwidth will only be throttled if there is a producer
    registered.

    @cvar CHUNK_SIZE: number of bytes one protocol may flush from its buffer
        before the next queued protocol gets a turn.
    @ivar connectionCount: number of protocols built and not yet unregistered.
    @ivar maxConnectionCount: connections beyond this number are refused.
    @ivar readLimit: max bytes we should read per second, or C{None}.
    @ivar writeLimit: max bytes we should write per second, or C{None}.
    @ivar readThisSecond: bytes read since the last read bandwidth check.
    @ivar writeAvailable: bytes that may still be written before throttling.
    @ivar _writeQueue: throttled protocols waiting to flush their buffers.
    """

    protocol = ThrottlingProtocol
    CHUNK_SIZE = 4*1024

    # sys.maxint only exists on Python 2; fall back to sys.maxsize elsewhere
    def __init__(self, wrappedFactory,
                 maxConnectionCount=getattr(sys, "maxint", sys.maxsize),
                 readLimit=None, writeLimit=None):
        WrappingFactory.__init__(self, wrappedFactory)
        self.connectionCount = 0
        self.maxConnectionCount = maxConnectionCount
        self.readLimit = readLimit # max bytes we should read per second
        self.writeLimit = writeLimit # max bytes we should write per second
        self.readThisSecond = 0
        self.writeAvailable = writeLimit
        self._writeQueue = []
        self.unthrottleReadsID = None
        self.checkReadBandwidthID = None
        self.unthrottleWritesID = None
        self.checkWriteBandwidthID = None


    def callLater(self, period, func):
        """
        Wrapper around L{reactor.callLater} for test purpose.
        """
        return reactor.callLater(period, func)


    def registerWritten(self, length):
        """
        Called by protocol to tell us more bytes were written.

        @return: C{None} if there is no write limit, C{False} if the bytes
            were accounted for and may be written, or C{True} if the bytes
            could not be written and the protocol should pause itself.
        """
        # Check if there are bytes available to write
        if self.writeLimit is None:
            return None
        elif self.writeAvailable > 0:
            self.writeAvailable -= length
            return False

        return True


    def throttledWrites(self, p):
        """
        Called by the protocol to queue it for later writing.
        """
        assert p not in self._writeQueue
        self._writeQueue.append(p)


    def registerRead(self, length):
        """
        Called by protocol to tell us more bytes were read.
        """
        self.readThisSecond += length


    def checkReadBandwidth(self):
        """
        Checks if we've passed bandwidth limits.

        Reschedules itself to run again one second later.
        """
        if self.readThisSecond > self.readLimit:
            self.throttleReads()
            # Pause long enough for the average rate to fall back to the limit
            throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
            self.unthrottleReadsID = self.callLater(throttleTime,
                                                    self.unthrottleReads)
        self.readThisSecond = 0
        self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)


    def checkWriteBandwidth(self):
        """
        Add some new available bandwidth, and check for protocols to unthrottle.

        Reschedules itself to run again one second later.
        """
        # Increase the available write bytes, but not higher than the limit
        self.writeAvailable = min(self.writeLimit, self.writeAvailable + self.writeLimit)

        # Write from the queue until it's empty or we're throttled again
        while self.writeAvailable > 0 and self._writeQueue:
            # Get the first queued protocol
            p = self._writeQueue.pop(0)
            _tempWriteAvailable = self.writeAvailable
            bytesLeft = 1

            # Unthrottle writes until CHUNK_SIZE is reached or the protocol is unbuffered
            while self.writeAvailable > 0 and _tempWriteAvailable - self.writeAvailable < self.CHUNK_SIZE and bytesLeft > 0:
                # Unthrottle a single write (from the protocol's buffer)
                bytesLeft = p.unthrottleWrites()

            # If the protocol is not done, requeue it
            if bytesLeft > 0:
                self._writeQueue.append(p)

        self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)


    def throttleReads(self):
        """
        Throttle reads on all protocols.
        """
        log.msg("Throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.throttleReads()


    def unthrottleReads(self):
        """
        Stop throttling reads on all protocols.
        """
        self.unthrottleReadsID = None
        log.msg("Stopped throttling reads on %s" % self)
        for p in self.protocols.keys():
            p.unthrottleReads()


    def buildProtocol(self, addr):
        """
        Build a throttled protocol, or return C{None} once
        L{maxConnectionCount} connections exist.

        The periodic bandwidth checks are started when there are no
        connections yet.
        """
        if self.connectionCount == 0:
            if self.readLimit is not None:
                self.checkReadBandwidth()
            if self.writeLimit is not None:
                self.checkWriteBandwidth()

        if self.connectionCount < self.maxConnectionCount:
            self.connectionCount += 1
            return WrappingFactory.buildProtocol(self, addr)
        else:
            log.msg("Max connection count reached!")
            return None


    def unregisterProtocol(self, p):
        """
        Forget a disconnected protocol and stop all timers once no
        connections remain.
        """
        WrappingFactory.unregisterProtocol(self, p)

        # A protocol that disconnected while throttled must not be flushed
        # later, nor kept alive by the queue
        if p in self._writeQueue:
            self._writeQueue.remove(p)

        self.connectionCount -= 1
        if self.connectionCount == 0:
            # Clear each handle after cancelling it: a stale handle would be
            # cancelled a second time on the next idle transition, which
            # raises for an already cancelled call
            if self.unthrottleReadsID is not None:
                self.unthrottleReadsID.cancel()
                self.unthrottleReadsID = None
            if self.checkReadBandwidthID is not None:
                self.checkReadBandwidthID.cancel()
                self.checkReadBandwidthID = None
            if self.unthrottleWritesID is not None:
                self.unthrottleWritesID.cancel()
                self.unthrottleWritesID = None
            if self.checkWriteBandwidthID is not None:
                self.checkWriteBandwidthID.cancel()
                self.checkWriteBandwidthID = None
356
357
358
class SpewingProtocol(ProtocolWrapper):
    """Wrapper that logs every chunk of data received or written."""

    def write(self, data):
        """Log the outgoing data, then write it."""
        message = "Sending: %r" % data
        log.msg(message)
        ProtocolWrapper.write(self, data)

    def dataReceived(self, data):
        """Log the incoming data, then relay it."""
        message = "Received: %r" % data
        log.msg(message)
        ProtocolWrapper.dataReceived(self, data)
367
368
369
class SpewingFactory(WrappingFactory):
    """Wrapping factory whose protocols are wrapped in L{SpewingProtocol}."""

    protocol = SpewingProtocol
372
373
374
class LimitConnectionsByPeer(WrappingFactory):
    """
    Refuses new connections from a peer host that already has
    L{maxConnectionsPerPeer} connections.

    Stability: Unstable

    @cvar maxConnectionsPerPeer: connections allowed per peer host.
    @ivar peerConnections: maps a peer host to its current connection count;
        created in L{startFactory}.
    """

    maxConnectionsPerPeer = 5

    def startFactory(self):
        self.peerConnections = {}

    def _peerHost(self, addr):
        """
        Extract the host from an address.

        The same extraction has to be used when counting up and when
        counting down, otherwise the two operations use different keys.
        """
        host = getattr(addr, "host", None)
        if host is not None:
            return host
        # NOTE(review): fallback for legacy tuple addresses, assumed to be
        # either (host, port) or (type, host, port) -- confirm against the
        # Twisted version in use.
        if len(addr) == 2:
            return addr[0]
        return addr[1]

    def buildProtocol(self, addr):
        """
        Build a wrapped protocol, or return C{None} if the peer host is
        already at its connection limit.
        """
        peerHost = self._peerHost(addr)
        connectionCount = self.peerConnections.get(peerHost, 0)
        if connectionCount >= self.maxConnectionsPerPeer:
            return None
        self.peerConnections[peerHost] = connectionCount + 1
        return WrappingFactory.buildProtocol(self, addr)

    def unregisterProtocol(self, p):
        """Drop the wrapper from the base bookkeeping and from its peer's count."""
        # Without this the wrapper would stay in self.protocols forever
        WrappingFactory.unregisterProtocol(self, p)

        peerHost = self._peerHost(p.getPeer())
        # Be tolerant of a missing entry: an exception here would prevent
        # the wrapped protocol from being told about the lost connection
        remaining = self.peerConnections.get(peerHost, 0) - 1
        if remaining > 0:
            self.peerConnections[peerHost] = remaining
        else:
            self.peerConnections.pop(peerHost, None)
396
397
class LimitTotalConnectionsFactory(ServerFactory):
    """Factory that caps how many connections exist at the same time.

    API Stability: Unstable

    @type connectionCount: C{int}
    @ivar connectionCount: number of current connections.
    @type connectionLimit: C{int} or C{None}
    @cvar connectionLimit: maximum number of connections.
    @type overflowProtocol: L{Protocol} or C{None}
    @cvar overflowProtocol: Protocol to use for new connections when
        connectionLimit is exceeded.  If C{None} (the default value), excess
        connections will be closed immediately.
    """
    connectionCount = 0
    connectionLimit = None
    overflowProtocol = None

    def buildProtocol(self, addr):
        """
        Build either the normal or the overflow protocol, wrapped in a
        L{ProtocolWrapper}, or return C{None} to drop the connection.
        """
        limit = self.connectionLimit
        underLimit = limit is None or self.connectionCount < limit

        if underLimit:
            # Room left (or no limit at all): use the normal protocol
            inner = self.protocol()
        elif self.overflowProtocol is not None:
            # Over the limit, but an overflow protocol is configured
            inner = self.overflowProtocol()
        else:
            # Over the limit with nothing to serve: drop the connection
            return None

        inner.factory = self
        wrapper = ProtocolWrapper(self, inner)
        self.connectionCount += 1
        return wrapper

    def registerProtocol(self, p):
        """Nothing to do: the connection was already counted when built."""
        pass

    def unregisterProtocol(self, p):
        """Called by a wrapper when its connection is lost."""
        self.connectionCount -= 1
438
439
440
class TimeoutProtocol(ProtocolWrapper):
    """
    Protocol that automatically disconnects when the connection is idle.

    Stability: Unstable

    @ivar timeoutCall: the pending delayed call for the timeout, or C{None}.
    @ivar timeoutPeriod: number of seconds of inactivity before timing out.
    """

    def __init__(self, factory, wrappedProtocol, timeoutPeriod):
        """
        Constructor.

        @param factory: An L{IFactory}.
        @param wrappedProtocol: A L{Protocol} to wrap.
        @param timeoutPeriod: Number of seconds to wait for activity before
            timing out.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.timeoutCall = None
        self.setTimeout(timeoutPeriod)


    def setTimeout(self, timeoutPeriod=None):
        """
        Set a timeout.

        This will cancel any existing timeouts.

        @param timeoutPeriod: If not C{None}, change the timeout period.
            Otherwise, use the existing value.
        """
        self.cancelTimeout()
        if timeoutPeriod is not None:
            self.timeoutPeriod = timeoutPeriod
        self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)


    def cancelTimeout(self):
        """
        Cancel the timeout.

        If the timeout was already cancelled, this does nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.cancel()
            except (error.AlreadyCalled, error.AlreadyCancelled):
                # Already fired or cancelled: there is nothing left to cancel
                pass
            self.timeoutCall = None


    def resetTimeout(self):
        """
        Reset the timeout, usually because some activity just happened.

        If the timeout has already fired or been cancelled this does
        nothing.
        """
        if self.timeoutCall:
            try:
                self.timeoutCall.reset(self.timeoutPeriod)
            except (error.AlreadyCalled, error.AlreadyCancelled):
                # Traffic can still arrive between the timeout firing and
                # the connection actually being lost; the spent call cannot
                # be rescheduled, so just forget it
                self.timeoutCall = None


    def write(self, data):
        self.resetTimeout()
        ProtocolWrapper.write(self, data)


    def writeSequence(self, seq):
        self.resetTimeout()
        ProtocolWrapper.writeSequence(self, seq)


    def dataReceived(self, data):
        self.resetTimeout()
        ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self.cancelTimeout()
        ProtocolWrapper.connectionLost(self, reason)


    def timeoutFunc(self):
        """
        This method is called when the timeout is triggered.

        By default it calls L{loseConnection}.  Override this if you want
        something else to happen.
        """
        self.loseConnection()
527
528
529
class TimeoutFactory(WrappingFactory):
    """
    Factory that wraps its protocols in L{TimeoutProtocol}.

    Stability: Unstable

    @ivar timeoutPeriod: timeout, in seconds, handed to every protocol built.
    """
    protocol = TimeoutProtocol


    def __init__(self, wrappedFactory, timeoutPeriod=30*60):
        WrappingFactory.__init__(self, wrappedFactory)
        self.timeoutPeriod = timeoutPeriod


    def buildProtocol(self, addr):
        """Build the wrapped protocol and wrap it with this factory's timeout."""
        wrapped = self.wrappedFactory.buildProtocol(addr)
        return self.protocol(self, wrapped, timeoutPeriod=self.timeoutPeriod)


    def callLater(self, period, func):
        """
        Schedule C{func} through L{reactor.callLater}; kept as a method so
        tests can replace it.
        """
        return reactor.callLater(period, func)
554
555
556
class TrafficLoggingProtocol(ProtocolWrapper):
    """
    Wrapper that writes a line to a log file for every connection event and
    every chunk of data received or sent.
    """

    def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                 number=0):
        """
        @param factory: factory which created this protocol.
        @type factory: C{protocol.Factory}.
        @param wrappedProtocol: the underlying protocol.
        @type wrappedProtocol: C{protocol.Protocol}.
        @param logfile: file opened for writing used to write log messages.
        @type logfile: C{file}
        @param lengthLimit: maximum size of the datareceived logged.
        @type lengthLimit: C{int}
        @param number: identifier of the connection.
        @type number: C{int}.
        """
        ProtocolWrapper.__init__(self, factory, wrappedProtocol)
        self.logfile = logfile
        self.lengthLimit = lengthLimit
        self._number = number


    def _log(self, line):
        """Write C{line} to the log file and flush it immediately."""
        self.logfile.write(line + '\n')
        self.logfile.flush()


    def _mungeData(self, data):
        """
        Shorten C{data} for logging when it is longer than L{lengthLimit}.

        The tail is replaced by a 12 character marker.
        """
        if self.lengthLimit and len(data) > self.lengthLimit:
            # Clamp at zero: with a limit below 12 the slice bound would go
            # negative and keep almost all of the data instead of none
            keep = max(self.lengthLimit - 12, 0)
            data = data[:keep] + '<... elided>'
        return data


    # IProtocol
    def connectionMade(self):
        self._log('*')
        return ProtocolWrapper.connectionMade(self)


    def dataReceived(self, data):
        self._log('C %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.dataReceived(self, data)


    def connectionLost(self, reason):
        self._log('C %d: %r' % (self._number, reason))
        return ProtocolWrapper.connectionLost(self, reason)


    # ITransport
    def write(self, data):
        self._log('S %d: %r' % (self._number, self._mungeData(data)))
        return ProtocolWrapper.write(self, data)


    def writeSequence(self, iovec):
        self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
        return ProtocolWrapper.writeSequence(self, iovec)


    def loseConnection(self):
        self._log('S %d: *' % (self._number,))
        return ProtocolWrapper.loseConnection(self)
620
621
622
class TrafficLoggingFactory(WrappingFactory):
    """
    Factory whose protocols are wrapped in L{TrafficLoggingProtocol}, each
    one logging to its own file.

    @ivar logfilePrefix: log files are named C{<logfilePrefix>-<number>}.
    @ivar lengthLimit: passed to every protocol as its C{lengthLimit}.
    @ivar _counter: number given to the most recently built protocol.
    """
    protocol = TrafficLoggingProtocol

    _counter = 0

    def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
        self.logfilePrefix = logfilePrefix
        self.lengthLimit = lengthLimit
        WrappingFactory.__init__(self, wrappedFactory)


    def open(self, name):
        """
        Open the log file called C{name} for writing.
        """
        # The name "open" here resolves to the builtin, not this method.
        # The file() constructor it replaces only exists on Python 2.
        return open(name, 'w')


    def buildProtocol(self, addr):
        """
        Build the wrapped protocol together with a freshly opened,
        numbered log file.
        """
        self._counter += 1
        logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
        return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
                             logfile, self.lengthLimit, self._counter)


    def resetCounter(self):
        """
        Reset the value of the counter used to identify connections.
        """
        self._counter = 0
650
651
652
class TimeoutMixin:
    """Mixin giving a protocol an idle timeout.

    @cvar timeOut: The number of seconds after which to timeout the
        connection, or C{None} when no timeout is set.
    """
    timeOut = None

    # Pending delayed call for the timeout, or None when nothing is scheduled
    __timeoutCall = None

    def callLater(self, period, func):
        """Schedule C{func} through the reactor; separate so it can be replaced."""
        return reactor.callLater(period, func)


    def resetTimeout(self):
        """Restart the count down of the pending timeout, if there is one."""
        if self.__timeoutCall is None or self.timeOut is None:
            return
        self.__timeoutCall.reset(self.timeOut)

    def setTimeout(self, period):
        """Change the timeout period

        @type period: C{int} or C{NoneType}
        @param period: The period, in seconds, to change the timeout to, or
        C{None} to disable the timeout.
        @return: the timeout period that was in effect before this call.
        """
        previous = self.timeOut
        self.timeOut = period
        pending = self.__timeoutCall

        if pending is None:
            # Nothing scheduled yet: only schedule when a period is given
            if period is not None:
                self.__timeoutCall = self.callLater(period, self.__timedOut)
        elif period is None:
            # Disabling: drop the scheduled call
            pending.cancel()
            self.__timeoutCall = None
        else:
            # Re-arm the scheduled call with the new period
            pending.reset(period)

        return previous

    def __timedOut(self):
        # The call has fired, so it must not be reset or cancelled any more
        self.__timeoutCall = None
        self.timeoutConnection()

    def timeoutConnection(self):
        """Called when the connection times out.
        Override to define behavior other than dropping the connection.
        """
        self.transport.loseConnection()