bd89ef76adb14c887fa6e9b73e7bab947c73c3c0
[quix0rs-apt-p2p.git] / ktable.py
1 ## Copyright 2002 Andrew Loewenstern, All Rights Reserved
2
3 import time
4 from bisect import *
5 from types import *
6
7 import hash
8 import const
9 from const import K, HASH_LENGTH
10 from node import Node
11
12 class KTable:
13         """local routing table for a kademlia like distributed hash table"""
14         def __init__(self, node):
15                 # this is the root node, a.k.a. US!
16                 self.node = node
17                 self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)]
18                 self.insertNode(node)
19                 
20         def _bucketIndexForInt(self, num):
21                 """the index of the bucket that should hold int"""
22                 return bisect_left(self.buckets, num)
23         
24         def findNodes(self, id):
25                 """
26                         return K nodes in our own local table closest to the ID.
27                 """
28                 
29                 if isinstance(id, str):
30                         num = hash.intify(id)
31                 elif isinstance(id, Node):
32                         num = id.num
33                 elif isinstance(id, int) or isinstance(id, long):
34                         num = id
35                 else:
36                         raise TypeError, "findNodes requires an int, string, or Node"
37                         
38                 nodes = []
39                 i = self._bucketIndexForInt(num)
40                 
41                 # if this node is already in our table then return it
42                 try:
43                         index = self.buckets[i].l.index(num)
44                 except ValueError:
45                         pass
46                 else:
47                         return [self.buckets[i].l[index]]
48                         
49                 # don't have the node, get the K closest nodes
50                 nodes = nodes + self.buckets[i].l
51                 if len(nodes) < K:
52                         # need more nodes
53                         min = i - 1
54                         max = i + 1
55                         while len(nodes) < K and (min >= 0 or max < len(self.buckets)):
56                                 #ASw: note that this requires K be even
57                                 if min >= 0:
58                                         nodes = nodes + self.buckets[min].l
59                                 if max < len(self.buckets):
60                                         nodes = nodes + self.buckets[max].l
61                                 min = min - 1
62                                 max = max + 1
63         
64                 nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
65                 return nodes[:K]
66                 
67         def _splitBucket(self, a):
68                 diff = (a.max - a.min) / 2
69                 b = KBucket([], a.max - diff, a.max)
70                 self.buckets.insert(self.buckets.index(a.min) + 1, b)
71                 a.max = a.max - diff
72                 # transfer nodes to new bucket
73                 for anode in a.l[:]:
74                         if anode.num >= a.max:
75                                 a.l.remove(anode)
76                                 b.l.append(anode)
77         
78         def replaceStaleNode(self, stale, new):
79                 """this is used by clients to replace a node returned by insertNode after
80                 it fails to respond to a Pong message"""
81                 i = self._bucketIndexForInt(stale.num)
82                 try:
83                         it = self.buckets[i].l.index(stale.num)
84                 except ValueError:
85                         return
86         
87                 del(self.buckets[i].l[it])
88                 if new:
89                         self.buckets[i].l.append(new)
90         
91         def insertNode(self, node, contacted=1):
92                 """ 
93                 this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
94                 the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
95                 contacted means that yes, we contacted THEM and we know the node is reachable
96                 """
97                 assert node.id != " "*20
98                 if node.id == self.node.id: return
99                 # get the bucket for this node
100                 i = self. _bucketIndexForInt(node.num)
101                 # check to see if node is in the bucket already
102                 try:
103                         it = self.buckets[i].l.index(node.num)
104                 except ValueError:
105                         # no
106                         pass
107                 else:
108                         if contacted:
109                                 node.updateLastSeen()
110                                 # move node to end of bucket
111                                 xnode = self.buckets[i].l[it]
112                                 del(self.buckets[i].l[it])
113                                 # note that we removed the original and replaced it with the new one
114                                 # utilizing this nodes new contact info
115                                 self.buckets[i].l.append(xnode)
116                                 self.buckets[i].touch()
117                         return
118                 
119                 # we don't have this node, check to see if the bucket is full
120                 if len(self.buckets[i].l) < K:
121                         # no, append this node and return
122                         if contacted:
123                                 node.updateLastSeen()
124                         self.buckets[i].l.append(node)
125                         self.buckets[i].touch()
126                         return
127                         
128                 # bucket is full, check to see if self.node is in the bucket
129                 if not (self.buckets[i].min <= self.node < self.buckets[i].max):
130                         return self.buckets[i].l[0]
131                 
132                 # this bucket is full and contains our node, split the bucket
133                 if len(self.buckets) >= HASH_LENGTH:
134                         # our table is FULL, this is really unlikely
135                         print "Hash Table is FULL!  Increase K!"
136                         return
137                         
138                 self._splitBucket(self.buckets[i])
139                 
140                 # now that the bucket is split and balanced, try to insert the node again
141                 return self.insertNode(node)
142         
143         def justSeenNode(self, node):
144                 """call this any time you get a message from a node
145                 it will update it in the table if it's there """
146                 try:
147                         n = self.findNodes(node.num)[0]
148                 except IndexError:
149                         return None
150                 else:
151                         tstamp = n.lastSeen
152                         n.updateLastSeen()
153                         return tstamp
154         
155         def invalidateNode(self, n):
156                 """
157                         forget about node n - use when you know that node is invalid
158                 """
159                 self.replaceStaleNode(n, None)
160         
161         def nodeFailed(self, node):
162                 """ call this when a node fails to respond to a message, to invalidate that node """
163                 try:
164                         n = self.findNodes(node.num)[0]
165                 except IndexError:
166                         return None
167                 else:
168                         if n.msgFailed() >= const.MAX_FAILURES:
169                                 self.invalidateNode(n)
170                                         
171 class KBucket:
172         __slots__ = ('min', 'max', 'lastAccessed')
173         def __init__(self, contents, min, max):
174                 self.l = contents
175                 self.min = min
176                 self.max = max
177                 self.lastAccessed = time.time()
178                 
179         def touch(self):
180                 self.lastAccessed = time.time()
181         
182         def getNodeWithInt(self, num):
183                 if num in self.l: return num
184                 else: raise ValueError
185                 
186         def __repr__(self):
187                 return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
188         
189         ## Comparators    
190         # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
191         # compares integer or node object with the bucket's range
192         def __lt__(self, a):
193                 if isinstance(a, Node): a = a.num
194                 return self.max <= a
195         def __le__(self, a):
196                 if isinstance(a, Node): a = a.num
197                 return self.min < a
198         def __gt__(self, a):
199                 if isinstance(a, Node): a = a.num
200                 return self.min > a
201         def __ge__(self, a):
202                 if isinstance(a, Node): a = a.num
203                 return self.max >= a
204         def __eq__(self, a):
205                 if isinstance(a, Node): a = a.num
206                 return self.min <= a and self.max > a
207         def __ne__(self, a):
208                 if isinstance(a, Node): a = a.num
209                 return self.min >= a or self.max < a
210
211
212 ### UNIT TESTS ###
213 import unittest
214
215 class TestKTable(unittest.TestCase):
216     def setUp(self):
217         self.a = Node().init(hash.newID(), 'localhost', 2002)
218         self.t = KTable(self.a)
219         print self.t.buckets[0].l
220
221     def test_replace_stale_node(self):
222         self.b = Node().init(hash.newID(), 'localhost', 2003)
223         self.t.replaceStaleNode(self.a, self.b)
224         assert len(self.t.buckets[0].l) == 1
225         assert self.t.buckets[0].l[0].id == self.b.id
226
227 if __name__ == "__main__":
228     unittest.main()