git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht_Khashmir/DHT.py
New DHT method 'join' like 'ping' but returns our IP and port.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
1
2 import os, sha, random
3
4 from twisted.internet import defer, reactor
5 from twisted.internet.abstract import isIPAddress
6 from twisted.python import log
7 from twisted.trial import unittest
8 from zope.interface import implements
9
10 from apt_dht.interfaces import IDHT
11 from khashmir import Khashmir
12
class DHTError(Exception):
    """Raised when a DHT operation is attempted in the wrong state or fails."""
    pass
15
class DHT:
    """Expose a Khashmir node through the L{apt_dht.interfaces.IDHT} interface.

    Concurrent getValue() calls for the same key share one Khashmir lookup,
    and only one store per (key, value) pair may be in flight at a time.
    """

    implements(IDHT)

    def __init__(self):
        self.config = None           # Khashmir settings dict; None until configured
        self.cache_dir = ''          # directory handed to Khashmir
        self.bootstrap = []          # "host:port" strings of nodes to join through
        self.bootstrap_node = False  # if True, join succeeds even when no nodes are found
        self.joining = None          # Deferred of the join in progress, if any
        self.joined = False
        self.foundAddrs = []         # addresses returned by contacted bootstrap nodes
        self.storing = {}            # key -> {value: Deferred} for stores in flight
        self.retrieving = {}         # key -> [Deferred] waiting on the lookup in flight
        self.retrieved = {}          # key -> values accumulated so far for that lookup

    def loadConfig(self, config, section):
        """See L{apt_dht.interfaces.IDHT}.

        Reads C{section} of C{config}, converting known integer, time and
        boolean options to their types and keeping all others as strings.
        """
        self.config_parser = config
        self.section = section
        self.config = {}
        self.cache_dir = self.config_parser.get(section, 'cache_dir')
        self.bootstrap = self.config_parser.getstringlist(section, 'BOOTSTRAP')
        self.bootstrap_node = self.config_parser.getboolean(section, 'BOOTSTRAP_NODE')
        for k in self.config_parser.options(section):
            if k in ['K', 'HASH_LENGTH', 'CONCURRENT_REQS', 'STORE_REDUNDANCY',
                     'MAX_FAILURES', 'PORT']:
                self.config[k] = self.config_parser.getint(section, k)
            elif k in ['CHECKPOINT_INTERVAL', 'MIN_PING_INTERVAL',
                       'BUCKET_STALENESS', 'KEINITIAL_DELAY', 'KE_DELAY', 'KE_AGE']:
                self.config[k] = self.config_parser.gettime(section, k)
            elif k in ['SPEW']:
                self.config[k] = self.config_parser.getboolean(section, k)
            else:
                self.config[k] = self.config_parser.get(section, k)

    def join(self):
        """See L{apt_dht.interfaces.IDHT}.

        @raise DHTError: if no configuration is loaded or a join is running
        @return: a Deferred that fires with Khashmir's close-node result, or
            errbacks with L{DHTError} if no nodes could be found
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if self.joining is not None:
            raise DHTError("a join is already in progress")

        self.khashmir = Khashmir(self.config, self.cache_dir)

        # Keep a local reference: if Khashmir calls back synchronously,
        # _join_complete() clears self.joining before we could return it.
        df = self.joining = defer.Deferred()
        if not self.bootstrap:
            # With no bootstrap nodes nothing would ever reach
            # _join_complete(), so the join would never finish.
            self.khashmir.findCloseNodes(self._join_complete)
        for node in self.bootstrap:
            host, port = node.rsplit(':', 1)
            port = int(port)
            if isIPAddress(host):
                self._join_gotIP(host, port)
            else:
                reactor.resolve(host).addCallbacks(
                    self._join_gotIP, self._join_resolveFailed,
                    callbackArgs=(port,), errbackArgs=(host,))
        return df

    def _isActive(self):
        """Return True while a join is running or has completed (i.e. not left)."""
        return self.joining is not None or self.joined

    def _join_gotIP(self, ip, port):
        """Called after an IP address has been found for a single bootstrap node."""
        if not self._isActive():
            # leave() was called while the hostname was being resolved.
            return
        self.khashmir.addContact(ip, port, self._join_single)

    def _join_resolveFailed(self, failure, host):
        """Called when a bootstrap node's hostname could not be resolved."""
        log.msg('Could not resolve bootstrap node %s: %s'
                % (host, failure.getErrorMessage()))
        # Treat it like a node that did not answer, so the join can still
        # complete (or fail) instead of silently waiting forever.
        self._join_single(None)

    def _join_single(self, addr):
        """Called when a single bootstrap node has been added.

        @param addr: the address the bootstrap node reported, or a false
            value if it did not respond (presumably Khashmir's failure
            convention for addContact -- confirm)
        """
        if not self._isActive():
            return
        if addr:
            self.foundAddrs.append(addr)
        log.msg('Got back from bootstrap node: %r' % (addr,))
        self.khashmir.findCloseNodes(self._join_complete)

    def _join_complete(self, result):
        """Called when the tables have been initialized with nodes.

        Only the first call after join() fires the join Deferred; later
        calls (other bootstrap nodes answering) or calls arriving after
        leave() are ignored.
        """
        df = self.joining
        if self.joined or df is None:
            return
        self.joined = True
        self.joining = None
        if len(result) > 0 or self.bootstrap_node:
            df.callback(result)
        else:
            df.errback(DHTError('could not find any nodes to bootstrap to'))

    def getAddrs(self):
        """Return the addresses the bootstrap nodes reported back during join."""
        return self.foundAddrs

    def leave(self):
        """See L{apt_dht.interfaces.IDHT}."""
        if self.config is None:
            raise DHTError("configuration not loaded")

        if self.joined or self.joining is not None:
            if self.joining is not None:
                # Clear the state before firing, so errback handlers see a
                # DHT that is no longer joining.
                df = self.joining
                self.joining = None
                df.errback(DHTError('still joining when leave was called'))
            self.joined = False
            self.khashmir.shutdown()

    def getValue(self, key):
        """See L{apt_dht.interfaces.IDHT}.

        Requests for a key that is already being looked up piggy-back on
        that lookup instead of starting another one.
        """
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        d = defer.Deferred()
        waiting = self.retrieving.get(key)
        if waiting is None:
            # Register before starting the lookup, in case Khashmir calls
            # _getValue synchronously.
            self.retrieving[key] = [d]
            self.khashmir.valueForKey(key, self._getValue)
        else:
            waiting.append(d)
        return d

    def _getValue(self, key, result):
        """Collect lookup results; an empty result signals the lookup is done.

        On completion every Deferred waiting on C{key} fires with its own
        copy of the accumulated values.
        """
        if key not in self.retrieving:
            # Stale callback for a lookup that has already finished.
            return
        if result:
            self.retrieved.setdefault(key, []).extend(result)
            return
        final_result = self.retrieved.pop(key, [])
        # Detach the waiters first: a callback that calls getValue(key)
        # again must start a fresh lookup rather than join this finished one.
        waiting = self.retrieving.pop(key)
        for d in waiting:
            d.callback(list(final_result))

    def storeValue(self, key, value):
        """See L{apt_dht.interfaces.IDHT}."""
        if self.config is None:
            raise DHTError("configuration not loaded")
        if not self.joined:
            raise DHTError("have not joined a network yet")

        if value in self.storing.get(key, {}):
            raise DHTError("already storing that key with the same value")

        d = defer.Deferred()
        # Register before starting the store, in case Khashmir calls
        # _storeValue synchronously.
        self.storing.setdefault(key, {})[value] = d
        self.khashmir.storeValueForKey(key, value, self._storeValue)
        return d

    def _storeValue(self, key, value, result):
        """Fire the Deferred of a pending store: success if any node stored it."""
        values = self.storing.get(key)
        if values is None or value not in values:
            return
        d = values.pop(value)
        if not values:
            del self.storing[key]
        # Unregistered before firing, so a callback may store the same
        # key/value again without hitting "already storing".
        if len(result) > 0:
            d.callback(result)
        else:
            d.errback(DHTError('could not store value %s in key %s' % (value, key)))
