Merge small DHT buckets when a node is removed.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 """The routing table and buckets for a kademlia-like DHT."""
5
6 from datetime import datetime
7 from bisect import bisect_left
8 from math import log as loge
9
10 from twisted.python import log
11 from twisted.trial import unittest
12
13 import khash
14 from node import Node, NULL_ID
15
16 class KTable:
17     """Local routing table for a kademlia-like distributed hash table.
18     
19     @type node: L{node.Node}
20     @ivar node: the local node
21     @type config: C{dictionary}
22     @ivar config: the configuration parameters for the DHT
23     @type buckets: C{list} of L{KBucket}
24     @ivar buckets: the buckets of nodes in the routing table
25     """
26     
27     def __init__(self, node, config):
28         """Initialize the first empty bucket of everything.
29         
30         @type node: L{node.Node}
31         @param node: the local node
32         @type config: C{dictionary}
33         @param config: the configuration parameters for the DHT
34         """
35         # this is the root node, a.k.a. US!
36         assert node.id != NULL_ID
37         self.node = node
38         self.config = config
39         self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
40         
41     def _bucketIndexForInt(self, num):
42         """Find the index of the bucket that should hold the node's ID number."""
43         return bisect_left(self.buckets, num)
44     
45     def _nodeNum(self, id):
46         """Takes different types of input and converts to the node ID number.
47
48         @type id: C{string} of C{int} or L{node.Node}
49         @param id: the ID to find nodes that are close to
50         @raise TypeError: if id does not properly identify an ID
51         """
52
53         # Get the ID number from the input
54         if isinstance(id, str):
55             return khash.intify(id)
56         elif isinstance(id, Node):
57             return id.num
58         elif isinstance(id, int) or isinstance(id, long):
59             return id
60         else:
61             raise TypeError, "requires an int, string, or Node input"
62             
63     def findNodes(self, id):
64         """Find the K nodes in our own local table closest to the ID.
65
66         @type id: C{string} of C{int} or L{node.Node}
67         @param id: the ID to find nodes that are close to
68         """
69
70         # Get the ID number from the input
71         num = self._nodeNum(id)
72             
73         # Get the K closest nodes from the appropriate bucket
74         i = self._bucketIndexForInt(num)
75         nodes = list(self.buckets[i].l)
76         
77         # Make sure we have enough
78         if len(nodes) < self.config['K']:
79             # Look in adjoining buckets for nodes
80             min = i - 1
81             max = i + 1
82             while len(nodes) < self.config['K'] and (min >= 0 or max < len(self.buckets)):
83                 # Add the adjoining buckets' nodes to the list
84                 if min >= 0:
85                     nodes = nodes + self.buckets[min].l
86                 if max < len(self.buckets):
87                     nodes = nodes + self.buckets[max].l
88                 min = min - 1
89                 max = max + 1
90     
91         # Sort the found nodes by proximity to the id and return the closest K
92         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
93         return nodes[:self.config['K']]
94         
95     def _splitBucket(self, a):
96         """Split a bucket in two.
97         
98         @type a: L{KBucket}
99         @param a: the bucket to split
100         """
101         # Create a new bucket with half the (upper) range of the current bucket
102         diff = (a.max - a.min) / 2
103         b = KBucket([], a.max - diff, a.max)
104         self.buckets.insert(self.buckets.index(a.min) + 1, b)
105         
106         # Reduce the input bucket's (upper) range 
107         a.max = a.max - diff
108
109         # Transfer nodes to the new bucket
110         for anode in a.l[:]:
111             if anode.num >= a.max:
112                 a.l.remove(anode)
113                 b.l.append(anode)
114     
115     def _mergeBucket(self, i):
116         """Merge unneeded buckets after removing a node.
117         
118         @type i: C{int}
119         @param i: the index of the bucket that lost a node
120         """
121         bucketRange = self.buckets[i].max - self.buckets[i].min
122         otherBucket = None
123
124         # Find if either of the neighbor buckets is the same size
125         # (this will only happen if this or the neighbour has our node ID in its range)
126         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
127             otherBucket = i-1
128         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
129             otherBucket = i+1
130             
131         # Decide if we should do a merge
132         if otherBucket is not None and len(self.buckets[i].l) + len(self.buckets[otherBucket].l) <= self.config['K']:
133             # Remove one bucket and set the other to cover its range as well
134             b = self.buckets[i]
135             a = self.buckets.pop(otherBucket)
136             b.min = min(b.min, a.min)
137             b.max = max(b.max, a.max)
138
139             # Transfer the nodes to the bucket we're keeping, merging the sorting
140             bi = 0
141             for anode in a.l:
142                 while bi < len(b.l) and b.l[bi].lastSeen <= anode.lastSeen:
143                     bi += 1
144                 b.l.insert(bi, anode)
145                 bi += 1
146                 
147             # Recurse to check if the neighbour buckets can also be merged
148             self._mergeBucket(min(i, otherBucket))
149     
150     def replaceStaleNode(self, stale, new = None):
151         """Replace a stale node in a bucket with a new one.
152         
153         This is used by clients to replace a node returned by insertNode after
154         it fails to respond to a ping.
155         
156         @type stale: L{node.Node}
157         @param stale: the stale node to remove from the bucket
158         @type new: L{node.Node}
159         @param new: the new node to add in it's place (optional, defaults to
160             not adding any node in the old node's place)
161         """
162         # Find the stale node's bucket
163         removed = False
164         i = self._bucketIndexForInt(stale.num)
165         try:
166             it = self.buckets[i].l.index(stale.num)
167         except ValueError:
168             pass
169         else:
170             # Remove the stale node
171             del(self.buckets[i].l[it])
172             removed = True
173         
174         # Insert the new node
175         if new and self._bucketIndexForInt(new.num) == i and len(self.buckets[i].l) < self.config['K']:
176             self.buckets[i].l.append(new)
177         elif removed:
178             self._mergeBucket(i)
179     
180     def insertNode(self, node, contacted = True):
181         """Try to insert a node in the routing table.
182         
183         This inserts the node, returning None if successful, otherwise returns
184         the oldest node in the bucket if it's full. The caller is then
185         responsible for pinging the returned node and calling replaceStaleNode
186         if it doesn't respond. contacted means that yes, we contacted THEM and
187         we know the node is reachable.
188         
189         @type node: L{node.Node}
190         @param node: the new node to try and insert
191         @type contacted: C{boolean}
192         @param contacted: whether the new node is known to be good, i.e.
193             responded to a request (optional, defaults to True)
194         @rtype: L{node.Node}
195         @return: None if successful (the bucket wasn't full), otherwise returns the oldest node in the bucket
196         """
197         assert node.id != NULL_ID
198         if node.id == self.node.id: return
199
200         # Get the bucket for this node
201         i = self._bucketIndexForInt(node.num)
202
203         # Check to see if node is in the bucket already
204         try:
205             it = self.buckets[i].l.index(node.num)
206         except ValueError:
207             pass
208         else:
209             # The node is already in the bucket
210             if contacted:
211                 # It responded, so update it
212                 node.updateLastSeen()
213                 # move node to end of bucket
214                 del(self.buckets[i].l[it])
215                 # note that we removed the original and replaced it with the new one
216                 # utilizing this nodes new contact info
217                 self.buckets[i].l.append(node)
218                 self.buckets[i].touch()
219             return
220         
221         # We don't have this node, check to see if the bucket is full
222         if len(self.buckets[i].l) < self.config['K']:
223             # Not full, append this node and return
224             if contacted:
225                 node.updateLastSeen()
226             self.buckets[i].l.append(node)
227             self.buckets[i].touch()
228             return
229             
230         # Bucket is full, check to see if the local node is not in the bucket
231         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
232             # Local node not in the bucket, can't split it, return the oldest node
233             return self.buckets[i].l[0]
234         
235         # Make sure our table isn't FULL, this is really unlikely
236         if len(self.buckets) >= self.config['HASH_LENGTH']:
237             log.err("Hash Table is FULL!  Increase K!")
238             return
239             
240         # This bucket is full and contains our node, split the bucket
241         self._splitBucket(self.buckets[i])
242         
243         # Now that the bucket is split and balanced, try to insert the node again
244         return self.insertNode(node)
245     
246     def justSeenNode(self, id):
247         """Mark a node as just having been seen.
248         
249         Call this any time you get a message from a node, it will update it
250         in the table if it's there.
251
252         @type id: C{string} of C{int} or L{node.Node}
253         @param id: the node ID to mark as just having been seen
254         @rtype: C{datetime.datetime}
255         @return: the old lastSeen time of the node, or None if it's not in the table
256         """
257         # Get the bucket number
258         num = self._nodeNum(id)
259         i = self._bucketIndexForInt(num)
260
261         # Check to see if node is in the bucket
262         try:
263             it = self.buckets[i].l.index(num)
264         except ValueError:
265             return None
266         else:
267             # The node is in the bucket
268             n = self.buckets[i].l[it]
269             tstamp = n.lastSeen
270             n.updateLastSeen()
271             
272             # Move the node to the end and touch the bucket
273             del(self.buckets[i].l[it])
274             self.buckets[i].l.append(n)
275             self.buckets[i].touch()
276             
277             return tstamp
278     
279     def invalidateNode(self, n):
280         """Remove the node from the routing table.
281         
282         Forget about node n. Use this when you know that a node is invalid.
283         """
284         self.replaceStaleNode(n)
285     
286     def nodeFailed(self, node):
287         """Mark a node as having failed once, and remove it if it has failed too much."""
288         # Get the bucket number
289         num = self._nodeNum(node)
290         i = self._bucketIndexForInt(num)
291
292         # Check to see if node is in the bucket
293         try:
294             it = self.buckets[i].l.index(num)
295         except ValueError:
296             return None
297         else:
298             # The node is in the bucket
299             n = self.buckets[i].l[it]
300             if n.msgFailed() >= self.config['MAX_FAILURES']:
301                 self.invalidateNode(n)
302                         
303 class KBucket:
304     """Single bucket of nodes in a kademlia-like routing table.
305     
306     @type l: C{list} of L{node.Node}
307     @ivar l: the nodes that are in this bucket
308     @type min: C{long}
309     @ivar min: the minimum node ID that can be in this bucket
310     @type max: C{long}
311     @ivar max: the maximum node ID that can be in this bucket
312     @type lastAccessed: C{datetime.datetime}
313     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
314     """
315     
316     def __init__(self, contents, min, max):
317         """Initialize the bucket with nodes.
318         
319         @type contents: C{list} of L{node.Node}
320         @param contents: the nodes to store in the bucket
321         @type min: C{long}
322         @param min: the minimum node ID that can be in this bucket
323         @type max: C{long}
324         @param max: the maximum node ID that can be in this bucket
325         """
326         self.l = contents
327         self.min = min
328         self.max = max
329         self.lastAccessed = datetime.now()
330         
331     def touch(self):
332         """Update the L{lastAccessed} time."""
333         self.lastAccessed = datetime.now()
334     
335     def sort(self):
336         """Sort the nodes in the bucket by their lastSeen time."""
337         def _sort(a, b):
338             """Sort nodes by their lastSeen time."""
339             if a.lastSeen > b.lastSeen:
340                 return 1
341             elif a.lastSeen < b.lastSeen:
342                 return -1
343             return 0
344         self.l.sort(_sort)
345
346     def getNodeWithInt(self, num):
347         """Get the node in the bucket with that number.
348         
349         @type num: C{long}
350         @param num: the node ID to look for
351         @raise ValueError: if the node ID is not in the bucket
352         @rtype: L{node.Node}
353         @return: the node
354         """
355         if num in self.l: return num
356         else: raise ValueError
357         
358     def __repr__(self):
359         return "<KBucket %d items (%f to %f, range %d)>" % (
360                 len(self.l), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
361     
362     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
363     def __lt__(self, a):
364         if isinstance(a, Node): a = a.num
365         return self.max <= a
366     def __le__(self, a):
367         if isinstance(a, Node): a = a.num
368         return self.min < a
369     def __gt__(self, a):
370         if isinstance(a, Node): a = a.num
371         return self.min > a
372     def __ge__(self, a):
373         if isinstance(a, Node): a = a.num
374         return self.max >= a
375     def __eq__(self, a):
376         if isinstance(a, Node): a = a.num
377         return self.min <= a and self.max > a
378     def __ne__(self, a):
379         if isinstance(a, Node): a = a.num
380         return self.min >= a or self.max < a
381
382 class TestKTable(unittest.TestCase):
383     """Unit tests for the routing table."""
384     
385     def setUp(self):
386         self.a = Node(khash.newID(), '127.0.0.1', 2002)
387         self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
388
389     def testAddNode(self):
390         self.b = Node(khash.newID(), '127.0.0.1', 2003)
391         self.t.insertNode(self.b)
392         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
393         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
394
395     def testRemove(self):
396         self.testAddNode()
397         self.t.invalidateNode(self.b)
398         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
399
400     def testMergeBuckets(self):
401         for i in xrange(1000):
402             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
403             self.t.insertNode(b)
404         num = len(self.t.buckets)
405         i = self.t._bucketIndexForInt(self.a.num)
406         for b in self.t.buckets[i].l[:]:
407             self.t.invalidateNode(b)
408         self.failUnlessEqual(len(self.t.buckets), num-1)
409
410     def testFail(self):
411         self.testAddNode()
412         for i in range(self.t.config['MAX_FAILURES'] - 1):
413             self.t.nodeFailed(self.b)
414             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
415             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
416             
417         self.t.nodeFailed(self.b)
418         self.failUnlessEqual(len(self.t.buckets[0].l), 0)