Increase the concurrency of DHT requests to 8.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1
2 """The routing table and buckets for a kademlia-like DHT.
3
4 @var K: the Kademlia "K" constant, this should be an even number
5 """
6
7 from datetime import datetime
8 from bisect import bisect_left
9 from math import log as loge
10
11 from twisted.python import log
12 from twisted.trial import unittest
13
14 import khash
15 from node import Node, NULL_ID
16
17 K = 8
18
19 class KTable:
20     """Local routing table for a kademlia-like distributed hash table.
21     
22     @type node: L{node.Node}
23     @ivar node: the local node
24     @type config: C{dictionary}
25     @ivar config: the configuration parameters for the DHT
26     @type buckets: C{list} of L{KBucket}
27     @ivar buckets: the buckets of nodes in the routing table
28     """
29     
30     def __init__(self, node, config):
31         """Initialize the first empty bucket of everything.
32         
33         @type node: L{node.Node}
34         @param node: the local node
35         @type config: C{dictionary}
36         @param config: the configuration parameters for the DHT
37         """
38         # this is the root node, a.k.a. US!
39         assert node.id != NULL_ID
40         self.node = node
41         self.config = config
42         self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
43         
44     def _bucketIndexForInt(self, num):
45         """Find the index of the bucket that should hold the node's ID number."""
46         return bisect_left(self.buckets, num)
47     
48     def _nodeNum(self, id):
49         """Takes different types of input and converts to the node ID number.
50
51         @type id: C{string} or C{int} or L{node.Node}
52         @param id: the ID to find nodes that are close to
53         @raise TypeError: if id does not properly identify an ID
54         """
55
56         # Get the ID number from the input
57         if isinstance(id, str):
58             return khash.intify(id)
59         elif isinstance(id, Node):
60             return id.num
61         elif isinstance(id, int) or isinstance(id, long):
62             return id
63         else:
64             raise TypeError, "requires an int, string, or Node input"
65             
66     def findNodes(self, id):
67         """Find the K nodes in our own local table closest to the ID.
68
69         @type id: C{string} or C{int} or L{node.Node}
70         @param id: the ID to find nodes that are close to
71         """
72
73         # Get the ID number from the input
74         num = self._nodeNum(id)
75             
76         # Get the K closest nodes from the appropriate bucket
77         i = self._bucketIndexForInt(num)
78         nodes = self.buckets[i].list()
79         
80         # Make sure we have enough
81         if len(nodes) < K:
82             # Look in adjoining buckets for nodes
83             min = i - 1
84             max = i + 1
85             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
86                 # Add the adjoining buckets' nodes to the list
87                 if min >= 0:
88                     nodes = nodes + self.buckets[min].list()
89                 if max < len(self.buckets):
90                     nodes = nodes + self.buckets[max].list()
91                 min = min - 1
92                 max = max + 1
93     
94         # Sort the found nodes by proximity to the id and return the closest K
95         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
96         return nodes[:K]
97         
98     def _mergeBucket(self, i):
99         """Merge unneeded buckets after removing a node.
100         
101         @type i: C{int}
102         @param i: the index of the bucket that lost a node
103         """
104         bucketRange = self.buckets[i].max - self.buckets[i].min
105         otherBucket = None
106
107         # Find if either of the neighbor buckets is the same size
108         # (this will only happen if this or the neighbour has our node ID in its range)
109         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
110             otherBucket = i-1
111         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
112             otherBucket = i+1
113             
114         # Try and do a merge
115         if otherBucket is not None and self.buckets[i].merge(self.buckets[otherBucket]):
116             # Merge was successful, remove the old bucket
117             self.buckets.pop(otherBucket)
118                 
119             # Recurse to check if the neighbour buckets can also be merged
120             self._mergeBucket(min(i, otherBucket))
121     
122     def replaceStaleNode(self, stale, new = None):
123         """Replace a stale node in a bucket with a new one.
124         
125         This is used by clients to replace a node returned by insertNode after
126         it fails to respond to a ping.
127         
128         @type stale: L{node.Node}
129         @param stale: the stale node to remove from the bucket
130         @type new: L{node.Node}
131         @param new: the new node to add in it's place (optional, defaults to
132             not adding any node in the old node's place)
133         """
134         # Find the stale node's bucket
135         removed = False
136         i = self._bucketIndexForInt(stale.num)
137         try:
138             self.buckets[i].remove(stale.num)
139         except ValueError:
140             pass
141         else:
142             # Removed the stale node
143             removed = True
144             log.msg('Removed node from routing table: %s/%s' % (stale.host, stale.port))
145         
146         # Insert the new node
147         if new and self._bucketIndexForInt(new.num) == i and self.buckets[i].len() < K:
148             self.buckets[i].add(new)
149         elif removed:
150             self._mergeBucket(i)
151     
152     def insertNode(self, node, contacted = True):
153         """Try to insert a node in the routing table.
154         
155         This inserts the node, returning True if successful, False if the
156         node could have been added if it responds to a ping, otherwise returns
157         the oldest node in the bucket if it's full. The caller is then
158         responsible for pinging the returned node and calling replaceStaleNode
159         if it doesn't respond. contacted means that yes, we contacted THEM and
160         we know the node is reachable.
161         
162         @type node: L{node.Node}
163         @param node: the new node to try and insert
164         @type contacted: C{boolean}
165         @param contacted: whether the new node is known to be good, i.e.
166             responded to a request (optional, defaults to True)
167         @rtype: L{node.Node} or C{boolean}
168         @return: True if successful (the bucket wasn't full), False if the
169             node could have been added if it was contacted, otherwise
170             returns the oldest node in the bucket
171         """
172         assert node.id != NULL_ID
173         if node.id == self.node.id: return True
174
175         # Get the bucket for this node
176         i = self._bucketIndexForInt(node.num)
177
178         # Check to see if node is in the bucket already
179         try:
180             self.buckets[i].node(node.num)
181         except ValueError:
182             pass
183         else:
184             # The node is already in the bucket
185             if contacted:
186                 # It responded, so update it
187                 node.updateLastSeen()
188                 # move node to end of bucket
189                 self.buckets[i].remove(node.num)
190                 # note that we removed the original and replaced it with the new one
191                 # utilizing this nodes new contact info
192                 self.buckets[i].add(node)
193                 self.buckets[i].touch()
194             return True
195         
196         # We don't have this node, check to see if the bucket is full
197         if self.buckets[i].len() < K:
198             # Not full, append this node and return
199             if contacted:
200                 node.updateLastSeen()
201                 self.buckets[i].add(node)
202                 self.buckets[i].touch()
203                 log.msg('Added node to routing table: %s/%s' % (node.host, node.port))
204                 return True
205             return False
206             
207         # Bucket is full, check to see if the local node is not in the bucket
208         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
209             # Local node not in the bucket, can't split it, return the oldest node
210             return self.buckets[i].oldest()
211         
212         # Make sure our table isn't FULL, this is really unlikely
213         if len(self.buckets) >= (khash.HASH_LENGTH*8):
214             log.err(RuntimeError("Hash Table is FULL! Increase K!"))
215             return
216             
217         # This bucket is full and contains our node, split the bucket
218         newBucket = self.buckets[i].split()
219         self.buckets.insert(i + 1, newBucket)
220         
221         # Now that the bucket is split and balanced, try to insert the node again
222         return self.insertNode(node)
223     
224     def justSeenNode(self, id):
225         """Mark a node as just having been seen.
226         
227         Call this any time you get a message from a node, it will update it
228         in the table if it's there.
229
230         @type id: C{string} or C{int} or L{node.Node}
231         @param id: the node ID to mark as just having been seen
232         @rtype: C{datetime.datetime}
233         @return: the old lastSeen time of the node, or None if it's not in the table
234         """
235         # Get the bucket number
236         num = self._nodeNum(id)
237         i = self._bucketIndexForInt(num)
238
239         # Check to see if node is in the bucket
240         try:
241             tstamp = self.buckets[i].justSeen(num)
242         except ValueError:
243             return None
244         else:
245             self.buckets[i].touch()
246             return tstamp
247     
248     def invalidateNode(self, n):
249         """Remove the node from the routing table.
250         
251         Forget about node n. Use this when you know that a node is invalid.
252         """
253         self.replaceStaleNode(n)
254     
255     def nodeFailed(self, node):
256         """Mark a node as having failed once, and remove it if it has failed too much."""
257         # Get the bucket number
258         num = self._nodeNum(node)
259         i = self._bucketIndexForInt(num)
260
261         # Check to see if node is in the bucket
262         try:
263             n = self.buckets[i].node(num)
264         except ValueError:
265             return None
266         else:
267             # The node is in the bucket
268             if n.msgFailed() >= self.config['MAX_FAILURES']:
269                 self.invalidateNode(n)
270                         
271 class KBucket:
272     """Single bucket of nodes in a kademlia-like routing table.
273     
274     @type nodes: C{list} of L{node.Node}
275     @ivar nodes: the nodes that are in this bucket
276     @type min: C{long}
277     @ivar min: the minimum node ID that can be in this bucket
278     @type max: C{long}
279     @ivar max: the maximum node ID that can be in this bucket
280     @type lastAccessed: C{datetime.datetime}
281     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
282     """
283     
284     def __init__(self, contents, min, max):
285         """Initialize the bucket with nodes.
286         
287         @type contents: C{list} of L{node.Node}
288         @param contents: the nodes to store in the bucket
289         @type min: C{long}
290         @param min: the minimum node ID that can be in this bucket
291         @type max: C{long}
292         @param max: the maximum node ID that can be in this bucket
293         """
294         self.nodes = contents
295         self.min = min
296         self.max = max
297         self.lastAccessed = datetime.now()
298         
299     def __repr__(self):
300         return "<KBucket %d items (%f to %f, range %d)>" % (
301                 len(self.nodes), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
302     
303     #{ List-like functions
304     def len(self): return len(self.nodes)
305     def list(self): return list(self.nodes)
306     def node(self, num): return self.nodes[self.nodes.index(num)]
307     def remove(self, num): return self.nodes.pop(self.nodes.index(num))
308     def oldest(self): return self.nodes[0]
309
310     def add(self, node):
311         """Add the node in the correct sorted order."""
312         i = len(self.nodes)
313         while i > 0 and node.lastSeen < self.nodes[i-1].lastSeen:
314             i -= 1
315         self.nodes.insert(i, node)
316         
317     def sort(self):
318         """Sort the nodes in the bucket by their lastSeen time."""
319         def _sort(a, b):
320             """Sort nodes by their lastSeen time."""
321             if a.lastSeen > b.lastSeen:
322                 return 1
323             elif a.lastSeen < b.lastSeen:
324                 return -1
325             return 0
326         self.nodes.sort(_sort)
327         
328     #{ Bucket functions
329     def touch(self):
330         """Update the L{lastAccessed} time."""
331         self.lastAccessed = datetime.now()
332     
333     def justSeen(self, num):
334         """Mark a node as having been seen.
335         
336         @param num: the number of the node just seen
337         """
338         i = self.nodes.index(num)
339         
340         # The node is in the bucket
341         n = self.nodes[i]
342         tstamp = n.lastSeen
343         n.updateLastSeen()
344         
345         # Move the node to the end and touch the bucket
346         self.nodes.pop(i)
347         self.nodes.append(n)
348         
349         return tstamp
350
351     def split(self):
352         """Split a bucket in two.
353         
354         @rtype: L{KBucket}
355         @return: the new bucket split from this one
356         """
357         # Create a new bucket with half the (upper) range of the current bucket
358         diff = (self.max - self.min) / 2
359         new = KBucket([], self.max - diff, self.max)
360         
361         # Reduce the input bucket's (upper) range 
362         self.max = self.max - diff
363
364         # Transfer nodes to the new bucket
365         for node in self.nodes[:]:
366             if node.num >= self.max:
367                 self.nodes.remove(node)
368                 new.add(node)
369         return new
370     
371     def merge(self, old):
372         """Try to merge two buckets into one.
373         
374         @type old: L{KBucket}
375         @param old: the bucket to merge into this one
376         @return: whether a merge was done or not
377         """
378         # Decide if we should do a merge
379         if len(self.nodes) + old.len() > K:
380             return False
381
382         # Set the range to cover the other's as well
383         self.min = min(self.min, old.min)
384         self.max = max(self.max, old.max)
385
386         # Transfer the other's nodes to this bucket, merging the sorting
387         i = 0
388         for node in old.list():
389             while i < len(self.nodes) and self.nodes[i].lastSeen <= node.lastSeen:
390                 i += 1
391             self.nodes.insert(i, node)
392             i += 1
393
394         return True
395                 
396     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
397     def __lt__(self, a):
398         if isinstance(a, Node): a = a.num
399         return self.max <= a
400     def __le__(self, a):
401         if isinstance(a, Node): a = a.num
402         return self.min < a
403     def __gt__(self, a):
404         if isinstance(a, Node): a = a.num
405         return self.min > a
406     def __ge__(self, a):
407         if isinstance(a, Node): a = a.num
408         return self.max >= a
409     def __eq__(self, a):
410         if isinstance(a, Node): a = a.num
411         return self.min <= a and self.max > a
412     def __ne__(self, a):
413         if isinstance(a, Node): a = a.num
414         return self.min >= a or self.max < a
415
416 class TestKTable(unittest.TestCase):
417     """Unit tests for the routing table."""
418     
419     def setUp(self):
420         self.a = Node(khash.newID(), '127.0.0.1', 2002)
421         self.t = KTable(self.a, {'MAX_FAILURES': 3})
422
423     def testAddNode(self):
424         self.b = Node(khash.newID(), '127.0.0.1', 2003)
425         self.t.insertNode(self.b)
426         self.failUnlessEqual(len(self.t.buckets[0].nodes), 1)
427         self.failUnlessEqual(self.t.buckets[0].nodes[0], self.b)
428
429     def testRemove(self):
430         self.testAddNode()
431         self.t.invalidateNode(self.b)
432         self.failUnlessEqual(len(self.t.buckets[0].nodes), 0)
433
434     def testMergeBuckets(self):
435         for i in xrange(1000):
436             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
437             self.t.insertNode(b)
438         num = len(self.t.buckets)
439         i = self.t._bucketIndexForInt(self.a.num)
440         for b in self.t.buckets[i].nodes[:]:
441             self.t.invalidateNode(b)
442         self.failUnlessEqual(len(self.t.buckets), num-1)
443
444     def testFail(self):
445         self.testAddNode()
446         for i in range(self.t.config['MAX_FAILURES'] - 1):
447             self.t.nodeFailed(self.b)
448             self.failUnlessEqual(len(self.t.buckets[0].nodes), 1)
449             self.failUnlessEqual(self.t.buckets[0].nodes[0], self.b)
450             
451         self.t.nodeFailed(self.b)
452         self.failUnlessEqual(len(self.t.buckets[0].nodes), 0)