]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht_Khashmir/DHT.py
Fixed some minor bugs.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
1
2 import os, sha, random
3
4 from twisted.internet import defer
5 from twisted.trial import unittest
6 from zope.interface import implements
7
8 from apt_dht.interfaces import IDHT
9 from khashmir import Khashmir
10
class DHTError(Exception):
    """Error raised by the DHT wrapper.

    Used both for misuse (no configuration loaded, not joined yet, duplicate
    store) and as the failure delivered through join/store Deferreds.
    """
    pass
13
class DHT:
    """Expose a Khashmir node through the IDHT interface.

    Configuration comes from loadConfig(), or from assigning C{config},
    C{cache_dir}, C{bootstrap} and C{bootstrap_node} directly. join() then
    creates the Khashmir instance and contacts every bootstrap node.
    Concurrent getValue() calls for one key share a single lookup.
    storeValue() refuses a second store of an identical key/value pair while
    the first is still pending.
    """
    
    implements(IDHT)

    # Option names that loadConfig() converts with getint() / gettime().
    _INT_OPTIONS = ('K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
                    'MAX_FAILURES', 'PORT')
    _TIME_OPTIONS = ('CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
                     'BUCKET_STALENESS', 'KEINITIAL_DELAY', 'KE_DELAY', 'KE_AGE')
    
    def __init__(self):
        self.config = None          # option dict; None until configured
        self.cache_dir = ''
        self.bootstrap = []         # "host:port" strings contacted by join()
        self.bootstrap_node = False # if True, an empty join result is a success
        self.joining = None         # Deferred of the join() in progress, if any
        self.joined = False
        self.storing = {}           # key -> {value: Deferred} for pending stores
        self.retrieving = {}        # key -> [Deferred] sharing one lookup
        self.retrieved = {}         # key -> values collected so far by a lookup
        self.khashmir = None        # Khashmir instance created by join()
        self._running = False       # True from join() until leave() shuts it down
    
    def loadConfig(self, config, section):
        """See L{apt_dht.interfaces.IDHT}.

        Copies every option of C{section} into C{self.config}, converting
        the known integer and time options. C{cache_dir} is read from the
        DEFAULT section. C{PORT} is also taken from DEFAULT when the section
        does not define it.
        """
        self.config_parser = config
        self.section = section
        self.config = {}
        self.cache_dir = config.get('DEFAULT', 'cache_dir')
        self.bootstrap = config.getstringlist(section, 'BOOTSTRAP')
        self.bootstrap_node = config.getboolean(section, 'BOOTSTRAP_NODE')
        for k in config.options(section):
            if k in self._INT_OPTIONS:
                self.config[k] = config.getint(section, k)
            elif k in self._TIME_OPTIONS:
                self.config[k] = config.gettime(section, k)
            else:
                self.config[k] = config.get(section, k)
        if 'PORT' not in self.config:
            self.config['PORT'] = config.getint('DEFAULT', 'PORT')
    
    def join(self):
        """See L{apt_dht.interfaces.IDHT}.

        @return: a Deferred. It fires with the nodes found when the first
            bootstrap attempt completes. It errbacks with L{DHTError} when no
            nodes were found and this is not a bootstrap node.
        @raise DHTError: if no configuration is loaded, or a join is already
            in progress.
        @raise ValueError: if a bootstrap entry is not of the form host:port.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if self.joining:
            raise DHTError("a join is already in progress")

        # Parse every bootstrap entry before touching any state. A malformed
        # entry then cannot leave a half-started join (and a stale
        # self.joining) behind.
        contacts = []
        for node in self.bootstrap:
            host, port = node.rsplit(':', 1)
            contacts.append((host, int(port)))

        self.khashmir = Khashmir(self.config, self.cache_dir)
        self._running = True
        
        df = self.joining = defer.Deferred()
        if not contacts:
            # Without any bootstrap contact _join_single would never run.
            # Resolve now rather than leaving the Deferred pending forever.
            self._join_complete([])
        for host, port in contacts:
            self.khashmir.addContact(host, port, self._join_single)
        
        # Return the local reference, because _join_complete may already
        # have reset self.joining to None.
        return df
    
    def _join_single(self):
        """Called when a single bootstrap node has been added."""
        if not self._running:
            # leave() already shut the Khashmir instance down.
            return
        self.khashmir.findCloseNodes(self._join_complete)
    
    def _join_complete(self, result):
        """Resolve the pending join with the nodes found by findCloseNodes.

        Only the first completion counts. Later completions are ignored, as
        are completions arriving after leave() abandoned the join. The node
        is marked joined only when the join actually succeeds.
        """
        df = self.joining
        if df is None:
            return
        self.joining = None
        if result or self.bootstrap_node:
            self.joined = True
            df.callback(result)
        else:
            df.errback(DHTError('could not find any nodes to bootstrap to'))
        
    def leave(self):
        """See L{apt_dht.interfaces.IDHT}.

        A join still in progress is errbacked with L{DHTError}. The Khashmir
        instance created by join() is shut down exactly once.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        
        # Clear state before firing, so a handler that re-enters leave()
        # does not errback the same Deferred twice.
        df = self.joining
        self.joining = None
        self.joined = False
        if df is not None:
            df.errback(DHTError('still joining when leave was called'))
        if self._running:
            self._running = False
            self.khashmir.shutdown()
        
    def getValue(self, key):
        """See L{apt_dht.interfaces.IDHT}.

        @return: a Deferred that fires with the list of values found for
            C{key}. Simultaneous requests for one key share a single lookup.
        @raise DHTError: if no configuration is loaded or the node has not
            joined.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        d = defer.Deferred()
        waiting = self.retrieving.get(key)
        if waiting is not None:
            # A lookup for this key is already running, so share its result.
            waiting.append(d)
            return d

        # Register before starting the lookup, so that a result delivered
        # synchronously still finds this Deferred.
        self.retrieving[key] = [d]
        try:
            self.khashmir.valueForKey(key, self._getValue)
        except:
            self.retrieving.pop(key, None)
            raise
        return d
        
    def _getValue(self, key, result):
        """Collect lookup results for C{key}; an empty result ends the lookup.

        On completion all waiting Deferreds fire with the accumulated values.
        """
        if result:
            self.retrieved.setdefault(key, []).extend(result)
            return

        final_result = self.retrieved.pop(key, [])
        # Detach the waiters before firing them. A getValue(key) issued from
        # a callback then starts a fresh lookup instead of being dropped.
        waiting = self.retrieving.pop(key, [])
        for d in waiting:
            d.callback(final_result)

    def storeValue(self, key, value):
        """See L{apt_dht.interfaces.IDHT}.

        @return: a Deferred that fires with the store result, or errbacks
            with L{DHTError} if the value could not be stored.
        @raise DHTError: if no configuration is loaded, the node has not
            joined, or the same key/value store is already pending.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        if value in self.storing.get(key, {}):
            raise DHTError("already storing that key with the same value")

        d = defer.Deferred()
        # Register before starting the store, so that a result delivered
        # synchronously still finds this Deferred.
        self.storing.setdefault(key, {})[value] = d
        try:
            self.khashmir.storeValueForKey(key, value, self._storeValue)
        except:
            self._popStore(key, value)
            raise
        return d

    def _popStore(self, key, value):
        """Remove and return the pending Deferred for key/value, or None."""
        pending = self.storing.get(key)
        if pending is None or value not in pending:
            return None
        d = pending.pop(value)
        if not pending:
            del self.storing[key]
        return d
    
    def _storeValue(self, key, value, result):
        """Fire the storeValue() Deferred for key/value.

        A non-empty C{result} counts as success; an empty one is reported as
        a L{DHTError}.
        """
        # Drop the bookkeeping first, so a callback may store the same pair
        # again without hitting "already storing".
        d = self._popStore(key, value)
        if d is None:
            return
        if len(result) > 0:
            d.callback(result)
        else:
            d.errback(DHTError('could not store value %s in key %s' % (value, key)))
