Improve the creation of nodes and move all node creation to the main khashmir class.
author Cameron Dale <camrdale@gmail.com>
Wed, 9 Jan 2008 03:11:59 +0000 (19:11 -0800)
committer Cameron Dale <camrdale@gmail.com>
Wed, 9 Jan 2008 03:11:59 +0000 (19:11 -0800)
Nodes are now initialized with their id, host and port
(optionally stored in a dict) instead of needing 2 steps.

All other classes now call the main khashmir Node() constructor
which adds the udp connection and table.

Actions that previously called insertNode on the table now
call it on the main khashmir class, so that nodes still have
a chance of being added when their bucket is full.

apt_dht_Khashmir/actions.py
apt_dht_Khashmir/khashmir.py
apt_dht_Khashmir/ktable.py
apt_dht_Khashmir/node.py

index 860205c..a99b7ea 100644 (file)
@@ -7,8 +7,8 @@ from khash import intify
 
 class ActionBase:
     """ base class for some long running asynchronous proccesses like finding nodes or values """
-    def __init__(self, table, target, callback, config):
-        self.table = table
+    def __init__(self, caller, target, callback, config):
+        self.caller = caller
         self.target = target
         self.config = config
         self.num = intify(target)
@@ -41,21 +41,16 @@ class FindNode(ActionBase):
     def handleGotNodes(self, dict):
         _krpc_sender = dict['_krpc_sender']
         dict = dict['rsp']
+        n = self.caller.Node(dict["id"], _krpc_sender[0], _krpc_sender[1])
+        self.caller.insertNode(n)
         l = dict["nodes"]
-        sender = {'id' : dict["id"]}
-        sender['port'] = _krpc_sender[1]        
-        sender['host'] = _krpc_sender[0]        
-        sender = self.table.Node().initWithDict(sender)
-        sender.conn = self.table.udp.connectionForAddr((sender.host, sender.port))
-        self.table.table.insertNode(sender)
-        if self.finished or self.answered.has_key(sender.id):
+        if self.finished or self.answered.has_key(dict["id"]):
             # a day late and a dollar short
             return
         self.outstanding = self.outstanding - 1
-        self.answered[sender.id] = 1
+        self.answered[dict["id"]] = 1
         for node in l:
-            n = self.table.Node().initWithDict(node)
-            n.conn = self.table.udp.connectionForAddr((n.host, n.port))
+            n = self.caller.Node(node)
             if not self.found.has_key(n.id):
                 self.found[n.id] = n
         self.schedule()
@@ -72,9 +67,9 @@ class FindNode(ActionBase):
             if node.id == self.target:
                 self.finished=1
                 return self.callback([node])
-            if (not self.queried.has_key(node.id)) and node.id != self.table.node.id:
+            if (not self.queried.has_key(node.id)) and node.id != self.caller.node.id:
                 #xxxx t.timeout = time.time() + FIND_NODE_TIMEOUT
-                df = node.findNode(self.target, self.table.node.id)
+                df = node.findNode(self.target, self.caller.node.id)
                 df.addCallbacks(self.handleGotNodes, self.makeMsgFailed(node))
                 self.outstanding = self.outstanding + 1
                 self.queried[node.id] = 1
@@ -88,8 +83,8 @@ class FindNode(ActionBase):
     
     def makeMsgFailed(self, node):
         def defaultGotNodes(err, self=self, node=node):
-            print ">>> find failed %s/%s" % (node.host, node.port), err
-            self.table.table.nodeFailed(node)
+            print ">>> find failed (%s) %s/%s" % (self.config['PORT'], node.host, node.port), err
+            self.caller.table.nodeFailed(node)
             self.outstanding = self.outstanding - 1
             self.schedule()
         return defaultGotNodes
