Make sure full timeouts occur even if a resend timeout has not.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / krpc.py
index a4fbacc..b151a55 100644 (file)
@@ -1,9 +1,6 @@
-## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
-# see LICENSE.txt for license information
 
 """The KRPC communication protocol implementation.
 
-@var KRPC_TIMEOUT: the number of seconds after which requests timeout
 @var UDP_PACKET_LIMIT: the maximum number of bytes that can be sent in a
     UDP packet without fragmentation
 
 """
 
 from bencode import bencode, bdecode
-from time import asctime
+from datetime import datetime, timedelta
 from math import ceil
 
-from twisted.internet.defer import Deferred
+from twisted.internet import defer
 from twisted.internet import protocol, reactor
 from twisted.python import log
 from twisted.trial import unittest
 
 from khash import newID
 
-KRPC_TIMEOUT = 20
 UDP_PACKET_LIMIT = 1472
 
 # Remote node errors
@@ -127,6 +123,8 @@ class hostbroker(protocol.DatagramProtocol):
     
     @type server: L{khashmir.Khashmir}
     @ivar server: the main Khashmir program
+    @type stats: L{stats.StatsLogger}
+    @ivar stats: the statistics logger to save transport info
     @type config: C{dictionary}
     @ivar config: the configuration parameters for the DHT
     @type connections: C{dictionary}
@@ -139,15 +137,18 @@ class hostbroker(protocol.DatagramProtocol):
     @ivar addr: the IP address and port of this node
     """
     
-    def __init__(self, server, config):
+    def __init__(self, server, stats, config):
         """Initialize the factory.
         
         @type server: L{khashmir.Khashmir}
         @param server: the main DHT program
