Only add nodes to the routing table that have responded to a request.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1
2 """The routing table and buckets for a kademlia-like DHT.
3
4 @var K: the Kademlia "K" constant, this should be an even number
5 """
6
7 from datetime import datetime
8 from bisect import bisect_left
9 from math import log as loge
10
11 from twisted.python import log
12 from twisted.trial import unittest
13
14 import khash
15 from node import Node, NULL_ID
16
17 K = 8
18
19 class KTable:
20     """Local routing table for a kademlia-like distributed hash table.
21     
22     @type node: L{node.Node}
23     @ivar node: the local node
24     @type config: C{dictionary}
25     @ivar config: the configuration parameters for the DHT
26     @type buckets: C{list} of L{KBucket}
27     @ivar buckets: the buckets of nodes in the routing table
28     """
29     
30     def __init__(self, node, config):
31         """Initialize the first empty bucket of everything.
32         
33         @type node: L{node.Node}
34         @param node: the local node
35         @type config: C{dictionary}
36         @param config: the configuration parameters for the DHT
37         """
38         # this is the root node, a.k.a. US!
39         assert node.id != NULL_ID
40         self.node = node
41         self.config = config
42         self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
43         
44     def _bucketIndexForInt(self, num):
45         """Find the index of the bucket that should hold the node's ID number."""
46         return bisect_left(self.buckets, num)
47     
48     def _nodeNum(self, id):
49         """Takes different types of input and converts to the node ID number.
50
51         @type id: C{string} of C{int} or L{node.Node}
52         @param id: the ID to find nodes that are close to
53         @raise TypeError: if id does not properly identify an ID
54         """
55
56         # Get the ID number from the input
57         if isinstance(id, str):
58             return khash.intify(id)
59         elif isinstance(id, Node):
60             return id.num
61         elif isinstance(id, int) or isinstance(id, long):
62             return id
63         else:
64             raise TypeError, "requires an int, string, or Node input"
65             
66     def findNodes(self, id):
67         """Find the K nodes in our own local table closest to the ID.
68
69         @type id: C{string} of C{int} or L{node.Node}
70         @param id: the ID to find nodes that are close to
71         """
72
73         # Get the ID number from the input
74         num = self._nodeNum(id)
75             
76         # Get the K closest nodes from the appropriate bucket
77         i = self._bucketIndexForInt(num)
78         nodes = list(self.buckets[i].l)
79         
80         # Make sure we have enough
81         if len(nodes) < K:
82             # Look in adjoining buckets for nodes
83             min = i - 1
84             max = i + 1
85             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
86                 # Add the adjoining buckets' nodes to the list
87                 if min >= 0:
88                     nodes = nodes + self.buckets[min].l
89                 if max < len(self.buckets):
90                     nodes = nodes + self.buckets[max].l
91                 min = min - 1
92                 max = max + 1
93     
94         # Sort the found nodes by proximity to the id and return the closest K
95         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
96         return nodes[:K]
97         
98     def _splitBucket(self, a):
99         """Split a bucket in two.
100         
101         @type a: L{KBucket}
102         @param a: the bucket to split
103         """
104         # Create a new bucket with half the (upper) range of the current bucket
105         diff = (a.max - a.min) / 2
106         b = KBucket([], a.max - diff, a.max)
107         self.buckets.insert(self.buckets.index(a.min) + 1, b)
108         
109         # Reduce the input bucket's (upper) range 
110         a.max = a.max - diff
111
112         # Transfer nodes to the new bucket
113         for anode in a.l[:]:
114             if anode.num >= a.max:
115                 a.l.remove(anode)
116                 b.l.append(anode)
117     
118     def _mergeBucket(self, i):
119         """Merge unneeded buckets after removing a node.
120         
121         @type i: C{int}
122         @param i: the index of the bucket that lost a node
123         """
124         bucketRange = self.buckets[i].max - self.buckets[i].min
125         otherBucket = None
126
127         # Find if either of the neighbor buckets is the same size
128         # (this will only happen if this or the neighbour has our node ID in its range)
129         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
130             otherBucket = i-1
131         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
132             otherBucket = i+1
133             
134         # Decide if we should do a merge
135         if otherBucket is not None and len(self.buckets[i].l) + len(self.buckets[otherBucket].l) <= K:
136             # Remove one bucket and set the other to cover its range as well
137             b = self.buckets[i]
138             a = self.buckets.pop(otherBucket)
139             b.min = min(b.min, a.min)
140             b.max = max(b.max, a.max)
141
142             # Transfer the nodes to the bucket we're keeping, merging the sorting
143             bi = 0
144             for anode in a.l:
145                 while bi < len(b.l) and b.l[bi].lastSeen <= anode.lastSeen:
146                     bi += 1
147                 b.l.insert(bi, anode)
148                 bi += 1
149                 
150             # Recurse to check if the neighbour buckets can also be merged
151             self._mergeBucket(min(i, otherBucket))
152     
153     def replaceStaleNode(self, stale, new = None):
154         """Replace a stale node in a bucket with a new one.
155         
156         This is used by clients to replace a node returned by insertNode after
157         it fails to respond to a ping.
158         
159         @type stale: L{node.Node}
160         @param stale: the stale node to remove from the bucket
161         @type new: L{node.Node}
162         @param new: the new node to add in it's place (optional, defaults to
163             not adding any node in the old node's place)
164         """
165         # Find the stale node's bucket
166         removed = False
167         i = self._bucketIndexForInt(stale.num)
168         try:
169             it = self.buckets[i].l.index(stale.num)
170         except ValueError:
171             pass
172         else:
173             # Remove the stale node
174             del(self.buckets[i].l[it])
175             removed = True
176             log.msg('Removed node from routing table: %s/%s' % (stale.host, stale.port))
177         
178         # Insert the new node
179         if new and self._bucketIndexForInt(new.num) == i and len(self.buckets[i].l) < K:
180             self.buckets[i].l.append(new)
181         elif removed:
182             self._mergeBucket(i)
183     
184     def insertNode(self, node, contacted = True):
185         """Try to insert a node in the routing table.
186         
187         This inserts the node, returning True if successful, False if the
188         node could have been added if it responds to a ping, otherwise returns
189         the oldest node in the bucket if it's full. The caller is then
190         responsible for pinging the returned node and calling replaceStaleNode
191         if it doesn't respond. contacted means that yes, we contacted THEM and
192         we know the node is reachable.
193         
194         @type node: L{node.Node}
195         @param node: the new node to try and insert
196         @type contacted: C{boolean}
197         @param contacted: whether the new node is known to be good, i.e.
198             responded to a request (optional, defaults to True)
199         @rtype: L{node.Node} or C{boolean}
200         @return: True if successful (the bucket wasn't full), False if the
201             node could have been added if it was contacted, otherwise
202             returns the oldest node in the bucket
203         """
204         assert node.id != NULL_ID
205         if node.id == self.node.id: return True
206
207         # Get the bucket for this node
208         i = self._bucketIndexForInt(node.num)
209
210         # Check to see if node is in the bucket already
211         try:
212             it = self.buckets[i].l.index(node.num)
213         except ValueError:
214             pass
215         else:
216             # The node is already in the bucket
217             if contacted:
218                 # It responded, so update it
219                 node.updateLastSeen()
220                 # move node to end of bucket
221                 del(self.buckets[i].l[it])
222                 # note that we removed the original and replaced it with the new one
223                 # utilizing this nodes new contact info
224                 self.buckets[i].l.append(node)
225                 self.buckets[i].touch()
226             return True
227         
228         # We don't have this node, check to see if the bucket is full
229         if len(self.buckets[i].l) < K:
230             # Not full, append this node and return
231             if contacted:
232                 node.updateLastSeen()
233                 self.buckets[i].l.append(node)
234                 self.buckets[i].touch()
235                 log.msg('Added node to routing table: %s/%s' % (node.host, node.port))
236                 return True
237             return False
238             
239         # Bucket is full, check to see if the local node is not in the bucket
240         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
241             # Local node not in the bucket, can't split it, return the oldest node
242             return self.buckets[i].l[0]
243         
244         # Make sure our table isn't FULL, this is really unlikely
245         if len(self.buckets) >= (khash.HASH_LENGTH*8):
246             log.err(RuntimeError("Hash Table is FULL! Increase K!"))
247             return
248             
249         # This bucket is full and contains our node, split the bucket
250         self._splitBucket(self.buckets[i])
251         
252         # Now that the bucket is split and balanced, try to insert the node again
253         return self.insertNode(node)
254     
255     def justSeenNode(self, id):
256         """Mark a node as just having been seen.
257         
258         Call this any time you get a message from a node, it will update it
259         in the table if it's there.
260
261         @type id: C{string} of C{int} or L{node.Node}
262         @param id: the node ID to mark as just having been seen
263         @rtype: C{datetime.datetime}
264         @return: the old lastSeen time of the node, or None if it's not in the table
265         """
266         # Get the bucket number
267         num = self._nodeNum(id)
268         i = self._bucketIndexForInt(num)
269
270         # Check to see if node is in the bucket
271         try:
272             it = self.buckets[i].l.index(num)
273         except ValueError:
274             return None
275         else:
276             # The node is in the bucket
277             n = self.buckets[i].l[it]
278             tstamp = n.lastSeen
279             n.updateLastSeen()
280             
281             # Move the node to the end and touch the bucket
282             del(self.buckets[i].l[it])
283             self.buckets[i].l.append(n)
284             self.buckets[i].touch()
285             
286             return tstamp
287     
288     def invalidateNode(self, n):
289         """Remove the node from the routing table.
290         
291         Forget about node n. Use this when you know that a node is invalid.
292         """
293         self.replaceStaleNode(n)
294     
295     def nodeFailed(self, node):
296         """Mark a node as having failed once, and remove it if it has failed too much."""
297         # Get the bucket number
298         num = self._nodeNum(node)
299         i = self._bucketIndexForInt(num)
300
301         # Check to see if node is in the bucket
302         try:
303             it = self.buckets[i].l.index(num)
304         except ValueError:
305             return None
306         else:
307             # The node is in the bucket
308             n = self.buckets[i].l[it]
309             if n.msgFailed() >= self.config['MAX_FAILURES']:
310                 self.invalidateNode(n)
311                         
312 class KBucket:
313     """Single bucket of nodes in a kademlia-like routing table.
314     
315     @type l: C{list} of L{node.Node}
316     @ivar l: the nodes that are in this bucket
317     @type min: C{long}
318     @ivar min: the minimum node ID that can be in this bucket
319     @type max: C{long}
320     @ivar max: the maximum node ID that can be in this bucket
321     @type lastAccessed: C{datetime.datetime}
322     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
323     """
324     
325     def __init__(self, contents, min, max):
326         """Initialize the bucket with nodes.
327         
328         @type contents: C{list} of L{node.Node}
329         @param contents: the nodes to store in the bucket
330         @type min: C{long}
331         @param min: the minimum node ID that can be in this bucket
332         @type max: C{long}
333         @param max: the maximum node ID that can be in this bucket
334         """
335         self.l = contents
336         self.min = min
337         self.max = max
338         self.lastAccessed = datetime.now()
339         
340     def touch(self):
341         """Update the L{lastAccessed} time."""
342         self.lastAccessed = datetime.now()
343     
344     def sort(self):
345         """Sort the nodes in the bucket by their lastSeen time."""
346         def _sort(a, b):
347             """Sort nodes by their lastSeen time."""
348             if a.lastSeen > b.lastSeen:
349                 return 1
350             elif a.lastSeen < b.lastSeen:
351                 return -1
352             return 0
353         self.l.sort(_sort)
354
355     def getNodeWithInt(self, num):
356         """Get the node in the bucket with that number.
357         
358         @type num: C{long}
359         @param num: the node ID to look for
360         @raise ValueError: if the node ID is not in the bucket
361         @rtype: L{node.Node}
362         @return: the node
363         """
364         if num in self.l: return num
365         else: raise ValueError
366         
367     def __repr__(self):
368         return "<KBucket %d items (%f to %f, range %d)>" % (
369                 len(self.l), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
370     
371     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
372     def __lt__(self, a):
373         if isinstance(a, Node): a = a.num
374         return self.max <= a
375     def __le__(self, a):
376         if isinstance(a, Node): a = a.num
377         return self.min < a
378     def __gt__(self, a):
379         if isinstance(a, Node): a = a.num
380         return self.min > a
381     def __ge__(self, a):
382         if isinstance(a, Node): a = a.num
383         return self.max >= a
384     def __eq__(self, a):
385         if isinstance(a, Node): a = a.num
386         return self.min <= a and self.max > a
387     def __ne__(self, a):
388         if isinstance(a, Node): a = a.num
389         return self.min >= a or self.max < a
390
391 class TestKTable(unittest.TestCase):
392     """Unit tests for the routing table."""
393     
394     def setUp(self):
395         self.a = Node(khash.newID(), '127.0.0.1', 2002)
396         self.t = KTable(self.a, {'MAX_FAILURES': 3})
397
398     def testAddNode(self):
399         self.b = Node(khash.newID(), '127.0.0.1', 2003)
400         self.t.insertNode(self.b)
401         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
402         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
403
404     def testRemove(self):
405         self.testAddNode()
406         self.t.invalidateNode(self.b)
407         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
408
409     def testMergeBuckets(self):
410         for i in xrange(1000):
411             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
412             self.t.insertNode(b)
413         num = len(self.t.buckets)
414         i = self.t._bucketIndexForInt(self.a.num)
415         for b in self.t.buckets[i].l[:]:
416             self.t.invalidateNode(b)
417         self.failUnlessEqual(len(self.t.buckets), num-1)
418
419     def testFail(self):
420         self.testAddNode()
421         for i in range(self.t.config['MAX_FAILURES'] - 1):
422             self.t.nodeFailed(self.b)
423             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
424             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
425             
426         self.t.nodeFailed(self.b)
427         self.failUnlessEqual(len(self.t.buckets[0].l), 0)