ab206d2abf3ba04a5e531325646249f08307a5af
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
1
2 from datetime import datetime
3 import os, sha, random
4
5 from twisted.internet import defer, reactor
6 from twisted.internet.abstract import isIPAddress
7 from twisted.python import log
8 from twisted.trial import unittest
9 from zope.interface import implements
10
11 from apt_dht.interfaces import IDHT
12 from khashmir import Khashmir
13 from bencode import bencode, bdecode
14
# Subdirectory, below the configured cache_dir, that DHT.loadConfig uses as
# the Khashmir cache directory.
khashmir_dir = "apt-dht-Khashmir"
16
class DHTError(Exception):
    """Raised (or used as an errback value) for failures in the DHT.

    Covers both misuse, such as calling an operation before the configuration
    is loaded, and operations that fail on the network.
    """
    pass
19
class DHT:
    """A DHT backend built on a single L{Khashmir} node.

    Implements L{apt_dht.interfaces.IDHT}. Values are bencoded before they
    are handed to Khashmir and bdecoded when they come back. Concurrent
    L{getValue} calls for the same key share one Khashmir lookup. A
    (key, value) pair that is already being stored cannot be stored again
    until the first store finishes.
    """

    implements(IDHT)

    def __init__(self):
        self.config = None            # Khashmir settings; None until configured
        self.cache_dir = ''           # directory passed to Khashmir
        self.bootstrap = []           # 'host:port' strings to contact when joining
        self.bootstrap_node = False   # if True, a join succeeds even with no nodes found
        self.joining = None           # Deferred of the join in progress, if any
        self.joined = False
        self.outstandingJoins = 0     # bootstrap contacts still awaiting an answer
        self.foundAddrs = []          # truthy results passed to _join_single
        self.storing = {}             # key -> {bencoded value: Deferred}
        self.retrieving = {}          # key -> [Deferred, ...] awaiting a lookup
        self.retrieved = {}           # key -> bdecoded values gathered so far

    def loadConfig(self, config, section):
        """See L{apt_dht.interfaces.IDHT}.

        Reads C{section} of C{config} and creates the Khashmir cache directory
        if needed. Known Khashmir options are converted with the matching
        typed getter; every other option is stored as returned by C{get}.
        """
        self.config_parser = config
        self.section = section
        self.config = {}
        self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)
        self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
        self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
        for k in self.config_parser.options(section):
            if k in ('K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
                     'MAX_FAILURES', 'PORT'):
                self.config[k] = self.config_parser.getint(section, k)
            elif k in ('CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
                       'BUCKET_STALENESS', 'KEY_EXPIRE'):
                # NOTE(review): gettime presumably parses a duration; confirm
                # its units against the project's config parser.
                self.config[k] = self.config_parser.gettime(section, k)
            elif k == 'SPEW':
                self.config[k] = self.config_parser.getboolean(section, k)
            else:
                self.config[k] = self.config_parser.get(section, k)

    def join(self):
        """See L{apt_dht.interfaces.IDHT}.

        Contacts every bootstrap node and returns a Deferred. The Deferred
        fires with the list of found addresses once all contacts have
        answered or failed. It errbacks with L{DHTError} if no nodes were
        found and this is not a bootstrap node.

        @raise DHTError: if the configuration is not loaded or a join is
            already in progress.
        @raise ValueError: if a bootstrap entry is not 'host:port'.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if self.joining:
            raise DHTError("a join is already in progress")

        # Parse every entry before creating any state, so a malformed entry
        # does not leave a Khashmir behind or a join stuck "in progress".
        nodes = []
        for node in self.bootstrap:
            host, port = node.rsplit(':', 1)
            nodes.append((host, int(port)))

        self.khashmir = Khashmir(self.config, self.cache_dir)

        # Start from a clean slate, so a re-join does not inherit addresses
        # or counters from an earlier (possibly interrupted) join.
        self.foundAddrs = []
        # Count every contact up front. Otherwise a fast IP contact could
        # finish the join while hostnames are still being resolved.
        self.outstandingJoins = len(nodes)
        # Keep a local reference: if every contact completes synchronously,
        # _join_complete clears self.joining before we get to return it.
        d = self.joining = defer.Deferred()

        if not nodes:
            # No contact callback will ever run, so finish the join now.
            self._join_complete([])
            return d

        for host, port in nodes:
            if isIPAddress(host):
                self._join_gotIP(host, port)
            else:
                # A failed lookup must still count as a finished contact,
                # or the join would wait forever.
                reactor.resolve(host).addCallbacks(self._join_gotIP, self._join_error,
                                                   callbackArgs=(port,))
        return d

    def _join_gotIP(self, ip, port):
        """Called after an IP address has been found for a single bootstrap node.

        The contact was already counted in L{join}, so this only hands it
        to Khashmir.
        """
        self.khashmir.addContact(ip, port, self._join_single, self._join_error)

    def _join_single(self, addr):
        """Called when a single bootstrap node has been added.

        A truthy C{addr} is recorded. A node search is started whenever a
        contact answers with an address, or when this was the last
        outstanding contact.
        """
        self.outstandingJoins -= 1
        if addr:
            self.foundAddrs.append(addr)
        if addr or self.outstandingJoins <= 0:
            self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
        log.msg('Got back from bootstrap node: %r' % (addr,))

    def _join_error(self, failure = None):
        """Called when a single bootstrap node has failed (or could not be resolved)."""
        self.outstandingJoins -= 1
        log.msg("bootstrap node could not be reached")
        if self.outstandingJoins <= 0:
            self.khashmir.findCloseNodes(self._join_complete, self._join_complete)

    def _join_complete(self, result):
        """Called when the tables have been initialized with nodes.

        Finishes the pending join once no bootstrap contacts are outstanding.
        """
        # NOTE(review): findCloseNodes is given this method as both callback
        # and errback; if it ever errbacks with a Failure, len() below will
        # raise -- confirm what Khashmir passes here.
        if not self.joined and len(result) > 0:
            self.joined = True
        if self.joining and self.outstandingJoins <= 0:
            df = self.joining
            self.joining = None
            if self.joined or self.bootstrap_node:
                self.joined = True
                df.callback(self.foundAddrs)
            else:
                # NOTE(review): the Khashmir created by join() keeps running
                # after a failed join; leave() will not shut it down.
                df.errback(DHTError('could not find any nodes to bootstrap to'))

    def getAddrs(self):
        """Return the addresses collected from bootstrap nodes during the last join."""
        return self.foundAddrs

    def leave(self):
        """See L{apt_dht.interfaces.IDHT}.

        Shuts down the Khashmir node if we joined or are joining. A join
        still in progress is errbacked with L{DHTError}.

        @raise DHTError: if the configuration is not loaded.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not (self.joined or self.joining):
            return

        # Update our own state before firing anything, so errback handlers
        # see a node that has already left (and may safely call join again).
        df, self.joining = self.joining, None
        self.joined = False
        try:
            self.khashmir.shutdown()
        finally:
            if df is not None:
                df.errback(DHTError('still joining when leave was called'))

    def getValue(self, key):
        """See L{apt_dht.interfaces.IDHT}.

        Returns a Deferred that fires with the list of bdecoded values found
        for C{key}. Concurrent requests for the same key share one lookup.

        @raise DHTError: if the configuration is not loaded or not joined.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        d = defer.Deferred()
        waiting = self.retrieving.setdefault(key, [])
        waiting.append(d)
        # Only the first waiter starts a lookup. It is registered first, so
        # a synchronous answer from valueForKey still finds it.
        if len(waiting) == 1:
            try:
                self.khashmir.valueForKey(key, self._getValue)
            except Exception:
                # The lookup never started; do not leave a waiter that
                # would block every later lookup of this key.
                self.retrieving.pop(key, None)
                raise
        return d

    def _getValue(self, key, result):
        """Khashmir callback for L{getValue} lookups.

        A truthy C{result} adds more bencoded values for C{key}. A falsy
        C{result} is treated as the end of the lookup, and every waiter
        receives the values gathered so far.
        """
        if result:
            self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
            return

        final_result = self.retrieved.pop(key, [])
        # Detach the waiters before firing them, so a callback that calls
        # getValue(key) again starts a new lookup instead of being dropped.
        waiters = self.retrieving.pop(key, [])
        for d in waiters:
            # Each waiter gets its own list so callbacks cannot affect each other.
            d.callback(list(final_result))

    def storeValue(self, key, value):
        """See L{apt_dht.interfaces.IDHT}.

        Bencodes C{value} and stores it under C{key}. Returns a Deferred that
        fires with Khashmir's non-empty result, or errbacks with L{DHTError}
        if the store produced no result.

        @raise DHTError: if the configuration is not loaded, not joined, or
            the same value is already being stored under this key.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        bvalue = bencode(value)

        if key in self.storing and bvalue in self.storing[key]:
            raise DHTError("already storing that key with the same value")

        d = defer.Deferred()
        # Register before starting the store so a synchronous completion
        # still finds this Deferred.
        self.storing.setdefault(key, {})[bvalue] = d
        try:
            self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
        except Exception:
            self._popStoring(key, bvalue)
            raise
        return d

    def _popStoring(self, key, bvalue):
        """Remove and return the pending store Deferred for (key, bvalue), or None."""
        values = self.storing.get(key)
        if values is None or bvalue not in values:
            return None
        d = values.pop(bvalue)
        if not values:
            del self.storing[key]
        return d

    def _storeValue(self, key, bvalue, result):
        """Khashmir callback for L{storeValue}: fire the matching Deferred."""
        # Detach first, so a callback may store the same pair again without
        # being rejected as a duplicate.
        d = self._popStoring(key, bvalue)
        if d is None:
            return
        if len(result) > 0:
            d.callback(result)
        else:
            d.errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
