More and better error messages in the DHT.
authorCameron Dale <camrdale@gmail.com>
Fri, 22 Feb 2008 02:40:56 +0000 (18:40 -0800)
committerCameron Dale <camrdale@gmail.com>
Fri, 22 Feb 2008 02:40:56 +0000 (18:40 -0800)
Also added message verification for the incoming message
to prevent abuse.

TODO
apt_dht_Khashmir/actions.py
apt_dht_Khashmir/khashmir.py
apt_dht_Khashmir/krpc.py

diff --git a/TODO b/TODO
index bb5ae3d..fac9201 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,9 +1,3 @@
-Comply with the newly defined protocol on the web page.
-
-Various things need to done to comply with the newly defined protocol:
- - standardize the error messages (especially for a bad token)
-
-
 Reduce the memory footprint by clearing the AptPackages caches.
 
 The memory usage is a little bit high due to keeping the AptPackages
index 6766cd9..7822579 100644 (file)
@@ -31,6 +31,13 @@ class ActionBase:
             return 0
         self.sort = sort
         
+    def actionFailed(self, err, node):
+        log.msg("action %s failed (%s) %s/%s" % (self.__class__.__name__, self.config['PORT'], node.host, node.port))
+        log.err(err)
+        self.caller.table.nodeFailed(node)
+        self.outstanding = self.outstanding - 1
+        self.schedule()
+    
     def goWithNodes(self, t):
         pass
     
@@ -75,7 +82,7 @@ class FindNode(ActionBase):
             if (not self.queried.has_key(node.id)) and node.id != self.caller.node.id:
                 #xxxx t.timeout = time.time() + FIND_NODE_TIMEOUT
                 df = node.findNode(self.target, self.caller.node.id)
-                df.addCallbacks(self.handleGotNodes, self.makeMsgFailed(node))
+                df.addCallbacks(self.handleGotNodes, self.actionFailed, errbackArgs = (node, ))
                 self.outstanding = self.outstanding + 1
                 self.queried[node.id] = 1
             if self.outstanding >= self.config['CONCURRENT_REQS']:
@@ -86,15 +93,6 @@ class FindNode(ActionBase):
             self.finished=1
             reactor.callLater(0, self.callback, l[:self.config['K']])
     
