Standardize the number of values retrieved from the DHT.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
1
2 from datetime import datetime
3 import os, sha, random
4
5 from twisted.internet import defer, reactor
6 from twisted.internet.abstract import isIPAddress
7 from twisted.python import log
8 from twisted.trial import unittest
9 from zope.interface import implements
10
11 from apt_dht.interfaces import IDHT
12 from khashmir import Khashmir
13 from bencode import bencode, bdecode
14
# Sub-directory, joined onto the configured cache_dir, that is handed to
# Khashmir as its own cache directory.
khashmir_dir = "apt-dht-Khashmir"
16
class DHTError(Exception):
    """Error raised by, or delivered to the errback of, a DHT operation."""
    pass
19
class DHT:
    """Adapter exposing a Khashmir DHT node through the IDHT interface.

    All network operations are asynchronous: the public methods return
    Deferreds that are fired later from the private C{_*} callbacks that
    are handed to Khashmir.
    """
    
    implements(IDHT)
    
    def __init__(self):
        self.config = None            # dict of Khashmir settings; None until configured
        self.cache_dir = ''           # directory passed to Khashmir
        self.bootstrap = []           # "host:port" strings contacted by join()
        self.bootstrap_node = False   # if True, join() succeeds even when no nodes are found
        self.joining = None           # Deferred of the join in progress, else None
        self.joined = False
        self.outstandingJoins = 0     # bootstrap contacts that have not answered yet
        self.foundAddrs = []          # addresses reported back by bootstrap contacts
        self.storing = {}             # key -> {bencoded value: Deferred}
        self.retrieving = {}          # key -> [Deferred, ...] sharing one lookup
        self.retrieved = {}           # key -> decoded values collected so far
    
    def loadConfig(self, config, section):
        """See L{apt_dht.interfaces.IDHT}.

        Reads the options of C{section} from the C{config} parser into
        C{self.config}, converting known numeric, time and boolean options
        with C{getint}, C{gettime} and C{getboolean}. Also creates the
        Khashmir cache directory if it does not exist yet.
        """
        self.config_parser = config
        self.section = section
        self.config = {}
        self.cache_dir = os.path.join(self.config_parser.get(section, 'cache_dir'), khashmir_dir)
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)
        self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
        self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
        for k in self.config_parser.options(section):
            if k in ('K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
                     'RETRIEVE_VALUES', 'MAX_FAILURES', 'PORT'):
                self.config[k] = self.config_parser.getint(section, k)
            elif k in ('CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
                       'BUCKET_STALENESS', 'KEY_EXPIRE'):
                self.config[k] = self.config_parser.gettime(section, k)
            elif k == 'SPEW':
                self.config[k] = self.config_parser.getboolean(section, k)
            else:
                self.config[k] = self.config_parser.get(section, k)
    
    def join(self):
        """See L{apt_dht.interfaces.IDHT}.

        Creates the Khashmir node and contacts every bootstrap node.

        @return: a Deferred that fires with the list of addresses found, or
            errbacks with L{DHTError} if no node could be found (unless this
            is the bootstrap node).
        @raise DHTError: if no configuration is loaded or a join is already
            in progress.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if self.joining:
            raise DHTError("a join is already in progress")

        # Parse all bootstrap entries before changing any state, so a
        # malformed entry cannot leave a half-started join behind.
        nodes = []
        for node in self.bootstrap:
            host, port = node.rsplit(':', 1)
            nodes.append((host, int(port)))

        self.khashmir = Khashmir(self.config, self.cache_dir)

        # Keep a local reference: _join_complete may run synchronously and
        # reset self.joining to None before this method returns.
        df = self.joining = defer.Deferred()

        if not nodes:
            # Nothing to contact: finish now rather than never firing df.
            self._join_complete([])
            return df

        # Count every contact up front, so that an early answer (or failure)
        # cannot make the join look finished while other hostnames are
        # still being resolved.
        self.outstandingJoins += len(nodes)
        for host, port in nodes:
            if isIPAddress(host):
                self._join_gotIP(host, port)
            else:
                # A failed resolution counts as a failed bootstrap node
                # instead of leaving an unhandled error and a hung join.
                reactor.resolve(host).addCallbacks(self._join_gotIP, self._join_error,
                                                   callbackArgs=(port,))
        
        return df

    def _join_gotIP(self, ip, port):
        """Called after an IP address has been found for a single bootstrap node.

        The node was already counted in C{outstandingJoins} by L{join}.
        """
        self.khashmir.addContact(ip, port, self._join_single, self._join_error)
    
    def _join_single(self, addr):
        """Called when a single bootstrap node has been added."""
        self.outstandingJoins -= 1
        if addr:
            self.foundAddrs.append(addr)
        if addr or self.outstandingJoins <= 0:
            self.khashmir.findCloseNodes(self._join_complete, self._join_complete)
        log.msg('Got back from bootstrap node: %r' % (addr,))
    
    def _join_error(self, failure = None):
        """Called when a single bootstrap node failed or could not be resolved."""
        self.outstandingJoins -= 1
        log.msg("bootstrap node could not be reached")
        if self.outstandingJoins <= 0:
            self.khashmir.findCloseNodes(self._join_complete, self._join_complete)

    def _join_complete(self, result):
        """Called when the tables have been initialized with nodes.

        Fires the pending join Deferred once no bootstrap contacts remain
        outstanding: with C{foundAddrs} on success, or with a L{DHTError} if
        no nodes were found and this is not the bootstrap node.
        """
        if not self.joined and len(result) > 0:
            self.joined = True
        if self.joining and self.outstandingJoins <= 0:
            df = self.joining
            self.joining = None
            if self.joined or self.bootstrap_node:
                self.joined = True
                df.callback(self.foundAddrs)
            else:
                df.errback(DHTError('could not find any nodes to bootstrap to'))
        
    def getAddrs(self):
        """Return the addresses reported back by bootstrap nodes during joins."""
        return self.foundAddrs
        
    def leave(self):
        """See L{apt_dht.interfaces.IDHT}.

        Errbacks a join still in progress and shuts down the Khashmir node.

        @raise DHTError: if no configuration is loaded.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        
        if self.joined or self.joining:
            if self.joining:
                self.joining.errback(DHTError('still joining when leave was called'))
                self.joining = None
            self.joined = False
            self.khashmir.shutdown()
        
    def getValue(self, key):
        """See L{apt_dht.interfaces.IDHT}.

        Concurrent requests for the same key share a single Khashmir lookup.

        @return: a Deferred that fires with the list of decoded values found.
        @raise DHTError: if not configured or not joined.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        d = defer.Deferred()
        # Register the caller before starting the lookup: if Khashmir
        # answers synchronously, _getValue must already find this Deferred.
        first = key not in self.retrieving
        self.retrieving.setdefault(key, []).append(d)
        if first:
            self.khashmir.valueForKey(key, self._getValue)
        return d
        
    def _getValue(self, key, result):
        """Collect values for C{key} reported by Khashmir.

        A non-empty C{result} is treated as another batch of bencoded values
        to decode and accumulate; an empty one as the end of the lookup, at
        which point every Deferred waiting on C{key} is fired.
        """
        if result:
            self.retrieved.setdefault(key, []).extend([bdecode(r) for r in result])
            return

        # Detach the bookkeeping before firing, so a callback that calls
        # getValue(key) again starts a new lookup instead of being dropped.
        final_result = self.retrieved.pop(key, [])
        waiting = self.retrieving.pop(key, [])
        for d in waiting:
            # Give each caller its own list so one caller mutating its
            # result cannot affect the others.
            d.callback(list(final_result))

    def storeValue(self, key, value):
        """See L{apt_dht.interfaces.IDHT}.

        @return: a Deferred that fires with Khashmir's store result, or
            errbacks with L{DHTError} if the value could not be stored.
        @raise DHTError: if not configured, not joined, or the same value is
            already being stored under C{key}.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        bvalue = bencode(value)

        if key in self.storing and bvalue in self.storing[key]:
            raise DHTError("already storing that key with the same value")

        d = defer.Deferred()
        # Register before starting the store: a synchronous completion in
        # _storeValue must find this Deferred, or it would never fire.
        self.storing.setdefault(key, {})[bvalue] = d
        self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
        return d
    
    def _storeValue(self, key, bvalue, result):
        """Fire the Deferred waiting on the store of C{bvalue} under C{key}.

        A non-empty C{result} is treated as success; an empty one errbacks
        the Deferred with a L{DHTError}.
        """
        pending = self.storing.get(key)
        if not pending or bvalue not in pending:
            return
        # Remove the bookkeeping before firing, so a callback may store the
        # same key/value again without being rejected as a duplicate.
        d = pending.pop(bvalue)
        if not pending:
            del self.storing[key]
        if len(result) > 0:
            d.callback(result)
        else:
            d.errback(DHTError('could not store value %s in key %s' % (bvalue, key)))
