2f1cf314598ab50a104946d97e60df0aa147ebf4
[quix0rs-apt-p2p.git] / ktable.py
1 ## Copyright 2002 Andrew Loewenstern, All Rights Reserved
2
3 import time
4 from bisect import *
5 from types import *
6
7 import hash
8 import const
9 from const import K, HASH_LENGTH
10 from node import Node
11
12 class KTable:
13     """local routing table for a kademlia like distributed hash table"""
14     def __init__(self, node):
15         # this is the root node, a.k.a. US!
16         self.node = node
17         self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)]
18         self.insertNode(node)
19         
20     def _bucketIndexForInt(self, int):
21         """the index of the bucket that should hold int"""
22         return bisect_left(self.buckets, int)
23     
24     def findNodes(self, id):
25         """k nodes in our own local table closest to the ID.
26         
27         NOTE: response may actually include ourself, it's your responsibilty 
28         to not send messages to yourself if it matters."""
29         
30         if isinstance(id, str):
31             int = hash.intify(id)
32         elif isinstance(id, Node):
33             int = id.int
34         elif isinstance(id, int) or isinstance(id, long):
35             int = id
36         else:
37             raise TypeError, "findNodes requires an int, string, or Node"
38             
39         nodes = []
40         i = self._bucketIndexForInt(int)
41         
42         # if this node is already in our table then return it
43         try:
44             index = self.buckets[i].l.index(int)
45         except ValueError:
46             pass
47         else:
48             return [self.buckets[i].l[index]]            
49         nodes = nodes + self.buckets[i].l
50         if len(nodes) < K:
51             # need more nodes
52             min = i - 1
53             max = i + 1
54             while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
55                 #ASw: note that this requires K be even
56                 if min >= 0:
57                     nodes = nodes + self.buckets[min].l
58                 if max < len(self.buckets):
59                     nodes = nodes + self.buckets[max].l
60                 min = min - 1
61                 max = max + 1
62
63         nodes.sort(lambda a, b, int=int: cmp(int ^ a.int, int ^ b.int))
64         return nodes[:K]
65         
66     def _splitBucket(self, a):
67         diff = (a.max - a.min) / 2
68         b = KBucket([], a.max - diff, a.max)
69         self.buckets.insert(self.buckets.index(a.min) + 1, b)
70         a.max = a.max - diff
71         # transfer nodes to new bucket
72         for anode in a.l[:]:
73             if anode.int >= a.max:
74                 a.l.remove(anode)
75                 b.l.append(anode)
76
77     def replaceStaleNode(self, stale, new):
78         """this is used by clients to replace a node returned by insertNode after
79         it fails to respond to a Pong message"""
80         i = self._bucketIndexForInt(stale.int)
81         try:
82             it = self.buckets[i].l.index(stale.int)
83         except ValueError:
84             return
85
86         del(self.buckets[i].l[it])
87         if new:
88             self.buckets[i].l.append(new)
89
90     def insertNode(self, node, contacted=1):
91         """ 
92         this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
93         the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
94         contacted means that yes, we contacted THEM and we know the node is available
95         """
96         assert node.id != " "*20
97         if node.id == self.node.id: return
98         # get the bucket for this node
99         i = self. _bucketIndexForInt(node.int)
100         # check to see if node is in the bucket already
101         try:
102             it = self.buckets[i].l.index(node.int)
103         except ValueError:
104             # no
105             pass
106         else:
107             if contacted:
108                 node.updateLastSeen()
109                 # move node to end of bucket
110                 xnode = self.buckets[i].l[it]
111                 del(self.buckets[i].l[it])
112                 # note that we removed the original and replaced it with the new one
113                 # utilizing this nodes new contact info
114                 self.buckets[i].l.append(xnode)
115                 self.buckets[i].touch()
116             return
117         
118         # we don't have this node, check to see if the bucket is full
119         if len(self.buckets[i].l) < K:
120             # no, append this node and return
121             if contacted:
122                 node.updateLastSeen()
123             self.buckets[i].l.append(node)
124             self.buckets[i].touch()
125             return
126             
127         # bucket is full, check to see if self.node is in the bucket
128         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
129             return self.buckets[i].l[0]
130         
131         # this bucket is full and contains our node, split the bucket
132         if len(self.buckets) >= HASH_LENGTH:
133             # our table is FULL, this is really unlikely
134             print "Hash Table is FULL!  Increase K!"
135             return
136             
137         self._splitBucket(self.buckets[i])
138         
139         # now that the bucket is split and balanced, try to insert the node again
140         return self.insertNode(node)
141
142     def justSeenNode(self, node):
143         """call this any time you get a message from a node
144         it will update it in the table if it's there """
145         try:
146             n = self.findNodes(node.int)[0]
147         except IndexError:
148             return None
149         else:
150             tstamp = n.lastSeen
151             n.updateLastSeen()
152             return tstamp
153
154     def nodeFailed(self, node):
155         """ call this when a node fails to respond to a message, to invalidate that node """
156         try:
157             n = self.findNodes(node.int)[0]
158         except IndexError:
159             return None
160         else:
161             if n.msgFailed() >= const.MAX_FAILURES:
162                 self.replaceStaleNode(n, None)
163         
164 class KBucket:
165     __slots = ['min', 'max', 'lastAccessed']
166     def __init__(self, contents, min, max):
167         self.l = contents
168         self.min = min
169         self.max = max
170         self.lastAccessed = time.time()
171         
172     def touch(self):
173         self.lastAccessed = time.time()
174
175     def getNodeWithInt(self, int):
176         if int in self.l: return int
177         else: raise ValueError
178         
179     def __repr__(self):
180         return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
181     
182     ## Comparators    
183     # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
184     # compares integer or node object with the bucket's range
185     def __lt__(self, a):
186         if isinstance(a, Node): a = a.int
187         return self.max <= a
188     def __le__(self, a):
189         if isinstance(a, Node): a = a.int
190         return self.min < a
191     def __gt__(self, a):
192         if isinstance(a, Node): a = a.int
193         return self.min > a
194     def __ge__(self, a):
195         if isinstance(a, Node): a = a.int
196         return self.max >= a
197     def __eq__(self, a):
198         if isinstance(a, Node): a = a.int
199         return self.min <= a and self.max > a
200     def __ne__(self, a):
201         if isinstance(a, Node): a = a.int
202         return self.min >= a or self.max < a
203
204
205 ### UNIT TESTS ###
206 import unittest
207
208 class TestKTable(unittest.TestCase):
209     def setUp(self):
210         self.a = Node().init(hash.newID(), 'localhost', 2002)
211         self.t = KTable(self.a)
212         print self.t.buckets[0].l
213
214     def test_replace_stale_node(self):
215         self.b = Node().init(hash.newID(), 'localhost', 2003)
216         self.t.replaceStaleNode(self.a, self.b)
217         assert len(self.t.buckets[0].l) == 1
218         assert self.t.buckets[0].l[0].id == self.b.id
219
220 if __name__ == "__main__":
221     unittest.main()