From: Cameron Dale <camrdale@gmail.com>
Date: Thu, 3 Jan 2008 02:06:46 +0000 (-0800)
Subject: Made the get and storeValue DHT functions work.
X-Git-Url: https://git.mxchange.org/?a=commitdiff_plain;h=2a7aeb736af7b22f58a716a5891308de3785cc67;p=quix0rs-apt-p2p.git

Made the get and storeValue DHT functions work.

Also added tests for them.
Had to modify the callback from khashmir's get and storeValue functions
to include the key, and the key and value, respectively so multiple
calls can be tracked.
---

diff --git a/apt_dht_Khashmir/DHT.py b/apt_dht_Khashmir/DHT.py
index 80a322b..087201e 100644
--- a/apt_dht_Khashmir/DHT.py
+++ b/apt_dht_Khashmir/DHT.py
@@ -1,5 +1,5 @@
 
-import os
+import os, sha, random
 
 from twisted.internet import defer
 from twisted.trial import unittest
@@ -22,6 +22,9 @@ class DHT:
         self.bootstrap_node = False
         self.joining = None
         self.joined = False
+        self.storing = {}
+        self.retrieving = {}
+        self.retrieved = {}
     
     def loadConfig(self, config, section):
         """See L{apt_dht.interfaces.IDHT}."""
@@ -97,9 +100,24 @@ class DHT:
             raise DHTError, "have not joined a network yet"
 
         d = defer.Deferred()
-        self.khashmir.valueForKey(key, d.callback)
+        if key not in self.retrieving:
+            self.khashmir.valueForKey(key, self._getValue)
+        self.retrieving.setdefault(key, []).append(d)
         return d
         
+    def _getValue(self, key, result = -1):
+        if result:
+            self.retrieved.setdefault(key, []).extend(result)
+        else:
+            final_result = []
+            if key in self.retrieved:
+                final_result = self.retrieved[key]
+                del self.retrieved[key]
+            for i in range(len(self.retrieving[key])):
+                d = self.retrieving[key].pop(0)
+                d.callback(final_result)
+            del self.retrieving[key]
+
     def storeValue(self, key, value):
         """See L{apt_dht.interfaces.IDHT}."""
         if self.config is None:
@@ -107,7 +125,23 @@ class DHT:
         if not self.joined:
             raise DHTError, "have not joined a network yet"
 
-        self.khashmir.storeValueForKey(key, value)
+        if key in self.storing and value in self.storing[key]:
+            raise DHTError, "already storing that key with the same value"
+
+        d = defer.Deferred()
+        self.khashmir.storeValueForKey(key, value, self._storeValue)
+        self.storing.setdefault(key, {})[value] = d
+        return d
+    
+    def _storeValue(self, key, value, result):
+        if key in self.storing and value in self.storing[key]:
+            if len(result) > 0:
+                self.storing[key][value].callback(result)
+            else:
+                self.storing[key][value].errback(DHTError('could not store value %s in key %s' % (value, key)))
+            del self.storing[key][value]
+            if len(self.storing[key].keys()) == 0:
+                del self.storing[key]
 
 class TestSimpleDHT(unittest.TestCase):
     """Unit tests for the DHT."""
@@ -147,7 +181,49 @@ class TestSimpleDHT(unittest.TestCase):
         d.addCallback(self.node_join)
         d.addCallback(self.lastDefer.callback)
         return self.lastDefer
+
+    def value_stored(self, result, value):
+        self.stored -= 1
+        if self.stored == 0:
+            self.get_values()
         
+    def store_values(self, result):
+        self.stored = 3
+        d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
+        d.addCallback(self.value_stored, 4045)
+        d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
+        d.addCallback(self.value_stored, 4044)
+        d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
+        d.addCallback(self.value_stored, 4045)
+
+    def check_values(self, result, values):
+        self.checked -= 1
+        self.failUnless(len(result) == len(values))
+        for v in result:
+            self.failUnless(v in values)
+        if self.checked == 0:
+            self.lastDefer.callback(1)
+    
+    def get_values(self):
+        self.checked = 4
+        d = self.a.getValue(sha.new('4044').digest())
+        d.addCallback(self.check_values, [str(4044*2)])
+        d = self.b.getValue(sha.new('4044').digest())
+        d.addCallback(self.check_values, [str(4044*2)])
+        d = self.a.getValue(sha.new('4045').digest())
+        d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
+        d = self.b.getValue(sha.new('4045').digest())
+        d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
+
+    def test_store(self):
+        from twisted.internet.base import DelayedCall
+        DelayedCall.debug = True
+        self.lastDefer = defer.Deferred()
+        d = self.a.join()
+        d.addCallback(self.node_join)
+        d.addCallback(self.store_values)
+        return self.lastDefer
+
     def tearDown(self):
         self.a.leave()
         try:
@@ -162,7 +238,7 @@ class TestSimpleDHT(unittest.TestCase):
 
 class TestMultiDHT(unittest.TestCase):
     
-    timeout = 10
+    timeout = 60
     num = 20
     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
@@ -173,7 +249,7 @@ class TestMultiDHT(unittest.TestCase):
 
     def setUp(self):
         self.l = []
