]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht_Khashmir/krpc.py
More work on the TODO.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / krpc.py
index 1458fc25903634bdf87400ef1db4cd584349016c..1428545430c104a20fc729e5d042715436295a5f 100644 (file)
@@ -8,8 +8,11 @@ from traceback import format_exception
 
 from twisted.internet.defer import Deferred
 from twisted.internet import protocol, reactor
+from twisted.python import log
 from twisted.trial import unittest
 
+from khash import newID
+
 KRPC_TIMEOUT = 20
 
 KRPC_ERROR = 1
@@ -25,9 +28,13 @@ TYP = 'y'
 ARG = 'a'
 ERR = 'e'
 
+class ProtocolError(Exception):
+    pass
+
 class hostbroker(protocol.DatagramProtocol):       
-    def __init__(self, server):
+    def __init__(self, server, config):
         self.server = server
+        self.config = config
         # this should be changed to storage that drops old entries
         self.connections = {}
         
@@ -43,7 +50,7 @@ class hostbroker(protocol.DatagramProtocol):
         if addr == self.addr:
             raise Exception
         if not self.connections.has_key(addr):
-            conn = self.protocol(addr, self.server, self.transport)
+            conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
             self.connections[addr] = conn
         else:
             conn = self.connections[addr]
@@ -53,27 +60,36 @@ class hostbroker(protocol.DatagramProtocol):
         protocol.DatagramProtocol.makeConnection(self, transport)
         tup = transport.getHost()
         self.addr = (tup.host, tup.port)
+        
+    def stopProtocol(self):
+        for conn in self.connections.values():
+            conn.stop()
+        protocol.DatagramProtocol.stopProtocol(self)
 
 ## connection
 class KRPC:
-    noisy = 1
-    def __init__(self, addr, server, transport):
+    def __init__(self, addr, server, transport, spew = False):
         self.transport = transport
         self.factory = server
         self.addr = addr
+        self.noisy = spew
         self.tids = {}
-        self.mtid = 0
+        self.stopped = False
 
     def datagramReceived(self, str, addr):
+        if self.stopped:
+            if self.noisy:
+                log.msg("stopped, dropping message from %r: %s" % (addr, str))
         # bdecode
         try:
             msg = bdecode(str)
         except Exception, e:
             if self.noisy:
-                print "response decode error: " + `e`
+                log.msg("response decode error: ")
+                log.err(e)
         else:
-            #if self.noisy:
-            #    print msg
+            if self.noisy:
+                log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
             # look at msg type
             if msg[TYP]  == REQ:
                 ilen = len(str)
@@ -83,32 +99,19 @@ class KRPC:
                 msg[ARG]['_krpc_sender'] =  self.addr
                 if f and callable(f):
                     try:
-                        ret = apply(f, (), msg[ARG])
+                        ret = f(*(), **msg[ARG])
                     except Exception, e:
-                        ## send error
-                        out = bencode({TID:msg[TID], TYP:ERR, ERR :`format_exception(type(e), e, sys.exc_info()[2])`})
-                        olen = len(out)
-                        self.transport.write(out, addr)
+                        olen = self._sendResponse(addr, msg[TID], ERR, `format_exception(type(e), e, sys.exc_info()[2])`)
                     else:
-                        if ret:
-                            #  make response
-                            out = bencode({TID : msg[TID], TYP : RSP, RSP : ret})
-                        else:
-                            out = bencode({TID : msg[TID], TYP : RSP, RSP : {}})
-                        #      send response
-                        olen = len(out)
-                        self.transport.write(out, addr)
-
+                        olen = self._sendResponse(addr, msg[TID], RSP, ret)
                 else:
                     if self.noisy:
-                        print "don't know about method %s" % msg[REQ]
+                        log.msg("don't know about method %s" % msg[REQ])
                     # unknown method
-                    out = bencode({TID:msg[TID], TYP:ERR, ERR : KRPC_ERROR_METHOD_UNKNOWN})
-                    olen = len(out)
-                    self.transport.write(out, addr)
+                    olen = self._sendResponse(addr, msg[TID], ERR, KRPC_ERROR_METHOD_UNKNOWN)
                 if self.noisy:
-                    print "%s %s >>> %s - %s %s %s" % (asctime(), addr, self.factory.node.port, 
-                                                    ilen, msg[REQ], olen)
+                    log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
+                                                      ilen, msg[REQ], olen))
             elif msg[TYP] == RSP:
                 # if response
                 #      lookup tid
