Fixed justSeenNode in KTable to update the bucket properly.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19
20 khashmir_dir = 'apt-p2p-Khashmir'
21
22 class DHTError(Exception):
23     """Represents errors that occur in the DHT."""
24
25 class DHT:
26     """The main interface instance to the Khashmir DHT.
27     
28     @type config: C{dictionary}
29     @ivar config: the DHT configuration values
30     @type cache_dir: C{string}
31     @ivar cache_dir: the directory to use for storing files
32     @type bootstrap: C{list} of C{string}
33     @ivar bootstrap: the nodes to contact to bootstrap into the system
34     @type bootstrap_node: C{boolean}
35     @ivar bootstrap_node: whether this node is a bootstrap node
36     @type joining: L{twisted.internet.defer.Deferred}
37     @ivar joining: if a join is underway, the deferred that will signal it's end
38     @type joined: C{boolean}
39     @ivar joined: whether the DHT network has been successfully joined
40     @type outstandingJoins: C{int}
41     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
42     @type foundAddrs: C{list} of (C{string}, C{int})
43     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
44     @type storing: C{dictionary}
45     @ivar storing: keys are keys for which store requests are active, values
46         are dictionaries with keys the values being stored and values the
47         deferred to call when complete
48     @type retrieving: C{dictionary}
49     @ivar retrieving: keys are the keys for which getValue requests are active,
50         values are lists of the deferreds waiting for the requests
51     @type retrieved: C{dictionary}
52     @ivar retrieved: keys are the keys for which getValue requests are active,
53         values are list of the values returned so far
54     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
55     @ivar config_parser: the configuration info for the main program
56     @type section: C{string}
57     @ivar section: the section of the configuration info that applies to the DHT
58     @type khashmir: L{khashmir.Khashmir}
59     @ivar khashmir: the khashmir DHT instance to use
60     """
61     
62     implements(IDHT)
63     
64     def __init__(self):
65         """Initialize the DHT."""
66         self.config = None
67         self.cache_dir = ''
68         self.bootstrap = []
69         self.bootstrap_node = False
70         self.joining = None
71         self.joined = False
72         self.outstandingJoins = 0
73         self.foundAddrs = []
74         self.storing = {}
75         self.retrieving = {}
76         self.retrieved = {}
77     
78     def loadConfig(self, config, section):
79         """See L{apt_p2p.interfaces.IDHT}."""
80         self.config_parser = config
81         self.section = section
82         self.config = {}
83         
84         # Get some initial values
85         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
86         if not os.path.exists(self.cache_dir):
87             os.makedirs(self.cache_dir)
88         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
89         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
90         for k in self.config_parser.options(section):
91             # The numbers in the config file
92             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
93                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
94                 self.config[k] = self.config_parser.getint(section, k)
95             # The times in the config file
96             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
97                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
98                 self.config[k] = self.config_parser.gettime(section, k)
99             # The booleans in the config file
100             elif k in ['SPEW']:
101                 self.config[k] = self.config_parser.getboolean(section, k)
102             # Everything else is a string
103             else:
104                 self.config[k] = self.config_parser.get(section, k)
105     
106     def join(self):
107         """See L{apt_p2p.interfaces.IDHT}."""
108         if self.config is None:
109             raise DHTError, "configuration not loaded"
110         if self.joining:
111             raise DHTError, "a join is already in progress"
112
113         # Create the new khashmir instance
114         self.khashmir = Khashmir(self.config, self.cache_dir)
115         
116         self.joining = defer.Deferred()
117         for node in self.bootstrap:
118             host, port = node.rsplit(':', 1)
119             port = int(port)
120             
121             # Translate host names into IP addresses
122             if isIPAddress(host):
123                 self._join_gotIP(host, port)
124             else:
125                 reactor.resolve(host).addCallback(self._join_gotIP, port)
126         
127         return self.joining
128
129     def _join_gotIP(self, ip, port):
130         """Join the DHT using a single bootstrap nodes IP address."""
131         self.outstandingJoins += 1
132         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
133     
134     def _join_single(self, addr):
135         """Process the response from the bootstrap node.
136         
137         Finish the join by contacting close nodes.
138         """
139         self.outstandingJoins -= 1
140         if addr:
141             self.foundAddrs.append(addr)
142         if addr or self.outstandingJoins <= 0:
143             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
144         log.msg('Got back from bootstrap node: %r' % (addr,))
145     
146     def _join_error(self, failure = None):
147         """Process an error in contacting the bootstrap node.
148         
149         If no bootstrap nodes remain, finish the process by contacting
150         close nodes.
151         """
152         self.outstandingJoins -= 1
153         log.msg("bootstrap node could not be reached")
154         if self.outstandingJoins <= 0:
155             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
156
157     def _join_complete(self, result):
158         """End the joining process and return the addresses found for this node."""
159         if not self.joined and len(result) > 0:
160             self.joined = True
161         if self.joining and self.outstandingJoins <= 0:
162             df = self.joining
163             self.joining = None
164             if self.joined or self.bootstrap_node:
165                 self.joined = True
166                 df.callback(self.foundAddrs)
167             else:
168                 df.errback(DHTError('could not find any nodes to bootstrap to'))
169         
170     def getAddrs(self):
171         """Get the list of addresses returned by bootstrap nodes for this node."""
172         return self.foundAddrs
173         
174     def leave(self):
175         """See L{apt_p2p.interfaces.IDHT}."""
176         if self.config is None:
177             raise DHTError, "configuration not loaded"
178         
179         if self.joined or self.joining:
180             if self.joining:
181                 self.joining.errback(DHTError('still joining when leave was called'))
182                 self.joining = None
183             self.joined = False
184             self.khashmir.shutdown()
185         
186     def _normKey(self, key, bits=None, bytes=None):
187         """Normalize the length of keys used in the DHT."""
188         bits = self.config["HASH_LENGTH"]
189         if bits is not None:
190             bytes = (bits - 1) // 8 + 1
191         else:
192             if bytes is None:
193                 raise DHTError, "you must specify one of bits or bytes for normalization"
194             
195         # Extend short keys with null bytes
196         if len(key) < bytes:
197             key = key + '\000'*(bytes - len(key))
198         # Truncate long keys
199         elif len(key) > bytes:
200             key = key[:bytes]
201         return key
202
203     def getValue(self, key):
204         """See L{apt_p2p.interfaces.IDHT}."""
205         if self.config is None:
206             raise DHTError, "configuration not loaded"
207         if not self.joined:
208             raise DHTError, "have not joined a network yet"
209         
210         key = self._normKey(key)
211
212         d = defer.Deferred()
213         if key not in self.retrieving:
214             self.khashmir.valueForKey(key, self._getValue)
215         self.retrieving.setdefault(key, []).append(d)
216         return d
217         
218     def _getValue(self, key, result):
219         """Process a returned list of values from the DHT."""
220         # Save the list of values to return when it is complete
221         if result:
222             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
223         else:
224             # Empty list, the get is complete, return the result
225             final_result = []
226             if key in self.retrieved:
227                 final_result = self.retrieved[key]
228                 del self.retrieved[key]
229             for i in range(len(self.retrieving[key])):
230                 d = self.retrieving[key].pop(0)
231                 d.callback(final_result)
232             del self.retrieving[key]
233
234     def storeValue(self, key, value):
235         """See L{apt_p2p.interfaces.IDHT}."""
236         if self.config is None:
237             raise DHTError, "configuration not loaded"
238         if not self.joined:
239             raise DHTError, "have not joined a network yet"
240
241         key = self._normKey(key)
242         bvalue = bencode(value)
243
244         if key in self.storing and bvalue in self.storing[key]:
245             raise DHTError, "already storing that key with the same value"
246
247         d = defer.Deferred()
248         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
249         self.storing.setdefault(key, {})[bvalue] = d
250         return d
251     
252     def _storeValue(self, key, bvalue, result):
253         """Process the response from the DHT."""
254         if key in self.storing and bvalue in self.storing[key]:
255             # Check if the store succeeded
256             if len(result) > 0:
257                 self.storing[key][bvalue].callback(result)
258             else:
259                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
260             del self.storing[key][bvalue]
261             if len(self.storing[key].keys()) == 0:
262                 del self.storing[key]
263
264 class TestSimpleDHT(unittest.TestCase):
265     """Simple 2-node unit tests for the DHT."""
266     
267     timeout = 2
268     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
269                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
270                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
271                     'MAX_FAILURES': 3,
272                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
273                     'KEY_EXPIRE': 3600, 'SPEW': False, }
274
275     def setUp(self):
276         self.a = DHT()
277         self.b = DHT()
278         self.a.config = self.DHT_DEFAULTS.copy()
279         self.a.config['PORT'] = 4044
280         self.a.bootstrap = ["127.0.0.1:4044"]
281         self.a.bootstrap_node = True
282         self.a.cache_dir = '/tmp'
283         self.b.config = self.DHT_DEFAULTS.copy()
284         self.b.config['PORT'] = 4045
285         self.b.bootstrap = ["127.0.0.1:4044"]
286         self.b.cache_dir = '/tmp'
287         
288     def test_bootstrap_join(self):
289         d = self.a.join()
290         return d
291         
292     def node_join(self, result):
293         d = self.b.join()
294         return d
295     
296     def test_join(self):
297         self.lastDefer = defer.Deferred()
298         d = self.a.join()
299         d.addCallback(self.node_join)
300         d.addCallback(self.lastDefer.callback)
301         return self.lastDefer
302
303     def test_normKey(self):
304         h = self.a._normKey('12345678901234567890')
305         self.failUnless(h == '12345678901234567890')
306         h = self.a._normKey('12345678901234567')
307         self.failUnless(h == '12345678901234567\000\000\000')
308         h = self.a._normKey('1234567890123456789012345')
309         self.failUnless(h == '12345678901234567890')
310         h = self.a._normKey('1234567890123456789')
311         self.failUnless(h == '1234567890123456789\000')
312         h = self.a._normKey('123456789012345678901')
313         self.failUnless(h == '12345678901234567890')
314
315     def value_stored(self, result, value):
316         self.stored -= 1
317         if self.stored == 0:
318             self.get_values()
319         
320     def store_values(self, result):
321         self.stored = 3
322         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
323         d.addCallback(self.value_stored, 4045)
324         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
325         d.addCallback(self.value_stored, 4044)
326         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
327         d.addCallback(self.value_stored, 4045)
328
329     def check_values(self, result, values):
330         self.checked -= 1
331         self.failUnless(len(result) == len(values))
332         for v in result:
333             self.failUnless(v in values)
334         if self.checked == 0:
335             self.lastDefer.callback(1)
336     
337     def get_values(self):
338         self.checked = 4
339         d = self.a.getValue(sha.new('4044').digest())
340         d.addCallback(self.check_values, [str(4044*2)])
341         d = self.b.getValue(sha.new('4044').digest())
342         d.addCallback(self.check_values, [str(4044*2)])
343         d = self.a.getValue(sha.new('4045').digest())
344         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
345         d = self.b.getValue(sha.new('4045').digest())
346         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
347
348     def test_store(self):
349         from twisted.internet.base import DelayedCall
350         DelayedCall.debug = True
351         self.lastDefer = defer.Deferred()
352         d = self.a.join()
353         d.addCallback(self.node_join)
354         d.addCallback(self.store_values)
355         return self.lastDefer
356
357     def tearDown(self):
358         self.a.leave()
359         try:
360             os.unlink(self.a.khashmir.store.db)
361         except:
362             pass
363         self.b.leave()
364         try:
365             os.unlink(self.b.khashmir.store.db)
366         except:
367             pass
368
369 class TestMultiDHT(unittest.TestCase):
370     """More complicated 20-node tests for the DHT."""
371     
372     timeout = 60
373     num = 20
374     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
375                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
376                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
377                     'MAX_FAILURES': 3,
378                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
379                     'KEY_EXPIRE': 3600, 'SPEW': False, }
380
381     def setUp(self):
382         self.l = []
383         self.startport = 4081
384         for i in range(self.num):
385             self.l.append(DHT())
386             self.l[i].config = self.DHT_DEFAULTS.copy()
387             self.l[i].config['PORT'] = self.startport + i
388             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
389             self.l[i].cache_dir = '/tmp'
390         self.l[0].bootstrap_node = True
391         
392     def node_join(self, result, next_node):
393         d = self.l[next_node].join()
394         if next_node + 1 < len(self.l):
395             d.addCallback(self.node_join, next_node + 1)
396         else:
397             d.addCallback(self.lastDefer.callback)
398     
399     def test_join(self):
400         self.timeout = 2
401         self.lastDefer = defer.Deferred()
402         d = self.l[0].join()
403         d.addCallback(self.node_join, 1)
404         return self.lastDefer
405         
406     def store_values(self, result, i = 0, j = 0):
407         if j > i:
408             j -= i+1
409             i += 1
410         if i == len(self.l):
411             self.get_values()
412         else:
413             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
414             d.addCallback(self.store_values, i, j+1)
415     
416     def get_values(self, result = None, check = None, i = 0, j = 0):
417         if result is not None:
418             self.failUnless(len(result) == len(check))
419             for v in result:
420                 self.failUnless(v in check)
421         if j >= len(self.l):
422             j -= len(self.l)
423             i += 1
424         if i == len(self.l):
425             self.lastDefer.callback(1)
426         else:
427             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
428             check = []
429             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
430                 check.append(str(k))
431             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
432
433     def store_join(self, result, next_node):
434         d = self.l[next_node].join()
435         if next_node + 1 < len(self.l):
436             d.addCallback(self.store_join, next_node + 1)
437         else:
438             d.addCallback(self.store_values)
439     
440     def test_store(self):
441         from twisted.internet.base import DelayedCall
442         DelayedCall.debug = True
443         self.lastDefer = defer.Deferred()
444         d = self.l[0].join()
445         d.addCallback(self.store_join, 1)
446         return self.lastDefer
447
448     def tearDown(self):
449         for i in self.l:
450             try:
451                 i.leave()
452                 os.unlink(i.khashmir.store.db)
453             except:
454                 pass