+        @type stats: L{stats.StatsLogger}
+        @param stats: the statistics logger to save transport info
         @type config: C{dictionary}
         @param config: the configuration parameters for the DHT
         """
         self.server = server
+        self.stats = stats
         self.config = config
         # this should be changed to storage that drops old entries
         self.connections = {}
@@ -177,7 +178,7 @@ class hostbroker(protocol.DatagramProtocol):
         
         # Create a new protocol object if necessary
         if not self.connections.has_key(addr):
-            conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
+            conn = self.protocol(addr, self.server, self.stats, self.transport, self.config)
             self.connections[addr] = conn
         else:
             conn = self.connections[addr]
@@ -195,16 +196,111 @@ class hostbroker(protocol.DatagramProtocol):
             conn.stop()
         protocol.DatagramProtocol.stopProtocol(self)
 
+class KrpcRequest(defer.Deferred):
+    """An outstanding request to a remote node.
+    
+    @type protocol: L{KRPC}
+    @ivar protocol: the protocol to send data with
+    @ivar tid: the transaction ID of the request
+    @type method: C{string}
+    @ivar method: the name of the method to call on the remote node
+    @type data: C{string}
+    @ivar data: the message to send to the remote node
+    @type config: C{dictionary}
+    @ivar config: the configuration parameters for the DHT
+    @type delay: C{int}
+    @ivar delay: the last timeout delay sent
+    @type start: C{datetime}
+    @ivar start: the time to request was started at
+    @type laterNextTimeout: L{twisted.internet.interfaces.IDelayedCall}
+    @ivar laterNextTimeout: the pending call to timeout the last sent request
+    @type laterFinalTimeout: L{twisted.internet.interfaces.IDelayedCall}
+    @ivar laterFinalTimeout: the pending call to timeout the entire request
+    """
+    
+    def __init__(self, protocol, newTID, method, data, config):
+        """Initialize the request, and send it out.
+        
+        @type protocol: L{KRPC}
+        @param protocol: the protocol to send data with
+        @param newTID: the transaction ID of the request
+        @type method: C{string}
+        @param method: the name of the method to call on the remote node
+        @type data: C{string}
+        @param data: the message to send to the remote node
+        @type config: C{dictionary}
+        @param config: the configuration parameters for the DHT
+        """
+        defer.Deferred.__init__(self)
+        self.protocol = protocol
+        self.tid = newTID
+        self.method = method
+        self.data = data
+        self.config = config
+        self.delay = self.config.get('KRPC_INITIAL_DELAY', 2)
+        self.start = datetime.now()
+        self.laterNextTimeout = None
+        self.laterFinalTimeout = reactor.callLater(self.config.get('KRPC_TIMEOUT', 9), self.finalTimeout)
+        reactor.callLater(0, self.send)
+        
+    def send(self):
+        """Send the request to the remote node."""
+        assert not self.laterNextTimeout, 'There is already a pending request'
+        self.laterNextTimeout = reactor.callLater(self.delay, self.nextTimeout)
+        try:
+            self.protocol.sendData(self.method, self.data)
+        except:
+            log.err()
+
+    def nextTimeout(self):
+        """Check for a unrecoverable timeout, otherwise resend."""
+        self.laterNextTimeout = None
+        if datetime.now() - self.start > timedelta(seconds = self.config.get('KRPC_TIMEOUT', 9)):
+            self.finalTimeout()
+        elif self.protocol.stopped:
+            log.msg('Timeout but can not resend %r, protocol has been stopped' % self.tid)
+        else:
+            self.delay *= 2
+            log.msg('Trying to resend %r now with delay %d sec' % (self.tid, self.delay))
+            reactor.callLater(0, self.send)
+        
+    def finalTimeout(self):
+        """Timeout the request after an unrecoverable timeout."""
+        self.dropTimeOut()
+        delay = datetime.now() - self.start
+        log.msg('%r timed out after %0.2f sec' %
+                (self.tid, delay.seconds + delay.microseconds/1000000.0))
+        self.protocol.timeOut(self.tid, self.method)
+        
+    def callback(self, resp):
+        self.dropTimeOut()
+        defer.Deferred.callback(self, resp)
+        
+    def errback(self, resp):
+        self.dropTimeOut()
+        defer.Deferred.errback(self, resp)
+        
+    def dropTimeOut(self):
+        """Cancel the timeout call when a response is received."""
+        if self.laterFinalTimeout and self.laterFinalTimeout.active():
+            self.laterFinalTimeout.cancel()
+        self.laterFinalTimeout = None
+        if self.laterNextTimeout and self.laterNextTimeout.active():
+            self.laterNextTimeout.cancel()
+        self.laterNextTimeout = None
+
 class KRPC:
     """The KRPC protocol implementation.
     
     @ivar transport: the transport to use for the protocol
     @type factory: L{khashmir.Khashmir}
     @ivar factory: the main Khashmir program
+    @type stats: L{stats.StatsLogger}
+    @ivar stats: the statistics logger to save transport info
     @type addr: (C{string}, C{int})
     @ivar addr: the IP address and port of the source node
-    @type noisy: C{boolean}
-    @ivar noisy: whether to log additional details of the protocol
+    @type config: C{dictionary}
+    @ivar config: the configuration parameters for the DHT
     @type tids: C{dictionary}
     @ivar tids: the transaction IDs outstanding for requests, keys are the
         transaction ID of the request, values are the deferreds to call with
@@ -213,22 +309,25 @@ class KRPC:
     @ivar stopped: whether the protocol has been stopped
     """
     
-    def __init__(self, addr, server, transport, spew = False):
+    def __init__(self, addr, server, stats, transport, config = {}):
         """Initialize the protocol.
         
         @type addr: (C{string}, C{int})
         @param addr: the IP address and port of the source node
         @type server: L{khashmir.Khashmir}
         @param server: the main Khashmir program
+        @type stats: L{stats.StatsLogger}
+        @param stats: the statistics logger to save transport info
         @param transport: the transport to use for the protocol
-        @type spew: C{boolean}
-        @param spew: whether to log additional details of the protocol
-            (optional, defaults to False)
+        @type config: C{dictionary}
+        @param config: the configuration parameters for the DHT
+            (optional, defaults to using defaults)
         """
         self.transport = transport
         self.factory = server
+        self.stats = stats
         self.addr = addr
-        self.noisy = spew
+        self.config = config
         self.tids = {}
         self.stopped = False
 
@@ -240,15 +339,16 @@ class KRPC:
         @type addr: (C{string}, C{int})
         @param addr: source IP address and port of datagram.
         """