-    def makeMsgFailed(self, node):
-        def defaultGotNodes(err, self=self, node=node):
-            log.msg("find failed (%s) %s/%s" % (self.config['PORT'], node.host, node.port))
-            log.err(err)
-            self.caller.table.nodeFailed(node)
-            self.outstanding = self.outstanding - 1
-            self.schedule()
-        return defaultGotNodes
-    
     def goWithNodes(self, nodes):
         """
             this starts the process, our argument is a transaction with t.extras being our list of nodes
@@ -163,8 +161,7 @@ class GetValue(FindNode):
                     log.msg("findValue %s doesn't have a %s method!" % (node, self.findValue))
                 else:
                     df = f(self.target, self.caller.node.id)
-                    df.addCallback(self.handleGotNodes)
-                    df.addErrback(self.makeMsgFailed(node))
+                    df.addCallbacks(self.handleGotNodes, self.actionFailed, errbackArgs = (node, ))
                     self.outstanding = self.outstanding + 1
                     self.queried[node.id] = 1
             if self.outstanding >= self.config['CONCURRENT_REQS']:
@@ -211,15 +208,6 @@ class StoreValue(ActionBase):
                 self.schedule()
         return t
     
-    def storeFailed(self, t, node):
-        log.msg("store failed %s/%s" % (node.host, node.port))
-        self.caller.nodeFailed(node)
-        self.outstanding -= 1
-        if self.finished:
-            return t
-        self.schedule()
-        return t
-    
     def schedule(self):
         if self.finished:
             return
@@ -242,8 +230,7 @@ class StoreValue(ActionBase):
                         log.msg("%s doesn't have a %s method!" % (node, self.store))
                     else:
                         df = f(self.target, self.value, node.token, self.caller.node.id)
-                        df.addCallback(self.storedValue, node=node)
-                        df.addErrback(self.storeFailed, node=node)
+                        df.addCallbacks(self.storedValue, self.actionFailed, callbackArgs = (node, ), errbackArgs = (node, ))
                     
     def goWithNodes(self, nodes):
         self.nodes = nodes
index 3f5327a..d3479e6 100644 (file)
@@ -287,8 +287,8 @@ class KhashmirWrite(KhashmirRead):
             this_token = sha(secret + _krpc_sender[0]).digest()
             if token == this_token:
                 self.store.storeValue(key, value)
-                break;
-        return {"id" : self.node.id}
+                return {"id" : self.node.id}
+        raise krpc.KrpcError, (krpc.KRPC_ERROR_INVALID_TOKEN, 'token is invalid, do a find_nodes to get a fresh one')
 
 # the whole shebang, for testing
 class Khashmir(KhashmirWrite):
index 1428545..63086ed 100644 (file)
@@ -3,8 +3,6 @@
 
 from bencode import bencode, bdecode
 from time import asctime
-import sys
-from traceback import format_exception
 
 from twisted.internet.defer import Deferred
 from twisted.internet import protocol, reactor
@@ -15,10 +13,19 @@ from khash import newID
 
 KRPC_TIMEOUT = 20
 
-KRPC_ERROR = 1
-KRPC_ERROR_METHOD_UNKNOWN = 2
-KRPC_ERROR_RECEIVED_UNKNOWN = 3
-KRPC_ERROR_TIMEOUT = 4
+# Remote node errors
+KRPC_ERROR = 200
+KRPC_ERROR_SERVER_ERROR = 201
+KRPC_ERROR_MALFORMED_PACKET = 202
+KRPC_ERROR_METHOD_UNKNOWN = 203
+KRPC_ERROR_MALFORMED_REQUEST = 204
+KRPC_ERROR_INVALID_TOKEN = 205
+
+# Local errors
+KRPC_ERROR_INTERNAL = 100
+KRPC_ERROR_RECEIVED_UNKNOWN = 101
+KRPC_ERROR_TIMEOUT = 102
+KRPC_ERROR_PROTOCOL_STOPPED = 103
 
 # commands
 TID = 't'
@@ -28,9 +35,53 @@ TYP = 'y'
 ARG = 'a'
 ERR = 'e'
 
-class ProtocolError(Exception):
+class KrpcError(Exception):
     pass
 
+def verifyMessage(msg):
+    """Check received message for corruption and errors.
+    
+    @type msg: C{dictionary}
+    @param msg: the dictionary of information received on the connection
+    @raise KrpcError: if the message is corrupt
+    """
+    
+    if type(msg) != dict:
+        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "not a dictionary")
+    if TYP not in msg:
+        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no message type")
+    if msg[TYP] == REQ:
+        if REQ not in msg:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type not specified")
+        if type(msg[REQ]) != str:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type is not a string")
+        if ARG not in msg:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no arguments for request")
+        if type(msg[ARG]) != dict:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "arguments for request are not in a dictionary")
+    elif msg[TYP] == RSP:
+        if RSP not in msg:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response not specified")
+        if type(msg[RSP]) != dict:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response is not a dictionary")
+    elif msg[TYP] == ERR:
+        if ERR not in msg:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error not specified")
+        if type(msg[ERR]) != list:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a list")
+        if len(msg[ERR]) != 2:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a 2-element list")
+        if type(msg[ERR][0]) not in (int, long):
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error number is not a number")
+        if type(msg[ERR][1]) != str:
+            raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error string is not a string")
+#    else:
+#        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "unknown message type")
+    if TID not in msg:
+        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no transaction ID specified")
+    if type(msg[TID]) != str:
+        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "transaction id is not a string")
+
 class hostbroker(protocol.DatagramProtocol):       
     def __init__(self, server, config):
         self.server = server
@@ -76,73 +127,91 @@ class KRPC:
         self.tids = {}
         self.stopped = False
 
-    def datagramReceived(self, str, addr):
+    def datagramReceived(self, data, addr):
         if self.stopped:
             if self.noisy:
-                log.msg("stopped, dropping message from %r: %s" % (addr, str))
+                log.msg("stopped, dropping message from %r: %s" % (addr, data))
         # bdecode
         try:
-            msg = bdecode(str)
+            msg = bdecode(data)
         except Exception, e:
             if self.noisy:
-                log.msg("response decode error: ")
+                log.msg("krpc bdecode error: ")
                 log.err(e)
-        else:
-            if self.noisy:
-                log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
-            # look at msg type
-            if msg[TYP]  == REQ:
-                ilen = len(str)
-                # if request
-                #      tell factory to handle
-                f = getattr(self.factory ,"krpc_" + msg[REQ], None)
-                msg[ARG]['_krpc_sender'] =  self.addr
-                if f and callable(f):
-                    try:
-                        ret = f(*(), **msg[ARG])
-                    except Exception, e:
-                        olen = self._sendResponse(addr, msg[TID], ERR, `format_exception(type(e), e, sys.exc_info()[2])`)
-                    else:
-                        olen = self._sendResponse(addr, msg[TID], RSP, ret)
-                else:
-                    if self.noisy:
-                        log.msg("don't know about method %s" % msg[REQ])
-                    # unknown method
-                    olen = self._sendResponse(addr, msg[TID], ERR, KRPC_ERROR_METHOD_UNKNOWN)
-                if self.noisy:
-                    log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
-                                                      ilen, msg[REQ], olen))
-            elif msg[TYP] == RSP:
-                # if response
-                #      lookup tid
-                if self.tids.has_key(msg[TID]):
-                    df = self.tids[msg[TID]]
-                    #  callback
-                    del(self.tids[msg[TID]])
-                    df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
-                else:
-                    # no tid, this transaction timed out already...
-                    if self.noisy:
-                        log.msg('timeout: %r' % msg[RSP]['id'])
-            elif msg[TYP] == ERR:
-                # if error
-                #      lookup tid
-                if self.tids.has_key(msg[TID]):
-                    df = self.tids[msg[TID]]
-                    #  callback
-                    df.errback(msg[ERR])
-                    del(self.tids[msg[TID]])
+            return
+
+        try:
+            verifyMessage(msg)
+        except Exception, e:
+            log.msg("krpc message verification error: ")
+            log.err(e)
+            return
+
+        if self.noisy:
+            log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
+        # look at msg type
+        if msg[TYP]  == REQ:
+            ilen = len(data)
+            # if request
+            #  tell factory to handle
+            f = getattr(self.factory ,"krpc_" + msg[REQ], None)
+            msg[ARG]['_krpc_sender'] =  self.addr
+            if f and callable(f):
+                try:
+                    ret = f(*(), **msg[ARG])
+                except KrpcError, e:
+                    olen = self._sendResponse(addr, msg[TID], ERR, [e[0], e[1]])
+                except TypeError, e:
+                    olen = self._sendResponse(addr, msg[TID], ERR,
+                                              [KRPC_ERROR_MALFORMED_REQUEST, str(e)])
+                except Exception, e:
+                    olen = self._sendResponse(addr, msg[TID], ERR,
+                                              [KRPC_ERROR_SERVER_ERROR, str(e)])
                 else:
-                    # day late and dollar short
-                    pass
+                    olen = self._sendResponse(addr, msg[TID], RSP, ret)
             else:
                 if self.noisy:
-                    log.msg("unknown message type: %r" % msg)
-                # unknown message type
+                    log.msg("don't know about method %s" % msg[REQ])
+                # unknown method
+                olen = self._sendResponse(addr, msg[TID], ERR,
+                                          [KRPC_ERROR_METHOD_UNKNOWN, "unknown method "+str(msg[REQ])])
+            if self.noisy:
+                log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
+                                                  ilen, msg[REQ], olen))
+        elif msg[TYP] == RSP:
+            # if response
+            #  lookup tid
+            if self.tids.has_key(msg[TID]):
                 df = self.tids[msg[TID]]
                 #      callback
-                df.errback(KRPC_ERROR_RECEIVED_UNKNOWN)
                 del(self.tids[msg[TID]])
+                df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
+            else:
+                # no tid, this transaction timed out already...
+                if self.noisy:
+                    log.msg('timeout: %r' % msg[RSP]['id'])
+        elif msg[TYP] == ERR:
+            # if error
+            #  lookup tid
+            if self.tids.has_key(msg[TID]):
+                df = self.tids[msg[TID]]
+                del(self.tids[msg[TID]])
+                # callback
+                df.errback(KrpcError(*msg[ERR]))
+            else:
+                # day late and dollar short, just log it
+                log.msg("Got an error for an unknown request: %r" % (msg[ERR], ))
+                pass
+        else:
+            if self.noisy:
+                log.msg("unknown message type: %r" % msg)
+            # unknown message type
+            if msg[TID] in self.tids:
+                df = self.tids[msg[TID]]
+                del(self.tids[msg[TID]])
+                # callback
+                df.errback(KrpcError(KRPC_ERROR_RECEIVED_UNKNOWN,
+                                     "Received an unknown message type: %r" % msg[TYP]))
                 
     def _sendResponse(self, addr, tid, msgType, response):
         if not response:
@@ -159,34 +228,33 @@ class KRPC:
     
     def sendRequest(self, method, args):
         if self.stopped:
-            raise ProtocolError, "connection has been stopped"
+            raise KrpcError, (KRPC_ERROR_PROTOCOL_STOPPED, "cannot send, connection has been stopped")
         # make message
         # send it
         msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
         if self.noisy:
             log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
-        str = bencode(msg)
+        data = bencode(msg)
         d = Deferred()
         self.tids[msg[TID]] = d
-        def timeOut(tids = self.tids, id = msg[TID], msg = msg):
+        def timeOut(tids = self.tids, id = msg[TID], method = method, addr = self.addr):
             if tids.has_key(id):
                 df = tids[id]
                 del(tids[id])
-                log.msg(">>>>>> KRPC_ERROR_TIMEOUT")
-                df.errback(ProtocolError('timeout waiting for %r' % msg))
+                df.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % (method, addr)))
         later = reactor.callLater(KRPC_TIMEOUT, timeOut)
         def dropTimeOut(dict, later_call = later):
             if later_call.active():
                 later_call.cancel()
             return dict
         d.addBoth(dropTimeOut)
-        self.transport.write(str, self.addr)
+        self.transport.write(data, self.addr)
         return d
     
     def stop(self):
         """Timeout all pending requests."""
         for df in self.tids.values():
-            df.errback(ProtocolError('connection has been closed'))
+            df.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED, 'connection has been stopped while waiting for response'))
         self.tids = {}
         self.stopped = True
  
@@ -199,8 +267,9 @@ class Receiver(protocol.Factory):
         self.buf = []
     def krpc_store(self, msg, _krpc_sender):
         self.buf += [msg]
+        return {}
     def krpc_echo(self, msg, _krpc_sender):
-        return msg
+        return {'msg': msg}
 
 def make(port):
     af = Receiver()
@@ -240,7 +309,7 @@ class KRPCTests(unittest.TestCase):
     def gotMsg(self, dict, should_be):
         _krpc_sender = dict['_krpc_sender']
         msg = dict['rsp']
-        self.failUnlessEqual(msg, should_be)
+        self.failUnlessEqual(msg['msg'], should_be)
 
     def testManyEcho(self):
         for i in xrange(100):
@@ -277,8 +346,13 @@ class KRPCTests(unittest.TestCase):
 
     def testUnknownMeth(self):
         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('blahblah', {'msg' : "This is a test."})
-        df.addErrback(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
+        df.addBoth(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
+        return df
+
+    def testMalformedRequest(self):
+        df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test.", 'foo': 'bar'})
+        df.addBoth(self.gotErr, KRPC_ERROR_MALFORMED_REQUEST)
         return df
 
     def gotErr(self, err, should_be):
-        self.failUnlessEqual(err.value, should_be)
+        self.failUnlessEqual(err.value[0], should_be)