Make better use of defer.Fail for returning deferred errors.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19
20 try:
21     from twisted.web2 import channel, server, resource, http, http_headers
22     _web2 = True
23 except ImportError:
24     _web2 = False
25
26 khashmir_dir = 'apt-p2p-Khashmir'
27
28 class DHTError(Exception):
29     """Represents errors that occur in the DHT."""
30
31 class DHT:
32     """The main interface instance to the Khashmir DHT.
33     
34     @type config: C{dictionary}
35     @ivar config: the DHT configuration values
36     @type cache_dir: C{string}
37     @ivar cache_dir: the directory to use for storing files
38     @type bootstrap: C{list} of C{string}
39     @ivar bootstrap: the nodes to contact to bootstrap into the system
40     @type bootstrap_node: C{boolean}
41     @ivar bootstrap_node: whether this node is a bootstrap node
42     @type joining: L{twisted.internet.defer.Deferred}
43     @ivar joining: if a join is underway, the deferred that will signal it's end
44     @type joined: C{boolean}
45     @ivar joined: whether the DHT network has been successfully joined
46     @type outstandingJoins: C{int}
47     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
48     @type next_rejoin: C{int}
49     @ivar next_rejoin: the number of seconds before retrying the next join
50     @type foundAddrs: C{list} of (C{string}, C{int})
51     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
52     @type storing: C{dictionary}
53     @ivar storing: keys are keys for which store requests are active, values
54         are dictionaries with keys the values being stored and values the
55         deferred to call when complete
56     @type retrieving: C{dictionary}
57     @ivar retrieving: keys are the keys for which getValue requests are active,
58         values are lists of the deferreds waiting for the requests
59     @type retrieved: C{dictionary}
60     @ivar retrieved: keys are the keys for which getValue requests are active,
61         values are list of the values returned so far
62     @type factory: L{twisted.web2.channel.HTTPFactory}
63     @ivar factory: the factory to use to serve HTTP requests for statistics
64     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
65     @ivar config_parser: the configuration info for the main program
66     @type section: C{string}
67     @ivar section: the section of the configuration info that applies to the DHT
68     @type khashmir: L{khashmir.Khashmir}
69     @ivar khashmir: the khashmir DHT instance to use
70     """
71     
72     if _web2:
73         implements(IDHT, IDHTStats, IDHTStatsFactory)
74     else:
75         implements(IDHT, IDHTStats)
76         
77     def __init__(self):
78         """Initialize the DHT."""
79         self.config = None
80         self.cache_dir = ''
81         self.bootstrap = []
82         self.bootstrap_node = False
83         self.joining = None
84         self.khashmir = None
85         self.joined = False
86         self.outstandingJoins = 0
87         self.next_rejoin = 20
88         self.foundAddrs = []
89         self.storing = {}
90         self.retrieving = {}
91         self.retrieved = {}
92         self.factory = None
93     
94     def loadConfig(self, config, section):
95         """See L{apt_p2p.interfaces.IDHT}."""
96         self.config_parser = config
97         self.section = section
98         self.config = {}
99         
100         # Get some initial values
101         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
102         if not os.path.exists(self.cache_dir):
103             os.makedirs(self.cache_dir)
104         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
105         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
106         for k in self.config_parser.options(section):
107             # The numbers in the config file
108             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
109                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
110                 self.config[k] = self.config_parser.getint(section, k)
111             # The times in the config file
112             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
113                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
114                 self.config[k] = self.config_parser.gettime(section, k)
115             # The booleans in the config file
116             elif k in ['SPEW']:
117                 self.config[k] = self.config_parser.getboolean(section, k)
118             # Everything else is a string
119             else:
120                 self.config[k] = self.config_parser.get(section, k)
121     
122     def join(self, deferred = None):
123         """See L{apt_p2p.interfaces.IDHT}.
124         
125         @param deferred: the deferred to callback when the join is complete
126             (optional, defaults to creating a new deferred and returning it)
127         """
128         # Check for multiple simultaneous joins 
129         if self.joining:
130             if deferred:
131                 deferred.errback(DHTError("a join is already in progress"))
132                 return
133             else:
134                 raise DHTError, "a join is already in progress"
135
136         if deferred:
137             self.joining = deferred
138         else:
139             self.joining = defer.Deferred()
140
141         if self.config is None:
142             self.joining.errback(DHTError("configuration not loaded"))
143             return self.joining
144
145         # Create the new khashmir instance
146         if not self.khashmir:
147             self.khashmir = Khashmir(self.config, self.cache_dir)
148
149         self.outstandingJoins = 0
150         for node in self.bootstrap:
151             host, port = node.rsplit(':', 1)
152             port = int(port)
153             self.outstandingJoins += 1
154             
155             # Translate host names into IP addresses
156             if isIPAddress(host):
157                 self._join_gotIP(host, port)
158             else:
159                 reactor.resolve(host).addCallbacks(self._join_gotIP,
160                                                    self._join_resolveFailed,
161                                                    callbackArgs = (port, ),
162                                                    errbackArgs = (host, port))
163         
164         return self.joining
165
166     def _join_gotIP(self, ip, port):
167         """Join the DHT using a single bootstrap nodes IP address."""
168         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
169     
170     def _join_resolveFailed(self, err, host, port):
171         """Failed to lookup the IP address of the bootstrap node."""
172         log.msg('Failed to find an IP address for host: (%r, %r)' % (host, port))
173         log.err(err)
174         self.outstandingJoins -= 1
175         if self.outstandingJoins <= 0:
176             self.khashmir.findCloseNodes(self._join_complete)
177     
178     def _join_single(self, addr):
179         """Process the response from the bootstrap node.
180         
181         Finish the join by contacting close nodes.
182         """
183         self.outstandingJoins -= 1
184         if addr:
185             self.foundAddrs.append(addr)
186         if addr or self.outstandingJoins <= 0:
187             self.khashmir.findCloseNodes(self._join_complete)
188         log.msg('Got back from bootstrap node: %r' % (addr,))
189     
190     def _join_error(self, failure = None):
191         """Process an error in contacting the bootstrap node.
192         
193         If no bootstrap nodes remain, finish the process by contacting
194         close nodes.
195         """
196         self.outstandingJoins -= 1
197         log.msg("bootstrap node could not be reached")
198         if self.outstandingJoins <= 0:
199             self.khashmir.findCloseNodes(self._join_complete)
200
201     def _join_complete(self, result):
202         """End the joining process and return the addresses found for this node."""
203         if not self.joined and isinstance(result, list) and len(result) > 1:
204             self.joined = True
205         if self.joining and self.outstandingJoins <= 0:
206             df = self.joining
207             self.joining = None
208             if self.joined or self.bootstrap_node:
209                 self.joined = True
210                 df.callback(self.foundAddrs)
211             else:
212                 # Try to join later using exponential backoff delays
213                 log.msg('Join failed, retrying in %d seconds' % self.next_rejoin)
214                 reactor.callLater(self.next_rejoin, self.join, df)
215                 self.next_rejoin *= 2
216         
217     def getAddrs(self):
218         """Get the list of addresses returned by bootstrap nodes for this node."""
219         return self.foundAddrs
220         
221     def leave(self):
222         """See L{apt_p2p.interfaces.IDHT}."""
223         if self.config is None:
224             raise DHTError, "configuration not loaded"
225         
226         if self.joined or self.joining:
227             if self.joining:
228                 self.joining.errback(DHTError('still joining when leave was called'))
229                 self.joining = None
230             self.joined = False
231             self.khashmir.shutdown()
232         
233     def _normKey(self, key, bits=None, bytes=None):
234         """Normalize the length of keys used in the DHT."""
235         bits = self.config["HASH_LENGTH"]
236         if bits is not None:
237             bytes = (bits - 1) // 8 + 1
238         else:
239             if bytes is None:
240                 raise DHTError, "you must specify one of bits or bytes for normalization"
241             
242         # Extend short keys with null bytes
243         if len(key) < bytes:
244             key = key + '\000'*(bytes - len(key))
245         # Truncate long keys
246         elif len(key) > bytes:
247             key = key[:bytes]
248         return key
249
250     def getValue(self, key):
251         """See L{apt_p2p.interfaces.IDHT}."""
252         if self.config is None:
253             return defer.fail(DHTError("configuration not loaded"))
254         if not self.joined:
255             return defer.fail(DHTError("have not joined a network yet"))
256         
257         d = defer.Deferred()
258         key = self._normKey(key)
259
260         if key not in self.retrieving:
261             self.khashmir.valueForKey(key, self._getValue)
262         self.retrieving.setdefault(key, []).append(d)
263         return d
264         
265     def _getValue(self, key, result):
266         """Process a returned list of values from the DHT."""
267         # Save the list of values to return when it is complete
268         if result:
269             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
270         else:
271             # Empty list, the get is complete, return the result
272             final_result = []
273             if key in self.retrieved:
274                 final_result = self.retrieved[key]
275                 del self.retrieved[key]
276             for i in range(len(self.retrieving[key])):
277                 d = self.retrieving[key].pop(0)
278                 d.callback(final_result)
279             del self.retrieving[key]
280
281     def storeValue(self, key, value):
282         """See L{apt_p2p.interfaces.IDHT}."""
283         if self.config is None:
284             return defer.fail(DHTError("configuration not loaded"))
285         if not self.joined:
286             return defer.fail(DHTError("have not joined a network yet"))
287
288         d = defer.Deferred()
289         key = self._normKey(key)
290         bvalue = bencode(value)
291
292         if key in self.storing and bvalue in self.storing[key]:
293             raise DHTError, "already storing that key with the same value"
294
295         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
296         self.storing.setdefault(key, {})[bvalue] = d
297         return d
298     
299     def _storeValue(self, key, bvalue, result):
300         """Process the response from the DHT."""
301         if key in self.storing and bvalue in self.storing[key]:
302             # Check if the store succeeded
303             if len(result) > 0:
304                 self.storing[key][bvalue].callback(result)
305             else:
306                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
307             del self.storing[key][bvalue]
308             if len(self.storing[key].keys()) == 0:
309                 del self.storing[key]
310     
311     def getStats(self):
312         """See L{apt_p2p.interfaces.IDHTStats}."""
313         return self.khashmir.getStats()
314
315     def getStatsFactory(self):
316         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
317         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
318         if self.factory is None:
319             # Create a simple HTTP factory for stats
320             class StatsResource(resource.Resource):
321                 def __init__(self, manager):
322                     self.manager = manager
323                 def render(self, ctx):
324                     return http.Response(
325                         200,
326                         {'content-type': http_headers.MimeType('text', 'html')},
327                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
328                 def locateChild(self, request, segments):
329                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
330                     return self, ()
331             
332             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
333         return self.factory
334         
335
336 class TestSimpleDHT(unittest.TestCase):
337     """Simple 2-node unit tests for the DHT."""
338     
339     timeout = 50
340     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
341                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
342                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
343                     'MAX_FAILURES': 3,
344                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
345                     'KEY_EXPIRE': 3600, 'SPEW': False, }
346
347     def setUp(self):
348         self.a = DHT()
349         self.b = DHT()
350         self.a.config = self.DHT_DEFAULTS.copy()
351         self.a.config['PORT'] = 4044
352         self.a.bootstrap = ["127.0.0.1:4044"]
353         self.a.bootstrap_node = True
354         self.a.cache_dir = '/tmp'
355         self.b.config = self.DHT_DEFAULTS.copy()
356         self.b.config['PORT'] = 4045
357         self.b.bootstrap = ["127.0.0.1:4044"]
358         self.b.cache_dir = '/tmp'
359         
360     def test_bootstrap_join(self):
361         d = self.a.join()
362         return d
363
364     def no_krpc_errors(self, result):
365         from krpc import KrpcError
366         self.flushLoggedErrors(KrpcError)
367         return result
368
369     def test_failed_join(self):
370         d = self.b.join()
371         reactor.callLater(30, self.a.join)
372         d.addCallback(self.no_krpc_errors)
373         return d
374         
375     def node_join(self, result):
376         d = self.b.join()
377         return d
378     
379     def test_join(self):
380         d = self.a.join()
381         d.addCallback(self.node_join)
382         return d
383
384     def test_timeout_retransmit(self):
385         d = self.b.join()
386         reactor.callLater(4, self.a.join)
387         return d
388
389     def test_normKey(self):
390         h = self.a._normKey('12345678901234567890')
391         self.failUnless(h == '12345678901234567890')
392         h = self.a._normKey('12345678901234567')
393         self.failUnless(h == '12345678901234567\000\000\000')
394         h = self.a._normKey('1234567890123456789012345')
395         self.failUnless(h == '12345678901234567890')
396         h = self.a._normKey('1234567890123456789')
397         self.failUnless(h == '1234567890123456789\000')
398         h = self.a._normKey('123456789012345678901')
399         self.failUnless(h == '12345678901234567890')
400
401     def value_stored(self, result, value):
402         self.stored -= 1
403         if self.stored == 0:
404             self.get_values()
405         
406     def store_values(self, result):
407         self.stored = 3
408         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
409         d.addCallback(self.value_stored, 4045)
410         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
411         d.addCallback(self.value_stored, 4044)
412         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
413         d.addCallback(self.value_stored, 4045)
414
415     def check_values(self, result, values):
416         self.checked -= 1
417         self.failUnless(len(result) == len(values))
418         for v in result:
419             self.failUnless(v in values)
420         if self.checked == 0:
421             self.lastDefer.callback(1)
422     
423     def get_values(self):
424         self.checked = 4
425         d = self.a.getValue(sha.new('4044').digest())
426         d.addCallback(self.check_values, [str(4044*2)])
427         d = self.b.getValue(sha.new('4044').digest())
428         d.addCallback(self.check_values, [str(4044*2)])
429         d = self.a.getValue(sha.new('4045').digest())
430         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
431         d = self.b.getValue(sha.new('4045').digest())
432         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
433
434     def test_store(self):
435         from twisted.internet.base import DelayedCall
436         DelayedCall.debug = True
437         self.lastDefer = defer.Deferred()
438         d = self.a.join()
439         d.addCallback(self.node_join)
440         d.addCallback(self.store_values)
441         return self.lastDefer
442
443     def tearDown(self):
444         self.a.leave()
445         try:
446             os.unlink(self.a.khashmir.store.db)
447         except:
448             pass
449         self.b.leave()
450         try:
451             os.unlink(self.b.khashmir.store.db)
452         except:
453             pass
454
455 class TestMultiDHT(unittest.TestCase):
456     """More complicated 20-node tests for the DHT."""
457     
458     timeout = 80
459     num = 20
460     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
461                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
462                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
463                     'MAX_FAILURES': 3,
464                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
465                     'KEY_EXPIRE': 3600, 'SPEW': False, }
466
467     def setUp(self):
468         self.l = []
469         self.startport = 4081
470         for i in range(self.num):
471             self.l.append(DHT())
472             self.l[i].config = self.DHT_DEFAULTS.copy()
473             self.l[i].config['PORT'] = self.startport + i
474             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
475             self.l[i].cache_dir = '/tmp'
476         self.l[0].bootstrap_node = True
477         
478     def node_join(self, result, next_node):
479         d = self.l[next_node].join()
480         if next_node + 1 < len(self.l):
481             d.addCallback(self.node_join, next_node + 1)
482         else:
483             d.addCallback(self.lastDefer.callback)
484     
485     def test_join(self):
486         self.timeout = 2
487         self.lastDefer = defer.Deferred()
488         d = self.l[0].join()
489         d.addCallback(self.node_join, 1)
490         return self.lastDefer
491         
492     def store_values(self, result, i = 0, j = 0):
493         if j > i:
494             j -= i+1
495             i += 1
496         if i == len(self.l):
497             self.get_values()
498         else:
499             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
500             d.addCallback(self.store_values, i, j+1)
501     
502     def get_values(self, result = None, check = None, i = 0, j = 0):
503         if result is not None:
504             self.failUnless(len(result) == len(check))
505             for v in result:
506                 self.failUnless(v in check)
507         if j >= len(self.l):
508             j -= len(self.l)
509             i += 1
510         if i == len(self.l):
511             self.lastDefer.callback(1)
512         else:
513             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
514             check = []
515             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
516                 check.append(str(k))
517             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
518
519     def store_join(self, result, next_node):
520         d = self.l[next_node].join()
521         if next_node + 1 < len(self.l):
522             d.addCallback(self.store_join, next_node + 1)
523         else:
524             d.addCallback(self.store_values)
525     
526     def test_store(self):
527         from twisted.internet.base import DelayedCall
528         DelayedCall.debug = True
529         self.lastDefer = defer.Deferred()
530         d = self.l[0].join()
531         d.addCallback(self.store_join, 1)
532         return self.lastDefer
533
534     def tearDown(self):
535         for i in self.l:
536             try:
537                 i.leave()
538                 os.unlink(i.khashmir.store.db)
539             except:
540                 pass