159
class TestSimpleDHT(unittest.TestCase):
    """Unit tests for the DHT."""

    timeout = 2
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
                    'KE_AGE': 3600, 'SPEW': False, }

    def setUp(self):
        # Node a (port 4044) bootstraps off itself; node b (port 4045) joins via a.
        self.a = DHT()
        self.b = DHT()
        self.a.config = self.DHT_DEFAULTS.copy()
        self.a.config['PORT'] = 4044
        self.a.bootstrap = ["127.0.0.1:4044"]
        self.a.bootstrap_node = True
        self.a.cache_dir = '/tmp'
        self.b.config = self.DHT_DEFAULTS.copy()
        self.b.config['PORT'] = 4045
        self.b.bootstrap = ["127.0.0.1:4044"]
        self.b.cache_dir = '/tmp'

    def _fail(self, failure):
        """Fail the test through lastDefer instead of letting it time out."""
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)

    def test_bootstrap_join(self):
        d = self.a.join()
        return d

    def node_join(self, result):
        d = self.b.join()
        return d

    def test_join(self):
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallbacks(self.lastDefer.callback, self._fail)
        return self.lastDefer

    def value_stored(self, result, value):
        """Count down completed stores; read everything back after the last one."""
        self.stored -= 1
        if self.stored == 0:
            self.get_values()

    def store_values(self, result):
        """Store two values under key '4045' and one under key '4044'."""
        self.stored = 3
        for node, port, mult in ((self.a, 4045, 3),
                                 (self.a, 4044, 2),
                                 (self.b, 4045, 2)):
            d = node.storeValue(sha.new(str(port)).digest(), str(port*mult))
            d.addCallback(self.value_stored, port)
            d.addErrback(self._fail)

    def check_values(self, result, values):
        """Assert a lookup returned exactly C{values} (in any order)."""
        self.checked -= 1
        self.assertEqual(sorted(result), sorted(values))
        if self.checked == 0 and not self.lastDefer.called:
            self.lastDefer.callback(1)

    def get_values(self):
        """Look up both keys from both nodes and check the results."""
        self.checked = 4
        for node, port, values in ((self.a, 4044, [str(4044*2)]),
                                   (self.b, 4044, [str(4044*2)]),
                                   (self.a, 4045, [str(4045*2), str(4045*3)]),
                                   (self.b, 4045, [str(4045*2), str(4045*3)])):
            d = node.getValue(sha.new(str(port)).digest())
            d.addCallback(self.check_values, values)
            d.addErrback(self._fail)

    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.a.join()
        d.addCallback(self.node_join)
        d.addCallback(self.store_values)
        d.addErrback(self._fail)
        return self.lastDefer

    def _cleanup(self, node):
        """Make C{node} leave and remove its database file, if it created one."""
        try:
            node.leave()
        finally:
            try:
                os.unlink(node.khashmir.store.db)
            except (AttributeError, OSError):
                # No Khashmir instance (join never ran) or no file to remove.
                pass

    def tearDown(self):
        # try/finally so a failure cleaning up a still cleans up b.
        try:
            self._cleanup(self.a)
        finally:
            self._cleanup(self.b)
