5f08dae07498962759264b4dfb1c07169b5da059
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
1
2 from datetime import datetime
3 import os, sha, random
4
5 from twisted.internet import defer, reactor
6 from twisted.internet.abstract import isIPAddress
7 from twisted.python import log
8 from twisted.trial import unittest
9 from zope.interface import implements
10
11 from apt_dht.interfaces import IDHT
12 from khashmir import Khashmir
13 from bencode import bencode, bdecode
14
15 khashmir_dir = 'apt-dht-Khashmir'
16
17 class DHTError(Exception):
18     """Represents errors that occur in the DHT."""
19
20 class DHT:
21     
22     implements(IDHT)
23     
24     def __init__(self):
25         self.config = None
26         self.cache_dir = ''
27         self.bootstrap = []
28         self.bootstrap_node = False
29         self.joining = None
30         self.joined = False
31         self.outstandingJoins = 0
32         self.foundAddrs = []
33         self.storing = {}
34         self.retrieving = {}
35         self.retrieved = {}
36     
37     def loadConfig(self, config, section):
38         """See L{apt_dht.interfaces.IDHT}."""
39         self.config_parser = config
40         self.section = section
41         self.config = {}
42         self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
43         if not os.path.exists(self.cache_dir):
44             os.makedirs(self.cache_dir)
45         self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
46         self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
47         for k in self.config_parser.options(section):
48             if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY', 
49                      'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT']:
50                 self.config[k] = self.config_parser.getint(section, k)
51             elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL', 
52                        'BUCKET_STALENESS', 'KEY_EXPIRE']:
53                 self.config[k] = self.config_parser.gettime(section, k)
54             elif k in ['SPEW']:
55                 self.config[k] = self.config_parser.getboolean(section, k)
56             else:
57                 self.config[k] = self.config_parser.get(section, k)
58     
59     def join(self):
60         """See L{apt_dht.interfaces.IDHT}."""
61         if self.config is None:
62             raise DHTError, "configuration not loaded"
63         if self.joining:
64             raise DHTError, "a join is already in progress"
65
66         self.khashmir = Khashmir(self.config, self.cache_dir)
67         
68         self.joining = defer.Deferred()
69         for node in self.bootstrap:
70             host, port = node.rsplit(':', 1)
71             port = int(port)
72             if isIPAddress(host):
73                 self._join_gotIP(host, port)
74             else:
75                 reactor.resolve(host).addCallback(self._join_gotIP, port)
76         
77         return self.joining
78
79     def _join_gotIP(self, ip, port):
80         """Called after an IP address has been found for a single bootstrap node."""
81         self.outstandingJoins += 1
82         self.khashmir.addContact(ip, port, self._join_single, self._join_error)
83     
84     def _join_single(self, addr):
85         """Called when a single bootstrap node has been added."""
86         self.outstandingJoins -= 1
87         if addr:
88             self.foundAddrs.append(addr)
89         if addr or self.outstandingJoins <= 0:
90             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
91         log.msg('Got back from bootstrap node: %r' % (addr,))
92     
93     def _join_error(self, failure = None):
94         """Called when a single bootstrap node has failed."""
95         self.outstandingJoins -= 1
96         log.msg("bootstrap node could not be reached")
97         if self.outstandingJoins <= 0:
98             self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
99
100     def _join_complete(self, result):
101         """Called when the tables have been initialized with nodes."""
102         if not self.joined and len(result) > 0:
103             self.joined = True
104         if self.joining and self.outstandingJoins <= 0:
105             df = self.joining
106             self.joining = None
107             if self.joined or self.bootstrap_node:
108                 self.joined = True
109                 df.callback(self.foundAddrs)
110             else:
111                 df.errback(DHTError('could not find any nodes to bootstrap to'))
112         
113     def getAddrs(self):
114         return self.foundAddrs
115         
116     def leave(self):
117         """See L{apt_dht.interfaces.IDHT}."""
118         if self.config is None:
119             raise DHTError, "configuration not loaded"
120         
121         if self.joined or self.joining:
122             if self.joining:
123                 self.joining.errback(DHTError('still joining when leave was called'))
124                 self.joining = None
125             self.joined = False
126             self.khashmir.shutdown()
127         
128     def _normKey(self, key, bits=None, bytes=None):
129         bits = self.config["HASH_LENGTH"]
130         if bits is not None:
131             bytes = (bits - 1) // 8 + 1
132         else:
133             if bytes is None:
134                 raise DHTError, "you must specify one of bits or bytes for normalization"
135         if len(key) < bytes:
136             key = key + '\000'*(bytes - len(key))
137         elif len(key) > bytes:
138             key = key[:bytes]
139         return key
140
141     def getValue(self, key):
142         """See L{apt_dht.interfaces.IDHT}."""
143         if self.config is None:
144             raise DHTError, "configuration not loaded"
145         if not self.joined:
146             raise DHTError, "have not joined a network yet"
147         
148         key = self._normKey(key)
149
150         d = defer.Deferred()
151         if key not in self.retrieving:
152             self.khashmir.valueForKey(key, self._getValue)
153         self.retrieving.setdefault(key, []).append(d)
154         return d
155         
156     def _getValue(self, key, result):
157         if result:
158             self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
159         else:
160             final_result = []
161             if key in self.retrieved:
162                 final_result = self.retrieved[key]
163                 del self.retrieved[key]
164             for i in range(len(self.retrieving[key])):
165                 d = self.retrieving[key].pop(0)
166                 d.callback(final_result)
167             del self.retrieving[key]
168
169     def storeValue(self, key, value):
170         """See L{apt_dht.interfaces.IDHT}."""
171         if self.config is None:
172             raise DHTError, "configuration not loaded"
173         if not self.joined:
174             raise DHTError, "have not joined a network yet"
175
176         key = self._normKey(key)
177         bvalue = bencode(value)
178
179         if key in self.storing and bvalue in self.storing[key]:
180             raise DHTError, "already storing that key with the same value"
181
182         d = defer.Deferred()
183         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
184         self.storing.setdefault(key, {})[bvalue] = d
185         return d
186     
187     def _storeValue(self, key, bvalue, result):
188         if key in self.storing and bvalue in self.storing[key]:
189             if len(result) > 0:
190                 self.storing[key][bvalue].callback(result)
191             else:
192                 self.storing[key][bvalue].errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
193             del self.storing[key][bvalue]
194             if len(self.storing[key].keys()) == 0:
195                 del self.storing[key]
196
197 class TestSimpleDHT(unittest.TestCase):
198     """Unit tests for the DHT."""
199     
200     timeout = 2
201     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
202                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
203                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
204                     'MAX_FAILURES': 3,
205                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
206                     'KEY_EXPIRE': 3600, 'SPEW': False, }
207
208     def setUp(self):
209         self.a = DHT()
210         self.b = DHT()
211         self.a.config = self.DHT_DEFAULTS.copy()
212         self.a.config['PORT'] = 4044
213         self.a.bootstrap = ["127.0.0.1:4044"]
214         self.a.bootstrap_node = True
215         self.a.cache_dir = '/tmp'
216         self.b.config = self.DHT_DEFAULTS.copy()
217         self.b.config['PORT'] = 4045
218         self.b.bootstrap = ["127.0.0.1:4044"]
219         self.b.cache_dir = '/tmp'
220         
221     def test_bootstrap_join(self):
222         d = self.a.join()
223         return d
224         
225     def node_join(self, result):
226         d = self.b.join()
227         return d
228     
229     def test_join(self):
230         self.lastDefer = defer.Deferred()
231         d = self.a.join()
232         d.addCallback(self.node_join)
233         d.addCallback(self.lastDefer.callback)
234         return self.lastDefer
235
236     def test_normKey(self):
237         h = self.a._normKey('12345678901234567890')
238         self.failUnless(h == '12345678901234567890')
239         h = self.a._normKey('12345678901234567')
240         self.failUnless(h == '12345678901234567\000\000\000')
241         h = self.a._normKey('1234567890123456789012345')
242         self.failUnless(h == '12345678901234567890')
243         h = self.a._normKey('1234567890123456789')
244         self.failUnless(h == '1234567890123456789\000')
245         h = self.a._normKey('123456789012345678901')
246         self.failUnless(h == '12345678901234567890')
247
248     def value_stored(self, result, value):
249         self.stored -= 1
250         if self.stored == 0:
251             self.get_values()
252         
253     def store_values(self, result):
254         self.stored = 3
255         d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
256         d.addCallback(self.value_stored, 4045)
257         d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
258         d.addCallback(self.value_stored, 4044)
259         d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
260         d.addCallback(self.value_stored, 4045)
261
262     def check_values(self, result, values):
263         self.checked -= 1
264         self.failUnless(len(result) == len(values))
265         for v in result:
266             self.failUnless(v in values)
267         if self.checked == 0:
268             self.lastDefer.callback(1)
269     
270     def get_values(self):
271         self.checked = 4
272         d = self.a.getValue(sha.new('4044').digest())
273         d.addCallback(self.check_values, [str(4044*2)])
274         d = self.b.getValue(sha.new('4044').digest())
275         d.addCallback(self.check_values, [str(4044*2)])
276         d = self.a.getValue(sha.new('4045').digest())
277         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
278         d = self.b.getValue(sha.new('4045').digest())
279         d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
280
281     def test_store(self):
282         from twisted.internet.base import DelayedCall
283         DelayedCall.debug = True
284         self.lastDefer = defer.Deferred()
285         d = self.a.join()
286         d.addCallback(self.node_join)
287         d.addCallback(self.store_values)
288         return self.lastDefer
289
290     def tearDown(self):
291         self.a.leave()
292         try:
293             os.unlink(self.a.khashmir.store.db)
294         except:
295             pass
296         self.b.leave()
297         try:
298             os.unlink(self.b.khashmir.store.db)
299         except:
300             pass
301
302 class TestMultiDHT(unittest.TestCase):
303     
304     timeout = 60
305     num = 20
306     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
307                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
308                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
309                     'MAX_FAILURES': 3,
310                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
311                     'KEY_EXPIRE': 3600, 'SPEW': False, }
312
313     def setUp(self):
314         self.l = []
315         self.startport = 4081
316         for i in range(self.num):
317             self.l.append(DHT())
318             self.l[i].config = self.DHT_DEFAULTS.copy()
319             self.l[i].config['PORT'] = self.startport + i
320             self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
321             self.l[i].cache_dir = '/tmp'
322         self.l[0].bootstrap_node = True
323         
324     def node_join(self, result, next_node):
325         d = self.l[next_node].join()
326         if next_node + 1 < len(self.l):
327             d.addCallback(self.node_join, next_node + 1)
328         else:
329             d.addCallback(self.lastDefer.callback)
330     
331     def test_join(self):
332         self.timeout = 2
333         self.lastDefer = defer.Deferred()
334         d = self.l[0].join()
335         d.addCallback(self.node_join, 1)
336         return self.lastDefer
337         
338     def store_values(self, result, i = 0, j = 0):
339         if j > i:
340             j -= i+1
341             i += 1
342         if i == len(self.l):
343             self.get_values()
344         else:
345             d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
346             d.addCallback(self.store_values, i, j+1)
347     
348     def get_values(self, result = None, check = None, i = 0, j = 0):
349         if result is not None:
350             self.failUnless(len(result) == len(check))
351             for v in result:
352                 self.failUnless(v in check)
353         if j >= len(self.l):
354             j -= len(self.l)
355             i += 1
356         if i == len(self.l):
357             self.lastDefer.callback(1)
358         else:
359             d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
360             check = []
361             for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
362                 check.append(str(k))
363             d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
364
365     def store_join(self, result, next_node):
366         d = self.l[next_node].join()
367         if next_node + 1 < len(self.l):
368             d.addCallback(self.store_join, next_node + 1)
369         else:
370             d.addCallback(self.store_values)
371     
372     def test_store(self):
373         from twisted.internet.base import DelayedCall
374         DelayedCall.debug = True
375         self.lastDefer = defer.Deferred()
376         d = self.l[0].join()
377         d.addCallback(self.store_join, 1)
378         return self.lastDefer
379
380     def tearDown(self):
381         for i in self.l:
382             try:
383                 i.leave()
384                 os.unlink(i.khashmir.store.db)
385             except:
386                 pass