Clean up all the imports.
[quix0rs-apt-p2p.git] / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from time import time
5 from bisect import bisect_left
6
7 from const import K, HASH_LENGTH, NULL_ID, MAX_FAILURES
8 import khash
9 from node import Node
10
11 class KTable:
12     """local routing table for a kademlia like distributed hash table"""
13     def __init__(self, node):
14         # this is the root node, a.k.a. US!
15         self.node = node
16         self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)]
17         self.insertNode(node)
18         
19     def _bucketIndexForInt(self, num):
20         """the index of the bucket that should hold int"""
21         return bisect_left(self.buckets, num)
22     
23     def findNodes(self, id):
24         """
25             return K nodes in our own local table closest to the ID.
26         """
27         
28         if isinstance(id, str):
29             num = khash.intify(id)
30         elif isinstance(id, Node):
31             num = id.num
32         elif isinstance(id, int) or isinstance(id, long):
33             num = id
34         else:
35             raise TypeError, "findNodes requires an int, string, or Node"
36             
37         nodes = []
38         i = self._bucketIndexForInt(num)
39         
40         # if this node is already in our table then return it
41         try:
42             index = self.buckets[i].l.index(num)
43         except ValueError:
44             pass
45         else:
46             return [self.buckets[i].l[index]]
47             
48         # don't have the node, get the K closest nodes
49         nodes = nodes + self.buckets[i].l
50         if len(nodes) < K:
51             # need more nodes
52             min = i - 1
53             max = i + 1
54             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
55                 #ASw: note that this requires K be even
56                 if min >= 0:
57                     nodes = nodes + self.buckets[min].l
58                 if max < len(self.buckets):
59                     nodes = nodes + self.buckets[max].l
60                 min = min - 1
61                 max = max + 1
62     
63         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
64         return nodes[:K]
65         
66     def _splitBucket(self, a):
67         diff = (a.max - a.min) / 2
68         b = KBucket([], a.max - diff, a.max)
69         self.buckets.insert(self.buckets.index(a.min) + 1, b)
70         a.max = a.max - diff
71         # transfer nodes to new bucket
72         for anode in a.l[:]:
73             if anode.num >= a.max:
74                 a.l.remove(anode)
75                 b.l.append(anode)
76     
77     def replaceStaleNode(self, stale, new):
78         """this is used by clients to replace a node returned by insertNode after
79         it fails to respond to a Pong message"""
80         i = self._bucketIndexForInt(stale.num)
81         try:
82             it = self.buckets[i].l.index(stale.num)
83         except ValueError:
84             return
85     
86         del(self.buckets[i].l[it])
87         if new:
88             self.buckets[i].l.append(new)
89     
90     def insertNode(self, node, contacted=1):
91         """ 
92         this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
93         the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
94         contacted means that yes, we contacted THEM and we know the node is reachable
95         """
96         assert node.id != NULL_ID
97         if node.id == self.node.id: return
98         # get the bucket for this node
99         i = self. _bucketIndexForInt(node.num)
100         # check to see if node is in the bucket already
101         try:
102             it = self.buckets[i].l.index(node.num)
103         except ValueError:
104             # no
105             pass
106         else:
107             if contacted:
108                 node.updateLastSeen()
109                 # move node to end of bucket
110                 xnode = self.buckets[i].l[it]
111                 del(self.buckets[i].l[it])
112                 # note that we removed the original and replaced it with the new one
113                 # utilizing this nodes new contact info
114                 self.buckets[i].l.append(xnode)
115                 self.buckets[i].touch()
116             return
117         
118         # we don't have this node, check to see if the bucket is full
119         if len(self.buckets[i].l) < K:
120             # no, append this node and return
121             if contacted:
122                 node.updateLastSeen()
123             self.buckets[i].l.append(node)
124             self.buckets[i].touch()
125             return
126             
127         # bucket is full, check to see if self.node is in the bucket
128         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
129             return self.buckets[i].l[0]
130         
131         # this bucket is full and contains our node, split the bucket
132         if len(self.buckets) >= HASH_LENGTH:
133             # our table is FULL, this is really unlikely
134             print "Hash Table is FULL!  Increase K!"
135             return
136             
137         self._splitBucket(self.buckets[i])
138         
139         # now that the bucket is split and balanced, try to insert the node again
140         return self.insertNode(node)
141     
142     def justSeenNode(self, id):
143         """call this any time you get a message from a node
144         it will update it in the table if it's there """
145         try:
146             n = self.findNodes(id)[0]
147         except IndexError:
148             return None
149         else:
150             tstamp = n.lastSeen
151             n.updateLastSeen()
152             return tstamp
153     
154     def invalidateNode(self, n):
155         """
156             forget about node n - use when you know that node is invalid
157         """
158         self.replaceStaleNode(n, None)
159     
160     def nodeFailed(self, node):
161         """ call this when a node fails to respond to a message, to invalidate that node """
162         try:
163             n = self.findNodes(node.num)[0]
164         except IndexError:
165             return None
166         else:
167             if n.msgFailed() >= MAX_FAILURES:
168                 self.invalidateNode(n)
169                         
170 class KBucket:
171     def __init__(self, contents, min, max):
172         self.l = contents
173         self.min = min
174         self.max = max
175         self.lastAccessed = time()
176         
177     def touch(self):
178         self.lastAccessed = time()
179     
180     def getNodeWithInt(self, num):
181         if num in self.l: return num
182         else: raise ValueError
183         
184     def __repr__(self):
185         return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
186     
187     ## Comparators    
188     # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
189     # compares integer or node object with the bucket's range
190     def __lt__(self, a):
191         if isinstance(a, Node): a = a.num
192         return self.max <= a
193     def __le__(self, a):
194         if isinstance(a, Node): a = a.num
195         return self.min < a
196     def __gt__(self, a):
197         if isinstance(a, Node): a = a.num
198         return self.min > a
199     def __ge__(self, a):
200         if isinstance(a, Node): a = a.num
201         return self.max >= a
202     def __eq__(self, a):
203         if isinstance(a, Node): a = a.num
204         return self.min <= a and self.max > a
205     def __ne__(self, a):
206         if isinstance(a, Node): a = a.num
207         return self.min >= a or self.max < a
208
209
210 ### UNIT TESTS ###
211 import unittest
212
213 class TestKTable(unittest.TestCase):
214     def setUp(self):
215         self.a = Node().init(khash.newID(), 'localhost', 2002)
216         self.t = KTable(self.a)
217
218     def testAddNode(self):
219         self.b = Node().init(khash.newID(), 'localhost', 2003)
220         self.t.insertNode(self.b)
221         self.assertEqual(len(self.t.buckets[0].l), 1)
222         self.assertEqual(self.t.buckets[0].l[0], self.b)
223
224     def testRemove(self):
225         self.testAddNode()
226         self.t.invalidateNode(self.b)
227         self.assertEqual(len(self.t.buckets[0].l), 0)
228
229     def testFail(self):
230         self.testAddNode()
231         for i in range(MAX_FAILURES - 1):
232             self.t.nodeFailed(self.b)
233             self.assertEqual(len(self.t.buckets[0].l), 1)
234             self.assertEqual(self.t.buckets[0].l[0], self.b)
235             
236         self.t.nodeFailed(self.b)
237         self.assertEqual(len(self.t.buckets[0].l), 0)
238
239
240 if __name__ == "__main__":
241     unittest.main()