180
class TestSimpleDHT(unittest.TestCase):
    """Unit tests for the DHT using a bootstrap node and one other node."""
    
    timeout = 2
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
                    'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEY_EXPIRE': 3600, 'SPEW': False, }

    def setUp(self):
        self.a = DHT()
        self.b = DHT()
        self.a.config = self.DHT_DEFAULTS.copy()
        self.a.config['PORT'] = 4044
        self.a.bootstrap = ["127.0.0.1:4044"]
        self.a.bootstrap_node = True
        self.a.cache_dir = '/tmp'
        self.b.config = self.DHT_DEFAULTS.copy()
        self.b.config['PORT'] = 4045
        self.b.bootstrap = ["127.0.0.1:4044"]
        self.b.cache_dir = '/tmp'

    def _failTest(self, failure):
        """Errback that fails the test through lastDefer.

        The intermediate Deferreds are never returned to trial, so without
        this an assertion or DHT error inside a callback would be lost and
        the test would only fail by timing out.
        """
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)
        
    def test_bootstrap_join(self):
        d = self.a.join()
        return d
        
    def node_join(self, result):
        d = self.b.join()
        return d
    
    def test_join(self):
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallback(self.lastDefer.callback)
        d.addErrback(self._failTest)
        return self.lastDefer

    def value_stored(self, result, value):
        self.stored -= 1
        if self.stored == 0:
            self.get_values()
        
    def store_values(self, result):
        self.stored = 3
        d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
        d.addCallback(self.value_stored, 4045)
        d.addErrback(self._failTest)
        d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
        d.addCallback(self.value_stored, 4044)
        d.addErrback(self._failTest)
        d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
        d.addCallback(self.value_stored, 4045)
        d.addErrback(self._failTest)

    def check_values(self, result, values):
        self.checked -= 1
        self.failUnless(len(result) == len(values))
        for v in result:
            self.failUnless(v in values)
        if self.checked == 0:
            self.lastDefer.callback(1)
    
    def get_values(self):
        self.checked = 4
        d = self.a.getValue(sha.new('4044').digest())
        d.addCallback(self.check_values, [str(4044*2)])
        d.addErrback(self._failTest)
        d = self.b.getValue(sha.new('4044').digest())
        d.addCallback(self.check_values, [str(4044*2)])
        d.addErrback(self._failTest)
        d = self.a.getValue(sha.new('4045').digest())
        d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
        d.addErrback(self._failTest)
        d = self.b.getValue(sha.new('4045').digest())
        d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
        d.addErrback(self._failTest)

    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallback(self.store_values)
        d.addErrback(self._failTest)
        return self.lastDefer

    def tearDown(self):
        for node in (self.a, self.b):
            node.leave()
            try:
                os.unlink(node.khashmir.store.db)
            except (AttributeError, OSError):
                # Best effort: the node may never have joined (no khashmir
                # attribute), or its database file may already be gone.
                pass
