Clean up the copyrights mentioned in the code.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1
2 """The routing table and buckets for a kademlia-like DHT.
3
4 @var K: the Kademlia "K" constant, this should be an even number
5 """
6
7 from datetime import datetime
8 from bisect import bisect_left
9 from math import log as loge
10
11 from twisted.python import log
12 from twisted.trial import unittest
13
14 import khash
15 from node import Node, NULL_ID
16
17 K = 8
18
19 class KTable:
20     """Local routing table for a kademlia-like distributed hash table.
21     
22     @type node: L{node.Node}
23     @ivar node: the local node
24     @type config: C{dictionary}
25     @ivar config: the configuration parameters for the DHT
26     @type buckets: C{list} of L{KBucket}
27     @ivar buckets: the buckets of nodes in the routing table
28     """
29     
30     def __init__(self, node, config):
31         """Initialize the first empty bucket of everything.
32         
33         @type node: L{node.Node}
34         @param node: the local node
35         @type config: C{dictionary}
36         @param config: the configuration parameters for the DHT
37         """
38         # this is the root node, a.k.a. US!
39         assert node.id != NULL_ID
40         self.node = node
41         self.config = config
42         self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
43         
44     def _bucketIndexForInt(self, num):
45         """Find the index of the bucket that should hold the node's ID number."""
46         return bisect_left(self.buckets, num)
47     
48     def _nodeNum(self, id):
49         """Takes different types of input and converts to the node ID number.
50
51         @type id: C{string} of C{int} or L{node.Node}
52         @param id: the ID to find nodes that are close to
53         @raise TypeError: if id does not properly identify an ID
54         """
55
56         # Get the ID number from the input
57         if isinstance(id, str):
58             return khash.intify(id)
59         elif isinstance(id, Node):
60             return id.num
61         elif isinstance(id, int) or isinstance(id, long):
62             return id
63         else:
64             raise TypeError, "requires an int, string, or Node input"
65             
66     def findNodes(self, id):
67         """Find the K nodes in our own local table closest to the ID.
68
69         @type id: C{string} of C{int} or L{node.Node}
70         @param id: the ID to find nodes that are close to
71         """
72
73         # Get the ID number from the input
74         num = self._nodeNum(id)
75             
76         # Get the K closest nodes from the appropriate bucket
77         i = self._bucketIndexForInt(num)
78         nodes = list(self.buckets[i].l)
79         
80         # Make sure we have enough
81         if len(nodes) < K:
82             # Look in adjoining buckets for nodes
83             min = i - 1
84             max = i + 1
85             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
86                 # Add the adjoining buckets' nodes to the list
87                 if min >= 0:
88                     nodes = nodes + self.buckets[min].l
89                 if max < len(self.buckets):
90                     nodes = nodes + self.buckets[max].l
91                 min = min - 1
92                 max = max + 1
93     
94         # Sort the found nodes by proximity to the id and return the closest K
95         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
96         return nodes[:K]
97         
98     def _splitBucket(self, a):
99         """Split a bucket in two.
100         
101         @type a: L{KBucket}
102         @param a: the bucket to split
103         """
104         # Create a new bucket with half the (upper) range of the current bucket
105         diff = (a.max - a.min) / 2
106         b = KBucket([], a.max - diff, a.max)
107         self.buckets.insert(self.buckets.index(a.min) + 1, b)
108         
109         # Reduce the input bucket's (upper) range 
110         a.max = a.max - diff
111
112         # Transfer nodes to the new bucket
113         for anode in a.l[:]:
114             if anode.num >= a.max:
115                 a.l.remove(anode)
116                 b.l.append(anode)
117     
118     def _mergeBucket(self, i):
119         """Merge unneeded buckets after removing a node.
120         
121         @type i: C{int}
122         @param i: the index of the bucket that lost a node
123         """
124         bucketRange = self.buckets[i].max - self.buckets[i].min
125         otherBucket = None
126
127         # Find if either of the neighbor buckets is the same size
128         # (this will only happen if this or the neighbour has our node ID in its range)
129         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
130             otherBucket = i-1
131         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
132             otherBucket = i+1
133             
134         # Decide if we should do a merge
135         if otherBucket is not None and len(self.buckets[i].l) + len(self.buckets[otherBucket].l) <= K:
136             # Remove one bucket and set the other to cover its range as well
137             b = self.buckets[i]
138             a = self.buckets.pop(otherBucket)
139             b.min = min(b.min, a.min)
140             b.max = max(b.max, a.max)
141
142             # Transfer the nodes to the bucket we're keeping, merging the sorting
143             bi = 0
144             for anode in a.l:
145                 while bi < len(b.l) and b.l[bi].lastSeen <= anode.lastSeen:
146                     bi += 1
147                 b.l.insert(bi, anode)
148                 bi += 1
149                 
150             # Recurse to check if the neighbour buckets can also be merged
151             self._mergeBucket(min(i, otherBucket))
152     
153     def replaceStaleNode(self, stale, new = None):
154         """Replace a stale node in a bucket with a new one.
155         
156         This is used by clients to replace a node returned by insertNode after
157         it fails to respond to a ping.
158         
159         @type stale: L{node.Node}
160         @param stale: the stale node to remove from the bucket
161         @type new: L{node.Node}
162         @param new: the new node to add in it's place (optional, defaults to
163             not adding any node in the old node's place)
164         """
165         # Find the stale node's bucket
166         removed = False
167         i = self._bucketIndexForInt(stale.num)
168         try:
169             it = self.buckets[i].l.index(stale.num)
170         except ValueError:
171             pass
172         else:
173             # Remove the stale node
174             del(self.buckets[i].l[it])
175             removed = True
176         
177         # Insert the new node
178         if new and self._bucketIndexForInt(new.num) == i and len(self.buckets[i].l) < K:
179             self.buckets[i].l.append(new)
180         elif removed:
181             self._mergeBucket(i)
182     
183     def insertNode(self, node, contacted = True):
184         """Try to insert a node in the routing table.
185         
186         This inserts the node, returning None if successful, otherwise returns
187         the oldest node in the bucket if it's full. The caller is then
188         responsible for pinging the returned node and calling replaceStaleNode
189         if it doesn't respond. contacted means that yes, we contacted THEM and
190         we know the node is reachable.
191         
192         @type node: L{node.Node}
193         @param node: the new node to try and insert
194         @type contacted: C{boolean}
195         @param contacted: whether the new node is known to be good, i.e.
196             responded to a request (optional, defaults to True)
197         @rtype: L{node.Node}
198         @return: None if successful (the bucket wasn't full), otherwise returns the oldest node in the bucket
199         """
200         assert node.id != NULL_ID
201         if node.id == self.node.id: return
202
203         # Get the bucket for this node
204         i = self._bucketIndexForInt(node.num)
205
206         # Check to see if node is in the bucket already
207         try:
208             it = self.buckets[i].l.index(node.num)
209         except ValueError:
210             pass
211         else:
212             # The node is already in the bucket
213             if contacted:
214                 # It responded, so update it
215                 node.updateLastSeen()
216                 # move node to end of bucket
217                 del(self.buckets[i].l[it])
218                 # note that we removed the original and replaced it with the new one
219                 # utilizing this nodes new contact info
220                 self.buckets[i].l.append(node)
221                 self.buckets[i].touch()
222             return
223         
224         # We don't have this node, check to see if the bucket is full
225         if len(self.buckets[i].l) < K:
226             # Not full, append this node and return
227             if contacted:
228                 node.updateLastSeen()
229             self.buckets[i].l.append(node)
230             self.buckets[i].touch()
231             return
232             
233         # Bucket is full, check to see if the local node is not in the bucket
234         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
235             # Local node not in the bucket, can't split it, return the oldest node
236             return self.buckets[i].l[0]
237         
238         # Make sure our table isn't FULL, this is really unlikely
239         if len(self.buckets) >= (khash.HASH_LENGTH*8):
240             log.err(RuntimeError("Hash Table is FULL! Increase K!"))
241             return
242             
243         # This bucket is full and contains our node, split the bucket
244         self._splitBucket(self.buckets[i])
245         
246         # Now that the bucket is split and balanced, try to insert the node again
247         return self.insertNode(node)
248     
249     def justSeenNode(self, id):
250         """Mark a node as just having been seen.
251         
252         Call this any time you get a message from a node, it will update it
253         in the table if it's there.
254
255         @type id: C{string} of C{int} or L{node.Node}
256         @param id: the node ID to mark as just having been seen
257         @rtype: C{datetime.datetime}
258         @return: the old lastSeen time of the node, or None if it's not in the table
259         """
260         # Get the bucket number
261         num = self._nodeNum(id)
262         i = self._bucketIndexForInt(num)
263
264         # Check to see if node is in the bucket
265         try:
266             it = self.buckets[i].l.index(num)
267         except ValueError:
268             return None
269         else:
270             # The node is in the bucket
271             n = self.buckets[i].l[it]
272             tstamp = n.lastSeen
273             n.updateLastSeen()
274             
275             # Move the node to the end and touch the bucket
276             del(self.buckets[i].l[it])
277             self.buckets[i].l.append(n)
278             self.buckets[i].touch()
279             
280             return tstamp
281     
282     def invalidateNode(self, n):
283         """Remove the node from the routing table.
284         
285         Forget about node n. Use this when you know that a node is invalid.
286         """
287         self.replaceStaleNode(n)
288     
289     def nodeFailed(self, node):
290         """Mark a node as having failed once, and remove it if it has failed too much."""
291         # Get the bucket number
292         num = self._nodeNum(node)
293         i = self._bucketIndexForInt(num)
294
295         # Check to see if node is in the bucket
296         try:
297             it = self.buckets[i].l.index(num)
298         except ValueError:
299             return None
300         else:
301             # The node is in the bucket
302             n = self.buckets[i].l[it]
303             if n.msgFailed() >= self.config['MAX_FAILURES']:
304                 self.invalidateNode(n)
305                         
306 class KBucket:
307     """Single bucket of nodes in a kademlia-like routing table.
308     
309     @type l: C{list} of L{node.Node}
310     @ivar l: the nodes that are in this bucket
311     @type min: C{long}
312     @ivar min: the minimum node ID that can be in this bucket
313     @type max: C{long}
314     @ivar max: the maximum node ID that can be in this bucket
315     @type lastAccessed: C{datetime.datetime}
316     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
317     """
318     
319     def __init__(self, contents, min, max):
320         """Initialize the bucket with nodes.
321         
322         @type contents: C{list} of L{node.Node}
323         @param contents: the nodes to store in the bucket
324         @type min: C{long}
325         @param min: the minimum node ID that can be in this bucket
326         @type max: C{long}
327         @param max: the maximum node ID that can be in this bucket
328         """
329         self.l = contents
330         self.min = min
331         self.max = max
332         self.lastAccessed = datetime.now()
333         
334     def touch(self):
335         """Update the L{lastAccessed} time."""
336         self.lastAccessed = datetime.now()
337     
338     def sort(self):
339         """Sort the nodes in the bucket by their lastSeen time."""
340         def _sort(a, b):
341             """Sort nodes by their lastSeen time."""
342             if a.lastSeen > b.lastSeen:
343                 return 1
344             elif a.lastSeen < b.lastSeen:
345                 return -1
346             return 0
347         self.l.sort(_sort)
348
349     def getNodeWithInt(self, num):
350         """Get the node in the bucket with that number.
351         
352         @type num: C{long}
353         @param num: the node ID to look for
354         @raise ValueError: if the node ID is not in the bucket
355         @rtype: L{node.Node}
356         @return: the node
357         """
358         if num in self.l: return num
359         else: raise ValueError
360         
361     def __repr__(self):
362         return "<KBucket %d items (%f to %f, range %d)>" % (
363                 len(self.l), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
364     
365     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
366     def __lt__(self, a):
367         if isinstance(a, Node): a = a.num
368         return self.max <= a
369     def __le__(self, a):
370         if isinstance(a, Node): a = a.num
371         return self.min < a
372     def __gt__(self, a):
373         if isinstance(a, Node): a = a.num
374         return self.min > a
375     def __ge__(self, a):
376         if isinstance(a, Node): a = a.num
377         return self.max >= a
378     def __eq__(self, a):
379         if isinstance(a, Node): a = a.num
380         return self.min <= a and self.max > a
381     def __ne__(self, a):
382         if isinstance(a, Node): a = a.num
383         return self.min >= a or self.max < a
384
385 class TestKTable(unittest.TestCase):
386     """Unit tests for the routing table."""
387     
388     def setUp(self):
389         self.a = Node(khash.newID(), '127.0.0.1', 2002)
390         self.t = KTable(self.a, {'MAX_FAILURES': 3})
391
392     def testAddNode(self):
393         self.b = Node(khash.newID(), '127.0.0.1', 2003)
394         self.t.insertNode(self.b)
395         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
396         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
397
398     def testRemove(self):
399         self.testAddNode()
400         self.t.invalidateNode(self.b)
401         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
402
403     def testMergeBuckets(self):
404         for i in xrange(1000):
405             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
406             self.t.insertNode(b)
407         num = len(self.t.buckets)
408         i = self.t._bucketIndexForInt(self.a.num)
409         for b in self.t.buckets[i].l[:]:
410             self.t.invalidateNode(b)
411         self.failUnlessEqual(len(self.t.buckets), num-1)
412
413     def testFail(self):
414         self.testAddNode()
415         for i in range(self.t.config['MAX_FAILURES'] - 1):
416             self.t.nodeFailed(self.b)
417             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
418             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
419             
420         self.t.nodeFailed(self.b)
421         self.failUnlessEqual(len(self.t.buckets[0].l), 0)