From 2a7aeb736af7b22f58a716a5891308de3785cc67 Mon Sep 17 00:00:00 2001 From: Cameron Dale Date: Wed, 2 Jan 2008 18:06:46 -0800 Subject: [PATCH] Made the get and storeValue DHT functions work. Also added tests for them. Had to modify the callback from khashmir's get and storeValue functions to include the key, and the key and value, respectively so multiple calls can be tracked. --- apt_dht_Khashmir/DHT.py | 134 +++++++++++++++++++++++++++++++++-- apt_dht_Khashmir/actions.py | 8 +-- apt_dht_Khashmir/khashmir.py | 10 +-- 3 files changed, 138 insertions(+), 14 deletions(-) diff --git a/apt_dht_Khashmir/DHT.py b/apt_dht_Khashmir/DHT.py index 80a322b..087201e 100644 --- a/apt_dht_Khashmir/DHT.py +++ b/apt_dht_Khashmir/DHT.py @@ -1,5 +1,5 @@ -import os +import os, sha, random from twisted.internet import defer from twisted.trial import unittest @@ -22,6 +22,9 @@ class DHT: self.bootstrap_node = False self.joining = None self.joined = False + self.storing = {} + self.retrieving = {} + self.retrieved = {} def loadConfig(self, config, section): """See L{apt_dht.interfaces.IDHT}.""" @@ -97,9 +100,24 @@ class DHT: raise DHTError, "have not joined a network yet" d = defer.Deferred() - self.khashmir.valueForKey(key, d.callback) + if key not in self.retrieving: + self.khashmir.valueForKey(key, self._getValue) + self.retrieving.setdefault(key, []).append(d) return d + def _getValue(self, key, result = -1): + if result: + self.retrieved.setdefault(key, []).extend(result) + else: + final_result = [] + if key in self.retrieved: + final_result = self.retrieved[key] + del self.retrieved[key] + for i in range(len(self.retrieving[key])): + d = self.retrieving[key].pop(0) + d.callback(final_result) + del self.retrieving[key] + def storeValue(self, key, value): """See L{apt_dht.interfaces.IDHT}.""" if self.config is None: @@ -107,7 +125,23 @@ class DHT: if not self.joined: raise DHTError, "have not joined a network yet" - self.khashmir.storeValueForKey(key, value) + if key in self.storing and value in self.storing[key]: + raise DHTError, "already storing that key with the same value" + + d = defer.Deferred() + self.khashmir.storeValueForKey(key, value, self._storeValue) + self.storing.setdefault(key, {})[value] = d + return d + + def _storeValue(self, key, value, result): + if key in self.storing and value in self.storing[key]: + if len(result) > 0: + self.storing[key][value].callback(result) + else: + self.storing[key][value].errback(DHTError('could not store value %s in key %s' % (value, key))) + del self.storing[key][value] + if len(self.storing[key].keys()) == 0: + del self.storing[key] class TestSimpleDHT(unittest.TestCase): """Unit tests for the DHT.""" @@ -147,7 +181,49 @@ class TestSimpleDHT(unittest.TestCase): d.addCallback(self.node_join) d.addCallback(self.lastDefer.callback) return self.lastDefer + + def value_stored(self, result, value): + self.stored -= 1 + if self.stored == 0: + self.get_values() + def store_values(self, result): + self.stored = 3 + d = self.a.storeValue(sha.new('4045').digest(), str(4045*3)) + d.addCallback(self.value_stored, 4045) + d = self.a.storeValue(sha.new('4044').digest(), str(4044*2)) + d.addCallback(self.value_stored, 4044) + d = self.b.storeValue(sha.new('4045').digest(), str(4045*2)) + d.addCallback(self.value_stored, 4045) + + def check_values(self, result, values): + self.checked -= 1 + self.failUnless(len(result) == len(values)) + for v in result: + self.failUnless(v in values) + if self.checked == 0: + self.lastDefer.callback(1) + + def get_values(self): + self.checked = 4 + d = self.a.getValue(sha.new('4044').digest()) + d.addCallback(self.check_values, [str(4044*2)]) + d = self.b.getValue(sha.new('4044').digest()) + d.addCallback(self.check_values, [str(4044*2)]) + d = self.a.getValue(sha.new('4045').digest()) + d.addCallback(self.check_values, [str(4045*2), str(4045*3)]) + d = self.b.getValue(sha.new('4045').digest()) + d.addCallback(self.check_values, [str(4045*2), str(4045*3)]) + + def test_store(self): + from twisted.internet.base import DelayedCall + DelayedCall.debug = True + self.lastDefer = defer.Deferred() + d = self.a.join() + d.addCallback(self.node_join) + d.addCallback(self.store_values) + return self.lastDefer + def tearDown(self): self.a.leave() try: @@ -162,7 +238,7 @@ class TestSimpleDHT(unittest.TestCase): class TestMultiDHT(unittest.TestCase): - timeout = 10 + timeout = 60 num = 20 DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160, 'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4, @@ -173,7 +249,7 @@ class TestMultiDHT(unittest.TestCase): def setUp(self): self.l = [] - self.startport = 4088 + self.startport = 4081 for i in range(self.num): self.l.append(DHT()) self.l[i].config = self.DHT_DEFAULTS.copy() @@ -190,11 +266,59 @@ class TestMultiDHT(unittest.TestCase): d.addCallback(self.lastDefer.callback) def test_join(self): + self.timeout = 2 self.lastDefer = defer.Deferred() d = self.l[0].join() d.addCallback(self.node_join, 1) return self.lastDefer + def value_stored(self, result, value): + self.stored -= 1 + if self.stored == 0: + self.get_values() + + def store_values(self, result): + self.stored = 0 + for i in range(len(self.l)): + for j in range(0, i+1): + self.stored += 1 + d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1))) + d.addCallback(self.value_stored, self.startport+i) + + def check_values(self, result, values): + self.checked -= 1 + self.failUnless(len(result) == len(values)) + for v in result: + self.failUnless(v in values) + if self.checked == 0: + self.lastDefer.callback(1) + + def get_values(self): + self.checked = 0 + for i in range(len(self.l)): + for j in random.sample(xrange(len(self.l)), 4): + self.checked += 1 + d = self.l[i].getValue(sha.new(str(self.startport+j)).digest()) + check = [] + for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j): + check.append(str(k)) + d.addCallback(self.check_values, check) + + def store_join(self, result, next_node): + d = self.l[next_node].join() + if next_node + 1 < len(self.l): + d.addCallback(self.store_join, next_node + 1) + else: + d.addCallback(self.store_values) + + def test_store(self): + from twisted.internet.base import DelayedCall + DelayedCall.debug = True + self.lastDefer = defer.Deferred() + d = self.l[0].join() + d.addCallback(self.store_join, 1) + return self.lastDefer + def tearDown(self): for i in self.l: try: diff --git a/apt_dht_Khashmir/actions.py b/apt_dht_Khashmir/actions.py index 088bc0f..7f8f911 100644 --- a/apt_dht_Khashmir/actions.py +++ b/apt_dht_Khashmir/actions.py @@ -149,7 +149,7 @@ class GetValue(FindNode): z = len(dict['values']) v = filter(None, map(x, dict['values'])) if(len(v)): - reactor.callLater(0, self.callback, v) + reactor.callLater(0, self.callback, self.target, v) self.schedule() ## get value @@ -178,7 +178,7 @@ class GetValue(FindNode): if self.outstanding == 0: ## all done, didn't find it!! self.finished=1 - reactor.callLater(0, self.callback,[]) + reactor.callLater(0, self.callback, self.target, []) ## get value def goWithNodes(self, nodes, found=None): @@ -210,7 +210,7 @@ class StoreValue(ActionBase): self.stored.append(t) if len(self.stored) >= self.config['STORE_REDUNDANCY']: self.finished=1 - self.callback(self.stored) + self.callback(self.target, self.value, self.stored) else: if not len(self.stored) + self.outstanding >= self.config['STORE_REDUNDANCY']: self.schedule() @@ -237,7 +237,7 @@ class StoreValue(ActionBase): except IndexError: if self.outstanding == 0: self.finished = 1 - self.callback(self.stored) + self.callback(self.target, self.value, self.stored) else: if not node.id == self.table.node.id: self.outstanding += 1 diff --git a/apt_dht_Khashmir/khashmir.py b/apt_dht_Khashmir/khashmir.py index 1baffa1..d184329 100644 --- a/apt_dht_Khashmir/khashmir.py +++ b/apt_dht_Khashmir/khashmir.py @@ -299,7 +299,7 @@ class KhashmirRead(KhashmirBase): if searchlocal: l = self.retrieveValues(key) if len(l) > 0: - reactor.callLater(0, callback, (l)) + reactor.callLater(0, callback, key, l) else: l = [] @@ -336,7 +336,7 @@ class KhashmirWrite(KhashmirRead): def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table): if not response: # default callback - def _storedValueHandler(sender): + def _storedValueHandler(key, value, sender): pass response=_storedValueHandler action = StoreValue(self.table, key, value, response, self.config) @@ -432,7 +432,7 @@ class SimpleTests(unittest.TestCase): reactor.iterate() reactor.iterate() - def _cb(self, val): + def _cb(self, key, val): if not val: self.assertEqual(self.got, 1) elif 'foobar' in val: @@ -496,14 +496,14 @@ class MultiTest(unittest.TestCase): for a in range(3): self.done = 0 - def _scb(val): + def _scb(key, value, result): self.done = 1 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb) while not self.done: reactor.iterate() - def _rcb(val): + def _rcb(key, val): if not val: self.done = 1 self.assertEqual(self.got, 1) -- 2.39.5