+        self.stats.receivedBytes(len(data))
         if self.stopped:
-            if self.noisy:
+            if self.config.get('SPEW', False):
                 log.msg("stopped, dropping message from %r: %s" % (addr, data))
 
         # Bdecode the message
         try:
             msg = bdecode(data)
         except Exception, e:
-            if self.noisy:
+            if self.config.get('SPEW', False):
                 log.msg("krpc bdecode error: ")
                 log.err(e)
             return
@@ -261,7 +361,7 @@ class KRPC:
             log.err(e)
             return
 
-        if self.noisy:
+        if self.config.get('SPEW', False):
             log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
 
         # Process it based on its type
@@ -272,68 +372,77 @@ class KRPC:
             f = getattr(self.factory ,"krpc_" + msg[REQ], None)
             msg[ARG]['_krpc_sender'] =  self.addr
             if f and callable(f):
+                self.stats.receivedAction(msg[REQ])
                 try:
                     ret = f(*(), **msg[ARG])
                 except KrpcError, e:
                     log.msg('Got a Krpc error while running: krpc_%s' % msg[REQ])
-                    log.err(e)
-                    olen = self._sendResponse(addr, msg[TID], ERR, [e[0], e[1]])
+                    if e[0] != KRPC_ERROR_INVALID_TOKEN:
+                        log.err(e)
+                    self.stats.errorAction(msg[REQ])
+                    olen = self._sendResponse(msg[REQ], addr, msg[TID], ERR,
+                                              [e[0], e[1]])
                 except TypeError, e:
                     log.msg('Got a malformed request for: krpc_%s' % msg[REQ])
                     log.err(e)
