Clean up the DHT config, making K and HASH_LENGTH constants instead.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
index fe117eecca1d36fbece52160673a4baba2f2815d..499c4d852e1e848fb93f30271728e0391a0cf256 100644 (file)
@@ -1,10 +1,14 @@
 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
 # see LICENSE.txt for license information
 
-"""The routing table and buckets for a kademlia-like DHT."""
+"""The routing table and buckets for a kademlia-like DHT.
+
+@var K: the Kademlia "K" constant, this should be an even number
+"""
 
 from datetime import datetime
 from bisect import bisect_left
+from math import log as loge
 
 from twisted.python import log
 from twisted.trial import unittest
@@ -12,6 +16,8 @@ from twisted.trial import unittest
 import khash
 from node import Node, NULL_ID
 
+K = 8
+
 class KTable:
     """Local routing table for a kademlia-like distributed hash table.
     
@@ -35,14 +41,14 @@ class KTable:
         assert node.id != NULL_ID
         self.node = node
         self.config = config
-        self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
+        self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
         
     def _bucketIndexForInt(self, num):
         """Find the index of the bucket that should hold the node's ID number."""
         return bisect_left(self.buckets, num)
     
-    def findNodes(self, id):
-        """Find the K nodes in our own local table closest to the ID.
+    def _nodeNum(self, id):
+        """Takes different types of input and converts to the node ID number.
 
         @type id: C{string} of C{int} or L{node.Node}
         @param id: the ID to find nodes that are close to
@@ -51,34 +57,34 @@ class KTable:
 
         # Get the ID number from the input
         if isinstance(id, str):
-            num = khash.intify(id)
+            return khash.intify(id)
         elif isinstance(id, Node):
-            num = id.num
+            return id.num
         elif isinstance(id, int) or isinstance(id, long):
-            num = id
+            return id
         else:
-            raise TypeError, "findNodes requires an int, string, or Node"
+            raise TypeError, "requires an int, string, or Node input"
             
-        nodes = []
-        i = self._bucketIndexForInt(num)
-        
-        # If this node is already in our table then return it
-        try:
-            index = self.buckets[i].l.index(num)
-        except ValueError:
-            pass
-        else:
-            return [self.buckets[i].l[index]]
+    def findNodes(self, id):
+        """Find the K nodes in our own local table closest to the ID.
+
+        @type id: C{string} of C{int} or L{node.Node}
+        @param id: the ID to find nodes that are close to
+        """
+
+        # Get the ID number from the input
+        num = self._nodeNum(id)
             
-        # Don't have the node, get the K closest nodes from the appropriate bucket
-        nodes = nodes + self.buckets[i].l
+        # Get the K closest nodes from the appropriate bucket
+        i = self._bucketIndexForInt(num)
+        nodes = list(self.buckets[i].l)
         
         # Make sure we have enough
-        if len(nodes) < self.config['K']:
+        if len(nodes) < K:
             # Look in adjoining buckets for nodes
             min = i - 1
             max = i + 1
-            while len(nodes) < self.config['K'] and (min >= 0 or max < len(self.buckets)):
+            while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
                 # Add the adjoining buckets' nodes to the list
                 if min >= 0:
                     nodes = nodes + self.buckets[min].l
@@ -89,7 +95,7 @@ class KTable:
     
         # Sort the found nodes by proximity to the id and return the closest K
         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
-        return nodes[:self.config['K']]
+        return nodes[:K]
         
     def _splitBucket(self, a):
         """Split a bucket in two.
@@ -111,6 +117,41 @@ class KTable:
                 a.l.remove(anode)
                 b.l.append(anode)
     
+    def _mergeBucket(self, i):
+        """Merge unneeded buckets after removing a node.
+        
+        @type i: C{int}
+        @param i: the index of the bucket that lost a node
+        """
+        bucketRange = self.buckets[i].max - self.buckets[i].min
+        otherBucket = None
+
+        # Find if either of the neighbor buckets is the same size
+        # (this will only happen if this or the neighbour has our node ID in its range)
+        if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
+            otherBucket = i-1
+        elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
+            otherBucket = i+1
+            
+        # Decide if we should do a merge
+        if otherBucket is not None and len(self.buckets[i].l) + len(self.buckets[otherBucket].l) <= K:
+            # Remove one bucket and set the other to cover its range as well
+            b = self.buckets[i]
+            a = self.buckets.pop(otherBucket)
+            b.min = min(b.min, a.min)
+            b.max = max(b.max, a.max)
+
+            # Transfer the nodes to the bucket we're keeping, merging the sorting
+            bi = 0
+            for anode in a.l:
+                while bi < len(b.l) and b.l[bi].lastSeen <= anode.lastSeen:
+                    bi += 1
+                b.l.insert(bi, anode)
+                bi += 1
+                
+            # Recurse to check if the neighbour buckets can also be merged
+            self._mergeBucket(min(i, otherBucket))
+    
     def replaceStaleNode(self, stale, new = None):
         """Replace a stale node in a bucket with a new one.
         
