199c6266f445ac603cf8bbcf96a6193c6cf8f3fe
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1
2 """The routing table and buckets for a kademlia-like DHT.
3
4 @var K: the Kademlia "K" constant, this should be an even number
5 """
6
7 from datetime import datetime
8 from bisect import bisect_left
9 from math import log as loge
10
11 from twisted.python import log
12 from twisted.trial import unittest
13
14 import khash
15 from node import Node, NULL_ID
16
17 K = 8
18
19 class KTable:
20     """Local routing table for a kademlia-like distributed hash table.
21     
22     @type node: L{node.Node}
23     @ivar node: the local node
24     @type config: C{dictionary}
25     @ivar config: the configuration parameters for the DHT
26     @type buckets: C{list} of L{KBucket}
27     @ivar buckets: the buckets of nodes in the routing table
28     """
29     
30     def __init__(self, node, config):
31         """Initialize the first empty bucket of everything.
32         
33         @type node: L{node.Node}
34         @param node: the local node
35         @type config: C{dictionary}
36         @param config: the configuration parameters for the DHT
37         """
38         # this is the root node, a.k.a. US!
39         assert node.id != NULL_ID
40         self.node = node
41         self.config = config
42         self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
43         
44     def _bucketIndexForInt(self, num):
45         """Find the index of the bucket that should hold the node's ID number."""
46         return bisect_left(self.buckets, num)
47     
48     def _nodeNum(self, id):
49         """Takes different types of input and converts to the node ID number.
50
51         @type id: C{string} or C{int} or L{node.Node}
52         @param id: the ID to find nodes that are close to
53         @raise TypeError: if id does not properly identify an ID
54         """
55
56         # Get the ID number from the input
57         if isinstance(id, str):
58             return khash.intify(id)
59         elif isinstance(id, Node):
60             return id.num
61         elif isinstance(id, int) or isinstance(id, long):
62             return id
63         else:
64             raise TypeError, "requires an int, string, or Node input"
65             
66     def findNodes(self, id):
67         """Find the K nodes in our own local table closest to the ID.
68
69         @type id: C{string} or C{int} or L{node.Node}
70         @param id: the ID to find nodes that are close to
71         """
72
73         # Get the ID number from the input
74         num = self._nodeNum(id)
75             
76         # Get the K closest nodes from the appropriate bucket
77         i = self._bucketIndexForInt(num)
78         nodes = self.buckets[i].list()
79         
80         # Make sure we have enough
81         if len(nodes) < K:
82             # Look in adjoining buckets for nodes
83             min = i - 1
84             max = i + 1
85             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
86                 # Add the adjoining buckets' nodes to the list
87                 if min >= 0:
88                     nodes = nodes + self.buckets[min].list()
89                 if max < len(self.buckets):
90                     nodes = nodes + self.buckets[max].list()
91                 min = min - 1
92                 max = max + 1
93     
94         # Sort the found nodes by proximity to the id and return the closest K
95         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
96         return nodes[:K]
97         
98     def touch(self, id):
99         """Mark a bucket as having been looked up.
100
101         @type id: C{string} or C{int} or L{node.Node}
102         @param id: the ID in the bucket that was accessed
103         """
104         # Get the bucket number from the input
105         num = self._nodeNum(id)
106         i = self._bucketIndexForInt(num)
107         
108         self.buckets[i].touch()
109
110     def _mergeBucket(self, i):
111         """Merge unneeded buckets after removing a node.
112         
113         @type i: C{int}
114         @param i: the index of the bucket that lost a node
115         """
116         bucketRange = self.buckets[i].max - self.buckets[i].min
117         otherBucket = None
118
119         # Find if either of the neighbor buckets is the same size
120         # (this will only happen if this or the neighbour has our node ID in its range)
121         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
122             otherBucket = i-1
123         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
124             otherBucket = i+1
125             
126         # Try and do a merge
127         if otherBucket is not None and self.buckets[i].merge(self.buckets[otherBucket]):
128             # Merge was successful, remove the old bucket
129             self.buckets.pop(otherBucket)
130                 
131             # Recurse to check if the neighbour buckets can also be merged
132             self._mergeBucket(min(i, otherBucket))
133     
134     def replaceStaleNode(self, stale, new = None):
135         """Replace a stale node in a bucket with a new one.
136         
137         This is used by clients to replace a node returned by insertNode after
138         it fails to respond to a ping.
139         
140         @type stale: L{node.Node}
141         @param stale: the stale node to remove from the bucket
142         @type new: L{node.Node}
143         @param new: the new node to add in it's place (optional, defaults to
144             not adding any node in the old node's place)
145         """
146         # Find the stale node's bucket
147         removed = False
148         i = self._bucketIndexForInt(stale.num)
149         try:
150             self.buckets[i].remove(stale.num)
151         except ValueError:
152             pass
153         else:
154             # Removed the stale node
155             removed = True
156             log.msg('Removed node from routing table: %s/%s' % (stale.host, stale.port))
157         
158         # Insert the new node
159         if new and self._bucketIndexForInt(new.num) == i and self.buckets[i].len() < K:
160             self.buckets[i].add(new)
161         elif removed:
162             self._mergeBucket(i)
163     
164     def insertNode(self, node, contacted = True):
165         """Try to insert a node in the routing table.
166         
167         This inserts the node, returning True if successful, False if the
168         node could have been added if it responds to a ping, otherwise returns
169         the oldest node in the bucket if it's full. The caller is then
170         responsible for pinging the returned node and calling replaceStaleNode
171         if it doesn't respond. contacted means that yes, we contacted THEM and
172         we know the node is reachable.
173         
174         @type node: L{node.Node}
175         @param node: the new node to try and insert
176         @type contacted: C{boolean}
177         @param contacted: whether the new node is known to be good, i.e.
178             responded to a request (optional, defaults to True)
179         @rtype: L{node.Node} or C{boolean}
180         @return: True if successful (the bucket wasn't full), False if the
181             node could have been added if it was contacted, otherwise
182             returns the oldest node in the bucket
183         """
184         assert node.id != NULL_ID
185         if node.id == self.node.id: return True
186
187         # Get the bucket for this node
188         i = self._bucketIndexForInt(node.num)
189
190         # Check to see if node is in the bucket already
191         try:
192             self.buckets[i].node(node.num)
193         except ValueError:
194             pass
195         else:
196             # The node is already in the bucket
197             if contacted:
198                 # It responded, so update it
199                 node.updateLastSeen()
200                 # move node to end of bucket
201                 self.buckets[i].remove(node.num)
202                 # note that we removed the original and replaced it with the new one
203                 # utilizing this nodes new contact info
204                 self.buckets[i].add(node)
205             return True
206         
207         # We don't have this node, check to see if the bucket is full
208         if self.buckets[i].len() < K:
209             # Not full, append this node and return
210             if contacted:
211                 node.updateLastSeen()
212                 self.buckets[i].add(node)
213                 log.msg('Added node to routing table: %s/%s' % (node.host, node.port))
214                 return True
215             return False
216             
217         # Bucket is full, check to see if the local node is not in the bucket
218         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
219             # Local node not in the bucket, can't split it, return the oldest node
220             return self.buckets[i].oldest()
221         
222         # Make sure our table isn't FULL, this is really unlikely
223         if len(self.buckets) >= (khash.HASH_LENGTH*8):
224             log.err(RuntimeError("Hash Table is FULL! Increase K!"))
225             return
226             
227         # This bucket is full and contains our node, split the bucket
228         newBucket = self.buckets[i].split()
229         self.buckets.insert(i + 1, newBucket)
230         
231         # Now that the bucket is split and balanced, try to insert the node again
232         return self.insertNode(node)
233     
234     def justSeenNode(self, id):
235         """Mark a node as just having been seen.
236         
237         Call this any time you get a message from a node, it will update it
238         in the table if it's there.
239
240         @type id: C{string} or C{int} or L{node.Node}
241         @param id: the node ID to mark as just having been seen
242         @rtype: C{datetime.datetime}
243         @return: the old lastSeen time of the node, or None if it's not in the table
244         """
245         # Get the bucket number
246         num = self._nodeNum(id)
247         i = self._bucketIndexForInt(num)
248
249         # Check to see if node is in the bucket
250         try:
251             tstamp = self.buckets[i].justSeen(num)
252         except ValueError:
253             return None
254         else:
255             return tstamp
256     
257     def invalidateNode(self, n):
258         """Remove the node from the routing table.
259         
260         Forget about node n. Use this when you know that a node is invalid.
261         """
262         self.replaceStaleNode(n)
263     
264     def nodeFailed(self, node):
265         """Mark a node as having failed once, and remove it if it has failed too much."""
266         # Get the bucket number
267         num = self._nodeNum(node)
268         i = self._bucketIndexForInt(num)
269
270         # Check to see if node is in the bucket
271         try:
272             n = self.buckets[i].node(num)
273         except ValueError:
274             return None
275         else:
276             # The node is in the bucket
277             if n.msgFailed() >= self.config['MAX_FAILURES']:
278                 self.invalidateNode(n)
279                         
280 class KBucket:
281     """Single bucket of nodes in a kademlia-like routing table.
282     
283     @type nodes: C{list} of L{node.Node}
284     @ivar nodes: the nodes that are in this bucket
285     @type min: C{long}
286     @ivar min: the minimum node ID that can be in this bucket
287     @type max: C{long}
288     @ivar max: the maximum node ID that can be in this bucket
289     @type lastAccessed: C{datetime.datetime}
290     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
291     """
292     
293     def __init__(self, contents, min, max):
294         """Initialize the bucket with nodes.
295         
296         @type contents: C{list} of L{node.Node}
297         @param contents: the nodes to store in the bucket
298         @type min: C{long}
299         @param min: the minimum node ID that can be in this bucket
300         @type max: C{long}
301         @param max: the maximum node ID that can be in this bucket
302         """
303         self.nodes = contents
304         self.min = min
305         self.max = max
306         self.lastAccessed = datetime.now()
307         
308     def __repr__(self):
309         return "<KBucket %d items (%f to %f, range %d)>" % (
310                 len(self.nodes), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
311     
312     #{ List-like functions
313     def len(self): return len(self.nodes)
314     def list(self): return list(self.nodes)
315     def node(self, num): return self.nodes[self.nodes.index(num)]
316     def remove(self, num): return self.nodes.pop(self.nodes.index(num))
317     def oldest(self): return self.nodes[0]
318
319     def add(self, node):
320         """Add the node in the correct sorted order."""
321         i = len(self.nodes)
322         while i > 0 and node.lastSeen < self.nodes[i-1].lastSeen:
323             i -= 1
324         self.nodes.insert(i, node)
325         
326     def sort(self):
327         """Sort the nodes in the bucket by their lastSeen time."""
328         def _sort(a, b):
329             """Sort nodes by their lastSeen time."""
330             if a.lastSeen > b.lastSeen:
331                 return 1
332             elif a.lastSeen < b.lastSeen:
333                 return -1
334             return 0
335         self.nodes.sort(_sort)
336         
337     #{ Bucket functions
338     def touch(self):
339         """Update the L{lastAccessed} time."""
340         self.lastAccessed = datetime.now()
341     
342     def justSeen(self, num):
343         """Mark a node as having been seen.
344         
345         @param num: the number of the node just seen
346         """
347         i = self.nodes.index(num)
348         
349         # The node is in the bucket
350         n = self.nodes[i]
351         tstamp = n.lastSeen
352         n.updateLastSeen()
353         
354         # Move the node to the end and touch the bucket
355         self.nodes.pop(i)
356         self.nodes.append(n)
357         
358         return tstamp
359
360     def split(self):
361         """Split a bucket in two.
362         
363         @rtype: L{KBucket}
364         @return: the new bucket split from this one
365         """
366         # Create a new bucket with half the (upper) range of the current bucket
367         diff = (self.max - self.min) / 2
368         new = KBucket([], self.max - diff, self.max)
369         
370         # Reduce the input bucket's (upper) range 
371         self.max = self.max - diff
372
373         # Transfer nodes to the new bucket
374         for node in self.nodes[:]:
375             if node.num >= self.max:
376                 self.nodes.remove(node)
377                 new.add(node)
378         return new
379     
380     def merge(self, old):
381         """Try to merge two buckets into one.
382         
383         @type old: L{KBucket}
384         @param old: the bucket to merge into this one
385         @return: whether a merge was done or not
386         """
387         # Decide if we should do a merge
388         if len(self.nodes) + old.len() > K:
389             return False
390
391         # Set the range to cover the other's as well
392         self.min = min(self.min, old.min)
393         self.max = max(self.max, old.max)
394
395         # Transfer the other's nodes to this bucket, merging the sorting
396         i = 0
397         for node in old.list():
398             while i < len(self.nodes) and self.nodes[i].lastSeen <= node.lastSeen:
399                 i += 1
400             self.nodes.insert(i, node)
401             i += 1
402
403         return True
404                 
405     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
406     def __lt__(self, a):
407         if isinstance(a, Node): a = a.num
408         return self.max <= a
409     def __le__(self, a):
410         if isinstance(a, Node): a = a.num
411         return self.min < a
412     def __gt__(self, a):
413         if isinstance(a, Node): a = a.num
414         return self.min > a
415     def __ge__(self, a):
416         if isinstance(a, Node): a = a.num
417         return self.max >= a
418     def __eq__(self, a):
419         if isinstance(a, Node): a = a.num
420         return self.min <= a and self.max > a
421     def __ne__(self, a):
422         if isinstance(a, Node): a = a.num
423         return self.min >= a or self.max < a
424
425 class TestKTable(unittest.TestCase):
426     """Unit tests for the routing table."""
427     
428     def setUp(self):
429         self.a = Node(khash.newID(), '127.0.0.1', 2002)
430         self.t = KTable(self.a, {'MAX_FAILURES': 3})
431
432     def testAddNode(self):
433         self.b = Node(khash.newID(), '127.0.0.1', 2003)
434         self.t.insertNode(self.b)
435         self.failUnlessEqual(len(self.t.buckets[0].nodes), 1)
436         self.failUnlessEqual(self.t.buckets[0].nodes[0], self.b)
437
438     def testRemove(self):
439         self.testAddNode()
440         self.t.invalidateNode(self.b)
441         self.failUnlessEqual(len(self.t.buckets[0].nodes), 0)
442
443     def testMergeBuckets(self):
444         for i in xrange(1000):
445             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
446             self.t.insertNode(b)
447         num = len(self.t.buckets)
448         i = self.t._bucketIndexForInt(self.a.num)
449         for b in self.t.buckets[i].nodes[:]:
450             self.t.invalidateNode(b)
451         self.failUnlessEqual(len(self.t.buckets), num-1)
452
453     def testFail(self):
454         self.testAddNode()
455         for i in range(self.t.config['MAX_FAILURES'] - 1):
456             self.t.nodeFailed(self.b)
457             self.failUnlessEqual(len(self.t.buckets[0].nodes), 1)
458             self.failUnlessEqual(self.t.buckets[0].nodes[0], self.b)
459             
460         self.t.nodeFailed(self.b)
461         self.failUnlessEqual(len(self.t.buckets[0].nodes), 0)