-                    olen = self._sendResponse(addr, msg[TID], ERR,
+                    self.stats.errorAction(msg[REQ])
+                    olen = self._sendResponse(msg[REQ], addr, msg[TID], ERR,
                                               [KRPC_ERROR_MALFORMED_REQUEST, str(e)])
                 except Exception, e:
                     log.msg('Got an unknown error while running: krpc_%s' % msg[REQ])
                     log.err(e)
-                    olen = self._sendResponse(addr, msg[TID], ERR,
+                    self.stats.errorAction(msg[REQ])
+                    olen = self._sendResponse(msg[REQ], addr, msg[TID], ERR,
                                               [KRPC_ERROR_SERVER_ERROR, str(e)])
                 else:
-                    olen = self._sendResponse(addr, msg[TID], RSP, ret)
+                    olen = self._sendResponse(msg[REQ], addr, msg[TID], RSP, ret)
             else:
                 # Request for unknown method
                 log.msg("ERROR: don't know about method %s" % msg[REQ])
-                olen = self._sendResponse(addr, msg[TID], ERR,
+                self.stats.receivedAction('unknown')
+                olen = self._sendResponse(msg[REQ], addr, msg[TID], ERR,
                                           [KRPC_ERROR_METHOD_UNKNOWN, "unknown method "+str(msg[REQ])])
-            if self.noisy:
-                log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
-                                                  ilen, msg[REQ], olen))
+
+            log.msg("%s >>> %s %s %s" % (addr, ilen, msg[REQ], olen))
         elif msg[TYP] == RSP:
             # Responses get processed by their TID's deferred
             if self.tids.has_key(msg[TID]):
-                df = self.tids[msg[TID]]
+                req = self.tids[msg[TID]]
                 #      callback
                 del(self.tids[msg[TID]])
-                df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
+                msg[RSP]['_krpc_sender'] = addr
+                req.callback(msg[RSP])
             else:
-                # no tid, this transaction timed out already...
-                if self.noisy:
-                    log.msg('timeout: %r' % msg[RSP]['id'])
+                # no tid, this transaction was finished already...
+                if self.config.get('SPEW', False):
+                    log.msg('received response from %r for completed request: %r' %
+                            (msg[RSP]['id'], msg[TID]))
         elif msg[TYP] == ERR:
             # Errors get processed by their TID's deferred's errback
             if self.tids.has_key(msg[TID]):
-                df = self.tids[msg[TID]]
+                req = self.tids[msg[TID]]
                 del(self.tids[msg[TID]])
                 # callback
-                df.errback(KrpcError(*msg[ERR]))
+                req.errback(KrpcError(*msg[ERR]))
             else:
-                # day late and dollar short, just log it
-                log.msg("Got an error for an unknown request: %r" % (msg[ERR], ))
-                pass
+                # no tid, this transaction was finished already...
+                log.msg('received an error %r from %r for completed request: %r' %
+                        (msg[ERR], msg[RSP]['id'], msg[TID]))
         else:
             # Received an unknown message type
-            if self.noisy:
+            if self.config.get('SPEW', False):
                 log.msg("unknown message type: %r" % msg)
             if msg[TID] in self.tids:
-                df = self.tids[msg[TID]]
+                req = self.tids[msg[TID]]
                 del(self.tids[msg[TID]])
                 # callback
-                df.errback(KrpcError(KRPC_ERROR_RECEIVED_UNKNOWN,
-                                     "Received an unknown message type: %r" % msg[TYP]))
+                req.errback(KrpcError(KRPC_ERROR_RECEIVED_UNKNOWN,
+                                      "Received an unknown message type: %r" % msg[TYP]))
                 
-    def _sendResponse(self, addr, tid, msgType, response):
+    def _sendResponse(self, request, addr, tid, msgType, response):
         """Helper function for sending responses to nodes.
-        
+
+        @param request: the name of the requested method
         @type addr: (C{string}, C{int})
         @param addr: source IP address and port of datagram.
         @param tid: the transaction ID of the request
@@ -347,7 +456,7 @@ class KRPC:
             # Create the response message
             msg = {TID : tid, TYP : msgType, msgType : response}
     
-            if self.noisy:
+            if self.config.get('SPEW', False):
                 log.msg("%d responding to %r: %s" % (self.factory.port, addr, msg))
     
             out = bencode(msg)
@@ -388,14 +497,17 @@ class KRPC:
                 else:
                     # Too long a response, send an error
                     log.msg('Could not send response, too long: %d bytes' % len(out))
+                    self.stats.errorAction(request)
                     msg = {TID : tid, TYP : ERR, ERR : [KRPC_ERROR_RESPONSE_TOO_LONG, "response was %d bytes" % len(out)]}
                     out = bencode(msg)
 
         except Exception, e:
             # Unknown error, send an error message
+            self.stats.errorAction(request)
             msg = {TID : tid, TYP : ERR, ERR : [KRPC_ERROR_SERVER_ERROR, "unknown error sending response: %s" % str(e)]}
             out = bencode(msg)
                     
+        self.stats.sentBytes(len(out))
         self.transport.write(out, addr)
         return len(out)
     
@@ -403,46 +515,63 @@ class KRPC:
         """Send a request to the remote node.
         
         @type method: C{string}
-        @param method: the methiod name to call on the remote node
+        @param method: the method name to call on the remote node
         @param args: the arguments to send to the remote node's method
         """
         if self.stopped:
-            raise KrpcError, (KRPC_ERROR_PROTOCOL_STOPPED, "cannot send, connection has been stopped")
+            return defer.fail(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED,
+                                        "cannot send, connection has been stopped"))
 
         # Create the request message
-        msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
-        if self.noisy:
+        newTID = newID()
+        msg = {TID : newTID, TYP : REQ,  REQ : method, ARG : args}
+        if self.config.get('SPEW', False):
             log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
         data = bencode(msg)
         
-        # Create the deferred and save it with the TID
-        d = Deferred()
-        self.tids[msg[TID]] = d
-
-        # Schedule a later timeout call
-        def timeOut(tids = self.tids, id = msg[TID], method = method, addr = self.addr):
-            """Call the deferred's errback if a timeout occurs."""
-            if tids.has_key(id):
-                df = tids[id]
-                del(tids[id])
-                df.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % (method, addr)))
-        later = reactor.callLater(KRPC_TIMEOUT, timeOut)
+        # Create the request object and save it with the TID
+        req = KrpcRequest(self, newTID, method, data, self.config)
+        self.tids[newTID] = req
         
