git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht_Khashmir/DHT.py
Pass the new HashObjects around everywhere (untested).
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
1
2 import os, sha, random
3
4 from twisted.internet import defer, reactor
5 from twisted.internet.abstract import isIPAddress
6 from twisted.trial import unittest
7 from zope.interface import implements
8
9 from apt_dht.interfaces import IDHT
10 from khashmir import Khashmir
11
class DHTError(Exception):
    """Exception type raised for failures inside the DHT implementation."""
14
class DHT:
    """Khashmir-backed implementation of L{apt_dht.interfaces.IDHT}.

    @ivar config: the parsed configuration options, or None until a
        configuration has been loaded
    @ivar cache_dir: the directory passed to Khashmir when it is created
    @ivar bootstrap: the 'host:port' strings of the nodes to bootstrap from
    @ivar bootstrap_node: whether this node may join successfully even when
        no other nodes were found
    @ivar joining: the deferred of the join in progress, or None
    @ivar joined: whether a join has completed successfully
    @ivar storing: pending store requests, as a dictionary mapping
        key -> value -> deferred
    @ivar retrieving: pending get requests, as a dictionary mapping
        key -> list of deferreds
    @ivar retrieved: the values received so far for each key being retrieved
    @ivar khashmir: the Khashmir instance, only present after L{join}
    """

    implements(IDHT)

    # Configuration options that are read as integers.
    _INT_OPTIONS = ('K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
                    'MAX_FAILURES', 'PORT')
    # Configuration options that are read with the parser's gettime().
    _TIME_OPTIONS = ('CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
                     'BUCKET_STALENESS', 'KEINITIAL_DELAY', 'KE_DELAY',
                     'KE_AGE')

    def __init__(self):
        """Initialize the (not yet configured, not yet joined) DHT state."""
        self.config = None
        self.cache_dir = ''
        self.bootstrap = []
        self.bootstrap_node = False
        self.joining = None
        self.joined = False
        self.storing = {}
        self.retrieving = {}
        self.retrieved = {}
        # True from the moment join() creates a Khashmir instance until
        # leave() shuts it down, so that a failed join is still cleaned up.
        self._started = False

    def loadConfig(self, config, section):
        """See L{apt_dht.interfaces.IDHT}.

        Every option of C{section} is copied into C{self.config}: the
        integer and time options are converted by the parser, everything
        else is stored as returned by the parser's get().
        """
        self.config_parser = config
        self.section = section
        self.config = {}
        self.cache_dir = self.config_parser.get(section, 'cache_dir')
        self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
        self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
        for k in self.config_parser.options(section):
            if k in self._INT_OPTIONS:
                self.config[k] = self.config_parser.getint(section, k)
            elif k in self._TIME_OPTIONS:
                self.config[k] = self.config_parser.gettime(section, k)
            else:
                self.config[k] = self.config_parser.get(section, k)

    def join(self):
        """See L{apt_dht.interfaces.IDHT}.

        @return: a deferred that fires with the result of the node search
            once the join succeeds, or fails with a L{DHTError}
        @raise DHTError: if no configuration is loaded or a join is
            already in progress
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if self.joining is not None:
            raise DHTError("a join is already in progress")

        self.khashmir = Khashmir(self.config, self.cache_dir)
        self._started = True

        # Keep a local reference: self.joining is reset to None as soon as
        # the join completes, which could happen before this method returns.
        joining = self.joining = defer.Deferred()
        for node in self.bootstrap:
            host, port = node.rsplit(':', 1)
            port = int(port)
            if isIPAddress(host):
                self._join_gotIP(host, port)
            else:
                # NOTE(review): a failed lookup has no errback here, so if no
                # bootstrap host resolves the returned deferred never fires.
                reactor.resolve(host).addCallback(self._join_gotIP, port)

        return joining

    def _join_gotIP(self, ip, port):
        """Called after an IP address has been found for a single bootstrap node."""
        self.khashmir.addContact(ip, port, self._join_single)

    def _join_single(self):
        """Called when a single bootstrap node has been added."""
        self.khashmir.findCloseNodes(self._join_complete)

    def _join_complete(self, result):
        """Called when the tables have been initialized with nodes.

        Only the first call for a join has any effect; later calls (from
        other bootstrap nodes, or arriving after L{leave}) are ignored.
        """
        df = self.joining
        if df is None:
            # Already resolved by an earlier call, or abandoned by leave().
            return
        # Clear the state before firing so callbacks see a consistent DHT.
        self.joining = None
        if len(result) > 0 or self.bootstrap_node:
            self.joined = True
            df.callback(result)
        else:
            # The join failed, so the node must not be marked as joined.
            df.errback(DHTError('could not find any nodes to bootstrap to'))

    def leave(self):
        """See L{apt_dht.interfaces.IDHT}.

        Fails any join still in progress and shuts down the Khashmir
        instance created by L{join}, if there is one.

        @raise DHTError: if no configuration is loaded
        """
        if self.config is None:
            raise DHTError("configuration not loaded")

        # Reset all state first so that errback handlers (and any late
        # _join_complete call) find the DHT cleanly in the 'left' state.
        pending, self.joining = self.joining, None
        self.joined = False
        khashmir = None
        if self._started:
            self._started = False
            khashmir = self.khashmir

        if pending is not None:
            pending.errback(DHTError('still joining when leave was called'))
        if khashmir is not None:
            khashmir.shutdown()

    def getValue(self, key):
        """See L{apt_dht.interfaces.IDHT}.

        Concurrent requests for the same key share a single lookup.

        @return: a deferred that fires with the list of values found
        @raise DHTError: if no configuration is loaded or the DHT has not
            been joined
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        d = defer.Deferred()
        if key not in self.retrieving:
            self.khashmir.valueForKey(key, self._getValue)
        self.retrieving.setdefault(key, []).append(d)
        return d

    def _getValue(self, key, result):
        """Collect lookup results for a key; an empty result ends the lookup."""
        if result:
            self.retrieved.setdefault(key, []).extend(result)
            return

        # Remove the bookkeeping before firing: a callback that requests the
        # same key again must start a new lookup instead of being appended
        # to a list that is about to be discarded.
        final_result = self.retrieved.pop(key, [])
        for d in self.retrieving.pop(key, []):
            d.callback(final_result)

    def storeValue(self, key, value):
        """See L{apt_dht.interfaces.IDHT}.

        @return: a deferred that fires with the store result, or fails with
            a L{DHTError} if the result was empty
        @raise DHTError: if no configuration is loaded, the DHT has not been
            joined, or the same value is already being stored in the key
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        if key in self.storing and value in self.storing[key]:
            raise DHTError("already storing that key with the same value")

        d = defer.Deferred()
        self.khashmir.storeValueForKey(key, value, self._storeValue)
        self.storing.setdefault(key, {})[value] = d
        return d

    def _storeValue(self, key, value, result):
        """Fire the deferred waiting on a finished store request, if any."""
        pending = self.storing.get(key)
        if pending is None or value not in pending:
            return

        # Remove the bookkeeping before firing so that a callback may store
        # the same key and value again without being rejected.
        d = pending.pop(value)
        if not pending:
            del self.storing[key]

        if len(result) > 0:
            d.callback(result)
        else:
            d.errback(DHTError('could not store value %s in key %s' % (value, key)))
151
class TestSimpleDHT(unittest.TestCase):
    """Two-node tests of joining, storing and retrieving with the DHT."""

    timeout = 2
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
                    'KE_AGE': 3600, }

    def setUp(self):
        """Configure node a (port 4044, bootstrap node) and node b (port 4045)."""
        self.a = DHT()
        self.b = DHT()
        for node, port in ((self.a, 4044), (self.b, 4045)):
            node.config = self.DHT_DEFAULTS.copy()
            node.config['PORT'] = port
            # Both nodes bootstrap from node a's address.
            node.bootstrap = ["127.0.0.1:4044"]
            node.cache_dir = '/tmp'
        self.a.bootstrap_node = True

    def test_bootstrap_join(self):
        """The bootstrap node can join on its own."""
        return self.a.join()

    def node_join(self, result):
        """Callback that joins node b, returning its join deferred."""
        return self.b.join()

    def test_join(self):
        """Node b can join once node a has joined."""
        self.lastDefer = defer.Deferred()
        joined = self.a.join()
        joined.addCallback(self.node_join)
        joined.addCallback(self.lastDefer.callback)
        return self.lastDefer

    def value_stored(self, result, value):
        """Count down the outstanding stores, then start the retrievals."""
        self.stored -= 1
        if self.stored != 0:
            return
        self.get_values()

    def store_values(self, result):
        """Issue three store requests spread over both nodes."""
        self.stored = 3
        # (node, port hashed to make the key, multiplier for the value)
        requests = ((self.a, 4045, 3), (self.a, 4044, 2), (self.b, 4045, 2))
        for node, port, factor in requests:
            d = node.storeValue(sha.new(str(port)).digest(), str(port * factor))
            d.addCallback(self.value_stored, port)

    def check_values(self, result, values):
        """Verify one retrieval; finish the test after the last one."""
        self.checked -= 1
        self.failUnless(len(result) == len(values))
        for found in result:
            self.failUnless(found in values)
        if self.checked == 0:
            self.lastDefer.callback(1)

    def get_values(self):
        """Retrieve both keys from both nodes and check what comes back."""
        self.checked = 4
        expectations = ((4044, [str(4044*2)]),
                        (4045, [str(4045*2), str(4045*3)]))
        for port, values in expectations:
            for node in (self.a, self.b):
                d = node.getValue(sha.new(str(port)).digest())
                d.addCallback(self.check_values, list(values))

    def test_store(self):
        """Values stored from either node can be retrieved from both."""
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        joined = self.a.join()
        joined.addCallback(self.node_join)
        joined.addCallback(self.store_values)
        return self.lastDefer

    def tearDown(self):
        """Leave the DHT on both nodes and remove their database files."""
        for node in (self.a, self.b):
            node.leave()
            # Best effort: the node may never have created a database.
            try:
                os.unlink(node.khashmir.store.db)
            except:
                pass
244
class TestMultiDHT(unittest.TestCase):
    """Tests of joining, storing and retrieving across many DHT nodes."""

    timeout = 60
    num = 20
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
                    'KE_AGE': 3600, }

    def setUp(self):
        """Create C{num} nodes on consecutive ports; the first is the bootstrap node."""
        self.l = []
        self.startport = 4081
        for offset in range(self.num):
            node = DHT()
            self.l.append(node)
            node.config = self.DHT_DEFAULTS.copy()
            node.config['PORT'] = self.startport + offset
            # Every node bootstraps from the first node's address.
            node.bootstrap = ["127.0.0.1:" + str(self.startport)]
            node.cache_dir = '/tmp'
        self.l[0].bootstrap_node = True

    def node_join(self, result, next_node):
        """Join node C{next_node}, chaining to the following node or finishing."""
        d = self.l[next_node].join()
        following = next_node + 1
        if following < len(self.l):
            d.addCallback(self.node_join, following)
        else:
            d.addCallback(self.lastDefer.callback)

    def test_join(self):
        """All nodes can join one after another."""
        self.timeout = 2
        self.lastDefer = defer.Deferred()
        first = self.l[0].join()
        first.addCallback(self.node_join, 1)
        return self.lastDefer

    def store_values(self, result, i = 0, j = 0):
        """Store, one request at a time, a value from node j under key i for all j <= i."""
        if j > i:
            # Key i is done: move on to the next key, starting at node 0.
            j -= i+1
            i += 1
        if i == len(self.l):
            self.get_values()
            return
        key_port = self.startport + i
        d = self.l[j].storeValue(sha.new(str(key_port)).digest(), str(key_port*(j+1)))
        d.addCallback(self.store_values, i, j+1)

    def get_values(self, result = None, check = None, i = 0, j = 0):
        """Check the previous retrieval, then fetch key j from node i.

        After each retrieval j advances by a random step, wrapping around
        to the next node once it passes the last key.
        """
        if result is not None:
            self.failUnless(len(result) == len(check))
            for found in result:
                self.failUnless(found in check)
        count = len(self.l)
        if j >= count:
            j -= count
            i += 1
        if i == count:
            self.lastDefer.callback(1)
            return
        key_port = self.startport + j
        d = self.l[i].getValue(sha.new(str(key_port)).digest())
        # The expected values are key_port multiplied by 1 .. j+1.
        check = [str(k) for k in range(key_port, key_port*(j+1)+1, key_port)]
        d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))

    def store_join(self, result, next_node):
        """Join node C{next_node}, chaining to the following node or to the stores."""
        d = self.l[next_node].join()
        following = next_node + 1
        if following < len(self.l):
            d.addCallback(self.store_join, following)
        else:
            d.addCallback(self.store_values)

    def test_store(self):
        """Values stored across the nodes can be retrieved from them."""
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        first = self.l[0].join()
        first.addCallback(self.store_join, 1)
        return self.lastDefer

    def tearDown(self):
        """Leave the DHT on every node and remove its database file."""
        for node in self.l:
            # Best effort: ignore nodes that fail to leave or have no database.
            try:
                node.leave()
                os.unlink(node.khashmir.store.db)
            except:
                pass