3b206dd86597b4c53d0c417d11a8adbfaa9a3f07
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19 from khash import HASH_LENGTH
20
21 try:
22     from twisted.web2 import channel, server, resource, http, http_headers
23     _web2 = True
24 except ImportError:
25     _web2 = False
26
27 khashmir_dir = 'apt-p2p-Khashmir'
28
29 class DHTError(Exception):
30     """Represents errors that occur in the DHT."""
31
32 class DHT:
33     """The main interface instance to the Khashmir DHT.
34     
35     @type config: C{dictionary}
36     @ivar config: the DHT configuration values
37     @type cache_dir: C{string}
38     @ivar cache_dir: the directory to use for storing files
39     @type bootstrap: C{list} of C{string}
40     @ivar bootstrap: the nodes to contact to bootstrap into the system
41     @type bootstrap_node: C{boolean}
42     @ivar bootstrap_node: whether this node is a bootstrap node
43     @type joining: L{twisted.internet.defer.Deferred}
44     @ivar joining: if a join is underway, the deferred that will signal it's end
45     @type joined: C{boolean}
46     @ivar joined: whether the DHT network has been successfully joined
47     @type outstandingJoins: C{int}
48     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
49     @type next_rejoin: C{int}
50     @ivar next_rejoin: the number of seconds before retrying the next join
51     @type foundAddrs: C{list} of (C{string}, C{int})
52     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
53     @type storing: C{dictionary}
54     @ivar storing: keys are keys for which store requests are active, values
55         are dictionaries with keys the values being stored and values the
56         deferred to call when complete
57     @type retrieving: C{dictionary}
58     @ivar retrieving: keys are the keys for which getValue requests are active,
59         values are lists of the deferreds waiting for the requests
60     @type retrieved: C{dictionary}
61     @ivar retrieved: keys are the keys for which getValue requests are active,
62         values are list of the values returned so far
63     @type factory: L{twisted.web2.channel.HTTPFactory}
64     @ivar factory: the factory to use to serve HTTP requests for statistics
65     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
66     @ivar config_parser: the configuration info for the main program
67     @type section: C{string}
68     @ivar section: the section of the configuration info that applies to the DHT
69     @type khashmir: L{khashmir.Khashmir}
70     @ivar khashmir: the khashmir DHT instance to use
71     """
72     
73     if _web2:
74         implements(IDHT, IDHTStats, IDHTStatsFactory)
75     else:
76         implements(IDHT, IDHTStats)
77         
78     def __init__(self):
79         """Initialize the DHT."""
80         self.config = None
81         self.cache_dir = ''
82         self.bootstrap = []
83         self.bootstrap_node = False
84         self.joining = None
85         self.khashmir = None
86         self.joined = False
87         self.outstandingJoins = 0
88         self.next_rejoin = 20
89         self.foundAddrs = []
90         self.storing = {}
91         self.retrieving = {}
92         self.retrieved = {}
93         self.factory = None
94     
95     def loadConfig(self, config, section):
96         """See L{apt_p2p.interfaces.IDHT}."""
97         self.config_parser = config
98         self.section = section
99         self.config = {}
100         
101         # Get some initial values
102         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
103         if not os.path.exists(self.cache_dir):
104             os.makedirs(self.cache_dir)
105         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
106         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
107         for k in self.config_parser.options(section):
108             # The numbers in the config file
109             if k in ['CONCURRENT_REQS', 'STORE_REDUNDANCY', 
110                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
111                 self.config[k] = self.config_parser.getint(section, k)
112             # The times in the config file
113             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
114                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
115                 self.config[k] = self.config_parser.gettime(section, k)
116             # The booleans in the config file
117             elif k in ['SPEW']:
118                 self.config[k] = self.config_parser.getboolean(section, k)
119             # Everything else is a string
120             else:
121                 self.config[k] = self.config_parser.get(section, k)
122     
123     def join(self, deferred = None):
124         """See L{apt_p2p.interfaces.IDHT}.
125         
126         @param deferred: the deferred to callback when the join is complete
127             (optional, defaults to creating a new deferred and returning it)
128         """
129         # Check for multiple simultaneous joins 
130         if self.joining:
131             if deferred:
132                 deferred.errback(DHTError("a join is already in progress"))
133                 return
134             else:
135                 raise DHTError, "a join is already in progress"
136
137         if deferred:
138             self.joining = deferred
139         else:
140             self.joining = defer.Deferred()
141
142         if self.config is None:
143             self.joining.errback(DHTError("configuration not loaded"))
144             return self.joining
145
146         # Create the new khashmir instance
147         if not self.khashmir:
148             self.khashmir = Khashmir(self.config, self.cache_dir)
149
150         self.outstandingJoins = 0
151         for node in self.bootstrap:
152             host, port = node.rsplit(':', 1)
153             port = int(port)
154             self.outstandingJoins += 1
155             
156             # Translate host names into IP addresses
157             if isIPAddress(host):
158                 self._join_gotIP(host, port)
159             else:
160                 reactor.resolve(host).addCallbacks(self._join_gotIP,
161                                                    self._join_resolveFailed,
162                                                    callbackArgs = (port, ),
163                                                    errbackArgs = (host, port))
164         
165         return self.joining
166
167     def _join_gotIP(self, ip, port):
168         """Join the DHT using a single bootstrap nodes IP address."""
169         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
170     
171     def _join_resolveFailed(self, err, host, port):
172         """Failed to lookup the IP address of the bootstrap node."""
173         log.msg('Failed to find an IP address for host: (%r, %r)' % (host, port))
174         log.err(err)
175         self.outstandingJoins -= 1
176         if self.outstandingJoins <= 0:
177             self.khashmir.findCloseNodes(self._join_complete)
178     
179     def _join_single(self, addr):
180         """Process the response from the bootstrap node.
181         
182         Finish the join by contacting close nodes.
183         """
184         self.outstandingJoins -= 1
185         if addr:
186             self.foundAddrs.append(addr)
187         if addr or self.outstandingJoins <= 0:
188             self.khashmir.findCloseNodes(self._join_complete)
189         log.msg('Got back from bootstrap node: %r' % (addr,))
190     
191     def _join_error(self, failure = None):
192         """Process an error in contacting the bootstrap node.
193         
194         If no bootstrap nodes remain, finish the process by contacting
195         close nodes.
196         """
197         self.outstandingJoins -= 1
198         log.msg("bootstrap node could not be reached")
199         if self.outstandingJoins <= 0:
200             self.khashmir.findCloseNodes(self._join_complete)
201
202     def _join_complete(self, result):
203         """End the joining process and return the addresses found for this node."""
204         if not self.joined and isinstance(result, list) and len(result) > 1:
205             self.joined = True
206         if self.joining and self.outstandingJoins <= 0:
207             df = self.joining
208             self.joining = None
209             if self.joined or self.bootstrap_node:
210                 self.joined = True
211                 df.callback(self.foundAddrs)
212             else:
213                 # Try to join later using exponential backoff delays
214                 log.msg('Join failed, retrying in %d seconds' % self.next_rejoin)
215                 reactor.callLater(self.next_rejoin, self.join, df)
216                 self.next_rejoin *= 2
217         
218     def getAddrs(self):
219         """Get the list of addresses returned by bootstrap nodes for this node."""
220         return self.foundAddrs
221         
222     def leave(self):
223         """See L{apt_p2p.interfaces.IDHT}."""
224         if self.config is None:
225             raise DHTError, "configuration not loaded"
226         
227         if self.joined or self.joining:
228             if self.joining:
229                 self.joining.errback(DHTError('still joining when leave was called'))
230                 self.joining = None
231             self.joined = False
232             self.khashmir.shutdown()
233         
234     def _normKey(self, key):
235         """Normalize the length of keys used in the DHT."""
236         # Extend short keys with null bytes
237         if len(key) < HASH_LENGTH:
238             key = key + '\000'*(HASH_LENGTH - len(key))
239         # Truncate long keys
240         elif len(key) > HASH_LENGTH:
241             key = key[:HASH_LENGTH]
242         return key
243
244     def getValue(self, key):
245         """See L{apt_p2p.interfaces.IDHT}."""
246         if self.config is None:
247             return defer.fail(DHTError("configuration not loaded"))
248         if not self.joined:
249             return defer.fail(DHTError("have not joined a network yet"))
250         
251         d = defer.Deferred()
252         key = self._normKey(key)
253
254         if key not in self.retrieving:
255             self.khashmir.valueForKey(key, self._getValue)
256         self.retrieving.setdefault(key, []).append(d)
257         return d
258         
259     def _getValue(self, key, result):
260         """Process a returned list of values from the DHT."""
261         # Save the list of values to return when it is complete
262         if result:
263             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
264         else:
265             # Empty list, the get is complete, return the result
266             final_result = []
267             if key in self.retrieved:
268                 final_result = self.retrieved[key]
269                 del self.retrieved[key]
270             for i in range(len(self.retrieving[key])):
271                 d = self.retrieving[key].pop(0)
272                 d.callback(final_result)
273             del self.retrieving[key]
274
275     def storeValue(self, key, value):
276         """See L{apt_p2p.interfaces.IDHT}."""
277         if self.config is None:
278             return defer.fail(DHTError("configuration not loaded"))
279         if not self.joined:
280             return defer.fail(DHTError("have not joined a network yet"))
281
282         d = defer.Deferred()
283         key = self._normKey(key)
284         bvalue = bencode(value)
285
286         if key in self.storing and bvalue in self.storing[key]:
287             raise DHTError, "already storing that key with the same value"
288
289         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
290         self.storing.setdefault(key, {})[bvalue] = d
291         return d
292     
293     def _storeValue(self, key, bvalue, result):
294         """Process the response from the DHT."""
295         if key in self.storing and bvalue in self.storing[key]:
296             # Check if the store succeeded
297             if len(result) > 0:
298                 self.storing[key][bvalue].callback(result)
299             else:
300                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
301             del self.storing[key][bvalue]
302             if len(self.storing[key].keys()) == 0:
303                 del self.storing[key]
304     
305     def getStats(self):
306         """See L{apt_p2p.interfaces.IDHTStats}."""
307         return self.khashmir.getStats()
308
309     def getStatsFactory(self):
310         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
311         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
312         if self.factory is None:
313             # Create a simple HTTP factory for stats
314             class StatsResource(resource.Resource):
315                 def __init__(self, manager):
316                     self.manager = manager
317                 def render(self, ctx):
318                     return http.Response(
319                         200,
320                         {'content-type': http_headers.MimeType('text', 'html')},
321                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
322                 def locateChild(self, request, segments):
323                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
324                     return self, ()
325             
326             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
327         return self.factory
328         
329
330 class TestSimpleDHT(unittest.TestCase):
331     """Simple 2-node unit tests for the DHT."""
332     
333     timeout = 50
334     DHT_DEFAULTS = {'PORT': 9977,
335                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
336                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
337                     'MAX_FAILURES': 3,
338                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
339                     'KRPC_TIMEOUT': 14, 'KRPC_INITIAL_DELAY': 2,
340                     'KEY_EXPIRE': 3600, 'SPEW': False, }
341
342     def setUp(self):
343         self.a = DHT()
344         self.b = DHT()
345         self.a.config = self.DHT_DEFAULTS.copy()
346         self.a.config['PORT'] = 4044
347         self.a.bootstrap = ["127.0.0.1:4044"]
348         self.a.bootstrap_node = True
349         self.a.cache_dir = '/tmp'
350         self.b.config = self.DHT_DEFAULTS.copy()
351         self.b.config['PORT'] = 4045
352         self.b.bootstrap = ["127.0.0.1:4044"]
353         self.b.cache_dir = '/tmp'
354         
355     def test_bootstrap_join(self):
356         d = self.a.join()
357         return d
358
359     def no_krpc_errors(self, result):
360         from krpc import KrpcError
361         self.flushLoggedErrors(KrpcError)
362         return result
363
364     def test_failed_join(self):
365         d = self.b.join()
366         reactor.callLater(30, self.a.join)
367         d.addCallback(self.no_krpc_errors)
368         return d
369         
370     def node_join(self, result):
371         d = self.b.join()
372         return d
373     
374     def test_join(self):
375         d = self.a.join()
376         d.addCallback(self.node_join)
377         return d
378
379     def test_timeout_retransmit(self):
380         d = self.b.join()
381         reactor.callLater(4, self.a.join)
382         return d
383
384     def test_normKey(self):
385         h = self.a._normKey('12345678901234567890')
386         self.failUnless(h == '12345678901234567890')
387         h = self.a._normKey('12345678901234567')
388         self.failUnless(h == '12345678901234567\000\000\000')
389         h = self.a._normKey('1234567890123456789012345')
390         self.failUnless(h == '12345678901234567890')
391         h = self.a._normKey('1234567890123456789')
392         self.failUnless(h == '1234567890123456789\000')
393         h = self.a._normKey('123456789012345678901')
394         self.failUnless(h == '12345678901234567890')
395
396     def value_stored(self, result, value):
397         self.stored -= 1
398         if self.stored == 0:
399             self.get_values()
400         
401     def store_values(self, result):
402         self.stored = 3
403         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
404         d.addCallback(self.value_stored, 4045)
405         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
406         d.addCallback(self.value_stored, 4044)
407         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
408         d.addCallback(self.value_stored, 4045)
409
410     def check_values(self, result, values):
411         self.checked -= 1
412         self.failUnless(len(result) == len(values))
413         for v in result:
414             self.failUnless(v in values)
415         if self.checked == 0:
416             self.lastDefer.callback(1)
417     
418     def get_values(self):
419         self.checked = 4
420         d = self.a.getValue(sha.new('4044').digest())
421         d.addCallback(self.check_values, [str(4044*2)])
422         d = self.b.getValue(sha.new('4044').digest())
423         d.addCallback(self.check_values, [str(4044*2)])
424         d = self.a.getValue(sha.new('4045').digest())
425         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
426         d = self.b.getValue(sha.new('4045').digest())
427         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
428
429     def test_store(self):
430         from twisted.internet.base import DelayedCall
431         DelayedCall.debug = True
432         self.lastDefer = defer.Deferred()
433         d = self.a.join()
434         d.addCallback(self.node_join)
435         d.addCallback(self.store_values)
436         return self.lastDefer
437
438     def tearDown(self):
439         self.a.leave()
440         try:
441             os.unlink(self.a.khashmir.store.db)
442         except:
443             pass
444         self.b.leave()
445         try:
446             os.unlink(self.b.khashmir.store.db)
447         except:
448             pass
449
450 class TestMultiDHT(unittest.TestCase):
451     """More complicated 20-node tests for the DHT."""
452     
453     timeout = 80
454     num = 20
455     DHT_DEFAULTS = {'PORT': 9977,
456                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
457                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
458                     'MAX_FAILURES': 3,
459                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
460                     'KRPC_TIMEOUT': 14, 'KRPC_INITIAL_DELAY': 2,
461                     'KEY_EXPIRE': 3600, 'SPEW': False, }
462
463     def setUp(self):
464         self.l = []
465         self.startport = 4081
466         for i in range(self.num):
467             self.l.append(DHT())
468             self.l[i].config = self.DHT_DEFAULTS.copy()
469             self.l[i].config['PORT'] = self.startport + i
470             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
471             self.l[i].cache_dir = '/tmp'
472         self.l[0].bootstrap_node = True
473         
474     def node_join(self, result, next_node):
475         d = self.l[next_node].join()
476         if next_node + 1 < len(self.l):
477             d.addCallback(self.node_join, next_node + 1)
478         else:
479             d.addCallback(self.lastDefer.callback)
480     
481     def test_join(self):
482         self.timeout = 2
483         self.lastDefer = defer.Deferred()
484         d = self.l[0].join()
485         d.addCallback(self.node_join, 1)
486         return self.lastDefer
487         
488     def store_values(self, result, i = 0, j = 0):
489         if j > i:
490             j -= i+1
491             i += 1
492         if i == len(self.l):
493             self.get_values()
494         else:
495             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
496             d.addCallback(self.store_values, i, j+1)
497     
498     def get_values(self, result = None, check = None, i = 0, j = 0):
499         if result is not None:
500             self.failUnless(len(result) == len(check))
501             for v in result:
502                 self.failUnless(v in check)
503         if j >= len(self.l):
504             j -= len(self.l)
505             i += 1
506         if i == len(self.l):
507             self.lastDefer.callback(1)
508         else:
509             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
510             check = []
511             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
512                 check.append(str(k))
513             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
514
515     def store_join(self, result, next_node):
516         d = self.l[next_node].join()
517         if next_node + 1 < len(self.l):
518             d.addCallback(self.store_join, next_node + 1)
519         else:
520             d.addCallback(self.store_values)
521     
522     def test_store(self):
523         from twisted.internet.base import DelayedCall
524         DelayedCall.debug = True
525         self.lastDefer = defer.Deferred()
526         d = self.l[0].join()
527         d.addCallback(self.store_join, 1)
528         return self.lastDefer
529
530     def tearDown(self):
531         for i in self.l:
532             try:
533                 i.leave()
534                 os.unlink(i.khashmir.store.db)
535             except:
536                 pass