273
class TestMultiDHT(unittest.TestCase):
    """Tests joining, storing and retrieving across many local DHT nodes."""
    
    timeout = 60
    num = 20
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
                    'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEY_EXPIRE': 3600, 'SPEW': False, }

    def setUp(self):
        self.l = []
        self.startport = 4081
        for i in range(self.num):
            node = DHT()
            node.config = self.DHT_DEFAULTS.copy()
            node.config['PORT'] = self.startport + i
            node.bootstrap = ["127.0.0.1:" + str(self.startport)]
            node.cache_dir = '/tmp'
            self.l.append(node)
        self.l[0].bootstrap_node = True

    def _failTest(self, failure):
        """Errback that fails the test through lastDefer.

        The intermediate Deferreds are never returned to trial, so without
        this an assertion or DHT error inside a callback would be lost and
        the test would only fail by timing out.
        """
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)
        
    def node_join(self, result, next_node):
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.node_join, next_node + 1)
        else:
            d.addCallback(self.lastDefer.callback)
        d.addErrback(self._failTest)
    
    def test_join(self):
        # Note: trial reads the timeout before the test method runs, so a
        # per-test timeout cannot be changed by assigning self.timeout here.
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.node_join, 1)
        d.addErrback(self._failTest)
        return self.lastDefer
        
    def store_values(self, result, i = 0, j = 0):
        """Have nodes 0..i store (startport+i)*(j+1) under key startport+i, for every i."""
        if j > i:
            j -= i+1
            i += 1
        if i == len(self.l):
            self.get_values()
        else:
            d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
            d.addCallback(self.store_values, i, j+1)
            d.addErrback(self._failTest)
    
    def get_values(self, result = None, check = None, i = 0, j = 0):
        """Check the previous lookup, then have node i look up a (random) next key j."""
        if result is not None:
            self.failUnless(len(result) == len(check))
            for v in result:
                self.failUnless(v in check)
        if j >= len(self.l):
            j -= len(self.l)
            i += 1
        if i == len(self.l):
            self.lastDefer.callback(1)
        else:
            base = self.startport + j
            d = self.l[i].getValue(sha.new(str(base)).digest())
            # store_values had nodes 0..j store base*1 .. base*(j+1) under this key.
            check = [str(base * (k + 1)) for k in range(j + 1)]
            d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
            d.addErrback(self._failTest)

    def store_join(self, result, next_node):
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.store_join, next_node + 1)
        else:
            d.addCallback(self.store_values)
        d.addErrback(self._failTest)
    
    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.store_join, 1)
        d.addErrback(self._failTest)
        return self.lastDefer

    def tearDown(self):
        for node in self.l:
            try:
                node.leave()
            except Exception:
                # Best effort: keep tearing down the remaining nodes, and
                # still remove this node's database file below.
                log.msg('could not cleanly leave DHT node on port %r' % (node.config.get('PORT'),))
            try:
                os.unlink(node.khashmir.store.db)
            except (AttributeError, OSError):
                # The node may never have joined, or the file is already gone.
                pass