145
class TestSimpleDHT(unittest.TestCase):
    """Unit tests for the DHT using two local nodes.

    Node C{a} (port 4044) is the bootstrap node and lists itself as its
    bootstrap contact. Node C{b} (port 4045) bootstraps from C{a}.
    """
    
    timeout = 2
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
                    'KE_AGE': 3600, }

    def setUp(self):
        self.a = DHT()
        self.b = DHT()
        self.a.config = self.DHT_DEFAULTS.copy()
        self.a.config['PORT'] = 4044
        self.a.bootstrap = ["127.0.0.1:4044"]
        self.a.bootstrap_node = True
        self.a.cache_dir = '/tmp'
        self.b.config = self.DHT_DEFAULTS.copy()
        self.b.config['PORT'] = 4045
        self.b.bootstrap = ["127.0.0.1:4044"]
        self.b.cache_dir = '/tmp'

    def _forwardFailure(self, failure):
        """Fail the test's final Deferred with C{failure}, unless it already fired.

        Without this, an error raised in an intermediate callback (including
        a failed assertion) is lost, and the test only reports a timeout.
        """
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)
        
    def test_bootstrap_join(self):
        d = self.a.join()
        return d
        
    def node_join(self, result):
        d = self.b.join()
        return d
    
    def test_join(self):
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallbacks(self.lastDefer.callback, self._forwardFailure)
        return self.lastDefer

    def value_stored(self, result, value):
        """Count finished stores; once all are done, start reading back."""
        self.stored -= 1
        if self.stored == 0:
            self.get_values()
        
    def store_values(self, result):
        """Store three values, two of them under the same key."""
        self.stored = 3
        for node, port, factor in ((self.a, 4045, 3),
                                   (self.a, 4044, 2),
                                   (self.b, 4045, 2)):
            d = node.storeValue(sha.new(str(port)).digest(), str(port * factor))
            d.addCallback(self.value_stored, port)
            d.addErrback(self._forwardFailure)

    def check_values(self, result, values):
        """Assert C{result} holds exactly C{values}; finish after the last check."""
        self.checked -= 1
        self.failUnlessEqual(len(result), len(values))
        for v in result:
            self.failUnless(v in values)
        if self.checked == 0 and not self.lastDefer.called:
            self.lastDefer.callback(1)
    
    def get_values(self):
        """Read both keys back from both nodes."""
        self.checked = 4
        for node, port, expected in (
                (self.a, 4044, [str(4044*2)]),
                (self.b, 4044, [str(4044*2)]),
                (self.a, 4045, [str(4045*2), str(4045*3)]),
                (self.b, 4045, [str(4045*2), str(4045*3)])):
            d = node.getValue(sha.new(str(port)).digest())
            d.addCallback(self.check_values, expected)
            d.addErrback(self._forwardFailure)

    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallback(self.store_values)
        d.addErrback(self._forwardFailure)
        return self.lastDefer

    def tearDown(self):
        for node in (self.a, self.b):
            node.leave()
            try:
                os.unlink(node.khashmir.db)
            except (OSError, AttributeError):
                # Best-effort cleanup: the node may never have joined, so it
                # may have no khashmir instance or database file.
                pass
