af247302f9ed6dc0b3838bbef09d2ebf681d54f8
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19 from khash import HASH_LENGTH
20
21 try:
22     from twisted.web2 import channel, server, resource, http, http_headers
23     _web2 = True
24 except ImportError:
25     _web2 = False
26
27 khashmir_dir = 'apt-p2p-Khashmir'
28
29 class DHTError(Exception):
30     """Represents errors that occur in the DHT."""
31
32 class DHT:
33     """The main interface instance to the Khashmir DHT.
34     
35     @type config: C{dictionary}
36     @ivar config: the DHT configuration values
37     @type cache_dir: C{string}
38     @ivar cache_dir: the directory to use for storing files
39     @type bootstrap: C{list} of C{string}
40     @ivar bootstrap: the nodes to contact to bootstrap into the system
41     @type bootstrap_node: C{boolean}
42     @ivar bootstrap_node: whether this node is a bootstrap node
43     @type joining: L{twisted.internet.defer.Deferred}
44     @ivar joining: if a join is underway, the deferred that will signal it's end
45     @type joined: C{boolean}
46     @ivar joined: whether the DHT network has been successfully joined
47     @type outstandingJoins: C{int}
48     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
49     @type next_rejoin: C{int}
50     @ivar next_rejoin: the number of seconds before retrying the next join
51     @type foundAddrs: C{list} of (C{string}, C{int})
52     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
53     @type storing: C{dictionary}
54     @ivar storing: keys are keys for which store requests are active, values
55         are dictionaries with keys the values being stored and values the
56         deferred to call when complete
57     @type retrieving: C{dictionary}
58     @ivar retrieving: keys are the keys for which getValue requests are active,
59         values are lists of the deferreds waiting for the requests
60     @type retrieved: C{dictionary}
61     @ivar retrieved: keys are the keys for which getValue requests are active,
62         values are list of the values returned so far
63     @type factory: L{twisted.web2.channel.HTTPFactory}
64     @ivar factory: the factory to use to serve HTTP requests for statistics
65     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
66     @ivar config_parser: the configuration info for the main program
67     @type section: C{string}
68     @ivar section: the section of the configuration info that applies to the DHT
69     @type khashmir: L{khashmir.Khashmir}
70     @ivar khashmir: the khashmir DHT instance to use
71     """
72     
73     if _web2:
74         implements(IDHT, IDHTStats, IDHTStatsFactory)
75     else:
76         implements(IDHT, IDHTStats)
77         
78     def __init__(self):
79         """Initialize the DHT."""
80         self.config = None
81         self.cache_dir = ''
82         self.bootstrap = []
83         self.bootstrap_node = False
84         self.joining = None
85         self.khashmir = None
86         self.joined = False
87         self.outstandingJoins = 0
88         self.next_rejoin = 20
89         self.foundAddrs = []
90         self.storing = {}
91         self.retrieving = {}
92         self.retrieved = {}
93         self.factory = None
94     
95     def loadConfig(self, config, section):
96         """See L{apt_p2p.interfaces.IDHT}."""
97         self.config_parser = config
98         self.section = section
99         self.config = {}
100         
101         # Get some initial values
102         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
103         if not os.path.exists(self.cache_dir):
104             os.makedirs(self.cache_dir)
105         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
106         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
107         for k in self.config_parser.options(section):
108             # The numbers in the config file
109             if k in ['CONCURRENT_REQS', 'STORE_REDUNDANCY', 
110                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
111                 self.config[k] = self.config_parser.getint(section, k)
112             # The times in the config file
113             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
114                        'BUCKET_STALENESS', 'KEY_EXPIRE',
115                        'KRPC_TIMEOUT', 'KRPC_INITIAL_DELAY']:
116                 self.config[k] = self.config_parser.gettime(section, k)
117             # The booleans in the config file
118             elif k in ['SPEW', 'LOCAL_OK']:
119                 self.config[k] = self.config_parser.getboolean(section, k)
120             # Everything else is a string
121             else:
122                 self.config[k] = self.config_parser.get(section, k)
123     
124     def join(self, deferred = None):
125         """See L{apt_p2p.interfaces.IDHT}.
126         
127         @param deferred: the deferred to callback when the join is complete
128             (optional, defaults to creating a new deferred and returning it)
129         """
130         # Check for multiple simultaneous joins 
131         if self.joining:
132             if deferred:
133                 deferred.errback(DHTError("a join is already in progress"))
134                 return
135             else:
136                 raise DHTError, "a join is already in progress"
137
138         if deferred:
139             self.joining = deferred
140         else:
141             self.joining = defer.Deferred()
142
143         if self.config is None:
144             self.joining.errback(DHTError("configuration not loaded"))
145             return self.joining
146
147         # Create the new khashmir instance
148         if not self.khashmir:
149             self.khashmir = Khashmir(self.config, self.cache_dir)
150
151         self.outstandingJoins = 0
152         for node in self.bootstrap:
153             host, port = node.rsplit(':', 1)
154             port = int(port)
155             self.outstandingJoins += 1
156             
157             # Translate host names into IP addresses
158             if isIPAddress(host):
159                 self._join_gotIP(host, port)
160             else:
161                 reactor.resolve(host).addCallbacks(self._join_gotIP,
162                                                    self._join_resolveFailed,
163                                                    callbackArgs = (port, ),
164                                                    errbackArgs = (host, port))
165         
166         return self.joining
167
168     def _join_gotIP(self, ip, port):
169         """Join the DHT using a single bootstrap nodes IP address."""
170         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
171     
172     def _join_resolveFailed(self, err, host, port):
173         """Failed to lookup the IP address of the bootstrap node."""
174         log.msg('Failed to find an IP address for host: (%r, %r)' % (host, port))
175         log.err(err)
176         self.outstandingJoins -= 1
177         if self.outstandingJoins <= 0:
178             self.khashmir.findCloseNodes(self._join_complete)
179     
180     def _join_single(self, addr):
181         """Process the response from the bootstrap node.
182         
183         Finish the join by contacting close nodes.
184         """
185         self.outstandingJoins -= 1
186         if addr:
187             self.foundAddrs.append(addr)
188         if addr or self.outstandingJoins <= 0:
189             self.khashmir.findCloseNodes(self._join_complete)
190         log.msg('Got back from bootstrap node: %r' % (addr,))
191     
192     def _join_error(self, failure = None):
193         """Process an error in contacting the bootstrap node.
194         
195         If no bootstrap nodes remain, finish the process by contacting
196         close nodes.
197         """
198         self.outstandingJoins -= 1
199         log.msg("bootstrap node could not be reached")
200         if self.outstandingJoins <= 0:
201             self.khashmir.findCloseNodes(self._join_complete)
202
203     def _join_complete(self, result):
204         """End the joining process and return the addresses found for this node."""
205         if not self.joined and isinstance(result, list) and len(result) > 1:
206             self.joined = True
207         if self.joining and self.outstandingJoins <= 0:
208             df = self.joining
209             self.joining = None
210             if self.joined or self.bootstrap_node:
211                 self.joined = True
212                 df.callback(self.foundAddrs)
213             else:
214                 # Try to join later using exponential backoff delays
215                 log.msg('Join failed, retrying in %d seconds' % self.next_rejoin)
216                 reactor.callLater(self.next_rejoin, self.join, df)
217                 self.next_rejoin *= 2
218         
219     def getAddrs(self):
220         """Get the list of addresses returned by bootstrap nodes for this node."""
221         return self.foundAddrs
222         
223     def leave(self):
224         """See L{apt_p2p.interfaces.IDHT}."""
225         if self.config is None:
226             raise DHTError, "configuration not loaded"
227         
228         if self.joined or self.joining:
229             if self.joining:
230                 self.joining.errback(DHTError('still joining when leave was called'))
231                 self.joining = None
232             self.joined = False
233             self.khashmir.shutdown()
234         
235     def _normKey(self, key):
236         """Normalize the length of keys used in the DHT."""
237         # Extend short keys with null bytes
238         if len(key) < HASH_LENGTH:
239             key = key + '\000'*(HASH_LENGTH - len(key))
240         # Truncate long keys
241         elif len(key) > HASH_LENGTH:
242             key = key[:HASH_LENGTH]
243         return key
244
245     def getValue(self, key):
246         """See L{apt_p2p.interfaces.IDHT}."""
247         if self.config is None:
248             return defer.fail(DHTError("configuration not loaded"))
249         if not self.joined:
250             return defer.fail(DHTError("have not joined a network yet"))
251         
252         d = defer.Deferred()
253         key = self._normKey(key)
254
255         if key not in self.retrieving:
256             self.khashmir.valueForKey(key, self._getValue)
257         self.retrieving.setdefault(key, []).append(d)
258         return d
259         
260     def _getValue(self, key, result):
261         """Process a returned list of values from the DHT."""
262         # Save the list of values to return when it is complete
263         if result:
264             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
265         else:
266             # Empty list, the get is complete, return the result
267             final_result = []
268             if key in self.retrieved:
269                 final_result = self.retrieved[key]
270                 del self.retrieved[key]
271             for i in range(len(self.retrieving[key])):
272                 d = self.retrieving[key].pop(0)
273                 d.callback(final_result)
274             del self.retrieving[key]
275
276     def storeValue(self, key, value):
277         """See L{apt_p2p.interfaces.IDHT}."""
278         if self.config is None:
279             return defer.fail(DHTError("configuration not loaded"))
280         if not self.joined:
281             return defer.fail(DHTError("have not joined a network yet"))
282
283         d = defer.Deferred()
284         key = self._normKey(key)
285         bvalue = bencode(value)
286
287         if key in self.storing and bvalue in self.storing[key]:
288             raise DHTError, "already storing that key with the same value"
289
290         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
291         self.storing.setdefault(key, {})[bvalue] = d
292         return d
293     
294     def _storeValue(self, key, bvalue, result):
295         """Process the response from the DHT."""
296         if key in self.storing and bvalue in self.storing[key]:
297             # Check if the store succeeded
298             if len(result) > 0:
299                 self.storing[key][bvalue].callback(result)
300             else:
301                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
302             del self.storing[key][bvalue]
303             if len(self.storing[key].keys()) == 0:
304                 del self.storing[key]
305     
306     def getStats(self):
307         """See L{apt_p2p.interfaces.IDHTStats}."""
308         return self.khashmir.getStats()
309
310     def getStatsFactory(self):
311         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
312         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
313         if self.factory is None:
314             # Create a simple HTTP factory for stats
315             class StatsResource(resource.Resource):
316                 def __init__(self, manager):
317                     self.manager = manager
318                 def render(self, ctx):
319                     return http.Response(
320                         200,
321                         {'content-type': http_headers.MimeType('text', 'html')},
322                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
323                 def locateChild(self, request, segments):
324                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
325                     return self, ()
326             
327             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
328         return self.factory
329         
330
331 class TestSimpleDHT(unittest.TestCase):
332     """Simple 2-node unit tests for the DHT."""
333     
334     timeout = 50
335     DHT_DEFAULTS = {'PORT': 9977,
336                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
337                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
338                     'MAX_FAILURES': 3, 'LOCAL_OK': True,
339                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
340                     'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
341                     'KEY_EXPIRE': 3600, 'SPEW': True, }
342
343     def setUp(self):
344         self.a = DHT()
345         self.b = DHT()
346         self.a.config = self.DHT_DEFAULTS.copy()
347         self.a.config['PORT'] = 4044
348         self.a.bootstrap = ["127.0.0.1:4044"]
349         self.a.bootstrap_node = True
350         self.a.cache_dir = '/tmp'
351         self.b.config = self.DHT_DEFAULTS.copy()
352         self.b.config['PORT'] = 4045
353         self.b.bootstrap = ["127.0.0.1:4044"]
354         self.b.cache_dir = '/tmp'
355         
356     def test_bootstrap_join(self):
357         d = self.a.join()
358         return d
359
360     def no_krpc_errors(self, result):
361         from krpc import KrpcError
362         self.flushLoggedErrors(KrpcError)
363         return result
364
365     def test_failed_join(self):
366         d = self.b.join()
367         reactor.callLater(30, self.a.join)
368         d.addCallback(self.no_krpc_errors)
369         return d
370         
371     def node_join(self, result):
372         d = self.b.join()
373         return d
374     
375     def test_join(self):
376         d = self.a.join()
377         d.addCallback(self.node_join)
378         return d
379
380     def test_timeout_retransmit(self):
381         d = self.b.join()
382         reactor.callLater(4, self.a.join)
383         return d
384
385     def test_normKey(self):
386         h = self.a._normKey('12345678901234567890')
387         self.failUnless(h == '12345678901234567890')
388         h = self.a._normKey('12345678901234567')
389         self.failUnless(h == '12345678901234567\000\000\000')
390         h = self.a._normKey('1234567890123456789012345')
391         self.failUnless(h == '12345678901234567890')
392         h = self.a._normKey('1234567890123456789')
393         self.failUnless(h == '1234567890123456789\000')
394         h = self.a._normKey('123456789012345678901')
395         self.failUnless(h == '12345678901234567890')
396
397     def value_stored(self, result, value):
398         self.stored -= 1
399         if self.stored == 0:
400             self.get_values()
401         
402     def store_values(self, result):
403         self.stored = 3
404         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
405         d.addCallback(self.value_stored, 4045)
406         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
407         d.addCallback(self.value_stored, 4044)
408         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
409         d.addCallback(self.value_stored, 4045)
410
411     def check_values(self, result, values):
412         self.checked -= 1
413         self.failUnless(len(result) == len(values))
414         for v in result:
415             self.failUnless(v in values)
416         if self.checked == 0:
417             self.lastDefer.callback(1)
418     
419     def get_values(self):
420         self.checked = 4
421         d = self.a.getValue(sha.new('4044').digest())
422         d.addCallback(self.check_values, [str(4044*2)])
423         d = self.b.getValue(sha.new('4044').digest())
424         d.addCallback(self.check_values, [str(4044*2)])
425         d = self.a.getValue(sha.new('4045').digest())
426         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
427         d = self.b.getValue(sha.new('4045').digest())
428         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
429
430     def test_store(self):
431         from twisted.internet.base import DelayedCall
432         DelayedCall.debug = True
433         self.lastDefer = defer.Deferred()
434         d = self.a.join()
435         d.addCallback(self.node_join)
436         d.addCallback(self.store_values)
437         return self.lastDefer
438
439     def tearDown(self):
440         self.a.leave()
441         try:
442             os.unlink(self.a.khashmir.store.db)
443         except:
444             pass
445         self.b.leave()
446         try:
447             os.unlink(self.b.khashmir.store.db)
448         except:
449             pass
450
451 class TestMultiDHT(unittest.TestCase):
452     """More complicated 20-node tests for the DHT."""
453     
454     timeout = 100
455     num = 20
456     DHT_DEFAULTS = {'PORT': 9977,
457                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
458                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
459                     'MAX_FAILURES': 3, 'LOCAL_OK': True,
460                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
461                     'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
462                     'KEY_EXPIRE': 3600, 'SPEW': True, }
463
464     def setUp(self):
465         self.l = []
466         self.startport = 4081
467         for i in range(self.num):
468             self.l.append(DHT())
469             self.l[i].config = self.DHT_DEFAULTS.copy()
470             self.l[i].config['PORT'] = self.startport + i
471             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
472             self.l[i].cache_dir = '/tmp'
473         self.l[0].bootstrap_node = True
474         
475     def node_join(self, result, next_node):
476         d = self.l[next_node].join()
477         if next_node + 1 < len(self.l):
478             d.addCallback(self.node_join, next_node + 1)
479         else:
480             d.addCallback(self.lastDefer.callback)
481     
482     def test_join(self):
483         self.timeout = 2
484         self.lastDefer = defer.Deferred()
485         d = self.l[0].join()
486         d.addCallback(self.node_join, 1)
487         return self.lastDefer
488         
489     def store_values(self, result, i = 0, j = 0):
490         if j > i:
491             j -= i+1
492             i += 1
493         if i == len(self.l):
494             self.get_values()
495         else:
496             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
497             d.addCallback(self.store_values, i, j+1)
498     
499     def get_values(self, result = None, check = None, i = 0, j = 0):
500         if result is not None:
501             self.failUnless(len(result) == len(check))
502             for v in result:
503                 self.failUnless(v in check)
504         if j >= len(self.l):
505             j -= len(self.l)
506             i += 1
507         if i == len(self.l):
508             self.lastDefer.callback(1)
509         else:
510             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
511             check = []
512             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
513                 check.append(str(k))
514             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
515
516     def store_join(self, result, next_node):
517         d = self.l[next_node].join()
518         if next_node + 1 < len(self.l):
519             d.addCallback(self.store_join, next_node + 1)
520         else:
521             d.addCallback(self.store_values)
522     
523     def test_store(self):
524         from twisted.internet.base import DelayedCall
525         DelayedCall.debug = True
526         self.lastDefer = defer.Deferred()
527         d = self.l[0].join()
528         d.addCallback(self.store_join, 1)
529         return self.lastDefer
530
531     def tearDown(self):
532         for i in self.l:
533             try:
534                 i.leave()
535                 os.unlink(i.khashmir.store.db)
536             except:
537                 pass