252
class TestMultiDHT(unittest.TestCase):
    """Tests a network of C{num} DHT nodes that all bootstrap off the first."""

    timeout = 60
    num = 20
    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                    'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
                    'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                    'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
                    'KE_AGE': 3600, 'SPEW': False, }

    def setUp(self):
        self.l = []
        self.startport = 4081
        for i in range(self.num):
            self.l.append(DHT())
            self.l[i].config = self.DHT_DEFAULTS.copy()
            self.l[i].config['PORT'] = self.startport + i
            self.l[i].bootstrap = ["127.0.0.1:" + str(self.startport)]
            self.l[i].cache_dir = '/tmp'
        self.l[0].bootstrap_node = True

    def _fail(self, failure):
        """Fail the test through lastDefer instead of letting it time out."""
        if not self.lastDefer.called:
            self.lastDefer.errback(failure)

    def node_join(self, result, next_node):
        """Join node C{next_node}, then the following ones, one at a time."""
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.node_join, next_node + 1)
        else:
            d.addCallback(self.lastDefer.callback)
        d.addErrback(self._fail)

    def test_join(self):
        # NOTE(review): trial appears to read the timeout before running the
        # test method, so this may only affect later phases (e.g. tearDown).
        self.timeout = 2
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.node_join, 1)
        d.addErrback(self._fail)
        return self.lastDefer

    def store_values(self, result, i = 0, j = 0):
        """Sequentially have node j store (startport+i)*(j+1) under key
        startport+i, for every 0 <= j <= i < num; then verify with get_values."""
        if j > i:
            j -= i+1
            i += 1
        if i == len(self.l):
            self.get_values()
        else:
            d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
            d.addCallback(self.store_values, i, j+1)
            d.addErrback(self._fail)

    def get_values(self, result = None, check = None, i = 0, j = 0):
        """Check the previous lookup (if any), then have node i look up key
        startport+j, stepping j by a random amount to sample the key space."""
        if result is not None:
            self.assertEqual(sorted(result), sorted(check))
        if j >= len(self.l):
            j -= len(self.l)
            i += 1
        if i == len(self.l):
            if not self.lastDefer.called:
                self.lastDefer.callback(1)
        else:
            d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
            # Key startport+j was stored by nodes 0..j with multipliers 1..j+1.
            base = self.startport + j
            check = [str(base * mult) for mult in range(1, j + 2)]
            d.addCallback(self.get_values, check, i, j + random.randrange(1, min(len(self.l), 10)))
            d.addErrback(self._fail)

    def store_join(self, result, next_node):
        """Join node C{next_node} and the rest in turn, then start storing."""
        d = self.l[next_node].join()
        if next_node + 1 < len(self.l):
            d.addCallback(self.store_join, next_node + 1)
        else:
            d.addCallback(self.store_values)
        d.addErrback(self._fail)

    def test_store(self):
        from twisted.internet.base import DelayedCall
        DelayedCall.debug = True
        self.lastDefer = defer.Deferred()
        d = self.l[0].join()
        d.addCallback(self.store_join, 1)
        d.addErrback(self._fail)
        return self.lastDefer

    def tearDown(self):
        for node in self.l:
            try:
                node.leave()
            except Exception:
                # Best effort: one node failing to leave must not stop its
                # database file, or the other nodes, from being cleaned up.
                log.msg('error while leaving DHT node on port %r'
                        % (node.config.get('PORT'),))
            try:
                os.unlink(node.khashmir.store.db)
            except (AttributeError, OSError):
                # No Khashmir instance (join never ran) or no file to remove.
                pass