400b10de1d7098d692e80805c73dba9e8959cb7c
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19
20 try:
21     from twisted.web2 import channel, server, resource, http, http_headers
22     _web2 = True
23 except ImportError:
24     _web2 = False
25
26 khashmir_dir = 'apt-p2p-Khashmir'
27
28 class DHTError(Exception):
29     """Represents errors that occur in the DHT."""
30
31 class DHT:
32     """The main interface instance to the Khashmir DHT.
33     
34     @type config: C{dictionary}
35     @ivar config: the DHT configuration values
36     @type cache_dir: C{string}
37     @ivar cache_dir: the directory to use for storing files
38     @type bootstrap: C{list} of C{string}
39     @ivar bootstrap: the nodes to contact to bootstrap into the system
40     @type bootstrap_node: C{boolean}
41     @ivar bootstrap_node: whether this node is a bootstrap node
42     @type joining: L{twisted.internet.defer.Deferred}
43     @ivar joining: if a join is underway, the deferred that will signal it's end
44     @type joined: C{boolean}
45     @ivar joined: whether the DHT network has been successfully joined
46     @type outstandingJoins: C{int}
47     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
48     @type next_rejoin: C{int}
49     @ivar next_rejoin: the number of seconds before retrying the next join
50     @type foundAddrs: C{list} of (C{string}, C{int})
51     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
52     @type storing: C{dictionary}
53     @ivar storing: keys are keys for which store requests are active, values
54         are dictionaries with keys the values being stored and values the
55         deferred to call when complete
56     @type retrieving: C{dictionary}
57     @ivar retrieving: keys are the keys for which getValue requests are active,
58         values are lists of the deferreds waiting for the requests
59     @type retrieved: C{dictionary}
60     @ivar retrieved: keys are the keys for which getValue requests are active,
61         values are list of the values returned so far
62     @type factory: L{twisted.web2.channel.HTTPFactory}
63     @ivar factory: the factory to use to serve HTTP requests for statistics
64     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
65     @ivar config_parser: the configuration info for the main program
66     @type section: C{string}
67     @ivar section: the section of the configuration info that applies to the DHT
68     @type khashmir: L{khashmir.Khashmir}
69     @ivar khashmir: the khashmir DHT instance to use
70     """
71     
72     if _web2:
73         implements(IDHT, IDHTStats, IDHTStatsFactory)
74     else:
75         implements(IDHT, IDHTStats)
76         
77     def __init__(self):
78         """Initialize the DHT."""
79         self.config = None
80         self.cache_dir = ''
81         self.bootstrap = []
82         self.bootstrap_node = False
83         self.joining = None
84         self.khashmir = None
85         self.joined = False
86         self.outstandingJoins = 0
87         self.next_rejoin = 20
88         self.foundAddrs = []
89         self.storing = {}
90         self.retrieving = {}
91         self.retrieved = {}
92         self.factory = None
93     
94     def loadConfig(self, config, section):
95         """See L{apt_p2p.interfaces.IDHT}."""
96         self.config_parser = config
97         self.section = section
98         self.config = {}
99         
100         # Get some initial values
101         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
102         if not os.path.exists(self.cache_dir):
103             os.makedirs(self.cache_dir)
104         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
105         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
106         for k in self.config_parser.options(section):
107             # The numbers in the config file
108             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
109                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
110                 self.config[k] = self.config_parser.getint(section, k)
111             # The times in the config file
112             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
113                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
114                 self.config[k] = self.config_parser.gettime(section, k)
115             # The booleans in the config file
116             elif k in ['SPEW']:
117                 self.config[k] = self.config_parser.getboolean(section, k)
118             # Everything else is a string
119             else:
120                 self.config[k] = self.config_parser.get(section, k)
121     
122     def join(self, deferred = None):
123         """See L{apt_p2p.interfaces.IDHT}.
124         
125         @param deferred: the deferred to callback when the join is complete
126             (optional, defaults to creating a new deferred and returning it)
127         """
128         # Check for multiple simultaneous joins 
129         if self.joining:
130             if deferred:
131                 deferred.errback(DHTError("a join is already in progress"))
132                 return
133             else:
134                 raise DHTError, "a join is already in progress"
135
136         if deferred:
137             self.joining = deferred
138         else:
139             self.joining = defer.Deferred()
140
141         if self.config is None:
142             self.joining.errback(DHTError("configuration not loaded"))
143             return self.joining
144
145         # Create the new khashmir instance
146         if not self.khashmir:
147             self.khashmir = Khashmir(self.config, self.cache_dir)
148
149         self.outstandingJoins = 0
150         for node in self.bootstrap:
151             host, port = node.rsplit(':', 1)
152             port = int(port)
153             self.outstandingJoins += 1
154             
155             # Translate host names into IP addresses
156             if isIPAddress(host):
157                 self._join_gotIP(host, port)
158             else:
159                 reactor.resolve(host).addCallbacks(self._join_gotIP,
160                                                    self._join_resolveFailed,
161                                                    callbackArgs = (port, ),
162                                                    errbackArgs = (host, port))
163         
164         return self.joining
165
166     def _join_gotIP(self, ip, port):
167         """Join the DHT using a single bootstrap nodes IP address."""
168         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
169     
170     def _join_resolveFailed(self, err, host, port):
171         """Failed to lookup the IP address of the bootstrap node."""
172         log.msg('Failed to find an IP address for host: (%r, %r)' % (host, port))
173         log.err(err)
174         self.outstandingJoins -= 1
175         if self.outstandingJoins <= 0:
176             self.khashmir.findCloseNodes(self._join_complete)
177     
178     def _join_single(self, addr):
179         """Process the response from the bootstrap node.
180         
181         Finish the join by contacting close nodes.
182         """
183         self.outstandingJoins -= 1
184         if addr:
185             self.foundAddrs.append(addr)
186         if addr or self.outstandingJoins <= 0:
187             self.khashmir.findCloseNodes(self._join_complete)
188         log.msg('Got back from bootstrap node: %r' % (addr,))
189     
190     def _join_error(self, failure = None):
191         """Process an error in contacting the bootstrap node.
192         
193         If no bootstrap nodes remain, finish the process by contacting
194         close nodes.
195         """
196         self.outstandingJoins -= 1
197         log.msg("bootstrap node could not be reached")
198         if self.outstandingJoins <= 0:
199             self.khashmir.findCloseNodes(self._join_complete)
200
201     def _join_complete(self, result):
202         """End the joining process and return the addresses found for this node."""
203         if not self.joined and isinstance(result, list) and len(result) > 1:
204             self.joined = True
205         if self.joining and self.outstandingJoins <= 0:
206             df = self.joining
207             self.joining = None
208             if self.joined or self.bootstrap_node:
209                 self.joined = True
210                 df.callback(self.foundAddrs)
211             else:
212                 # Try to join later using exponential backoff delays
213                 log.msg('Join failed, retrying in %d seconds' % self.next_rejoin)
214                 reactor.callLater(self.next_rejoin, self.join, df)
215                 self.next_rejoin *= 2
216         
217     def getAddrs(self):
218         """Get the list of addresses returned by bootstrap nodes for this node."""
219         return self.foundAddrs
220         
221     def leave(self):
222         """See L{apt_p2p.interfaces.IDHT}."""
223         if self.config is None:
224             raise DHTError, "configuration not loaded"
225         
226         if self.joined or self.joining:
227             if self.joining:
228                 self.joining.errback(DHTError('still joining when leave was called'))
229                 self.joining = None
230             self.joined = False
231             self.khashmir.shutdown()
232         
233     def _normKey(self, key, bits=None, bytes=None):
234         """Normalize the length of keys used in the DHT."""
235         bits = self.config["HASH_LENGTH"]
236         if bits is not None:
237             bytes = (bits - 1) // 8 + 1
238         else:
239             if bytes is None:
240                 raise DHTError, "you must specify one of bits or bytes for normalization"
241             
242         # Extend short keys with null bytes
243         if len(key) < bytes:
244             key = key + '\000'*(bytes - len(key))
245         # Truncate long keys
246         elif len(key) > bytes:
247             key = key[:bytes]
248         return key
249
250     def getValue(self, key):
251         """See L{apt_p2p.interfaces.IDHT}."""
252         d = defer.Deferred()
253
254         if self.config is None:
255             d.errback(DHTError("configuration not loaded"))
256             return d
257         if not self.joined:
258             d.errback(DHTError("have not joined a network yet"))
259             return d
260         
261         key = self._normKey(key)
262
263         if key not in self.retrieving:
264             self.khashmir.valueForKey(key, self._getValue)
265         self.retrieving.setdefault(key, []).append(d)
266         return d
267         
268     def _getValue(self, key, result):
269         """Process a returned list of values from the DHT."""
270         # Save the list of values to return when it is complete
271         if result:
272             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
273         else:
274             # Empty list, the get is complete, return the result
275             final_result = []
276             if key in self.retrieved:
277                 final_result = self.retrieved[key]
278                 del self.retrieved[key]
279             for i in range(len(self.retrieving[key])):
280                 d = self.retrieving[key].pop(0)
281                 d.callback(final_result)
282             del self.retrieving[key]
283
284     def storeValue(self, key, value):
285         """See L{apt_p2p.interfaces.IDHT}."""
286         d = defer.Deferred()
287
288         if self.config is None:
289             d.errback(DHTError("configuration not loaded"))
290             return d
291         if not self.joined:
292             d.errback(DHTError("have not joined a network yet"))
293             return d
294
295         key = self._normKey(key)
296         bvalue = bencode(value)
297
298         if key in self.storing and bvalue in self.storing[key]:
299             raise DHTError, "already storing that key with the same value"
300
301         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
302         self.storing.setdefault(key, {})[bvalue] = d
303         return d
304     
305     def _storeValue(self, key, bvalue, result):
306         """Process the response from the DHT."""
307         if key in self.storing and bvalue in self.storing[key]:
308             # Check if the store succeeded
309             if len(result) > 0:
310                 self.storing[key][bvalue].callback(result)
311             else:
312                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
313             del self.storing[key][bvalue]
314             if len(self.storing[key].keys()) == 0:
315                 del self.storing[key]
316     
317     def getStats(self):
318         """See L{apt_p2p.interfaces.IDHTStats}."""
319         return self.khashmir.getStats()
320
321     def getStatsFactory(self):
322         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
323         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
324         if self.factory is None:
325             # Create a simple HTTP factory for stats
326             class StatsResource(resource.Resource):
327                 def __init__(self, manager):
328                     self.manager = manager
329                 def render(self, ctx):
330                     return http.Response(
331                         200,
332                         {'content-type': http_headers.MimeType('text', 'html')},
333                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
334                 def locateChild(self, request, segments):
335                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
336                     return self, ()
337             
338             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
339         return self.factory
340         
341
342 class TestSimpleDHT(unittest.TestCase):
343     """Simple 2-node unit tests for the DHT."""
344     
345     timeout = 50
346     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
347                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
348                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
349                     'MAX_FAILURES': 3,
350                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
351                     'KEY_EXPIRE': 3600, 'SPEW': False, }
352
353     def setUp(self):
354         self.a = DHT()
355         self.b = DHT()
356         self.a.config = self.DHT_DEFAULTS.copy()
357         self.a.config['PORT'] = 4044
358         self.a.bootstrap = ["127.0.0.1:4044"]
359         self.a.bootstrap_node = True
360         self.a.cache_dir = '/tmp'
361         self.b.config = self.DHT_DEFAULTS.copy()
362         self.b.config['PORT'] = 4045
363         self.b.bootstrap = ["127.0.0.1:4044"]
364         self.b.cache_dir = '/tmp'
365         
366     def test_bootstrap_join(self):
367         d = self.a.join()
368         return d
369
370     def no_krpc_errors(self, result):
371         from krpc import KrpcError
372         self.flushLoggedErrors(KrpcError)
373         return result
374
375     def test_failed_join(self):
376         d = self.b.join()
377         reactor.callLater(30, self.a.join)
378         d.addCallback(self.no_krpc_errors)
379         return d
380         
381     def node_join(self, result):
382         d = self.b.join()
383         return d
384     
385     def test_join(self):
386         d = self.a.join()
387         d.addCallback(self.node_join)
388         return d
389
390     def test_timeout_retransmit(self):
391         d = self.b.join()
392         reactor.callLater(4, self.a.join)
393         return d
394
395     def test_normKey(self):
396         h = self.a._normKey('12345678901234567890')
397         self.failUnless(h == '12345678901234567890')
398         h = self.a._normKey('12345678901234567')
399         self.failUnless(h == '12345678901234567\000\000\000')
400         h = self.a._normKey('1234567890123456789012345')
401         self.failUnless(h == '12345678901234567890')
402         h = self.a._normKey('1234567890123456789')
403         self.failUnless(h == '1234567890123456789\000')
404         h = self.a._normKey('123456789012345678901')
405         self.failUnless(h == '12345678901234567890')
406
407     def value_stored(self, result, value):
408         self.stored -= 1
409         if self.stored == 0:
410             self.get_values()
411         
412     def store_values(self, result):
413         self.stored = 3
414         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
415         d.addCallback(self.value_stored, 4045)
416         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
417         d.addCallback(self.value_stored, 4044)
418         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
419         d.addCallback(self.value_stored, 4045)
420
421     def check_values(self, result, values):
422         self.checked -= 1
423         self.failUnless(len(result) == len(values))
424         for v in result:
425             self.failUnless(v in values)
426         if self.checked == 0:
427             self.lastDefer.callback(1)
428     
429     def get_values(self):
430         self.checked = 4
431         d = self.a.getValue(sha.new('4044').digest())
432         d.addCallback(self.check_values, [str(4044*2)])
433         d = self.b.getValue(sha.new('4044').digest())
434         d.addCallback(self.check_values, [str(4044*2)])
435         d = self.a.getValue(sha.new('4045').digest())
436         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
437         d = self.b.getValue(sha.new('4045').digest())
438         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
439
440     def test_store(self):
441         from twisted.internet.base import DelayedCall
442         DelayedCall.debug = True
443         self.lastDefer = defer.Deferred()
444         d = self.a.join()
445         d.addCallback(self.node_join)
446         d.addCallback(self.store_values)
447         return self.lastDefer
448
449     def tearDown(self):
450         self.a.leave()
451         try:
452             os.unlink(self.a.khashmir.store.db)
453         except:
454             pass
455         self.b.leave()
456         try:
457             os.unlink(self.b.khashmir.store.db)
458         except:
459             pass
460
461 class TestMultiDHT(unittest.TestCase):
462     """More complicated 20-node tests for the DHT."""
463     
464     timeout = 80
465     num = 20
466     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
467                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
468                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
469                     'MAX_FAILURES': 3,
470                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
471                     'KEY_EXPIRE': 3600, 'SPEW': False, }
472
473     def setUp(self):
474         self.l = []
475         self.startport = 4081
476         for i in range(self.num):
477             self.l.append(DHT())
478             self.l[i].config = self.DHT_DEFAULTS.copy()
479             self.l[i].config['PORT'] = self.startport + i
480             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
481             self.l[i].cache_dir = '/tmp'
482         self.l[0].bootstrap_node = True
483         
484     def node_join(self, result, next_node):
485         d = self.l[next_node].join()
486         if next_node + 1 < len(self.l):
487             d.addCallback(self.node_join, next_node + 1)
488         else:
489             d.addCallback(self.lastDefer.callback)
490     
491     def test_join(self):
492         self.timeout = 2
493         self.lastDefer = defer.Deferred()
494         d = self.l[0].join()
495         d.addCallback(self.node_join, 1)
496         return self.lastDefer
497         
498     def store_values(self, result, i = 0, j = 0):
499         if j > i:
500             j -= i+1
501             i += 1
502         if i == len(self.l):
503             self.get_values()
504         else:
505             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
506             d.addCallback(self.store_values, i, j+1)
507     
508     def get_values(self, result = None, check = None, i = 0, j = 0):
509         if result is not None:
510             self.failUnless(len(result) == len(check))
511             for v in result:
512                 self.failUnless(v in check)
513         if j >= len(self.l):
514             j -= len(self.l)
515             i += 1
516         if i == len(self.l):
517             self.lastDefer.callback(1)
518         else:
519             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
520             check = []
521             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
522                 check.append(str(k))
523             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
524
525     def store_join(self, result, next_node):
526         d = self.l[next_node].join()
527         if next_node + 1 < len(self.l):
528             d.addCallback(self.store_join, next_node + 1)
529         else:
530             d.addCallback(self.store_values)
531     
532     def test_store(self):
533         from twisted.internet.base import DelayedCall
534         DelayedCall.debug = True
535         self.lastDefer = defer.Deferred()
536         d = self.l[0].join()
537         d.addCallback(self.store_join, 1)
538         return self.lastDefer
539
540     def tearDown(self):
541         for i in self.l:
542             try:
543                 i.leave()
544                 os.unlink(i.khashmir.store.db)
545             except:
546                 pass