@@ -118,8 +121,9 @@ class KRPC:
                     del(self.tids[msg[TID]])
                     df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
                 else:
-                    print 'timeout ' + `msg[RSP]['id']`
                     # no tid, this transaction timed out already...
+                    if self.noisy:
+                        log.msg('timeout: %r' % msg[RSP]['id'])
             elif msg[TYP] == ERR:
                 # if error
                 #      lookup tid
@@ -132,27 +136,44 @@ class KRPC:
                     # day late and dollar short
                     pass
             else:
-                print "unknown message type " + `msg`
+                if self.noisy:
+                    log.msg("unknown message type: %r" % msg)
                 # unknown message type
                 df = self.tids[msg[TID]]
                 #      callback
                 df.errback(KRPC_ERROR_RECEIVED_UNKNOWN)
                 del(self.tids[msg[TID]])
                 
+    def _sendResponse(self, addr, tid, msgType, response):
+        if not response:
+            response = {}
+            
+        msg = {TID : tid, TYP : msgType, msgType : response}
+
+        if self.noisy:
+            log.msg("%d responding to %r: %s" % (self.factory.port, addr, msg))
+
+        out = bencode(msg)
+        self.transport.write(out, addr)
+        return len(out)
+    
     def sendRequest(self, method, args):
+        if self.stopped:
+            raise ProtocolError, "connection has been stopped"
         # make message
         # send it
-        msg = {TID : chr(self.mtid), TYP : REQ,  REQ : method, ARG : args}
-        self.mtid = (self.mtid + 1) % 256
+        msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
+        if self.noisy:
+            log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
         str = bencode(msg)
         d = Deferred()
         self.tids[msg[TID]] = d
-        def timeOut(tids = self.tids, id = msg[TID]):
+        def timeOut(tids = self.tids, id = msg[TID], msg = msg):
             if tids.has_key(id):
                 df = tids[id]
                 del(tids[id])
-                print ">>>>>> KRPC_ERROR_TIMEOUT"
-                df.errback(KRPC_ERROR_TIMEOUT)
+                log.msg(">>>>>> KRPC_ERROR_TIMEOUT")
+                df.errback(ProtocolError('timeout waiting for %r' % msg))
         later = reactor.callLater(KRPC_TIMEOUT, timeOut)
         def dropTimeOut(dict, later_call = later):
             if later_call.active():
@@ -161,6 +182,13 @@ class KRPC:
         d.addBoth(dropTimeOut)
         self.transport.write(str, self.addr)
         return d
+    
+    def stop(self):
+        """Timeout all pending requests."""
+        for df in self.tids.values():
+            df.errback(ProtocolError('connection has been closed'))
+        self.tids = {}
+        self.stopped = True
  
 def connectionForAddr(host, port):
     return host
@@ -176,14 +204,13 @@ class Receiver(protocol.Factory):
 
 def make(port):
     af = Receiver()
-    a = hostbroker(af)
+    a = hostbroker(af, {'SPEW': False})
     a.protocol = KRPC
     p = reactor.listenUDP(port, a)
     return af, a, p
     
 class KRPCTests(unittest.TestCase):
     def setUp(self):
-        KRPC.noisy = 0
         self.af, self.a, self.ap = make(1180)
         self.bf, self.b, self.bp = make(1181)
 
@@ -192,7 +219,7 @@ class KRPCTests(unittest.TestCase):
         self.bp.stopListening()
 
     def bufEquals(self, result, value):
-        self.assertEqual(self.bf.buf, value)
+        self.failUnlessEqual(self.bf.buf, value)
 
     def testSimpleMessage(self):
         d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
@@ -213,7 +240,7 @@ class KRPCTests(unittest.TestCase):
     def gotMsg(self, dict, should_be):
         _krpc_sender = dict['_krpc_sender']
         msg = dict['rsp']
-        self.assertEqual(msg, should_be)
+        self.failUnlessEqual(msg, should_be)
 
     def testManyEcho(self):
         for i in xrange(100):
@@ -254,4 +281,4 @@ class KRPCTests(unittest.TestCase):
         return df
 
     def gotErr(self, err, should_be):
-        self.assertEqual(err.value, should_be)
+        self.failUnlessEqual(err.value, should_be)