Try to rejoin DHT periodically after failures using exponential backoff.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19
20 try:
21     from twisted.web2 import channel, server, resource, http, http_headers
22     _web2 = True
23 except ImportError:
24     _web2 = False
25
26 khashmir_dir = 'apt-p2p-Khashmir'
27
28 class DHTError(Exception):
29     """Represents errors that occur in the DHT."""
30
31 class DHT:
32     """The main interface instance to the Khashmir DHT.
33     
34     @type config: C{dictionary}
35     @ivar config: the DHT configuration values
36     @type cache_dir: C{string}
37     @ivar cache_dir: the directory to use for storing files
38     @type bootstrap: C{list} of C{string}
39     @ivar bootstrap: the nodes to contact to bootstrap into the system
40     @type bootstrap_node: C{boolean}
41     @ivar bootstrap_node: whether this node is a bootstrap node
42     @type joining: L{twisted.internet.defer.Deferred}
43     @ivar joining: if a join is underway, the deferred that will signal it's end
44     @type joined: C{boolean}
45     @ivar joined: whether the DHT network has been successfully joined
46     @type outstandingJoins: C{int}
47     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
48     @type next_rejoin: C{int}
49     @ivar next_rejoin: the number of seconds before retrying the next join
50     @type foundAddrs: C{list} of (C{string}, C{int})
51     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
52     @type storing: C{dictionary}
53     @ivar storing: keys are keys for which store requests are active, values
54         are dictionaries with keys the values being stored and values the
55         deferred to call when complete
56     @type retrieving: C{dictionary}
57     @ivar retrieving: keys are the keys for which getValue requests are active,
58         values are lists of the deferreds waiting for the requests
59     @type retrieved: C{dictionary}
60     @ivar retrieved: keys are the keys for which getValue requests are active,
61         values are list of the values returned so far
62     @type factory: L{twisted.web2.channel.HTTPFactory}
63     @ivar factory: the factory to use to serve HTTP requests for statistics
64     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
65     @ivar config_parser: the configuration info for the main program
66     @type section: C{string}
67     @ivar section: the section of the configuration info that applies to the DHT
68     @type khashmir: L{khashmir.Khashmir}
69     @ivar khashmir: the khashmir DHT instance to use
70     """
71     
72     if _web2:
73         implements(IDHT, IDHTStats, IDHTStatsFactory)
74     else:
75         implements(IDHT, IDHTStats)
76         
77     def __init__(self):
78         """Initialize the DHT."""
79         self.config = None
80         self.cache_dir = ''
81         self.bootstrap = []
82         self.bootstrap_node = False
83         self.joining = None
84         self.khashmir = None
85         self.joined = False
86         self.outstandingJoins = 0
87         self.next_rejoin = 20
88         self.foundAddrs = []
89         self.storing = {}
90         self.retrieving = {}
91         self.retrieved = {}
92         self.factory = None
93     
94     def loadConfig(self, config, section):
95         """See L{apt_p2p.interfaces.IDHT}."""
96         self.config_parser = config
97         self.section = section
98         self.config = {}
99         
100         # Get some initial values
101         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
102         if not os.path.exists(self.cache_dir):
103             os.makedirs(self.cache_dir)
104         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
105         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
106         for k in self.config_parser.options(section):
107             # The numbers in the config file
108             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
109                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
110                 self.config[k] = self.config_parser.getint(section, k)
111             # The times in the config file
112             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
113                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
114                 self.config[k] = self.config_parser.gettime(section, k)
115             # The booleans in the config file
116             elif k in ['SPEW']:
117                 self.config[k] = self.config_parser.getboolean(section, k)
118             # Everything else is a string
119             else:
120                 self.config[k] = self.config_parser.get(section, k)
121     
122     def join(self, deferred = None):
123         """See L{apt_p2p.interfaces.IDHT}.
124         
125         @param deferred: the deferred to callback when the join is complete
126             (optional, defaults to creating a new deferred and returning it)
127         """
128         # Check for multiple simultaneous joins 
129         if self.joining:
130             if deferred:
131                 deferred.errback(DHTError("a join is already in progress"))
132                 return
133             else:
134                 raise DHTError, "a join is already in progress"
135
136         if deferred:
137             self.joining = deferred
138         else:
139             self.joining = defer.Deferred()
140
141         if self.config is None:
142             self.joining.errback(DHTError("configuration not loaded"))
143             return self.joining
144
145         # Create the new khashmir instance
146         if not self.khashmir:
147             self.khashmir = Khashmir(self.config, self.cache_dir)
148
149         for node in self.bootstrap:
150             host, port = node.rsplit(':', 1)
151             port = int(port)
152             
153             # Translate host names into IP addresses
154             if isIPAddress(host):
155                 self._join_gotIP(host, port)
156             else:
157                 reactor.resolve(host).addCallback(self._join_gotIP, port)
158         
159         return self.joining
160
161     def _join_gotIP(self, ip, port):
162         """Join the DHT using a single bootstrap nodes IP address."""
163         self.outstandingJoins += 1
164         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
165     
166     def _join_single(self, addr):
167         """Process the response from the bootstrap node.
168         
169         Finish the join by contacting close nodes.
170         """
171         self.outstandingJoins -= 1
172         if addr:
173             self.foundAddrs.append(addr)
174         if addr or self.outstandingJoins <= 0:
175             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
176         log.msg('Got back from bootstrap node: %r' % (addr,))
177     
178     def _join_error(self, failure = None):
179         """Process an error in contacting the bootstrap node.
180         
181         If no bootstrap nodes remain, finish the process by contacting
182         close nodes.
183         """
184         self.outstandingJoins -= 1
185         log.msg("bootstrap node could not be reached")
186         if self.outstandingJoins <= 0:
187             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
188
189     def _join_complete(self, result):
190         """End the joining process and return the addresses found for this node."""
191         if not self.joined and len(result) > 1:
192             self.joined = True
193         if self.joining and self.outstandingJoins <= 0:
194             df = self.joining
195             self.joining = None
196             if self.joined or self.bootstrap_node:
197                 self.joined = True
198                 df.callback(self.foundAddrs)
199             else:
200                 # Try to join later using exponential backoff delays
201                 log.msg('Join failed, retrying in %d seconds' % self.next_rejoin)
202                 reactor.callLater(self.next_rejoin, self.join, df)
203                 self.next_rejoin *= 2
204         
205     def getAddrs(self):
206         """Get the list of addresses returned by bootstrap nodes for this node."""
207         return self.foundAddrs
208         
209     def leave(self):
210         """See L{apt_p2p.interfaces.IDHT}."""
211         if self.config is None:
212             raise DHTError, "configuration not loaded"
213         
214         if self.joined or self.joining:
215             if self.joining:
216                 self.joining.errback(DHTError('still joining when leave was called'))
217                 self.joining = None
218             self.joined = False
219             self.khashmir.shutdown()
220         
221     def _normKey(self, key, bits=None, bytes=None):
222         """Normalize the length of keys used in the DHT."""
223         bits = self.config["HASH_LENGTH"]
224         if bits is not None:
225             bytes = (bits - 1) // 8 + 1
226         else:
227             if bytes is None:
228                 raise DHTError, "you must specify one of bits or bytes for normalization"
229             
230         # Extend short keys with null bytes
231         if len(key) < bytes:
232             key = key + '\000'*(bytes - len(key))
233         # Truncate long keys
234         elif len(key) > bytes:
235             key = key[:bytes]
236         return key
237
238     def getValue(self, key):
239         """See L{apt_p2p.interfaces.IDHT}."""
240         d = defer.Deferred()
241
242         if self.config is None:
243             d.errback(DHTError("configuration not loaded"))
244             return d
245         if not self.joined:
246             d.errback(DHTError("have not joined a network yet"))
247             return d
248         
249         key = self._normKey(key)
250
251         if key not in self.retrieving:
252             self.khashmir.valueForKey(key, self._getValue)
253         self.retrieving.setdefault(key, []).append(d)
254         return d
255         
256     def _getValue(self, key, result):
257         """Process a returned list of values from the DHT."""
258         # Save the list of values to return when it is complete
259         if result:
260             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
261         else:
262             # Empty list, the get is complete, return the result
263             final_result = []
264             if key in self.retrieved:
265                 final_result = self.retrieved[key]
266                 del self.retrieved[key]
267             for i in range(len(self.retrieving[key])):
268                 d = self.retrieving[key].pop(0)
269                 d.callback(final_result)
270             del self.retrieving[key]
271
272     def storeValue(self, key, value):
273         """See L{apt_p2p.interfaces.IDHT}."""
274         d = defer.Deferred()
275
276         if self.config is None:
277             d.errback(DHTError("configuration not loaded"))
278             return d
279         if not self.joined:
280             d.errback(DHTError("have not joined a network yet"))
281             return d
282
283         key = self._normKey(key)
284         bvalue = bencode(value)
285
286         if key in self.storing and bvalue in self.storing[key]:
287             raise DHTError, "already storing that key with the same value"
288
289         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
290         self.storing.setdefault(key, {})[bvalue] = d
291         return d
292     
293     def _storeValue(self, key, bvalue, result):
294         """Process the response from the DHT."""
295         if key in self.storing and bvalue in self.storing[key]:
296             # Check if the store succeeded
297             if len(result) > 0:
298                 self.storing[key][bvalue].callback(result)
299             else:
300                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
301             del self.storing[key][bvalue]
302             if len(self.storing[key].keys()) == 0:
303                 del self.storing[key]
304     
305     def getStats(self):
306         """See L{apt_p2p.interfaces.IDHTStats}."""
307         return self.khashmir.getStats()
308
309     def getStatsFactory(self):
310         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
311         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
312         if self.factory is None:
313             # Create a simple HTTP factory for stats
314             class StatsResource(resource.Resource):
315                 def __init__(self, manager):
316                     self.manager = manager
317                 def render(self, ctx):
318                     return http.Response(
319                         200,
320                         {'content-type': http_headers.MimeType('text', 'html')},
321                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
322                 def locateChild(self, request, segments):
323                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
324                     return self, ()
325             
326             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
327         return self.factory
328         
329
330 class TestSimpleDHT(unittest.TestCase):
331     """Simple 2-node unit tests for the DHT."""
332     
333     timeout = 50
334     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
335                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
336                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
337                     'MAX_FAILURES': 3,
338                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
339                     'KEY_EXPIRE': 3600, 'SPEW': False, }
340
341     def setUp(self):
342         self.a = DHT()
343         self.b = DHT()
344         self.a.config = self.DHT_DEFAULTS.copy()
345         self.a.config['PORT'] = 4044
346         self.a.bootstrap = ["127.0.0.1:4044"]
347         self.a.bootstrap_node = True
348         self.a.cache_dir = '/tmp'
349         self.b.config = self.DHT_DEFAULTS.copy()
350         self.b.config['PORT'] = 4045
351         self.b.bootstrap = ["127.0.0.1:4044"]
352         self.b.cache_dir = '/tmp'
353         
354     def test_bootstrap_join(self):
355         d = self.a.join()
356         return d
357
358     def test_failed_join(self):
359         from krpc import KrpcError
360         d = self.b.join()
361         reactor.callLater(30, self.a.join)
362         def no_errors(result, self = self):
363             self.flushLoggedErrors(KrpcError)
364             return result
365         d.addCallback(no_errors)
366         return d
367         
368     def node_join(self, result):
369         d = self.b.join()
370         return d
371     
372     def test_join(self):
373         self.lastDefer = defer.Deferred()
374         d = self.a.join()
375         d.addCallback(self.node_join)
376         d.addCallback(self.lastDefer.callback)
377         return self.lastDefer
378
379     def test_normKey(self):
380         h = self.a._normKey('12345678901234567890')
381         self.failUnless(h == '12345678901234567890')
382         h = self.a._normKey('12345678901234567')
383         self.failUnless(h == '12345678901234567\000\000\000')
384         h = self.a._normKey('1234567890123456789012345')
385         self.failUnless(h == '12345678901234567890')
386         h = self.a._normKey('1234567890123456789')
387         self.failUnless(h == '1234567890123456789\000')
388         h = self.a._normKey('123456789012345678901')
389         self.failUnless(h == '12345678901234567890')
390
391     def value_stored(self, result, value):
392         self.stored -= 1
393         if self.stored == 0:
394             self.get_values()
395         
396     def store_values(self, result):
397         self.stored = 3
398         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
399         d.addCallback(self.value_stored, 4045)
400         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
401         d.addCallback(self.value_stored, 4044)
402         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
403         d.addCallback(self.value_stored, 4045)
404
405     def check_values(self, result, values):
406         self.checked -= 1
407         self.failUnless(len(result) == len(values))
408         for v in result:
409             self.failUnless(v in values)
410         if self.checked == 0:
411             self.lastDefer.callback(1)
412     
413     def get_values(self):
414         self.checked = 4
415         d = self.a.getValue(sha.new('4044').digest())
416         d.addCallback(self.check_values, [str(4044*2)])
417         d = self.b.getValue(sha.new('4044').digest())
418         d.addCallback(self.check_values, [str(4044*2)])
419         d = self.a.getValue(sha.new('4045').digest())
420         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
421         d = self.b.getValue(sha.new('4045').digest())
422         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
423
424     def test_store(self):
425         from twisted.internet.base import DelayedCall
426         DelayedCall.debug = True
427         self.lastDefer = defer.Deferred()
428         d = self.a.join()
429         d.addCallback(self.node_join)
430         d.addCallback(self.store_values)
431         return self.lastDefer
432
433     def tearDown(self):
434         self.a.leave()
435         try:
436             os.unlink(self.a.khashmir.store.db)
437         except:
438             pass
439         self.b.leave()
440         try:
441             os.unlink(self.b.khashmir.store.db)
442         except:
443             pass
444
445 class TestMultiDHT(unittest.TestCase):
446     """More complicated 20-node tests for the DHT."""
447     
448     timeout = 80
449     num = 20
450     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
451                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
452                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
453                     'MAX_FAILURES': 3,
454                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
455                     'KEY_EXPIRE': 3600, 'SPEW': False, }
456
457     def setUp(self):
458         self.l = []
459         self.startport = 4081
460         for i in range(self.num):
461             self.l.append(DHT())
462             self.l[i].config = self.DHT_DEFAULTS.copy()
463             self.l[i].config['PORT'] = self.startport + i
464             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
465             self.l[i].cache_dir = '/tmp'
466         self.l[0].bootstrap_node = True
467         
468     def node_join(self, result, next_node):
469         d = self.l[next_node].join()
470         if next_node + 1 < len(self.l):
471             d.addCallback(self.node_join, next_node + 1)
472         else:
473             d.addCallback(self.lastDefer.callback)
474     
475     def test_join(self):
476         self.timeout = 2
477         self.lastDefer = defer.Deferred()
478         d = self.l[0].join()
479         d.addCallback(self.node_join, 1)
480         return self.lastDefer
481         
482     def store_values(self, result, i = 0, j = 0):
483         if j > i:
484             j -= i+1
485             i += 1
486         if i == len(self.l):
487             self.get_values()
488         else:
489             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
490             d.addCallback(self.store_values, i, j+1)
491     
492     def get_values(self, result = None, check = None, i = 0, j = 0):
493         if result is not None:
494             self.failUnless(len(result) == len(check))
495             for v in result:
496                 self.failUnless(v in check)
497         if j >= len(self.l):
498             j -= len(self.l)
499             i += 1
500         if i == len(self.l):
501             self.lastDefer.callback(1)
502         else:
503             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
504             check = []
505             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
506                 check.append(str(k))
507             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
508
509     def store_join(self, result, next_node):
510         d = self.l[next_node].join()
511         if next_node + 1 < len(self.l):
512             d.addCallback(self.store_join, next_node + 1)
513         else:
514             d.addCallback(self.store_values)
515     
516     def test_store(self):
517         from twisted.internet.base import DelayedCall
518         DelayedCall.debug = True
519         self.lastDefer = defer.Deferred()
520         d = self.l[0].join()
521         d.addCallback(self.store_join, 1)
522         return self.lastDefer
523
524     def tearDown(self):
525         for i in self.l:
526             try:
527                 i.leave()
528                 os.unlink(i.khashmir.store.db)
529             except:
530                 pass