Fix an error in the actions that allowed for the result to be sent twice.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 """The routing table and buckets for a kademlia-like DHT.
5
6 @var K: the Kademlia "K" constant, this should be an even number
7 """
8
9 from datetime import datetime
10 from bisect import bisect_left
11 from math import log as loge
12
13 from twisted.python import log
14 from twisted.trial import unittest
15
16 import khash
17 from node import Node, NULL_ID
18
19 K = 8
20
21 class KTable:
22     """Local routing table for a kademlia-like distributed hash table.
23     
24     @type node: L{node.Node}
25     @ivar node: the local node
26     @type config: C{dictionary}
27     @ivar config: the configuration parameters for the DHT
28     @type buckets: C{list} of L{KBucket}
29     @ivar buckets: the buckets of nodes in the routing table
30     """
31     
32     def __init__(self, node, config):
33         """Initialize the first empty bucket of everything.
34         
35         @type node: L{node.Node}
36         @param node: the local node
37         @type config: C{dictionary}
38         @param config: the configuration parameters for the DHT
39         """
40         # this is the root node, a.k.a. US!
41         assert node.id != NULL_ID
42         self.node = node
43         self.config = config
44         self.buckets = [KBucket([], 0L, 2L**(khash.HASH_LENGTH*8))]
45         
46     def _bucketIndexForInt(self, num):
47         """Find the index of the bucket that should hold the node's ID number."""
48         return bisect_left(self.buckets, num)
49     
50     def _nodeNum(self, id):
51         """Takes different types of input and converts to the node ID number.
52
53         @type id: C{string} of C{int} or L{node.Node}
54         @param id: the ID to find nodes that are close to
55         @raise TypeError: if id does not properly identify an ID
56         """
57
58         # Get the ID number from the input
59         if isinstance(id, str):
60             return khash.intify(id)
61         elif isinstance(id, Node):
62             return id.num
63         elif isinstance(id, int) or isinstance(id, long):
64             return id
65         else:
66             raise TypeError, "requires an int, string, or Node input"
67             
68     def findNodes(self, id):
69         """Find the K nodes in our own local table closest to the ID.
70
71         @type id: C{string} of C{int} or L{node.Node}
72         @param id: the ID to find nodes that are close to
73         """
74
75         # Get the ID number from the input
76         num = self._nodeNum(id)
77             
78         # Get the K closest nodes from the appropriate bucket
79         i = self._bucketIndexForInt(num)
80         nodes = list(self.buckets[i].l)
81         
82         # Make sure we have enough
83         if len(nodes) < K:
84             # Look in adjoining buckets for nodes
85             min = i - 1
86             max = i + 1
87             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
88                 # Add the adjoining buckets' nodes to the list
89                 if min >= 0:
90                     nodes = nodes + self.buckets[min].l
91                 if max < len(self.buckets):
92                     nodes = nodes + self.buckets[max].l
93                 min = min - 1
94                 max = max + 1
95     
96         # Sort the found nodes by proximity to the id and return the closest K
97         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
98         return nodes[:K]
99         
100     def _splitBucket(self, a):
101         """Split a bucket in two.
102         
103         @type a: L{KBucket}
104         @param a: the bucket to split
105         """
106         # Create a new bucket with half the (upper) range of the current bucket
107         diff = (a.max - a.min) / 2
108         b = KBucket([], a.max - diff, a.max)
109         self.buckets.insert(self.buckets.index(a.min) + 1, b)
110         
111         # Reduce the input bucket's (upper) range 
112         a.max = a.max - diff
113
114         # Transfer nodes to the new bucket
115         for anode in a.l[:]:
116             if anode.num >= a.max:
117                 a.l.remove(anode)
118                 b.l.append(anode)
119     
120     def _mergeBucket(self, i):
121         """Merge unneeded buckets after removing a node.
122         
123         @type i: C{int}
124         @param i: the index of the bucket that lost a node
125         """
126         bucketRange = self.buckets[i].max - self.buckets[i].min
127         otherBucket = None
128
129         # Find if either of the neighbor buckets is the same size
130         # (this will only happen if this or the neighbour has our node ID in its range)
131         if i-1 >= 0 and self.buckets[i-1].max - self.buckets[i-1].min == bucketRange:
132             otherBucket = i-1
133         elif i+1 < len(self.buckets) and self.buckets[i+1].max - self.buckets[i+1].min == bucketRange:
134             otherBucket = i+1
135             
136         # Decide if we should do a merge
137         if otherBucket is not None and len(self.buckets[i].l) + len(self.buckets[otherBucket].l) <= K:
138             # Remove one bucket and set the other to cover its range as well
139             b = self.buckets[i]
140             a = self.buckets.pop(otherBucket)
141             b.min = min(b.min, a.min)
142             b.max = max(b.max, a.max)
143
144             # Transfer the nodes to the bucket we're keeping, merging the sorting
145             bi = 0
146             for anode in a.l:
147                 while bi < len(b.l) and b.l[bi].lastSeen <= anode.lastSeen:
148                     bi += 1
149                 b.l.insert(bi, anode)
150                 bi += 1
151                 
152             # Recurse to check if the neighbour buckets can also be merged
153             self._mergeBucket(min(i, otherBucket))
154     
155     def replaceStaleNode(self, stale, new = None):
156         """Replace a stale node in a bucket with a new one.
157         
158         This is used by clients to replace a node returned by insertNode after
159         it fails to respond to a ping.
160         
161         @type stale: L{node.Node}
162         @param stale: the stale node to remove from the bucket
163         @type new: L{node.Node}
164         @param new: the new node to add in it's place (optional, defaults to
165             not adding any node in the old node's place)
166         """
167         # Find the stale node's bucket
168         removed = False
169         i = self._bucketIndexForInt(stale.num)
170         try:
171             it = self.buckets[i].l.index(stale.num)
172         except ValueError:
173             pass
174         else:
175             # Remove the stale node
176             del(self.buckets[i].l[it])
177             removed = True
178         
179         # Insert the new node
180         if new and self._bucketIndexForInt(new.num) == i and len(self.buckets[i].l) < K:
181             self.buckets[i].l.append(new)
182         elif removed:
183             self._mergeBucket(i)
184     
185     def insertNode(self, node, contacted = True):
186         """Try to insert a node in the routing table.
187         
188         This inserts the node, returning None if successful, otherwise returns
189         the oldest node in the bucket if it's full. The caller is then
190         responsible for pinging the returned node and calling replaceStaleNode
191         if it doesn't respond. contacted means that yes, we contacted THEM and
192         we know the node is reachable.
193         
194         @type node: L{node.Node}
195         @param node: the new node to try and insert
196         @type contacted: C{boolean}
197         @param contacted: whether the new node is known to be good, i.e.
198             responded to a request (optional, defaults to True)
199         @rtype: L{node.Node}
200         @return: None if successful (the bucket wasn't full), otherwise returns the oldest node in the bucket
201         """
202         assert node.id != NULL_ID
203         if node.id == self.node.id: return
204
205         # Get the bucket for this node
206         i = self._bucketIndexForInt(node.num)
207
208         # Check to see if node is in the bucket already
209         try:
210             it = self.buckets[i].l.index(node.num)
211         except ValueError:
212             pass
213         else:
214             # The node is already in the bucket
215             if contacted:
216                 # It responded, so update it
217                 node.updateLastSeen()
218                 # move node to end of bucket
219                 del(self.buckets[i].l[it])
220                 # note that we removed the original and replaced it with the new one
221                 # utilizing this nodes new contact info
222                 self.buckets[i].l.append(node)
223                 self.buckets[i].touch()
224             return
225         
226         # We don't have this node, check to see if the bucket is full
227         if len(self.buckets[i].l) < K:
228             # Not full, append this node and return
229             if contacted:
230                 node.updateLastSeen()
231             self.buckets[i].l.append(node)
232             self.buckets[i].touch()
233             return
234             
235         # Bucket is full, check to see if the local node is not in the bucket
236         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
237             # Local node not in the bucket, can't split it, return the oldest node
238             return self.buckets[i].l[0]
239         
240         # Make sure our table isn't FULL, this is really unlikely
241         if len(self.buckets) >= (khash.HASH_LENGTH*8):
242             log.err("Hash Table is FULL!  Increase K!")
243             return
244             
245         # This bucket is full and contains our node, split the bucket
246         self._splitBucket(self.buckets[i])
247         
248         # Now that the bucket is split and balanced, try to insert the node again
249         return self.insertNode(node)
250     
251     def justSeenNode(self, id):
252         """Mark a node as just having been seen.
253         
254         Call this any time you get a message from a node, it will update it
255         in the table if it's there.
256
257         @type id: C{string} of C{int} or L{node.Node}
258         @param id: the node ID to mark as just having been seen
259         @rtype: C{datetime.datetime}
260         @return: the old lastSeen time of the node, or None if it's not in the table
261         """
262         # Get the bucket number
263         num = self._nodeNum(id)
264         i = self._bucketIndexForInt(num)
265
266         # Check to see if node is in the bucket
267         try:
268             it = self.buckets[i].l.index(num)
269         except ValueError:
270             return None
271         else:
272             # The node is in the bucket
273             n = self.buckets[i].l[it]
274             tstamp = n.lastSeen
275             n.updateLastSeen()
276             
277             # Move the node to the end and touch the bucket
278             del(self.buckets[i].l[it])
279             self.buckets[i].l.append(n)
280             self.buckets[i].touch()
281             
282             return tstamp
283     
284     def invalidateNode(self, n):
285         """Remove the node from the routing table.
286         
287         Forget about node n. Use this when you know that a node is invalid.
288         """
289         self.replaceStaleNode(n)
290     
291     def nodeFailed(self, node):
292         """Mark a node as having failed once, and remove it if it has failed too much."""
293         # Get the bucket number
294         num = self._nodeNum(node)
295         i = self._bucketIndexForInt(num)
296
297         # Check to see if node is in the bucket
298         try:
299             it = self.buckets[i].l.index(num)
300         except ValueError:
301             return None
302         else:
303             # The node is in the bucket
304             n = self.buckets[i].l[it]
305             if n.msgFailed() >= self.config['MAX_FAILURES']:
306                 self.invalidateNode(n)
307                         
308 class KBucket:
309     """Single bucket of nodes in a kademlia-like routing table.
310     
311     @type l: C{list} of L{node.Node}
312     @ivar l: the nodes that are in this bucket
313     @type min: C{long}
314     @ivar min: the minimum node ID that can be in this bucket
315     @type max: C{long}
316     @ivar max: the maximum node ID that can be in this bucket
317     @type lastAccessed: C{datetime.datetime}
318     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
319     """
320     
321     def __init__(self, contents, min, max):
322         """Initialize the bucket with nodes.
323         
324         @type contents: C{list} of L{node.Node}
325         @param contents: the nodes to store in the bucket
326         @type min: C{long}
327         @param min: the minimum node ID that can be in this bucket
328         @type max: C{long}
329         @param max: the maximum node ID that can be in this bucket
330         """
331         self.l = contents
332         self.min = min
333         self.max = max
334         self.lastAccessed = datetime.now()
335         
336     def touch(self):
337         """Update the L{lastAccessed} time."""
338         self.lastAccessed = datetime.now()
339     
340     def sort(self):
341         """Sort the nodes in the bucket by their lastSeen time."""
342         def _sort(a, b):
343             """Sort nodes by their lastSeen time."""
344             if a.lastSeen > b.lastSeen:
345                 return 1
346             elif a.lastSeen < b.lastSeen:
347                 return -1
348             return 0
349         self.l.sort(_sort)
350
351     def getNodeWithInt(self, num):
352         """Get the node in the bucket with that number.
353         
354         @type num: C{long}
355         @param num: the node ID to look for
356         @raise ValueError: if the node ID is not in the bucket
357         @rtype: L{node.Node}
358         @return: the node
359         """
360         if num in self.l: return num
361         else: raise ValueError
362         
363     def __repr__(self):
364         return "<KBucket %d items (%f to %f, range %d)>" % (
365                 len(self.l), loge(self.min+1)/loge(2), loge(self.max)/loge(2), loge(self.max-self.min)/loge(2))
366     
367     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
368     def __lt__(self, a):
369         if isinstance(a, Node): a = a.num
370         return self.max <= a
371     def __le__(self, a):
372         if isinstance(a, Node): a = a.num
373         return self.min < a
374     def __gt__(self, a):
375         if isinstance(a, Node): a = a.num
376         return self.min > a
377     def __ge__(self, a):
378         if isinstance(a, Node): a = a.num
379         return self.max >= a
380     def __eq__(self, a):
381         if isinstance(a, Node): a = a.num
382         return self.min <= a and self.max > a
383     def __ne__(self, a):
384         if isinstance(a, Node): a = a.num
385         return self.min >= a or self.max < a
386
387 class TestKTable(unittest.TestCase):
388     """Unit tests for the routing table."""
389     
390     def setUp(self):
391         self.a = Node(khash.newID(), '127.0.0.1', 2002)
392         self.t = KTable(self.a, {'MAX_FAILURES': 3})
393
394     def testAddNode(self):
395         self.b = Node(khash.newID(), '127.0.0.1', 2003)
396         self.t.insertNode(self.b)
397         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
398         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
399
400     def testRemove(self):
401         self.testAddNode()
402         self.t.invalidateNode(self.b)
403         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
404
405     def testMergeBuckets(self):
406         for i in xrange(1000):
407             b = Node(khash.newID(), '127.0.0.1', 2003 + i)
408             self.t.insertNode(b)
409         num = len(self.t.buckets)
410         i = self.t._bucketIndexForInt(self.a.num)
411         for b in self.t.buckets[i].l[:]:
412             self.t.invalidateNode(b)
413         self.failUnlessEqual(len(self.t.buckets), num-1)
414
415     def testFail(self):
416         self.testAddNode()
417         for i in range(self.t.config['MAX_FAILURES'] - 1):
418             self.t.nodeFailed(self.b)
419             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
420             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
421             
422         self.t.nodeFailed(self.b)
423         self.failUnlessEqual(len(self.t.buckets[0].l), 0)