Ignore the pyc and eclipse project files.
[quix0rs-apt-p2p.git] / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import time
5 from bisect import *
6 from types import *
7
8 import khash as hash
9 import const
10 from const import K, HASH_LENGTH, NULL_ID
11 from node import Node
12
13 class KTable:
14     """local routing table for a kademlia like distributed hash table"""
15     def __init__(self, node):
16         # this is the root node, a.k.a. US!
17         self.node = node
18         self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)]
19         self.insertNode(node)
20         
21     def _bucketIndexForInt(self, num):
22         """the index of the bucket that should hold int"""
23         return bisect_left(self.buckets, num)
24     
25     def findNodes(self, id):
26         """
27             return K nodes in our own local table closest to the ID.
28         """
29         
30         if isinstance(id, str):
31             num = hash.intify(id)
32         elif isinstance(id, Node):
33             num = id.num
34         elif isinstance(id, int) or isinstance(id, long):
35             num = id
36         else:
37             raise TypeError, "findNodes requires an int, string, or Node"
38             
39         nodes = []
40         i = self._bucketIndexForInt(num)
41         
42         # if this node is already in our table then return it
43         try:
44             index = self.buckets[i].l.index(num)
45         except ValueError:
46             pass
47         else:
48             return [self.buckets[i].l[index]]
49             
50         # don't have the node, get the K closest nodes
51         nodes = nodes + self.buckets[i].l
52         if len(nodes) < K:
53             # need more nodes
54             min = i - 1
55             max = i + 1
56             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
57                 #ASw: note that this requires K be even
58                 if min >= 0:
59                     nodes = nodes + self.buckets[min].l
60                 if max < len(self.buckets):
61                     nodes = nodes + self.buckets[max].l
62                 min = min - 1
63                 max = max + 1
64     
65         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
66         return nodes[:K]
67         
68     def _splitBucket(self, a):
69         diff = (a.max - a.min) / 2
70         b = KBucket([], a.max - diff, a.max)
71         self.buckets.insert(self.buckets.index(a.min) + 1, b)
72         a.max = a.max - diff
73         # transfer nodes to new bucket
74         for anode in a.l[:]:
75             if anode.num >= a.max:
76                 a.l.remove(anode)
77                 b.l.append(anode)
78     
79     def replaceStaleNode(self, stale, new):
80         """this is used by clients to replace a node returned by insertNode after
81         it fails to respond to a Pong message"""
82         i = self._bucketIndexForInt(stale.num)
83         try:
84             it = self.buckets[i].l.index(stale.num)
85         except ValueError:
86             return
87     
88         del(self.buckets[i].l[it])
89         if new:
90             self.buckets[i].l.append(new)
91     
92     def insertNode(self, node, contacted=1):
93         """ 
94         this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
95         the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
96         contacted means that yes, we contacted THEM and we know the node is reachable
97         """
98         assert node.id != NULL_ID
99         if node.id == self.node.id: return
100         # get the bucket for this node
101         i = self. _bucketIndexForInt(node.num)
102         # check to see if node is in the bucket already
103         try:
104             it = self.buckets[i].l.index(node.num)
105         except ValueError:
106             # no
107             pass
108         else:
109             if contacted:
110                 node.updateLastSeen()
111                 # move node to end of bucket
112                 xnode = self.buckets[i].l[it]
113                 del(self.buckets[i].l[it])
114                 # note that we removed the original and replaced it with the new one
115                 # utilizing this nodes new contact info
116                 self.buckets[i].l.append(xnode)
117                 self.buckets[i].touch()
118             return
119         
120         # we don't have this node, check to see if the bucket is full
121         if len(self.buckets[i].l) < K:
122             # no, append this node and return
123             if contacted:
124                 node.updateLastSeen()
125             self.buckets[i].l.append(node)
126             self.buckets[i].touch()
127             return
128             
129         # bucket is full, check to see if self.node is in the bucket
130         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
131             return self.buckets[i].l[0]
132         
133         # this bucket is full and contains our node, split the bucket
134         if len(self.buckets) >= HASH_LENGTH:
135             # our table is FULL, this is really unlikely
136             print "Hash Table is FULL!  Increase K!"
137             return
138             
139         self._splitBucket(self.buckets[i])
140         
141         # now that the bucket is split and balanced, try to insert the node again
142         return self.insertNode(node)
143     
144     def justSeenNode(self, id):
145         """call this any time you get a message from a node
146         it will update it in the table if it's there """
147         try:
148             n = self.findNodes(id)[0]
149         except IndexError:
150             return None
151         else:
152             tstamp = n.lastSeen
153             n.updateLastSeen()
154             return tstamp
155     
156     def invalidateNode(self, n):
157         """
158             forget about node n - use when you know that node is invalid
159         """
160         self.replaceStaleNode(n, None)
161     
162     def nodeFailed(self, node):
163         """ call this when a node fails to respond to a message, to invalidate that node """
164         try:
165             n = self.findNodes(node.num)[0]
166         except IndexError:
167             return None
168         else:
169             if n.msgFailed() >= const.MAX_FAILURES:
170                 self.invalidateNode(n)
171                         
172 class KBucket:
173     __slots__ = ('min', 'max', 'lastAccessed')
174     def __init__(self, contents, min, max):
175         self.l = contents
176         self.min = min
177         self.max = max
178         self.lastAccessed = time.time()
179         
180     def touch(self):
181         self.lastAccessed = time.time()
182     
183     def getNodeWithInt(self, num):
184         if num in self.l: return num
185         else: raise ValueError
186         
187     def __repr__(self):
188         return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
189     
190     ## Comparators    
191     # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
192     # compares integer or node object with the bucket's range
193     def __lt__(self, a):
194         if isinstance(a, Node): a = a.num
195         return self.max <= a
196     def __le__(self, a):
197         if isinstance(a, Node): a = a.num
198         return self.min < a
199     def __gt__(self, a):
200         if isinstance(a, Node): a = a.num
201         return self.min > a
202     def __ge__(self, a):
203         if isinstance(a, Node): a = a.num
204         return self.max >= a
205     def __eq__(self, a):
206         if isinstance(a, Node): a = a.num
207         return self.min <= a and self.max > a
208     def __ne__(self, a):
209         if isinstance(a, Node): a = a.num
210         return self.min >= a or self.max < a
211
212
213 ### UNIT TESTS ###
214 import unittest
215
216 class TestKTable(unittest.TestCase):
217     def setUp(self):
218         self.a = Node().init(hash.newID(), 'localhost', 2002)
219         self.t = KTable(self.a)
220
221     def testAddNode(self):
222         self.b = Node().init(hash.newID(), 'localhost', 2003)
223         self.t.insertNode(self.b)
224         self.assertEqual(len(self.t.buckets[0].l), 1)
225         self.assertEqual(self.t.buckets[0].l[0], self.b)
226
227     def testRemove(self):
228         self.testAddNode()
229         self.t.invalidateNode(self.b)
230         self.assertEqual(len(self.t.buckets[0].l), 0)
231
232     def testFail(self):
233         self.testAddNode()
234         for i in range(const.MAX_FAILURES - 1):
235             self.t.nodeFailed(self.b)
236             self.assertEqual(len(self.t.buckets[0].l), 1)
237             self.assertEqual(self.t.buckets[0].l[0], self.b)
238             
239         self.t.nodeFailed(self.b)
240         self.assertEqual(len(self.t.buckets[0].l), 0)
241
242
243 if __name__ == "__main__":
244     unittest.main()