Workaround old sqlite not having 'select count(distinct key)'.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 """The routing table and buckets for a kademlia-like DHT."""
5
6 from datetime import datetime
7 from bisect import bisect_left
8
9 from twisted.python import log
10 from twisted.trial import unittest
11
12 import khash
13 from node import Node, NULL_ID
14
15 class KTable:
16     """Local routing table for a kademlia-like distributed hash table.
17     
18     @type node: L{node.Node}
19     @ivar node: the local node
20     @type config: C{dictionary}
21     @ivar config: the configuration parameters for the DHT
22     @type buckets: C{list} of L{KBucket}
23     @ivar buckets: the buckets of nodes in the routing table
24     """
25     
26     def __init__(self, node, config):
27         """Initialize the first empty bucket of everything.
28         
29         @type node: L{node.Node}
30         @param node: the local node
31         @type config: C{dictionary}
32         @param config: the configuration parameters for the DHT
33         """
34         # this is the root node, a.k.a. US!
35         assert node.id != NULL_ID
36         self.node = node
37         self.config = config
38         self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
39         
40     def _bucketIndexForInt(self, num):
41         """Find the index of the bucket that should hold the node's ID number."""
42         return bisect_left(self.buckets, num)
43     
44     def _nodeNum(self, id):
45         """Takes different types of input and converts to the node ID number.
46
47         @type id: C{string} of C{int} or L{node.Node}
48         @param id: the ID to find nodes that are close to
49         @raise TypeError: if id does not properly identify an ID
50         """
51
52         # Get the ID number from the input
53         if isinstance(id, str):
54             return khash.intify(id)
55         elif isinstance(id, Node):
56             return id.num
57         elif isinstance(id, int) or isinstance(id, long):
58             return id
59         else:
60             raise TypeError, "requires an int, string, or Node input"
61             
62     def findNodes(self, id):
63         """Find the K nodes in our own local table closest to the ID.
64
65         @type id: C{string} of C{int} or L{node.Node}
66         @param id: the ID to find nodes that are close to
67         """
68
69         # Get the ID number from the input
70         num = self._nodeNum(id)
71             
72         # Get the K closest nodes from the appropriate bucket
73         i = self._bucketIndexForInt(num)
74         nodes = list(self.buckets[i].l)
75         
76         # Make sure we have enough
77         if len(nodes) < self.config['K']:
78             # Look in adjoining buckets for nodes
79             min = i - 1
80             max = i + 1
81             while len(nodes) < self.config['K'] and (min >= 0 or max < len(self.buckets)):
82                 # Add the adjoining buckets' nodes to the list
83                 if min >= 0:
84                     nodes = nodes + self.buckets[min].l
85                 if max < len(self.buckets):
86                     nodes = nodes + self.buckets[max].l
87                 min = min - 1
88                 max = max + 1
89     
90         # Sort the found nodes by proximity to the id and return the closest K
91         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
92         return nodes[:self.config['K']]
93         
94     def _splitBucket(self, a):
95         """Split a bucket in two.
96         
97         @type a: L{KBucket}
98         @param a: the bucket to split
99         """
100         # Create a new bucket with half the (upper) range of the current bucket
101         diff = (a.max - a.min) / 2
102         b = KBucket([], a.max - diff, a.max)
103         self.buckets.insert(self.buckets.index(a.min) + 1, b)
104         
105         # Reduce the input bucket's (upper) range 
106         a.max = a.max - diff
107
108         # Transfer nodes to the new bucket
109         for anode in a.l[:]:
110             if anode.num >= a.max:
111                 a.l.remove(anode)
112                 b.l.append(anode)
113     
114     def replaceStaleNode(self, stale, new = None):
115         """Replace a stale node in a bucket with a new one.
116         
117         This is used by clients to replace a node returned by insertNode after
118         it fails to respond to a ping.
119         
120         @type stale: L{node.Node}
121         @param stale: the stale node to remove from the bucket
122         @type new: L{node.Node}
123         @param new: the new node to add in it's place (optional, defaults to
124             not adding any node in the old node's place)
125         """
126         # Find the stale node's bucket
127         i = self._bucketIndexForInt(stale.num)
128         try:
129             it = self.buckets[i].l.index(stale.num)
130         except ValueError:
131             pass
132         else:
133             # Remove the stale node
134             del(self.buckets[i].l[it])
135         
136         # Insert the new node
137         if new and self._bucketIndexForInt(new.num) == i and len(self.buckets[i].l) < self.config['K']:
138             self.buckets[i].l.append(new)
139     
140     def insertNode(self, node, contacted = True):
141         """Try to insert a node in the routing table.
142         
143         This inserts the node, returning None if successful, otherwise returns
144         the oldest node in the bucket if it's full. The caller is then
145         responsible for pinging the returned node and calling replaceStaleNode
146         if it doesn't respond. contacted means that yes, we contacted THEM and
147         we know the node is reachable.
148         
149         @type node: L{node.Node}
150         @param node: the new node to try and insert
151         @type contacted: C{boolean}
152         @param contacted: whether the new node is known to be good, i.e.
153             responded to a request (optional, defaults to True)
154         @rtype: L{node.Node}
155         @return: None if successful (the bucket wasn't full), otherwise returns the oldest node in the bucket
156         """
157         assert node.id != NULL_ID
158         if node.id == self.node.id: return
159
160         # Get the bucket for this node
161         i = self._bucketIndexForInt(node.num)
162
163         # Check to see if node is in the bucket already
164         try:
165             it = self.buckets[i].l.index(node.num)
166         except ValueError:
167             pass
168         else:
169             # The node is already in the bucket
170             if contacted:
171                 # It responded, so update it
172                 node.updateLastSeen()
173                 # move node to end of bucket
174                 del(self.buckets[i].l[it])
175                 # note that we removed the original and replaced it with the new one
176                 # utilizing this nodes new contact info
177                 self.buckets[i].l.append(node)
178                 self.buckets[i].touch()
179             return
180         
181         # We don't have this node, check to see if the bucket is full
182         if len(self.buckets[i].l) < self.config['K']:
183             # Not full, append this node and return
184             if contacted:
185                 node.updateLastSeen()
186             self.buckets[i].l.append(node)
187             self.buckets[i].touch()
188             return
189             
190         # Bucket is full, check to see if the local node is not in the bucket
191         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
192             # Local node not in the bucket, can't split it, return the oldest node
193             return self.buckets[i].l[0]
194         
195         # Make sure our table isn't FULL, this is really unlikely
196         if len(self.buckets) >= self.config['HASH_LENGTH']:
197             log.err("Hash Table is FULL!  Increase K!")
198             return
199             
200         # This bucket is full and contains our node, split the bucket
201         self._splitBucket(self.buckets[i])
202         
203         # Now that the bucket is split and balanced, try to insert the node again
204         return self.insertNode(node)
205     
206     def justSeenNode(self, id):
207         """Mark a node as just having been seen.
208         
209         Call this any time you get a message from a node, it will update it
210         in the table if it's there.
211
212         @type id: C{string} of C{int} or L{node.Node}
213         @param id: the node ID to mark as just having been seen
214         @rtype: C{datetime.datetime}
215         @return: the old lastSeen time of the node, or None if it's not in the table
216         """
217         # Get the bucket number
218         num = self._nodeNum(id)
219         i = self._bucketIndexForInt(num)
220
221         # Check to see if node is in the bucket
222         try:
223             it = self.buckets[i].l.index(num)
224         except ValueError:
225             return None
226         else:
227             # The node is in the bucket
228             n = self.buckets[i].l[it]
229             tstamp = n.lastSeen
230             n.updateLastSeen()
231             
232             # Move the node to the end and touch the bucket
233             del(self.buckets[i].l[it])
234             self.buckets[i].l.append(n)
235             self.buckets[i].touch()
236             
237             return tstamp
238     
239     def invalidateNode(self, n):
240         """Remove the node from the routing table.
241         
242         Forget about node n. Use this when you know that a node is invalid.
243         """
244         self.replaceStaleNode(n)
245     
246     def nodeFailed(self, node):
247         """Mark a node as having failed once, and remove it if it has failed too much."""
248         # Get the bucket number
249         num = self._nodeNum(node)
250         i = self._bucketIndexForInt(num)
251
252         # Check to see if node is in the bucket
253         try:
254             it = self.buckets[i].l.index(num)
255         except ValueError:
256             return None
257         else:
258             # The node is in the bucket
259             n = self.buckets[i].l[it]
260             if n.msgFailed() >= self.config['MAX_FAILURES']:
261                 self.invalidateNode(n)
262                         
263 class KBucket:
264     """Single bucket of nodes in a kademlia-like routing table.
265     
266     @type l: C{list} of L{node.Node}
267     @ivar l: the nodes that are in this bucket
268     @type min: C{long}
269     @ivar min: the minimum node ID that can be in this bucket
270     @type max: C{long}
271     @ivar max: the maximum node ID that can be in this bucket
272     @type lastAccessed: C{datetime.datetime}
273     @ivar lastAccessed: the last time a node in this bucket was successfully contacted
274     """
275     
276     def __init__(self, contents, min, max):
277         """Initialize the bucket with nodes.
278         
279         @type contents: C{list} of L{node.Node}
280         @param contents: the nodes to store in the bucket
281         @type min: C{long}
282         @param min: the minimum node ID that can be in this bucket
283         @type max: C{long}
284         @param max: the maximum node ID that can be in this bucket
285         """
286         self.l = contents
287         self.min = min
288         self.max = max
289         self.lastAccessed = datetime.now()
290         
291     def touch(self):
292         """Update the L{lastAccessed} time."""
293         self.lastAccessed = datetime.now()
294     
295     def getNodeWithInt(self, num):
296         """Get the node in the bucket with that number.
297         
298         @type num: C{long}
299         @param num: the node ID to look for
300         @raise ValueError: if the node ID is not in the bucket
301         @rtype: L{node.Node}
302         @return: the node
303         """
304         if num in self.l: return num
305         else: raise ValueError
306         
307     def __repr__(self):
308         return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
309     
310     #{ Comparators to bisect/index a list of buckets (by their range) with either a node or a long
311     def __lt__(self, a):
312         if isinstance(a, Node): a = a.num
313         return self.max <= a
314     def __le__(self, a):
315         if isinstance(a, Node): a = a.num
316         return self.min < a
317     def __gt__(self, a):
318         if isinstance(a, Node): a = a.num
319         return self.min > a
320     def __ge__(self, a):
321         if isinstance(a, Node): a = a.num
322         return self.max >= a
323     def __eq__(self, a):
324         if isinstance(a, Node): a = a.num
325         return self.min <= a and self.max > a
326     def __ne__(self, a):
327         if isinstance(a, Node): a = a.num
328         return self.min >= a or self.max < a
329
330 class TestKTable(unittest.TestCase):
331     """Unit tests for the routing table."""
332     
333     def setUp(self):
334         self.a = Node(khash.newID(), '127.0.0.1', 2002)
335         self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
336
337     def testAddNode(self):
338         self.b = Node(khash.newID(), '127.0.0.1', 2003)
339         self.t.insertNode(self.b)
340         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
341         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
342
343     def testRemove(self):
344         self.testAddNode()
345         self.t.invalidateNode(self.b)
346         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
347
348     def testFail(self):
349         self.testAddNode()
350         for i in range(self.t.config['MAX_FAILURES'] - 1):
351             self.t.nodeFailed(self.b)
352             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
353             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
354             
355         self.t.nodeFailed(self.b)
356         self.failUnlessEqual(len(self.t.buckets[0].l), 0)