@@ -100,7 +95,7 @@ class FindNode(ActionBase):
             it's a transaction since we got called from the dispatcher
         """
         for node in nodes:
-            if node.id == self.table.node.id:
+            if node.id == self.caller.node.id:
                 continue
             else:
                 self.found[node.id] = node
@@ -110,31 +105,26 @@ class FindNode(ActionBase):
 
 get_value_timeout = 15
 class GetValue(FindNode):
-    def __init__(self, table, target, callback, config, find="findValue"):
-        FindNode.__init__(self, table, target, callback, config)
+    def __init__(self, caller, target, callback, config, find="findValue"):
+        FindNode.__init__(self, caller, target, callback, config)
         self.findValue = find
             
     """ get value task """
     def handleGotNodes(self, dict):
         _krpc_sender = dict['_krpc_sender']
         dict = dict['rsp']
-        sender = {'id' : dict["id"]}
-        sender['port'] = _krpc_sender[1]
-        sender['host'] = _krpc_sender[0]                
-        sender = self.table.Node().initWithDict(sender)
-        sender.conn = self.table.udp.connectionForAddr((sender.host, sender.port))
-        self.table.table.insertNode(sender)
-        if self.finished or self.answered.has_key(sender.id):
+        n = self.caller.Node(dict["id"], _krpc_sender[0], _krpc_sender[1])
+        self.caller.insertNode(n)
+        if self.finished or self.answered.has_key(dict["id"]):
             # a day late and a dollar short
             return
         self.outstanding = self.outstanding - 1
-        self.answered[sender.id] = 1
+        self.answered[dict["id"]] = 1
         # go through nodes
         # if we have any closer than what we already got, query them
         if dict.has_key('nodes'):
             for node in dict['nodes']:
-                n = self.table.Node().initWithDict(node)
-                n.conn = self.table.udp.connectionForAddr((n.host, n.port))
+                n = self.caller.Node(node)
                 if not self.found.has_key(n.id):
                     self.found[n.id] = n
         elif dict.has_key('values'):
@@ -158,14 +148,14 @@ class GetValue(FindNode):
         l.sort(self.sort)
         
         for node in l[:self.config['K']]:
-            if (not self.queried.has_key(node.id)) and node.id != self.table.node.id:
+            if (not self.queried.has_key(node.id)) and node.id != self.caller.node.id:
                 #xxx t.timeout = time.time() + GET_VALUE_TIMEOUT
                 try:
                     f = getattr(node, self.findValue)
                 except AttributeError:
                     print ">>> findValue %s doesn't have a %s method!" % (node, self.findValue)
                 else:
-                    df = f(self.target, self.table.node.id)
+                    df = f(self.target, self.caller.node.id)
                     df.addCallback(self.handleGotNodes)
                     df.addErrback(self.makeMsgFailed(node))
                     self.outstanding = self.outstanding + 1
@@ -185,7 +175,7 @@ class GetValue(FindNode):
             for n in found:
                 self.results[n] = 1
         for node in nodes:
-            if node.id == self.table.node.id:
+            if node.id == self.caller.node.id:
                 continue
             else:
                 self.found[node.id] = node
@@ -194,15 +184,15 @@ class GetValue(FindNode):
 
 
 class StoreValue(ActionBase):
-    def __init__(self, table, target, value, callback, config, store="storeValue"):
-        ActionBase.__init__(self, table, target, callback, config)
+    def __init__(self, caller, target, value, callback, config, store="storeValue"):
+        ActionBase.__init__(self, caller, target, callback, config)
         self.value = value
         self.stored = []
         self.store = store
         
     def storedValue(self, t, node):
         self.outstanding -= 1
-        self.table.insertNode(node)
+        self.caller.insertNode(node)
         if self.finished:
             return
         self.stored.append(t)
@@ -216,7 +206,7 @@ class StoreValue(ActionBase):
     
     def storeFailed(self, t, node):
         print ">>> store failed %s/%s" % (node.host, node.port)
-        self.table.nodeFailed(node)
+        self.caller.nodeFailed(node)
         self.outstanding -= 1
         if self.finished:
             return t
@@ -237,14 +227,14 @@ class StoreValue(ActionBase):
                     self.finished = 1
                     self.callback(self.target, self.value, self.stored)
             else:
-                if not node.id == self.table.node.id:
+                if not node.id == self.caller.node.id:
                     self.outstanding += 1
                     try:
                         f = getattr(node, self.store)
                     except AttributeError:
                         print ">>> %s doesn't have a %s method!" % (node, self.store)
                     else:
-                        df = f(self.target, self.value, self.table.node.id)
+                        df = f(self.target, self.value, self.caller.node.id)
                         df.addCallback(self.storedValue, node=node)
                         df.addErrback(self.storeFailed, node=node)
                     
index c4d018c..ef1b826 100644 (file)
@@ -42,9 +42,11 @@ class KhashmirBase(protocol.Factory):
         self.refreshTable(force=1)
         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
 
-    def Node(self):
-        n = self._Node()
+    def Node(self, id, host = None, port = None):
+        """Create a new node."""
+        n = self._Node(id, host, port)
         n.table = self.table
+        n.conn = self.udp.connectionForAddr((n.host, n.port))
         return n
     
     def __del__(self):
@@ -54,7 +56,7 @@ class KhashmirBase(protocol.Factory):
         id = self.store.getSelfNode()
         if not id:
             id = newID()
-        return self._Node().init(id, host, port)
+        return self._Node(id, host, port)
         
     def checkpoint(self, auto=0):
         self.store.saveSelfNode(self.node.id)
@@ -72,15 +74,9 @@ class KhashmirBase(protocol.Factory):
         """
         nodes = self.store.getRoutingTable()
         for rec in nodes:
-            n = self.Node().initWithDict({'id':rec[0], 'host':rec[1], 'port':int(rec[2])})
-            n.conn = self.udp.connectionForAddr((n.host, n.port))
+            n = self.Node(rec[0], rec[1], int(rec[2]))
             self.table.insertNode(n, contacted=0)
             
-    def _update_node(self, id, host, port):
-        n = self.Node().init(id, host, port)
-        n.conn = self.udp.connectionForAddr((host, port))
-        self.insertNode(n, contacted=0)
-    
 
     #######
     #######  LOCAL INTERFACE    - use these methods!
