e9baaa8e45bace9cec49d5d02defb87f99e88158
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 from StringIO import StringIO
9 import os, sha, random
10
11 from twisted.internet import defer, reactor
12 from twisted.internet.abstract import isIPAddress
13 from twisted.python import log
14 from twisted.trial import unittest
15 from zope.interface import implements
16
17 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
18 from khashmir import Khashmir
19 from bencode import bencode, bdecode
20
21 try:
22     from twisted.web2 import channel, server, resource, http, http_headers
23     _web2 = True
24 except ImportError:
25     _web2 = False
26
27 khashmir_dir = 'apt-p2p-Khashmir'
28
29 class DHTError(Exception):
30     """Represents errors that occur in the DHT."""
31
32 class DHT:
33     """The main interface instance to the Khashmir DHT.
34     
35     @type config: C{dictionary}
36     @ivar config: the DHT configuration values
37     @type cache_dir: C{string}
38     @ivar cache_dir: the directory to use for storing files
39     @type bootstrap: C{list} of C{string}
40     @ivar bootstrap: the nodes to contact to bootstrap into the system
41     @type bootstrap_node: C{boolean}
42     @ivar bootstrap_node: whether this node is a bootstrap node
43     @type joining: L{twisted.internet.defer.Deferred}
44     @ivar joining: if a join is underway, the deferred that will signal it's end
45     @type joined: C{boolean}
46     @ivar joined: whether the DHT network has been successfully joined
47     @type outstandingJoins: C{int}
48     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
49     @type foundAddrs: C{list} of (C{string}, C{int})
50     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
51     @type storing: C{dictionary}
52     @ivar storing: keys are keys for which store requests are active, values
53         are dictionaries with keys the values being stored and values the
54         deferred to call when complete
55     @type retrieving: C{dictionary}
56     @ivar retrieving: keys are the keys for which getValue requests are active,
57         values are lists of the deferreds waiting for the requests
58     @type retrieved: C{dictionary}
59     @ivar retrieved: keys are the keys for which getValue requests are active,
60         values are list of the values returned so far
61     @type factory: L{twisted.web2.channel.HTTPFactory}
62     @ivar factory: the factory to use to serve HTTP requests for statistics
63     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
64     @ivar config_parser: the configuration info for the main program
65     @type section: C{string}
66     @ivar section: the section of the configuration info that applies to the DHT
67     @type khashmir: L{khashmir.Khashmir}
68     @ivar khashmir: the khashmir DHT instance to use
69     """
70     
71     if _web2:
72         implements(IDHT, IDHTStats, IDHTStatsFactory)
73     else:
74         implements(IDHT, IDHTStats)
75         
76     def __init__(self):
77         """Initialize the DHT."""
78         self.config = None
79         self.cache_dir = ''
80         self.bootstrap = []
81         self.bootstrap_node = False
82         self.joining = None
83         self.joined = False
84         self.outstandingJoins = 0
85         self.foundAddrs = []
86         self.storing = {}
87         self.retrieving = {}
88         self.retrieved = {}
89         self.factory = None
90     
91     def loadConfig(self, config, section):
92         """See L{apt_p2p.interfaces.IDHT}."""
93         self.config_parser = config
94         self.section = section
95         self.config = {}
96         
97         # Get some initial values
98         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
99         if not os.path.exists(self.cache_dir):
100             os.makedirs(self.cache_dir)
101         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
102         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
103         for k in self.config_parser.options(section):
104             # The numbers in the config file
105             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
106                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
107                 self.config[k] = self.config_parser.getint(section, k)
108             # The times in the config file
109             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
110                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
111                 self.config[k] = self.config_parser.gettime(section, k)
112             # The booleans in the config file
113             elif k in ['SPEW']:
114                 self.config[k] = self.config_parser.getboolean(section, k)
115             # Everything else is a string
116             else:
117                 self.config[k] = self.config_parser.get(section, k)
118     
119     def join(self):
120         """See L{apt_p2p.interfaces.IDHT}."""
121         if self.config is None:
122             raise DHTError, "configuration not loaded"
123         if self.joining:
124             raise DHTError, "a join is already in progress"
125
126         # Create the new khashmir instance
127         self.khashmir = Khashmir(self.config, self.cache_dir)
128         
129         self.joining = defer.Deferred()
130         for node in self.bootstrap:
131             host, port = node.rsplit(':', 1)
132             port = int(port)
133             
134             # Translate host names into IP addresses
135             if isIPAddress(host):
136                 self._join_gotIP(host, port)
137             else:
138                 reactor.resolve(host).addCallback(self._join_gotIP, port)
139         
140         return self.joining
141
142     def _join_gotIP(self, ip, port):
143         """Join the DHT using a single bootstrap nodes IP address."""
144         self.outstandingJoins += 1
145         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
146     
147     def _join_single(self, addr):
148         """Process the response from the bootstrap node.
149         
150         Finish the join by contacting close nodes.
151         """
152         self.outstandingJoins -= 1
153         if addr:
154             self.foundAddrs.append(addr)
155         if addr or self.outstandingJoins <= 0:
156             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
157         log.msg('Got back from bootstrap node: %r' % (addr,))
158     
159     def _join_error(self, failure = None):
160         """Process an error in contacting the bootstrap node.
161         
162         If no bootstrap nodes remain, finish the process by contacting
163         close nodes.
164         """
165         self.outstandingJoins -= 1
166         log.msg("bootstrap node could not be reached")
167         if self.outstandingJoins <= 0:
168             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
169
170     def _join_complete(self, result):
171         """End the joining process and return the addresses found for this node."""
172         if not self.joined and len(result) > 0:
173             self.joined = True
174         if self.joining and self.outstandingJoins <= 0:
175             df = self.joining
176             self.joining = None
177             if self.joined or self.bootstrap_node:
178                 self.joined = True
179                 df.callback(self.foundAddrs)
180             else:
181                 df.errback(DHTError('could not find any nodes to bootstrap to'))
182         
183     def getAddrs(self):
184         """Get the list of addresses returned by bootstrap nodes for this node."""
185         return self.foundAddrs
186         
187     def leave(self):
188         """See L{apt_p2p.interfaces.IDHT}."""
189         if self.config is None:
190             raise DHTError, "configuration not loaded"
191         
192         if self.joined or self.joining:
193             if self.joining:
194                 self.joining.errback(DHTError('still joining when leave was called'))
195                 self.joining = None
196             self.joined = False
197             self.khashmir.shutdown()
198         
199     def _normKey(self, key, bits=None, bytes=None):
200         """Normalize the length of keys used in the DHT."""
201         bits = self.config["HASH_LENGTH"]
202         if bits is not None:
203             bytes = (bits - 1) // 8 + 1
204         else:
205             if bytes is None:
206                 raise DHTError, "you must specify one of bits or bytes for normalization"
207             
208         # Extend short keys with null bytes
209         if len(key) < bytes:
210             key = key + '\000'*(bytes - len(key))
211         # Truncate long keys
212         elif len(key) > bytes:
213             key = key[:bytes]
214         return key
215
216     def getValue(self, key):
217         """See L{apt_p2p.interfaces.IDHT}."""
218         if self.config is None:
219             raise DHTError, "configuration not loaded"
220         if not self.joined:
221             raise DHTError, "have not joined a network yet"
222         
223         key = self._normKey(key)
224
225         d = defer.Deferred()
226         if key not in self.retrieving:
227             self.khashmir.valueForKey(key, self._getValue)
228         self.retrieving.setdefault(key, []).append(d)
229         return d
230         
231     def _getValue(self, key, result):
232         """Process a returned list of values from the DHT."""
233         # Save the list of values to return when it is complete
234         if result:
235             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
236         else:
237             # Empty list, the get is complete, return the result
238             final_result = []
239             if key in self.retrieved:
240                 final_result = self.retrieved[key]
241                 del self.retrieved[key]
242             for i in range(len(self.retrieving[key])):
243                 d = self.retrieving[key].pop(0)
244                 d.callback(final_result)
245             del self.retrieving[key]
246
247     def storeValue(self, key, value):
248         """See L{apt_p2p.interfaces.IDHT}."""
249         if self.config is None:
250             raise DHTError, "configuration not loaded"
251         if not self.joined:
252             raise DHTError, "have not joined a network yet"
253
254         key = self._normKey(key)
255         bvalue = bencode(value)
256
257         if key in self.storing and bvalue in self.storing[key]:
258             raise DHTError, "already storing that key with the same value"
259
260         d = defer.Deferred()
261         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
262         self.storing.setdefault(key, {})[bvalue] = d
263         return d
264     
265     def _storeValue(self, key, bvalue, result):
266         """Process the response from the DHT."""
267         if key in self.storing and bvalue in self.storing[key]:
268             # Check if the store succeeded
269             if len(result) > 0:
270                 self.storing[key][bvalue].callback(result)
271             else:
272                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
273             del self.storing[key][bvalue]
274             if len(self.storing[key].keys()) == 0:
275                 del self.storing[key]
276     
277     def getStats(self):
278         """See L{apt_p2p.interfaces.IDHTStats}."""
279         stats = self.khashmir.getStats()
280         out = StringIO()
281         out.write('<h2>DHT Statistics</h2>\n')
282         old_group = None
283         for stat in stats:
284             if stat['group'] != old_group:
285                 if old_group is not None:
286                     out.write('</table>\n')
287                 out.write('\n<h3>' + stat['group'] + '</h3>\n')
288                 out.write("<table border='1'>\n")
289                 if stat['group'] != 'Actions':
290                     out.write("<tr><th>Statistic</th><th>Value</th></tr>\n")
291                 else:
292                     out.write("<tr><th>Action</th><th>Started</th><th>Sent</th><th>OK</th><th>Failed</th><th>Received</th><th>Error</th></tr>\n")
293                 old_group = stat['group']
294             if stat['group'] != 'Actions':
295                 out.write("<tr title='" + stat['tip'] + "'><td>" + stat['desc'] + '</td><td>' + str(stat['value']) + '</td></tr>\n')
296             else:
297                 actions = stat['value'].keys()
298                 actions.sort()
299                 for action in actions:
300                     out.write("<tr><td>" + action + "</td>")
301                     for i in xrange(6):
302                         out.write("<td>" + str(stat['value'][action][i]) + "</td>")
303                     out.write('</tr>\n')
304                     
305         return out.getvalue()
306
307     def getStatsFactory(self):
308         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
309         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
310         if self.factory is None:
311             # Create a simple HTTP factory for stats
312             class StatsResource(resource.Resource):
313                 def __init__(self, manager):
314                     self.manager = manager
315                 def render(self, ctx):
316                     return http.Response(
317                         200,
318                         {'content-type': http_headers.MimeType('text', 'html')},
319                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
320                 def locateChild(self, request, segments):
321                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
322                     return self, ()
323             
324             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
325         return self.factory
326         
327
328 class TestSimpleDHT(unittest.TestCase):
329     """Simple 2-node unit tests for the DHT."""
330     
331     timeout = 2
332     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
333                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
334                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
335                     'MAX_FAILURES': 3,
336                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
337                     'KEY_EXPIRE': 3600, 'SPEW': False, }
338
339     def setUp(self):
340         self.a = DHT()
341         self.b = DHT()
342         self.a.config = self.DHT_DEFAULTS.copy()
343         self.a.config['PORT'] = 4044
344         self.a.bootstrap = ["127.0.0.1:4044"]
345         self.a.bootstrap_node = True
346         self.a.cache_dir = '/tmp'
347         self.b.config = self.DHT_DEFAULTS.copy()
348         self.b.config['PORT'] = 4045
349         self.b.bootstrap = ["127.0.0.1:4044"]
350         self.b.cache_dir = '/tmp'
351         
352     def test_bootstrap_join(self):
353         d = self.a.join()
354         return d
355         
356     def node_join(self, result):
357         d = self.b.join()
358         return d
359     
360     def test_join(self):
361         self.lastDefer = defer.Deferred()
362         d = self.a.join()
363         d.addCallback(self.node_join)
364         d.addCallback(self.lastDefer.callback)
365         return self.lastDefer
366
367     def test_normKey(self):
368         h = self.a._normKey('12345678901234567890')
369         self.failUnless(h == '12345678901234567890')
370         h = self.a._normKey('12345678901234567')
371         self.failUnless(h == '12345678901234567\000\000\000')
372         h = self.a._normKey('1234567890123456789012345')
373         self.failUnless(h == '12345678901234567890')
374         h = self.a._normKey('1234567890123456789')
375         self.failUnless(h == '1234567890123456789\000')
376         h = self.a._normKey('123456789012345678901')
377         self.failUnless(h == '12345678901234567890')
378
379     def value_stored(self, result, value):
380         self.stored -= 1
381         if self.stored == 0:
382             self.get_values()
383         
384     def store_values(self, result):
385         self.stored = 3
386         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
387         d.addCallback(self.value_stored, 4045)
388         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
389         d.addCallback(self.value_stored, 4044)
390         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
391         d.addCallback(self.value_stored, 4045)
392
393     def check_values(self, result, values):
394         self.checked -= 1
395         self.failUnless(len(result) == len(values))
396         for v in result:
397             self.failUnless(v in values)
398         if self.checked == 0:
399             self.lastDefer.callback(1)
400     
401     def get_values(self):
402         self.checked = 4
403         d = self.a.getValue(sha.new('4044').digest())
404         d.addCallback(self.check_values, [str(4044*2)])
405         d = self.b.getValue(sha.new('4044').digest())
406         d.addCallback(self.check_values, [str(4044*2)])
407         d = self.a.getValue(sha.new('4045').digest())
408         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
409         d = self.b.getValue(sha.new('4045').digest())
410         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
411
412     def test_store(self):
413         from twisted.internet.base import DelayedCall
414         DelayedCall.debug = True
415         self.lastDefer = defer.Deferred()
416         d = self.a.join()
417         d.addCallback(self.node_join)
418         d.addCallback(self.store_values)
419         return self.lastDefer
420
421     def tearDown(self):
422         self.a.leave()
423         try:
424             os.unlink(self.a.khashmir.store.db)
425         except:
426             pass
427         self.b.leave()
428         try:
429             os.unlink(self.b.khashmir.store.db)
430         except:
431             pass
432
433 class TestMultiDHT(unittest.TestCase):
434     """More complicated 20-node tests for the DHT."""
435     
436     timeout = 60
437     num = 20
438     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
439                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
440                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
441                     'MAX_FAILURES': 3,
442                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
443                     'KEY_EXPIRE': 3600, 'SPEW': False, }
444
445     def setUp(self):
446         self.l = []
447         self.startport = 4081
448         for i in range(self.num):
449             self.l.append(DHT())
450             self.l[i].config = self.DHT_DEFAULTS.copy()
451             self.l[i].config['PORT'] = self.startport + i
452             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
453             self.l[i].cache_dir = '/tmp'
454         self.l[0].bootstrap_node = True
455         
456     def node_join(self, result, next_node):
457         d = self.l[next_node].join()
458         if next_node + 1 < len(self.l):
459             d.addCallback(self.node_join, next_node + 1)
460         else:
461             d.addCallback(self.lastDefer.callback)
462     
463     def test_join(self):
464         self.timeout = 2
465         self.lastDefer = defer.Deferred()
466         d = self.l[0].join()
467         d.addCallback(self.node_join, 1)
468         return self.lastDefer
469         
470     def store_values(self, result, i = 0, j = 0):
471         if j > i:
472             j -= i+1
473             i += 1
474         if i == len(self.l):
475             self.get_values()
476         else:
477             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
478             d.addCallback(self.store_values, i, j+1)
479     
480     def get_values(self, result = None, check = None, i = 0, j = 0):
481         if result is not None:
482             self.failUnless(len(result) == len(check))
483             for v in result:
484                 self.failUnless(v in check)
485         if j >= len(self.l):
486             j -= len(self.l)
487             i += 1
488         if i == len(self.l):
489             self.lastDefer.callback(1)
490         else:
491             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
492             check = []
493             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
494                 check.append(str(k))
495             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
496
497     def store_join(self, result, next_node):
498         d = self.l[next_node].join()
499         if next_node + 1 < len(self.l):
500             d.addCallback(self.store_join, next_node + 1)
501         else:
502             d.addCallback(self.store_values)
503     
504     def test_store(self):
505         from twisted.internet.base import DelayedCall
506         DelayedCall.debug = True
507         self.lastDefer = defer.Deferred()
508         d = self.l[0].join()
509         d.addCallback(self.store_join, 1)
510         return self.lastDefer
511
512     def tearDown(self):
513         for i in self.l:
514             try:
515                 i.leave()
516                 os.unlink(i.khashmir.store.db)
517             except:
518                 pass