180
class TestSimpleDHT(unittest.TestCase):
    """Unit tests for the DHT using two local nodes.

    Node C{a} listens on port 4044, is the bootstrap node and bootstraps to
    itself. Node C{b} listens on port 4045 and bootstraps to C{a}.
    """

    timeout = 2
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEY_EXPIRE': 3600, 'SPEW': False, }

    def setUp(self):
        self.a = DHT()
        self.b = DHT()
        self.a.config = self.DHT_DEFAULTS.copy()
        self.a.config['PORT'] = 4044
        self.a.bootstrap = ["127.0.0.1:4044"]
        self.a.bootstrap_node = True
        self.a.cache_dir = '/tmp'
        self.b.config = self.DHT_DEFAULTS.copy()
        self.b.config['PORT'] = 4045
        self.b.bootstrap = ["127.0.0.1:4044"]
        self.b.cache_dir = '/tmp'

    def _fail(self, failure):
        """Forward an error from any callback chain to the test's result Deferred.

        Without this, a failed store or assertion inside a callback is only
        logged and the test stalls until it times out.
        """
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)

    def test_bootstrap_join(self):
        return self.a.join()

    def node_join(self, result):
        return self.b.join()

    def test_join(self):
        # Returning the chain itself lets trial see any join failure directly.
        d = self.a.join()
        d.addCallback(self.node_join)
        return d

    def value_stored(self, result, value):
        """Count down finished stores; start the lookups after the last one."""
        self.stored -= 1
        if self.stored == 0:
            self.get_values()

    def store_values(self, result):
        """Store two values under sha('4045') and one under sha('4044')."""
        self.stored = 3
        for node, port, factor in ((self.a, 4045, 3),
                                   (self.a, 4044, 2),
                                   (self.b, 4045, 2)):
            d = node.storeValue(sha.new(str(port)).digest(), str(port * factor))
            d.addCallback(self.value_stored, port)
            d.addErrback(self._fail)

    def check_values(self, result, values):
        """Assert a lookup returned exactly C{values}, in any order."""
        self.failUnlessEqual(sorted(result), sorted(values))
        # Count only successful checks, so a failed check (already forwarded
        # by _fail) can never lead to a second firing of lastDefer.
        self.checked -= 1
        if self.checked == 0:
            self.lastDefer.callback(1)

    def get_values(self):
        """Look up both keys from both nodes."""
        self.checked = 4
        for node, port, expected in ((self.a, 4044, [str(4044*2)]),
                                     (self.b, 4044, [str(4044*2)]),
                                     (self.a, 4045, [str(4045*2), str(4045*3)]),
                                     (self.b, 4045, [str(4045*2), str(4045*3)])):
            d = node.getValue(sha.new(str(port)).digest())
            d.addCallback(self.check_values, expected)
            d.addErrback(self._fail)

    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallback(self.store_values)
        d.addErrback(self._fail)
        return self.lastDefer

    def tearDown(self):
        for node in (self.a, self.b):
            node.leave()
            try:
                os.unlink(node.khashmir.store.db)
            except (AttributeError, OSError):
                # No Khashmir was created, or its database file is already gone.
                pass