-        self.startport = 4088
+        self.startport = 4081
         for i in range(self.num):
             self.l.append(DHT())
             self.l[i].config = self.DHT_DEFAULTS.copy()
@@ -190,11 +266,59 @@ class TestMultiDHT(unittest.TestCase):
             d.addCallback(self.lastDefer.callback)
     
     def test_join(self):
+        self.timeout = 2
         self.lastDefer = defer.Deferred()
         d = self.l[0].join()
         d.addCallback(self.node_join, 1)
         return self.lastDefer
         
+    def value_stored(self, result, value):
+        self.stored -= 1
+        if self.stored == 0:
+            self.get_values()
+        
+    def store_values(self, result):
+        self.stored = 0
+        for i in range(len(self.l)):
+            for j in range(0, i+1):
+                self.stored += 1
+                d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
+                d.addCallback(self.value_stored, self.startport+i)
+    
+    def check_values(self, result, values):
+        self.checked -= 1
+        self.failUnless(len(result) == len(values))
+        for v in result:
+            self.failUnless(v in values)
+        if self.checked == 0:
+            self.lastDefer.callback(1)
+    
+    def get_values(self):
+        self.checked = 0
+        for i in range(len(self.l)):
+            for j in random.sample(xrange(len(self.l)), 4):
+                self.checked += 1
+                d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
+                check = []
+                for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
+                    check.append(str(k))
+                d.addCallback(self.check_values, check)
+
+    def store_join(self, result, next_node):
+        d = self.l[next_node].join()
+        if next_node + 1 < len(self.l):
+            d.addCallback(self.store_join, next_node + 1)
+        else:
+            d.addCallback(self.store_values)
+    
+    def test_store(self):
+        from twisted.internet.base import DelayedCall
+        DelayedCall.debug = True
+        self.lastDefer = defer.Deferred()
+        d = self.l[0].join()
+        d.addCallback(self.store_join, 1)
+        return self.lastDefer
+
     def tearDown(self):
         for i in self.l:
             try:
diff --git a/apt_dht_Khashmir/actions.py b/apt_dht_Khashmir/actions.py
index 088bc0f..7f8f911 100644
--- a/apt_dht_Khashmir/actions.py
+++ b/apt_dht_Khashmir/actions.py
@@ -149,7 +149,7 @@ class GetValue(FindNode):
             z = len(dict['values'])
             v = filter(None, map(x, dict['values']))
             if(len(v)):
-                reactor.callLater(0, self.callback, v)
+                reactor.callLater(0, self.callback, self.target, v)
         self.schedule()
         
     ## get value
@@ -178,7 +178,7 @@ class GetValue(FindNode):
         if self.outstanding == 0:
             ## all done, didn't find it!!
             self.finished=1
-            reactor.callLater(0, self.callback,[])
+            reactor.callLater(0, self.callback, self.target, [])
 
     ## get value
     def goWithNodes(self, nodes, found=None):
@@ -210,7 +210,7 @@ class StoreValue(ActionBase):
         self.stored.append(t)
         if len(self.stored) >= self.config['STORE_REDUNDANCY']:
             self.finished=1
-            self.callback(self.stored)
+            self.callback(self.target, self.value, self.stored)
         else:
             if not len(self.stored) + self.outstanding >= self.config['STORE_REDUNDANCY']:
                 self.schedule()
@@ -237,7 +237,7 @@ class StoreValue(ActionBase):
             except IndexError:
                 if self.outstanding == 0:
                     self.finished = 1
-                    self.callback(self.stored)
+                    self.callback(self.target, self.value, self.stored)
             else:
                 if not node.id == self.table.node.id:
                     self.outstanding += 1
diff --git a/apt_dht_Khashmir/khashmir.py b/apt_dht_Khashmir/khashmir.py
index 1baffa1..d184329 100644
--- a/apt_dht_Khashmir/khashmir.py
+++ b/apt_dht_Khashmir/khashmir.py
@@ -299,7 +299,7 @@ class KhashmirRead(KhashmirBase):
         if searchlocal:
             l = self.retrieveValues(key)
             if len(l) > 0:
-                reactor.callLater(0, callback, (l))
+                reactor.callLater(0, callback, key, l)
         else:
             l = []
         
@@ -336,7 +336,7 @@ class KhashmirWrite(KhashmirRead):
         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
             if not response:
                 # default callback
-                def _storedValueHandler(sender):
+                def _storedValueHandler(key, value, sender):
                     pass
                 response=_storedValueHandler
             action = StoreValue(self.table, key, value, response, self.config)
@@ -432,7 +432,7 @@ class SimpleTests(unittest.TestCase):
         reactor.iterate()
         reactor.iterate()
 
-    def _cb(self, val):
+    def _cb(self, key, val):
         if not val:
             self.assertEqual(self.got, 1)
         elif 'foobar' in val:
@@ -496,14 +496,14 @@ class MultiTest(unittest.TestCase):
             
             for a in range(3):
                 self.done = 0
-                def _scb(val):
+                def _scb(key, value, result):
                     self.done = 1
                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
                 while not self.done:
                     reactor.iterate()
 
 
-                def _rcb(val):
+                def _rcb(key, val):
                     if not val:
                         self.done = 1
                         self.assertEqual(self.got, 1)