Use function for sending krpc responses, and add spew parameter.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from datetime import datetime
5 from bisect import bisect_left
6
7 from twisted.trial import unittest
8
9 import khash
10 from node import Node, NULL_ID
11
12 class KTable:
13     """local routing table for a kademlia like distributed hash table"""
14     def __init__(self, node, config):
15         # this is the root node, a.k.a. US!
16         assert node.id != NULL_ID
17         self.node = node
18         self.config = config
19         self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
20         
21     def _bucketIndexForInt(self, num):
22         """the index of the bucket that should hold int"""
23         return bisect_left(self.buckets, num)
24     
25     def findNodes(self, id):
26         """
27             return K nodes in our own local table closest to the ID.
28         """
29         
30         if isinstance(id, str):
31             num = khash.intify(id)
32         elif isinstance(id, Node):
33             num = id.num
34         elif isinstance(id, int) or isinstance(id, long):
35             num = id
36         else:
37             raise TypeError, "findNodes requires an int, string, or Node"
38             
39         nodes = []
40         i = self._bucketIndexForInt(num)
41         
42         # if this node is already in our table then return it
43         try:
44             index = self.buckets[i].l.index(num)
45         except ValueError:
46             pass
47         else:
48             return [self.buckets[i].l[index]]
49             
50         # don't have the node, get the K closest nodes
51         nodes = nodes + self.buckets[i].l
52         if len(nodes) < self.config['K']:
53             # need more nodes
54             min = i - 1
55             max = i + 1
56             while len(nodes) < self.config['K'] and (min >= 0 or max < len(self.buckets)):
57                 #ASw: note that this requires K be even
58                 if min >= 0:
59                     nodes = nodes + self.buckets[min].l
60                 if max < len(self.buckets):
61                     nodes = nodes + self.buckets[max].l
62                 min = min - 1
63                 max = max + 1
64     
65         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
66         return nodes[:self.config['K']]
67         
68     def _splitBucket(self, a):
69         diff = (a.max - a.min) / 2
70         b = KBucket([], a.max - diff, a.max)
71         self.buckets.insert(self.buckets.index(a.min) + 1, b)
72         a.max = a.max - diff
73         # transfer nodes to new bucket
74         for anode in a.l[:]:
75             if anode.num >= a.max:
76                 a.l.remove(anode)
77                 b.l.append(anode)
78     
79     def replaceStaleNode(self, stale, new):
80         """this is used by clients to replace a node returned by insertNode after
81         it fails to respond to a Pong message"""
82         i = self._bucketIndexForInt(stale.num)
83         try:
84             it = self.buckets[i].l.index(stale.num)
85         except ValueError:
86             return
87     
88         del(self.buckets[i].l[it])
89         if new:
90             self.buckets[i].l.append(new)
91     
92     def insertNode(self, node, contacted=1):
93         """ 
94         this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
95         the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
96         contacted means that yes, we contacted THEM and we know the node is reachable
97         """
98         assert node.id != NULL_ID
99         if node.id == self.node.id: return
100         # get the bucket for this node
101         i = self. _bucketIndexForInt(node.num)
102         # check to see if node is in the bucket already
103         try:
104             it = self.buckets[i].l.index(node.num)
105         except ValueError:
106             # no
107             pass
108         else:
109             if contacted:
110                 node.updateLastSeen()
111                 # move node to end of bucket
112                 xnode = self.buckets[i].l[it]
113                 del(self.buckets[i].l[it])
114                 # note that we removed the original and replaced it with the new one
115                 # utilizing this nodes new contact info
116                 self.buckets[i].l.append(xnode)
117                 self.buckets[i].touch()
118             return
119         
120         # we don't have this node, check to see if the bucket is full
121         if len(self.buckets[i].l) < self.config['K']:
122             # no, append this node and return
123             if contacted:
124                 node.updateLastSeen()
125             self.buckets[i].l.append(node)
126             self.buckets[i].touch()
127             return
128             
129         # bucket is full, check to see if self.node is in the bucket
130         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
131             return self.buckets[i].l[0]
132         
133         # this bucket is full and contains our node, split the bucket
134         if len(self.buckets) >= self.config['HASH_LENGTH']:
135             # our table is FULL, this is really unlikely
136             print "Hash Table is FULL!  Increase K!"
137             return
138             
139         self._splitBucket(self.buckets[i])
140         
141         # now that the bucket is split and balanced, try to insert the node again
142         return self.insertNode(node)
143     
144     def justSeenNode(self, id):
145         """call this any time you get a message from a node
146         it will update it in the table if it's there """
147         try:
148             n = self.findNodes(id)[0]
149         except IndexError:
150             return None
151         else:
152             tstamp = n.lastSeen
153             n.updateLastSeen()
154             return tstamp
155     
156     def invalidateNode(self, n):
157         """
158             forget about node n - use when you know that node is invalid
159         """
160         self.replaceStaleNode(n, None)
161     
162     def nodeFailed(self, node):
163         """ call this when a node fails to respond to a message, to invalidate that node """
164         try:
165             n = self.findNodes(node.num)[0]
166         except IndexError:
167             return None
168         else:
169             if n.msgFailed() >= self.config['MAX_FAILURES']:
170                 self.invalidateNode(n)
171                         
172 class KBucket:
173     def __init__(self, contents, min, max):
174         self.l = contents
175         self.min = min
176         self.max = max
177         self.lastAccessed = datetime.now()
178         
179     def touch(self):
180         self.lastAccessed = datetime.now()
181     
182     def getNodeWithInt(self, num):
183         if num in self.l: return num
184         else: raise ValueError
185         
186     def __repr__(self):
187         return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
188     
189     ## Comparators    
190     # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
191     # compares integer or node object with the bucket's range
192     def __lt__(self, a):
193         if isinstance(a, Node): a = a.num
194         return self.max <= a
195     def __le__(self, a):
196         if isinstance(a, Node): a = a.num
197         return self.min < a
198     def __gt__(self, a):
199         if isinstance(a, Node): a = a.num
200         return self.min > a
201     def __ge__(self, a):
202         if isinstance(a, Node): a = a.num
203         return self.max >= a
204     def __eq__(self, a):
205         if isinstance(a, Node): a = a.num
206         return self.min <= a and self.max > a
207     def __ne__(self, a):
208         if isinstance(a, Node): a = a.num
209         return self.min >= a or self.max < a
210
211 class TestKTable(unittest.TestCase):
212     def setUp(self):
213         self.a = Node(khash.newID(), 'localhost', 2002)
214         self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
215
216     def testAddNode(self):
217         self.b = Node(khash.newID(), 'localhost', 2003)
218         self.t.insertNode(self.b)
219         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
220         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
221
222     def testRemove(self):
223         self.testAddNode()
224         self.t.invalidateNode(self.b)
225         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
226
227     def testFail(self):
228         self.testAddNode()
229         for i in range(self.t.config['MAX_FAILURES'] - 1):
230             self.t.nodeFailed(self.b)
231             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
232             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
233             
234         self.t.nodeFailed(self.b)
235         self.failUnlessEqual(len(self.t.buckets[0].l), 0)