271             pass
272
class TestMultiDHT(unittest.TestCase):
    """Tests that join, store and retrieve across C{num} local DHT nodes.

    Node i listens on C{startport + i}. Every node bootstraps to node 0,
    which is the bootstrap node.
    """

    timeout = 60
    num = 20
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEY_EXPIRE': 3600, 'SPEW': False, }

    def setUp(self):
        self.l = []
        self.startport = 4081
        for i in range(self.num):
            node = DHT()
            node.config = self.DHT_DEFAULTS.copy()
            node.config['PORT'] = self.startport + i
            node.bootstrap = ["127.0.0.1:" + str(self.startport)]
            node.cache_dir = '/tmp'
            self.l.append(node)
        self.l[0].bootstrap_node = True

    def _fail(self, failure):
        """Forward an error from any callback chain to the test's result Deferred.

        Without this, a failed join, store or assertion inside a callback is
        only logged and the test stalls until it times out.
        """
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)

    def node_join(self, result, next_node):
        """Join node C{next_node}, then the remaining nodes one at a time."""
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.node_join, next_node + 1)
        else:
            d.addCallback(self.lastDefer.callback)
        d.addErrback(self._fail)

    def test_join(self):
        # NOTE(review): trial may read the timeout before the test method
        # runs, in which case this assignment has no effect -- confirm.
        self.timeout = 2
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.node_join, 1)
        d.addErrback(self._fail)
        return self.lastDefer

    def store_values(self, result, i = 0, j = 0):
        """Store values one at a time: key i is stored by nodes 0..i.

        Node j stores str((startport + i) * (j + 1)) under
        sha(str(startport + i)). Once all keys are stored, the lookups start.
        """
        if j > i:
            j -= i+1
            i += 1
        if i == len(self.l):
            self.get_values()
        else:
            d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
            d.addCallback(self.store_values, i, j+1)
            d.addErrback(self._fail)

    def get_values(self, result = None, check = None, i = 0, j = 0):
        """Check the previous lookup, then have node i look up key j.

        Key j advances by a random step, wrapping to the next node.
        """
        if result is not None:
            # Exact multiset comparison: duplicates must not hide a missing value.
            self.failUnlessEqual(sorted(result), sorted(check))
        if j >= len(self.l):
            j -= len(self.l)
            i += 1
        if i == len(self.l):
            self.lastDefer.callback(1)
        else:
            d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
            # Key j was stored by nodes 0..j, node n storing (startport+j)*(n+1).
            key = self.startport + j
            check = [str(key * (n + 1)) for n in range(j + 1)]
            d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
            d.addErrback(self._fail)

    def store_join(self, result, next_node):
        """Join node C{next_node} and the rest, then start storing values."""
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.store_join, next_node + 1)
        else:
            d.addCallback(self.store_values)
        d.addErrback(self._fail)

    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.store_join, 1)
        d.addErrback(self._fail)
        return self.lastDefer

    def tearDown(self):
        for node in self.l:
            # Best effort: one node failing to leave must not stop the others
            # from releasing their ports and database files.
            try:
                node.leave()
            except Exception:
                log.msg('error while node on port %r was leaving' % (node.config['PORT'],))
            try:
                os.unlink(node.khashmir.store.db)
            except (AttributeError, OSError):
                # No Khashmir was created, or its database file is already gone.
                pass