@@ -124,16 +165,22 @@ class KTable:
             not adding any node in the old node's place)
         """
         # Find the stale node's bucket
+        removed = False
         i = self._bucketIndexForInt(stale.num)
         try:
             it = self.buckets[i].l.index(stale.num)
         except ValueError:
-            return
-    
-        # Remove the stale node and insert the new one
-        del(self.buckets[i].l[it])
-        if new:
+            pass
+        else:
+            # Remove the stale node
+            del(self.buckets[i].l[it])
+            removed = True
+        
+        # Insert the new node
+        if new and self._bucketIndexForInt(new.num) == i and len(self.buckets[i].l) < K:
             self.buckets[i].l.append(new)
+        elif removed:
+            self._mergeBucket(i)
     
     def insertNode(self, node, contacted = True):
         """Try to insert a node in the routing table.
@@ -156,7 +203,7 @@ class KTable:
         if node.id == self.node.id: return
 
         # Get the bucket for this node
-        i = self. _bucketIndexForInt(node.num)
+        i = self._bucketIndexForInt(node.num)
 
         # Check to see if node is in the bucket already
         try:
@@ -177,7 +224,7 @@ class KTable:
             return
         
         # We don't have this node, check to see if the bucket is full
-        if len(self.buckets[i].l) < self.config['K']:
+        if len(self.buckets[i].l) < K:
             # Not full, append this node and return
             if contacted:
                 node.updateLastSeen()
@@ -191,7 +238,7 @@ class KTable:
             return self.buckets[i].l[0]
         
         # Make sure our table isn't FULL, this is really unlikely
-        if len(self.buckets) >= self.config['HASH_LENGTH']:
+        if len(self.buckets) >= (khash.HASH_LENGTH*8):
             log.err("Hash Table is FULL!  Increase K!")
             return
             
@@ -212,13 +259,26 @@ class KTable:
         @rtype: C{datetime.datetime}
         @return: the old lastSeen time of the node, or None if it's not in the table
         """
+        # Get the bucket number
+        num = self._nodeNum(id)
+        i = self._bucketIndexForInt(num)
+
+        # Check to see if node is in the bucket
         try:
-            n = self.findNodes(id)[0]
-        except IndexError:
+            it = self.buckets[i].l.index(num)
+        except ValueError:
             return None
         else:
+            # The node is in the bucket
+            n = self.buckets[i].l[it]
             tstamp = n.lastSeen
             n.updateLastSeen()
+            
+            # Move the node to the end and touch the bucket
+            del(self.buckets[i].l[it])
+            self.buckets[i].l.append(n)
+            self.buckets[i].touch()
+            
             return tstamp
     
     def invalidateNode(self, n):
@@ -230,11 +290,18 @@ class KTable:
     
     def nodeFailed(self, node):
         """Mark a node as having failed once, and remove it if it has failed too much."""
+        # Get the bucket number
+        num = self._nodeNum(node)
+        i = self._bucketIndexForInt(num)
+
+        # Check to see if node is in the bucket
         try:
-            n = self.findNodes(node.num)[0]
-        except IndexError:
+            it = self.buckets[i].l.index(num)
+        except ValueError:
             return None
         else:
+            # The node is in the bucket
+            n = self.buckets[i].l[it]
             if n.msgFailed() >= self.config['MAX_FAILURES']:
                 self.invalidateNode(n)
                         
@@ -270,6 +337,17 @@ class KBucket:
         """Update the L{lastAccessed} time."""
         self.lastAccessed = datetime.now()
     
+    def sort(self):
+        """Sort the nodes in the bucket by their lastSeen time."""
+        def _sort(a, b):
+            """Sort nodes by their lastSeen time."""
+            if a.lastSeen > b.lastSeen:
+                return 1
+            elif a.lastSeen < b.lastSeen:
+                return -1
+            return 0
+        self.l.sort(_sort)
+
     def getNodeWithInt(self, num):
         """Get the node in the bucket with that number.
         
@@ -283,7 +361,8 @@ class KBucket:
         else: raise ValueError
         
     def __repr__(self):
-        return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
+        return "<KBucket %d items (%f to %f, range %d)>" % (
+                len(self.l), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
     
     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
     def __lt__(self, a):
@@ -310,7 +389,7 @@ class TestKTable(unittest.TestCase):
     
     def setUp(self):
         self.a = Node(khash.newID(), '127.0.0.1', 2002)
-        self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
+        self.t = KTable(self.a, {'MAX_FAILURES': 3})
 
     def testAddNode(self):
         self.b = Node(khash.newID(), '127.0.0.1', 2003)
@@ -323,6 +402,16 @@ class TestKTable(unittest.TestCase):
         self.t.invalidateNode(self.b)
         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
 
+    def testMergeBuckets(self):
+        for i in xrange(1000):
+            b = Node(khash.newID(), '127.0.0.1', 2003 + i)
+            self.t.insertNode(b)
+        num = len(self.t.buckets)
+        i = self.t._bucketIndexForInt(self.a.num)
+        for b in self.t.buckets[i].l[:]:
+            self.t.invalidateNode(b)
+        self.failUnlessEqual(len(self.t.buckets), num-1)
+
     def testFail(self):
         self.testAddNode()
         for i in range(self.t.config['MAX_FAILURES'] - 1):