@@ -88,8 +84,7 @@ class KhashmirBase(protocol.Factory):
         """
             ping this node and add the contact info to the table on pong!
         """
-        n =self.Node().init(NULL_ID, host, port) 
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
+        n = self.Node(NULL_ID, host, port)
         self.sendPing(n, callback=callback)
 
     ## this call is async!
@@ -144,15 +139,9 @@ class KhashmirBase(protocol.Factory):
         """
         df = node.ping(self.node.id)
         ## these are the callbacks we use when we issue a PING
-        def _pongHandler(dict, node=node, table=self.table, callback=callback):
-            _krpc_sender = dict['_krpc_sender']
-            dict = dict['rsp']
-            sender = {'id' : dict['id']}
-            sender['host'] = _krpc_sender[0]
-            sender['port'] = _krpc_sender[1]
-            n = self.Node().initWithDict(sender)
-            n.conn = self.udp.connectionForAddr((n.host, n.port))
-            table.insertNode(n)
+        def _pongHandler(dict, node=node, self=self, callback=callback):
+            n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+            self.insertNode(n)
             if callback:
                 callback()
         def _defaultPong(err, node=node, table=self.table, callback=callback):
@@ -206,11 +195,13 @@ class KhashmirBase(protocol.Factory):
 
     #### Remote Interface - called by remote nodes
     def krpc_ping(self, id, _krpc_sender):
-        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+        self.insertNode(n, contacted=0)
         return {"id" : self.node.id}
         
     def krpc_find_node(self, target, id, _krpc_sender):
-        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+        self.insertNode(n, contacted=0)
         nodes = self.table.findNodes(target)
         nodes = map(lambda node: node.senderDict(), nodes)
         return {"nodes" : nodes, "id" : self.node.id}
@@ -243,7 +234,8 @@ class KhashmirRead(KhashmirBase):
 
     #### Remote Interface - called by remote nodes
     def krpc_find_value(self, key, id, _krpc_sender):
-        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+        self.insertNode(n, contacted=0)
     
         l = self.store.retrieveValues(key)
         if len(l) > 0:
@@ -277,7 +269,8 @@ class KhashmirWrite(KhashmirRead):
                     
     #### Remote Interface - called by remote nodes
     def krpc_store_value(self, key, value, id, _krpc_sender):
-        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+        self.insertNode(n, contacted=0)
         self.store.storeValue(key, value)
         return {"id" : self.node.id}
 
index 85abe4a..7ffde39 100644 (file)
@@ -13,10 +13,10 @@ class KTable:
     """local routing table for a kademlia like distributed hash table"""
     def __init__(self, node, config):
         # this is the root node, a.k.a. US!
+        assert node.id != NULL_ID
         self.node = node
         self.config = config
         self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
-        self.insertNode(node)
         
     def _bucketIndexForInt(self, num):
         """the index of the bucket that should hold int"""
@@ -210,11 +210,11 @@ class KBucket:
 
 class TestKTable(unittest.TestCase):
     def setUp(self):
-        self.a = Node().init(khash.newID(), 'localhost', 2002)
+        self.a = Node(khash.newID(), 'localhost', 2002)
         self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
 
     def testAddNode(self):
-        self.b = Node().init(khash.newID(), 'localhost', 2003)
+        self.b = Node(khash.newID(), 'localhost', 2003)
         self.t.insertNode(self.b)
         self.assertEqual(len(self.t.buckets[0].l), 1)
         self.assertEqual(self.t.buckets[0].l[0], self.b)
index 5f11e51..609e666 100644 (file)
@@ -13,26 +13,23 @@ NULL_ID = 20 * '\0'
 
 class Node:
     """encapsulate contact info"""
-    def __init__(self):
+    def __init__(self, id, host = None, port = None):
         self.fails = 0
         self.lastSeen = datetime(MINYEAR, 1, 1)
-        self.id = self.host = self.port = ''
-    
-    def init(self, id, host, port):
+
+        # Alternate method, init Node from dictionary
+        if isinstance(id, dict):
+            host = id['host']
+            port = id['port']
+            id = id['id']
+
+        assert(isinstance(id, str))
+        assert(isinstance(host, str))
         self.id = id
         self.num = khash.intify(id)
         self.host = host
-        self.port = port
+        self.port = int(port)
         self._senderDict = {'id': self.id, 'port' : self.port, 'host' : self.host}
-        return self
-    
-    def initWithDict(self, dict):
-        self._senderDict = dict
-        self.id = dict['id']
-        self.num = khash.intify(self.id)
-        self.port = dict['port']
-        self.host = dict['host']
-        return self
     
     def updateLastSeen(self):
         self.lastSeen = datetime.now()
@@ -77,7 +74,7 @@ class Node:
 
 class TestNode(unittest.TestCase):
     def setUp(self):
-        self.node = Node().init(khash.newID(), 'localhost', 2002)
+        self.node = Node(khash.newID(), 'localhost', 2002)
     def testUpdateLastSeen(self):
         t = self.node.lastSeen
         self.node.updateLastSeen()