-        # Cancel the timeout call if a response is received
-        def dropTimeOut(dict, later_call = later):
-            """Cancel the timeout call when a response is received."""
-            if later_call.active():
-                later_call.cancel()
-            return dict
-        d.addBoth(dropTimeOut)
+        # Save the conclusion of the action
+        req.addCallbacks(self.stats.responseAction, self.stats.failedAction,
+                         callbackArgs = (method, datetime.now()),
+                         errbackArgs = (method, datetime.now()))
+
+        return req
+    
+    def sendData(self, method, data):
+        """Write a request to the transport and save the stats.
         
+        @type method: C{string}
+        @param method: the name of the method to call on the remote node
+        @type data: C{string}
+        @param data: the message to send to the remote node
+        """
         self.transport.write(data, self.addr)
-        return d
-    
+        self.stats.sentAction(method)
+        self.stats.sentBytes(len(data))
+        
+    def timeOut(self, badTID, method):
+        """Call the deferred's errback if a timeout occurs.
+        
+        @param badTID: the transaction ID of the request
+        @type method: C{string}
+        @param method: the name of the method that timed out on the remote node
+        """
+        if badTID in self.tids:
+            req = self.tids[badTID]
+            del(self.tids[badTID])
+            req.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % 
+                                                      (method, self.addr)))
+        else:
+            log.msg('Received a timeout for an unknown request for %s from %r' % (method, self.addr))
+        
     def stop(self):
-        """Timeout all pending requests."""
-        for df in self.tids.values():
-            df.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED, 'connection has been stopped while waiting for response'))
+        """Cancel all pending requests."""
+        for req in self.tids.values():
+            req.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED,
+                                  'connection has been stopped while waiting for response'))
         self.tids = {}
         self.stopped = True
 
@@ -463,8 +592,10 @@ class Receiver(protocol.Factory):
         return {'values': ['1'*length]*num}
 
 def make(port):
+    from stats import StatsLogger
     af = Receiver()
-    a = hostbroker(af, {'SPEW': False})
+    a = hostbroker(af, StatsLogger(None, None),
+                   {'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2, 'SPEW': False})
     a.protocol = KRPC
     p = reactor.listenUDP(port, a)
     return af, a, p
@@ -501,8 +632,7 @@ class KRPCTests(unittest.TestCase):
 
     def gotMsg(self, dict, should_be):
         _krpc_sender = dict['_krpc_sender']
-        msg = dict['rsp']
-        self.failUnlessEqual(msg['msg'], should_be)
+        self.failUnlessEqual(dict['msg'], should_be)
 
     def testManyEcho(self):
         for i in xrange(100):
@@ -539,16 +669,20 @@ class KRPCTests(unittest.TestCase):
 
     def testUnknownMeth(self):
         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('blahblah', {'msg' : "This is a test."})
+        df = self.failUnlessFailure(df, KrpcError)
         df.addBoth(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
         return df
 
     def testMalformedRequest(self):
         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test.", 'foo': 'bar'})
-        df.addBoth(self.gotErr, KRPC_ERROR_MALFORMED_REQUEST)
+        df = self.failUnlessFailure(df, KrpcError)
+        df.addBoth(self.gotErr, KRPC_ERROR_MALFORMED_REQUEST, TypeError)
         return df
 
-    def gotErr(self, err, should_be):
-        self.failUnlessEqual(err.value[0], should_be)
+    def gotErr(self, value, should_be, *errorTypes):
+        self.failUnlessEqual(value[0], should_be)
+        if errorTypes:
+            self.flushLoggedErrors(*errorTypes)
         
     def testLongPackets(self):
         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('values', {'length' : 1, 'num': 2000})