Use compact encoding of node contact info in the DHT.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / ktable.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from datetime import datetime
5 from bisect import bisect_left
6
7 from twisted.python import log
8 from twisted.trial import unittest
9
10 import khash
11 from node import Node, NULL_ID
12
13 class KTable:
14     """local routing table for a kademlia like distributed hash table"""
15     def __init__(self, node, config):
16         # this is the root node, a.k.a. US!
17         assert node.id != NULL_ID
18         self.node = node
19         self.config = config
20         self.buckets = [KBucket([], 0L, 2L**self.config['HASH_LENGTH'])]
21         
22     def _bucketIndexForInt(self, num):
23         """the index of the bucket that should hold int"""
24         return bisect_left(self.buckets, num)
25     
26     def findNodes(self, id):
27         """
28             return K nodes in our own local table closest to the ID.
29         """
30         
31         if isinstance(id, str):
32             num = khash.intify(id)
33         elif isinstance(id, Node):
34             num = id.num
35         elif isinstance(id, int) or isinstance(id, long):
36             num = id
37         else:
38             raise TypeError, "findNodes requires an int, string, or Node"
39             
40         nodes = []
41         i = self._bucketIndexForInt(num)
42         
43         # if this node is already in our table then return it
44         try:
45             index = self.buckets[i].l.index(num)
46         except ValueError:
47             pass
48         else:
49             return [self.buckets[i].l[index]]
50             
51         # don't have the node, get the K closest nodes
52         nodes = nodes + self.buckets[i].l
53         if len(nodes) < self.config['K']:
54             # need more nodes
55             min = i - 1
56             max = i + 1
57             while len(nodes) < self.config['K'] and (min >= 0 or max < len(self.buckets)):
58                 #ASw: note that this requires K be even
59                 if min >= 0:
60                     nodes = nodes + self.buckets[min].l
61                 if max < len(self.buckets):
62                     nodes = nodes + self.buckets[max].l
63                 min = min - 1
64                 max = max + 1
65     
66         nodes.sort(lambda a, b, num=num: cmp(num ^ a.num, num ^ b.num))
67         return nodes[:self.config['K']]
68         
69     def _splitBucket(self, a):
70         diff = (a.max - a.min) / 2
71         b = KBucket([], a.max - diff, a.max)
72         self.buckets.insert(self.buckets.index(a.min) + 1, b)
73         a.max = a.max - diff
74         # transfer nodes to new bucket
75         for anode in a.l[:]:
76             if anode.num >= a.max:
77                 a.l.remove(anode)
78                 b.l.append(anode)
79     
80     def replaceStaleNode(self, stale, new):
81         """this is used by clients to replace a node returned by insertNode after
82         it fails to respond to a Pong message"""
83         i = self._bucketIndexForInt(stale.num)
84         try:
85             it = self.buckets[i].l.index(stale.num)
86         except ValueError:
87             return
88     
89         del(self.buckets[i].l[it])
90         if new:
91             self.buckets[i].l.append(new)
92     
93     def insertNode(self, node, contacted=1):
94         """ 
95         this insert the node, returning None if successful, returns the oldest node in the bucket if it's full
96         the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!!
97         contacted means that yes, we contacted THEM and we know the node is reachable
98         """
99         assert node.id != NULL_ID
100         if node.id == self.node.id: return
101         # get the bucket for this node
102         i = self. _bucketIndexForInt(node.num)
103         # check to see if node is in the bucket already
104         try:
105             it = self.buckets[i].l.index(node.num)
106         except ValueError:
107             # no
108             pass
109         else:
110             if contacted:
111                 node.updateLastSeen()
112                 # move node to end of bucket
113                 xnode = self.buckets[i].l[it]
114                 del(self.buckets[i].l[it])
115                 # note that we removed the original and replaced it with the new one
116                 # utilizing this nodes new contact info
117                 self.buckets[i].l.append(xnode)
118                 self.buckets[i].touch()
119             return
120         
121         # we don't have this node, check to see if the bucket is full
122         if len(self.buckets[i].l) < self.config['K']:
123             # no, append this node and return
124             if contacted:
125                 node.updateLastSeen()
126             self.buckets[i].l.append(node)
127             self.buckets[i].touch()
128             return
129             
130         # bucket is full, check to see if self.node is in the bucket
131         if not (self.buckets[i].min <= self.node < self.buckets[i].max):
132             return self.buckets[i].l[0]
133         
134         # this bucket is full and contains our node, split the bucket
135         if len(self.buckets) >= self.config['HASH_LENGTH']:
136             # our table is FULL, this is really unlikely
137             log.err("Hash Table is FULL!  Increase K!")
138             return
139             
140         self._splitBucket(self.buckets[i])
141         
142         # now that the bucket is split and balanced, try to insert the node again
143         return self.insertNode(node)
144     
145     def justSeenNode(self, id):
146         """call this any time you get a message from a node
147         it will update it in the table if it's there """
148         try:
149             n = self.findNodes(id)[0]
150         except IndexError:
151             return None
152         else:
153             tstamp = n.lastSeen
154             n.updateLastSeen()
155             return tstamp
156     
157     def invalidateNode(self, n):
158         """
159             forget about node n - use when you know that node is invalid
160         """
161         self.replaceStaleNode(n, None)
162     
163     def nodeFailed(self, node):
164         """ call this when a node fails to respond to a message, to invalidate that node """
165         try:
166             n = self.findNodes(node.num)[0]
167         except IndexError:
168             return None
169         else:
170             if n.msgFailed() >= self.config['MAX_FAILURES']:
171                 self.invalidateNode(n)
172                         
173 class KBucket:
174     def __init__(self, contents, min, max):
175         self.l = contents
176         self.min = min
177         self.max = max
178         self.lastAccessed = datetime.now()
179         
180     def touch(self):
181         self.lastAccessed = datetime.now()
182     
183     def getNodeWithInt(self, num):
184         if num in self.l: return num
185         else: raise ValueError
186         
187     def __repr__(self):
188         return "<KBucket %d items (%d to %d)>" % (len(self.l), self.min, self.max)
189     
190     ## Comparators    
191     # necessary for bisecting list of buckets with a hash expressed as an integer or a distance
192     # compares integer or node object with the bucket's range
193     def __lt__(self, a):
194         if isinstance(a, Node): a = a.num
195         return self.max <= a
196     def __le__(self, a):
197         if isinstance(a, Node): a = a.num
198         return self.min < a
199     def __gt__(self, a):
200         if isinstance(a, Node): a = a.num
201         return self.min > a
202     def __ge__(self, a):
203         if isinstance(a, Node): a = a.num
204         return self.max >= a
205     def __eq__(self, a):
206         if isinstance(a, Node): a = a.num
207         return self.min <= a and self.max > a
208     def __ne__(self, a):
209         if isinstance(a, Node): a = a.num
210         return self.min >= a or self.max < a
211
212 class TestKTable(unittest.TestCase):
213     def setUp(self):
214         self.a = Node(khash.newID(), '127.0.0.1', 2002)
215         self.t = KTable(self.a, {'HASH_LENGTH': 160, 'K': 8, 'MAX_FAILURES': 3})
216
217     def testAddNode(self):
218         self.b = Node(khash.newID(), '127.0.0.1', 2003)
219         self.t.insertNode(self.b)
220         self.failUnlessEqual(len(self.t.buckets[0].l), 1)
221         self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
222
223     def testRemove(self):
224         self.testAddNode()
225         self.t.invalidateNode(self.b)
226         self.failUnlessEqual(len(self.t.buckets[0].l), 0)
227
228     def testFail(self):
229         self.testAddNode()
230         for i in range(self.t.config['MAX_FAILURES'] - 1):
231             self.t.nodeFailed(self.b)
232             self.failUnlessEqual(len(self.t.buckets[0].l), 1)
233             self.failUnlessEqual(self.t.buckets[0].l[0], self.b)
234             
235         self.t.nodeFailed(self.b)
236         self.failUnlessEqual(len(self.t.buckets[0].l), 0)