edb626d7b338116bad8493bb9a267d6ec44eb7b1
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / DHT.py
1
2 """The main interface to the Khashmir DHT.
3
4 @var khashmir_dir: the name of the directory to use for DHT files
5 """
6
7 from datetime import datetime
8 import os, sha, random
9
10 from twisted.internet import defer, reactor
11 from twisted.internet.abstract import isIPAddress
12 from twisted.python import log
13 from twisted.trial import unittest
14 from zope.interface import implements
15
16 from apt_p2p.interfaces import IDHT, IDHTStats, IDHTStatsFactory
17 from khashmir import Khashmir
18 from bencode import bencode, bdecode
19
20 try:
21     from twisted.web2 import channel, server, resource, http, http_headers
22     _web2 = True
23 except ImportError:
24     _web2 = False
25
26 khashmir_dir = 'apt-p2p-Khashmir'
27
28 class DHTError(Exception):
29     """Represents errors that occur in the DHT."""
30
31 class DHT:
32     """The main interface instance to the Khashmir DHT.
33     
34     @type config: C{dictionary}
35     @ivar config: the DHT configuration values
36     @type cache_dir: C{string}
37     @ivar cache_dir: the directory to use for storing files
38     @type bootstrap: C{list} of C{string}
39     @ivar bootstrap: the nodes to contact to bootstrap into the system
40     @type bootstrap_node: C{boolean}
41     @ivar bootstrap_node: whether this node is a bootstrap node
42     @type joining: L{twisted.internet.defer.Deferred}
43     @ivar joining: if a join is underway, the deferred that will signal it's end
44     @type joined: C{boolean}
45     @ivar joined: whether the DHT network has been successfully joined
46     @type outstandingJoins: C{int}
47     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
48     @type foundAddrs: C{list} of (C{string}, C{int})
49     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
50     @type storing: C{dictionary}
51     @ivar storing: keys are keys for which store requests are active, values
52         are dictionaries with keys the values being stored and values the
53         deferred to call when complete
54     @type retrieving: C{dictionary}
55     @ivar retrieving: keys are the keys for which getValue requests are active,
56         values are lists of the deferreds waiting for the requests
57     @type retrieved: C{dictionary}
58     @ivar retrieved: keys are the keys for which getValue requests are active,
59         values are list of the values returned so far
60     @type factory: L{twisted.web2.channel.HTTPFactory}
61     @ivar factory: the factory to use to serve HTTP requests for statistics
62     @type config_parser: L{apt_p2p.apt_p2p_conf.AptP2PConfigParser}
63     @ivar config_parser: the configuration info for the main program
64     @type section: C{string}
65     @ivar section: the section of the configuration info that applies to the DHT
66     @type khashmir: L{khashmir.Khashmir}
67     @ivar khashmir: the khashmir DHT instance to use
68     """
69     
70     if _web2:
71         implements(IDHT, IDHTStats, IDHTStatsFactory)
72     else:
73         implements(IDHT, IDHTStats)
74         
75     def __init__(self):
76         """Initialize the DHT."""
77         self.config = None
78         self.cache_dir = ''
79         self.bootstrap = []
80         self.bootstrap_node = False
81         self.joining = None
82         self.joined = False
83         self.outstandingJoins = 0
84         self.foundAddrs = []
85         self.storing = {}
86         self.retrieving = {}
87         self.retrieved = {}
88         self.factory = None
89     
90     def loadConfig(self, config, section):
91         """See L{apt_p2p.interfaces.IDHT}."""
92         self.config_parser = config
93         self.section = section
94         self.config = {}
95         
96         # Get some initial values
97         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
98         if not os.path.exists(self.cache_dir):
99             os.makedirs(self.cache_dir)
100         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
101         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
102         for k in self.config_parser.options(section):
103             # The numbers in the config file
104             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
105                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
106                 self.config[k] = self.config_parser.getint(section, k)
107             # The times in the config file
108             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
109                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
110                 self.config[k] = self.config_parser.gettime(section, k)
111             # The booleans in the config file
112             elif k in ['SPEW']:
113                 self.config[k] = self.config_parser.getboolean(section, k)
114             # Everything else is a string
115             else:
116                 self.config[k] = self.config_parser.get(section, k)
117     
118     def join(self):
119         """See L{apt_p2p.interfaces.IDHT}."""
120         if self.config is None:
121             raise DHTError, "configuration not loaded"
122         if self.joining:
123             raise DHTError, "a join is already in progress"
124
125         # Create the new khashmir instance
126         self.khashmir = Khashmir(self.config, self.cache_dir)
127         
128         self.joining = defer.Deferred()
129         for node in self.bootstrap:
130             host, port = node.rsplit(':', 1)
131             port = int(port)
132             
133             # Translate host names into IP addresses
134             if isIPAddress(host):
135                 self._join_gotIP(host, port)
136             else:
137                 reactor.resolve(host).addCallback(self._join_gotIP, port)
138         
139         return self.joining
140
141     def _join_gotIP(self, ip, port):
142         """Join the DHT using a single bootstrap nodes IP address."""
143         self.outstandingJoins += 1
144         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
145     
146     def _join_single(self, addr):
147         """Process the response from the bootstrap node.
148         
149         Finish the join by contacting close nodes.
150         """
151         self.outstandingJoins -= 1
152         if addr:
153             self.foundAddrs.append(addr)
154         if addr or self.outstandingJoins <= 0:
155             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
156         log.msg('Got back from bootstrap node: %r' % (addr,))
157     
158     def _join_error(self, failure = None):
159         """Process an error in contacting the bootstrap node.
160         
161         If no bootstrap nodes remain, finish the process by contacting
162         close nodes.
163         """
164         self.outstandingJoins -= 1
165         log.msg("bootstrap node could not be reached")
166         if self.outstandingJoins <= 0:
167             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
168
169     def _join_complete(self, result):
170         """End the joining process and return the addresses found for this node."""
171         if not self.joined and len(result) > 0:
172             self.joined = True
173         if self.joining and self.outstandingJoins <= 0:
174             df = self.joining
175             self.joining = None
176             if self.joined or self.bootstrap_node:
177                 self.joined = True
178                 df.callback(self.foundAddrs)
179             else:
180                 df.errback(DHTError('could not find any nodes to bootstrap to'))
181         
182     def getAddrs(self):
183         """Get the list of addresses returned by bootstrap nodes for this node."""
184         return self.foundAddrs
185         
186     def leave(self):
187         """See L{apt_p2p.interfaces.IDHT}."""
188         if self.config is None:
189             raise DHTError, "configuration not loaded"
190         
191         if self.joined or self.joining:
192             if self.joining:
193                 self.joining.errback(DHTError('still joining when leave was called'))
194                 self.joining = None
195             self.joined = False
196             self.khashmir.shutdown()
197         
198     def _normKey(self, key, bits=None, bytes=None):
199         """Normalize the length of keys used in the DHT."""
200         bits = self.config["HASH_LENGTH"]
201         if bits is not None:
202             bytes = (bits - 1) // 8 + 1
203         else:
204             if bytes is None:
205                 raise DHTError, "you must specify one of bits or bytes for normalization"
206             
207         # Extend short keys with null bytes
208         if len(key) < bytes:
209             key = key + '\000'*(bytes - len(key))
210         # Truncate long keys
211         elif len(key) > bytes:
212             key = key[:bytes]
213         return key
214
215     def getValue(self, key):
216         """See L{apt_p2p.interfaces.IDHT}."""
217         if self.config is None:
218             raise DHTError, "configuration not loaded"
219         if not self.joined:
220             raise DHTError, "have not joined a network yet"
221         
222         key = self._normKey(key)
223
224         d = defer.Deferred()
225         if key not in self.retrieving:
226             self.khashmir.valueForKey(key, self._getValue)
227         self.retrieving.setdefault(key, []).append(d)
228         return d
229         
230     def _getValue(self, key, result):
231         """Process a returned list of values from the DHT."""
232         # Save the list of values to return when it is complete
233         if result:
234             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
235         else:
236             # Empty list, the get is complete, return the result
237             final_result = []
238             if key in self.retrieved:
239                 final_result = self.retrieved[key]
240                 del self.retrieved[key]
241             for i in range(len(self.retrieving[key])):
242                 d = self.retrieving[key].pop(0)
243                 d.callback(final_result)
244             del self.retrieving[key]
245
246     def storeValue(self, key, value):
247         """See L{apt_p2p.interfaces.IDHT}."""
248         if self.config is None:
249             raise DHTError, "configuration not loaded"
250         if not self.joined:
251             raise DHTError, "have not joined a network yet"
252
253         key = self._normKey(key)
254         bvalue = bencode(value)
255
256         if key in self.storing and bvalue in self.storing[key]:
257             raise DHTError, "already storing that key with the same value"
258
259         d = defer.Deferred()
260         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
261         self.storing.setdefault(key, {})[bvalue] = d
262         return d
263     
264     def _storeValue(self, key, bvalue, result):
265         """Process the response from the DHT."""
266         if key in self.storing and bvalue in self.storing[key]:
267             # Check if the store succeeded
268             if len(result) > 0:
269                 self.storing[key][bvalue].callback(result)
270             else:
271                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
272             del self.storing[key][bvalue]
273             if len(self.storing[key].keys()) == 0:
274                 del self.storing[key]
275     
276     def getStats(self):
277         """See L{apt_p2p.interfaces.IDHTStats}."""
278         return self.khashmir.getStats()
279
280     def getStatsFactory(self):
281         """See L{apt_p2p.interfaces.IDHTStatsFactory}."""
282         assert _web2, "NOT IMPLEMENTED: twisted.web2 must be installed to use the stats factory."
283         if self.factory is None:
284             # Create a simple HTTP factory for stats
285             class StatsResource(resource.Resource):
286                 def __init__(self, manager):
287                     self.manager = manager
288                 def render(self, ctx):
289                     return http.Response(
290                         200,
291                         {'content-type': http_headers.MimeType('text', 'html')},
292                         '<html><body>\n\n' + self.manager.getStats() + '\n</body></html>\n')
293                 def locateChild(self, request, segments):
294                     log.msg('Got HTTP stats request from %s' % (request.remoteAddr, ))
295                     return self, ()
296             
297             self.factory = channel.HTTPFactory(server.Site(StatsResource(self)))
298         return self.factory
299         
300
301 class TestSimpleDHT(unittest.TestCase):
302     """Simple 2-node unit tests for the DHT."""
303     
304     timeout = 2
305     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
306                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
307                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
308                     'MAX_FAILURES': 3,
309                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
310                     'KEY_EXPIRE': 3600, 'SPEW': False, }
311
312     def setUp(self):
313         self.a = DHT()
314         self.b = DHT()
315         self.a.config = self.DHT_DEFAULTS.copy()
316         self.a.config['PORT'] = 4044
317         self.a.bootstrap = ["127.0.0.1:4044"]
318         self.a.bootstrap_node = True
319         self.a.cache_dir = '/tmp'
320         self.b.config = self.DHT_DEFAULTS.copy()
321         self.b.config['PORT'] = 4045
322         self.b.bootstrap = ["127.0.0.1:4044"]
323         self.b.cache_dir = '/tmp'
324         
325     def test_bootstrap_join(self):
326         d = self.a.join()
327         return d
328         
329     def node_join(self, result):
330         d = self.b.join()
331         return d
332     
333     def test_join(self):
334         self.lastDefer = defer.Deferred()
335         d = self.a.join()
336         d.addCallback(self.node_join)
337         d.addCallback(self.lastDefer.callback)
338         return self.lastDefer
339
340     def test_normKey(self):
341         h = self.a._normKey('12345678901234567890')
342         self.failUnless(h == '12345678901234567890')
343         h = self.a._normKey('12345678901234567')
344         self.failUnless(h == '12345678901234567\000\000\000')
345         h = self.a._normKey('1234567890123456789012345')
346         self.failUnless(h == '12345678901234567890')
347         h = self.a._normKey('1234567890123456789')
348         self.failUnless(h == '1234567890123456789\000')
349         h = self.a._normKey('123456789012345678901')
350         self.failUnless(h == '12345678901234567890')
351
352     def value_stored(self, result, value):
353         self.stored -= 1
354         if self.stored == 0:
355             self.get_values()
356         
357     def store_values(self, result):
358         self.stored = 3
359         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
360         d.addCallback(self.value_stored, 4045)
361         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
362         d.addCallback(self.value_stored, 4044)
363         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
364         d.addCallback(self.value_stored, 4045)
365
366     def check_values(self, result, values):
367         self.checked -= 1
368         self.failUnless(len(result) == len(values))
369         for v in result:
370             self.failUnless(v in values)
371         if self.checked == 0:
372             self.lastDefer.callback(1)
373     
374     def get_values(self):
375         self.checked = 4
376         d = self.a.getValue(sha.new('4044').digest())
377         d.addCallback(self.check_values, [str(4044*2)])
378         d = self.b.getValue(sha.new('4044').digest())
379         d.addCallback(self.check_values, [str(4044*2)])
380         d = self.a.getValue(sha.new('4045').digest())
381         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
382         d = self.b.getValue(sha.new('4045').digest())
383         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
384
385     def test_store(self):
386         from twisted.internet.base import DelayedCall
387         DelayedCall.debug = True
388         self.lastDefer = defer.Deferred()
389         d = self.a.join()
390         d.addCallback(self.node_join)
391         d.addCallback(self.store_values)
392         return self.lastDefer
393
394     def tearDown(self):
395         self.a.leave()
396         try:
397             os.unlink(self.a.khashmir.store.db)
398         except:
399             pass
400         self.b.leave()
401         try:
402             os.unlink(self.b.khashmir.store.db)
403         except:
404             pass
405
406 class TestMultiDHT(unittest.TestCase):
407     """More complicated 20-node tests for the DHT."""
408     
409     timeout = 60
410     num = 20
411     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
412                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
413                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
414                     'MAX_FAILURES': 3,
415                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
416                     'KEY_EXPIRE': 3600, 'SPEW': False, }
417
418     def setUp(self):
419         self.l = []
420         self.startport = 4081
421         for i in range(self.num):
422             self.l.append(DHT())
423             self.l[i].config = self.DHT_DEFAULTS.copy()
424             self.l[i].config['PORT'] = self.startport + i
425             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
426             self.l[i].cache_dir = '/tmp'
427         self.l[0].bootstrap_node = True
428         
429     def node_join(self, result, next_node):
430         d = self.l[next_node].join()
431         if next_node + 1 < len(self.l):
432             d.addCallback(self.node_join, next_node + 1)
433         else:
434             d.addCallback(self.lastDefer.callback)
435     
436     def test_join(self):
437         self.timeout = 2
438         self.lastDefer = defer.Deferred()
439         d = self.l[0].join()
440         d.addCallback(self.node_join, 1)
441         return self.lastDefer
442         
443     def store_values(self, result, i = 0, j = 0):
444         if j > i:
445             j -= i+1
446             i += 1
447         if i == len(self.l):
448             self.get_values()
449         else:
450             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
451             d.addCallback(self.store_values, i, j+1)
452     
453     def get_values(self, result = None, check = None, i = 0, j = 0):
454         if result is not None:
455             self.failUnless(len(result) == len(check))
456             for v in result:
457                 self.failUnless(v in check)
458         if j >= len(self.l):
459             j -= len(self.l)
460             i += 1
461         if i == len(self.l):
462             self.lastDefer.callback(1)
463         else:
464             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
465             check = []
466             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
467                 check.append(str(k))
468             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
469
470     def store_join(self, result, next_node):
471         d = self.l[next_node].join()
472         if next_node + 1 < len(self.l):
473             d.addCallback(self.store_join, next_node + 1)
474         else:
475             d.addCallback(self.store_values)
476     
477     def test_store(self):
478         from twisted.internet.base import DelayedCall
479         DelayedCall.debug = True
480         self.lastDefer = defer.Deferred()
481         d = self.l[0].join()
482         d.addCallback(self.store_join, 1)
483         return self.lastDefer
484
485     def tearDown(self):
486         for i in self.l:
487             try:
488                 i.leave()
489                 os.unlink(i.khashmir.store.db)
490             except:
491                 pass