238
class TestMultiDHT(unittest.TestCase):
    """Tests that join C{num} local DHT nodes on consecutive ports.

    Node 0 is the bootstrap node. Every node, including node 0, uses the
    first port as its bootstrap contact.
    """
    
    timeout = 60
    num = 20
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
                    'KE_AGE': 3600, }

    def setUp(self):
        self.l = []
        self.startport = 4081
        for i in range(self.num):
            node = DHT()
            node.config = self.DHT_DEFAULTS.copy()
            node.config['PORT'] = self.startport + i
            node.bootstrap = ["127.0.0.1:" + str(self.startport)]
            node.cache_dir = '/tmp'
            self.l.append(node)
        self.l[0].bootstrap_node = True

    def _forwardFailure(self, failure):
        """Fail the test's final Deferred with C{failure}, unless it already fired.

        Without this, an error raised in an intermediate callback (including
        a failed assertion) is lost, and the test only reports a timeout.
        """
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)
        
    def node_join(self, result, next_node):
        """Join node C{next_node}, then the following nodes one at a time."""
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.node_join, next_node + 1)
        else:
            d.addCallback(self.lastDefer.callback)
        d.addErrback(self._forwardFailure)
    
    def test_join(self):
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.node_join, 1)
        d.addErrback(self._forwardFailure)
        return self.lastDefer
        
    def value_stored(self, result, value):
        """Count finished stores; once all are done, start reading back."""
        self.stored -= 1
        if self.stored == 0:
            self.get_values()
        
    def store_values(self, result):
        """Each node j stores (startport+i)*(j+1) under key i, for every i >= j."""
        n = len(self.l)
        # Set the total before issuing any request, so that a store completing
        # synchronously cannot drive the counter to zero early.
        self.stored = n * (n + 1) // 2
        for i in range(n):
            port = self.startport + i
            key = sha.new(str(port)).digest()
            for j in range(i + 1):
                d = self.l[j].storeValue(key, str(port * (j + 1)))
                d.addCallback(self.value_stored, port)
                d.addErrback(self._forwardFailure)
    
    def check_values(self, result, values):
        """Assert C{result} holds exactly C{values}; finish after the last check."""
        self.checked -= 1
        self.failUnlessEqual(len(result), len(values))
        for v in result:
            self.failUnless(v in values)
        if self.checked == 0 and not self.lastDefer.called:
            self.lastDefer.callback(1)
    
    def get_values(self):
        """Have every node read back 4 randomly chosen keys."""
        samples = 4
        # Set the total up front, for the same reason as in store_values.
        self.checked = len(self.l) * samples
        for node in self.l:
            for j in random.sample(xrange(len(self.l)), samples):
                port = self.startport + j
                # Nodes 0..j each stored port*(index+1) under key j.
                expected = [str(port * (m + 1)) for m in range(j + 1)]
                d = node.getValue(sha.new(str(port)).digest())
                d.addCallback(self.check_values, expected)
                d.addErrback(self._forwardFailure)

    def store_join(self, result, next_node):
        """Join nodes one at a time, then store the test values."""
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.store_join, next_node + 1)
        else:
            d.addCallback(self.store_values)
        d.addErrback(self._forwardFailure)
    
    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.store_join, 1)
        d.addErrback(self._forwardFailure)
        return self.lastDefer

    def tearDown(self):
        for node in self.l:
            # Keep leaving and unlinking independent, so a failed leave()
            # does not also skip removing the node's database file.
            try:
                node.leave()
            except DHTError:
                pass
            try:
                os.unlink(node.khashmir.db)
            except (OSError, AttributeError):
                # Best-effort cleanup: the node may never have joined.
                pass