Consume ping errors so they aren't printed in the log.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1
2 """The routing table and buckets for a kademlia-like DHT.
3
4 @var K: the Kademlia "K" constant, this should be an even number
5 """
6
7 from datetime import datetime
8 from bisect import bisect_left
9 from math import log as loge
10
11 from twisted.python import log
12 from twisted.trial import unittest
13
14 import khash
15 from node import Node, NULL_ID
16
17 K = 8
18
19 class KTable:
20     """Local routing table for a kademlia-like distributed hash table.
21     
22     @type node: L{node.Node}
23     @ivar node: the local node
24     @type config: C{dictionary}
25     @ivar config: the configuration parameters for the DHT
26     @type buckets: C{list} of L{KBucket}
27     @ivar buckets: the buckets of nodes in the routing table
28     """
29     
30     def __init__(self, node, config):
31         """Initialize the first empty bucket of everything.
32         
33         @type node: L{node.Node}
34         @param node: the local node
35         @type config: C{dictionary}
36         @param config: the configuration parameters for the DHT
37         """
38         # this is the root node, a.k.a. US!
39         assert node.id != NULL_ID
40         self.node = node
41         self.config = config
42         self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
43         
44     def _bucketIndexForInt(self, num):
45         """Find the index of the bucket that should hold the node's ID number."""
46         return bisect_left(self.buckets, num)
47     
48     def _nodeNum(self, id):
49         """Takes different types of input and converts to the node ID number.
50
51         @type id: C{string} or C{int} or L{node.Node}
52         @param id: the ID to find nodes that are close to
53         @raise TypeError: if id does not properly identify an ID
54         """
55
56         # Get the ID number from the input
57         if isinstance(id, str):
58             return khash.intify(id)
59         elif isinstance(id, Node):
60             return id.num
61         elif isinstance(id, int) or isinstance(id, long):
62             return id
63         else:
64             raise TypeError, "requires an int, string, or Node input"
65             
66     def findNodes(self, id):
67         """Find the K nodes in our own local table closest to the ID.
68
69         @type id: C{string} or C{int} or L{node.Node}
70         @param id: the ID to find nodes that are close to
71         """
72
73         # Get the ID number from the input
74         num = self._nodeNum(id)
75             
76         # Get the K closest nodes from the appropriate bucket
77         i = self._bucketIndexForInt(num)
78         nodes = self.buckets[i].list()
79         
80         # Make sure we have enough
81         if len(nodes) < K:
82             # Look in adjoining buckets for nodes
83             min = i - 1
84             max = i + 1
85             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
86                 # Add the adjoining buckets' nodes to the list
87                 if min >= 0:
88                     nodes = nodes + self.buckets[min].list()
89                 if max < len(self.buckets):
90                     nodes = nodes + self.buckets[max].list()
91                 min = min - 1
92                 max = max + 1
93     
94         # Sort the found nodes by proximity to the id and return the closest K
95         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
96         return nodes[:K]
97         
98     def touch(self, id):
99         """Mark a bucket as having been looked up.
100
101         @type id: C{string} or C{int} or L{node.Node}
102         @param id: the ID in the bucket that was accessed
103         """
104         # Get the bucket number from the input
105         num = self._nodeNum(id)
106         i = self._bucketIndexForInt(num)
107         
108         self.buckets[i].touch()
109
110     def _mergeBucket(self, i):
111         """Merge unneeded buckets after removing a node.
112         
113         @type i: C{int}
114         @param i: the index of the bucket that lost a node
115         """
116         bucketRange = self.buckets[i].max - self.buckets[i].min
117         otherBucket = None
118
119         # Find if either of the neighbor buckets is the same size
120         # (this will only happen if this or the neighbour has our node ID in its range)
121         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
122             otherBucket = i-1
123         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
124             otherBucket = i+1
125             
126         # Try and do a merge
127         if otherBucket is not None and self.buckets[i].merge(self.buckets[otherBucket]):
128             # Merge was successful, remove the old bucket
129             self.buckets.pop(otherBucket)
130                 
131             # Recurse to check if the neighbour buckets can also be merged
132             self._mergeBucket(min(i, otherBucket))
133     
134     def replaceStaleNode(self, stale, new = None):
135         """Replace a stale node in a bucket with a new one.
136         
137         This is used by clients to replace a node returned by insertNode after
138         it fails to respond to a ping.
139         
140         @type stale: L{node.Node}
141         @param stale: the stale node to remove from the bucket
142         @type new: L{node.Node}
143         @param new: the new node to add in it's place (optional, defaults to
144             not adding any node in the old node's place)
145         """
146         # Find the stale node's bucket
147         removed = False
148         i = self._bucketIndexForInt(stale.num)
149         try:
150             self.buckets[i].remove(stale.num)
151         except ValueError:
152             pass
153         else:
154             # Removed the stale node
155             removed = True
156             log.msg('Removed node from routing table: %s/%s' % (stale.host, stale.port))
157         
158         # Insert the new node
159         if new and self._bucketIndexForInt(new.num) == i and self.buckets[i].len() < K:
160             self.buckets[i].add(new)
161         elif removed:
162             self._mergeBucket(i)
163     
164     def insertNode(self, node, contacted = True):
165         """Try to insert a node in the routing table.
166         
167         This inserts the node, returning True if successful, False if the
168         node could have been added if it responds to a ping, otherwise returns
169         the oldest node in the bucket if it's full. The caller is then
170         responsible for pinging the returned node and calling replaceStaleNode
171         if it doesn't respond. contacted means that yes, we contacted THEM and
172         we know the node is reachable.
173         
174         @type node: L{node.Node}
175         @param node: the new node to try and insert
176         @type contacted: C{boolean}
177         @param contacted: whether the new node is known to be good, i.e.
178             responded to a request (optional, defaults to True)
179         @rtype: L{node.Node} or C{boolean}
180         @return: True if successful (the bucket wasn't full), False if the
181             node could have been added if it was contacted, otherwise
182             returns the oldest node in the bucket
183         """
184         assert node.id != NULL_ID
185         if node.id == self.node.id: return True
186
187         # Get the bucket for this node
188         i = self._bucketIndexForInt(node.num)
189
190         # Check to see if node is in the bucket already
191         try:
192             self.buckets[i].node(node.num)
193         except ValueError:
194             pass
195         else:
196             # The node is already in the bucket
197             if contacted:
198                 # It responded, so update it
199                 node.updateLastSeen()
200                 # move node to end of bucket
201                 self.buckets[i].remove(node.num)
202                 # note that we removed the original and replaced it with the new one
203                 # utilizing this nodes new contact info
204                 self.buckets[i].add(node)
205             return True
206         
207         # We don't have this node, check to see if the bucket is full
208         if self.buckets[i].len() < K:
209             # Not full, append this node and return
210             if contacted:
211                 node.updateLastSeen()
212                 self.buckets[i].add(node)
213                 log.msg('Added node to routing table: %s/%s' % (node.host, node.port))
214                 return True
215             return False
216             
217         # Bucket is full, check to see if the local node is not in the bucket
218         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
219             # Local node not in the bucket, can't split it, return the oldest node
220             return self.buckets[i].oldest()
221         
222         # Make sure our table isn't FULL, this is really unlikely
223         if len(self.buckets) >= (khash.HASH_LENGTH*8):
224             log.err(RuntimeError("Hash Table is FULL! Increase K!"))
225             return
226             
227         # This bucket is full and contains our node, split the bucket
228         newBucket = self.buckets[i].split()
229         self.buckets.insert(i + 1, newBucket)
230         
231         # Now that the bucket is split and balanced, try to insert the node again
232         return self.insertNode(node)
233     
234     def justSeenNode(self, id):
235         """Mark a node as just having been seen.
236         
237         Call this any time you get a message from a node, it will update it
238         in the table if it's there.
239
240         @type id: C{string} or C{int} or L{node.Node}
241         @param id: the node ID to mark as just having been seen
242         @rtype: C{datetime.datetime}
243         @return: the old lastSeen time of the node, or None if it's not in the table
244         """
245         # Get the bucket number
246         num = self._nodeNum(id)
247         i = self._bucketIndexForInt(num)
248
249         # Check to see if node is in the bucket
250         try:
251             tstamp = self.buckets[i].justSeen(num)
252         except ValueError:
253             return None
254         else:
255             return tstamp
256     
257     def invalidateNode(self, n):
258         """Remove the node from the routing table.
259         
260         Forget about node n. Use this when you know that a node is invalid.
261         """
262         self.replaceStaleNode(n)
263     
264     def nodeFailed(self, node):
265         """Mark a node as having failed once, and remove it if it has failed too much.
266         
267         @return: whether the node is in the routing table
268         """
269         # Get the bucket number
270         num = self._nodeNum(node)
271         i = self._bucketIndexForInt(num)
272
273         # Check to see if node is in the bucket
274         try:
275             n = self.buckets[i].node(num)
276         except ValueError:
277             return False
278         else:
279             # The node is in the bucket
280             if n.msgFailed() >= self.config['MAX_FAILURES']:
281                 self.invalidateNode(n)
282                 return False
283             return True
284                         
285 class KBucket:
286     """Single bucket of nodes in a kademlia-like routing table.
287     
288     @type nodes: C{list} of L{node.Node}
289     @ivar nodes: the nodes that are in this bucket
290     @type min: C{long}
291     @ivar min: the minimum node ID that can be in this bucket
292     @type max: C{long}
293     @ivar max: the maximum node ID that can be in this bucket
294     @type lastAccessed: C{datetime.datetime}
295     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
296     """
297     
298     def __init__(self, contents, min, max):
299         """Initialize the bucket with nodes.
300         
301         @type contents: C{list} of L{node.Node}
302         @param contents: the nodes to store in the bucket
303         @type min: C{long}
304         @param min: the minimum node ID that can be in this bucket
305         @type max: C{long}
306         @param max: the maximum node ID that can be in this bucket
307         """
308         self.nodes = contents
309         self.min = min
310         self.max = max
311         self.lastAccessed = datetime.now()
312         
313     def __repr__(self):
314         return "<KBucket %d items (%f to %f, range %d)>" % (
315                 len(self.nodes), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
316     
317     #{ List-like functions
318     def len(self): return len(self.nodes)
319     def list(self): return list(self.nodes)
320     def node(self, num): return self.nodes[self.nodes.index(num)]
321     def remove(self, num): return self.nodes.pop(self.nodes.index(num))
322     def oldest(self): return self.nodes[0]
323
324     def add(self, node):
325         """Add the node in the correct sorted order."""
326         i = len(self.nodes)
327         while i > 0 and node.lastSeen < self.nodes[i-1].lastSeen:
328             i -= 1
329         self.nodes.insert(i, node)
330         
331     def sort(self):
332         """Sort the nodes in the bucket by their lastSeen time."""
333         def _sort(a, b):
334             """Sort nodes by their lastSeen time."""
335             if a.lastSeen > b.lastSeen:
336                 return 1
337             elif a.lastSeen < b.lastSeen:
338                 return -1
339             return 0
340         self.nodes.sort(_sort)
341         
342     #{ Bucket functions
343     def touch(self):
344         """Update the L{lastAccessed} time."""
345         self.lastAccessed = datetime.now()
346     
347     def justSeen(self, num):
348         """Mark a node as having been seen.
349         
350         @param num: the number of the node just seen
351         """
352         i = self.nodes.index(num)
353         
354         # The node is in the bucket
355         n = self.nodes[i]
356         tstamp = n.lastSeen
357         n.updateLastSeen()
358         
359         # Move the node to the end and touch the bucket
360         self.nodes.pop(i)
361         self.nodes.append(n)
362         
363         return tstamp
364
365     def split(self):
366         """Split a bucket in two.
367         
368         @rtype: L{KBucket}
369         @return: the new bucket split from this one
370         """
371         # Create a new bucket with half the (upper) range of the current bucket
372         diff = (self.max - self.min) / 2
373         new = KBucket([], self.max - diff, self.max)
374         
375         # Reduce the input bucket's (upper) range 
376         self.max = self.max - diff
377
378         # Transfer nodes to the new bucket
379         for node in self.nodes[:]:
380             if node.num >= self.max:
381                 self.nodes.remove(node)
382                 new.add(node)
383         return new
384     
385     def merge(self, old):
386         """Try to merge two buckets into one.
387         
388         @type old: L{KBucket}
389         @param old: the bucket to merge into this one
390         @return: whether a merge was done or not
391         """
392         # Decide if we should do a merge
393         if len(self.nodes) + old.len() > K:
394             return False
395
396         # Set the range to cover the other's as well
397         self.min = min(self.min, old.min)
398         self.max = max(self.max, old.max)
399
400         # Transfer the other's nodes to this bucket, merging the sorting
401         i = 0
402         for node in old.list():
403             while i < len(self.nodes) and self.nodes[i].lastSeen <= node.lastSeen:
404                 i += 1
405             self.nodes.insert(i, node)
406             i += 1
407
408         return True
409                 
410     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
411     def __lt__(self, a):
412         if isinstance(a, Node): a = a.num
413         return self.max <= a
414     def __le__(self, a):
415         if isinstance(a, Node): a = a.num
416         return self.min < a
417     def __gt__(self, a):
418         if isinstance(a, Node): a = a.num
419         return self.min > a
420     def __ge__(self, a):
421         if isinstance(a, Node): a = a.num
422         return self.max >= a
423     def __eq__(self, a):
424         if isinstance(a, Node): a = a.num
425         return self.min <= a and self.max > a
426     def __ne__(self, a):
427         if isinstance(a, Node): a = a.num
428         return self.min >= a or self.max < a
429
430 class TestKTable(unittest.TestCase):
431     """Unit tests for the routing table."""
432     
433     def setUp(self):
434         self.a = Node(khash.newID(), '127.0.0.1', 2002)
435         self.t = KTable(self.a, {'MAX_FAILURES': 3})
436
437     def testAddNode(self):
438         self.b = Node(khash.newID(), '127.0.0.1', 2003)
439         self.t.insertNode(self.b)
440         self.failUnlessEqual(len(self.t.buckets[0].nodes), 1)
441         self.failUnlessEqual(self.t.buckets[0].nodes[0], self.b)
442
443     def testRemove(self):
444         self.testAddNode()
445         self.t.invalidateNode(self.b)
446         self.failUnlessEqual(len(self.t.buckets[0].nodes), 0)
447
448     def testMergeBuckets(self):
449         for i in xrange(1000):
450             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
451             self.t.insertNode(b)
452         num = len(self.t.buckets)
453         i = self.t._bucketIndexForInt(self.a.num)
454         for b in self.t.buckets[i].nodes[:]:
455             self.t.invalidateNode(b)
456         self.failUnlessEqual(len(self.t.buckets), num-1)
457
458     def testFail(self):
459         self.testAddNode()
460         for i in range(self.t.config['MAX_FAILURES'] - 1):
461             self.t.nodeFailed(self.b)
462             self.failUnlessEqual(len(self.t.buckets[0].nodes), 1)
463             self.failUnlessEqual(self.t.buckets[0].nodes[0], self.b)
464             
465         self.t.nodeFailed(self.b)
466         self.failUnlessEqual(len(self.t.buckets[0].nodes), 0)