]> git.mxchange.org Git - quix0rs-apt-p2p.git/commitdiff
Made the get and storeValue DHT functions work.
authorCameron Dale <camrdale@gmail.com>
Thu, 3 Jan 2008 02:06:46 +0000 (18:06 -0800)
committerCameron Dale <camrdale@gmail.com>
Thu, 3 Jan 2008 02:06:46 +0000 (18:06 -0800)
Also added tests for them.
Had to modify the callback from khashmir's get and storeValue functions
to include the key, and the key and value, respectively so multiple
calls can be tracked.

apt_dht_Khashmir/DHT.py
apt_dht_Khashmir/actions.py
apt_dht_Khashmir/khashmir.py

index 80a322b33405d1c38612b0a39fc7e8ad245e1bd0..087201edb24d00c815f4498ae83c5ea584d609ec 100644 (file)
@@ -1,5 +1,5 @@
 
-import os
+import os, sha, random
 
 from twisted.internet import defer
 from twisted.trial import unittest
@@ -22,6 +22,9 @@ class DHT:
         self.bootstrap_node = False
         self.joining = None
         self.joined = False
+        self.storing = {}
+        self.retrieving = {}
+        self.retrieved = {}
     
     def loadConfig(self, config, section):
         """See L{apt_dht.interfaces.IDHT}."""
@@ -97,9 +100,24 @@ class DHT:
             raise DHTError, "have not joined a network yet"
 
         d = defer.Deferred()
-        self.khashmir.valueForKey(key, d.callback)
+        if key not in self.retrieving:
+            self.khashmir.valueForKey(key, self._getValue)
+        self.retrieving.setdefault(key, []).append(d)
         return d
         
+    def _getValue(self, key, result = -1):
+        if result:
+            self.retrieved.setdefault(key, []).extend(result)
+        else:
+            final_result = []
+            if key in self.retrieved:
+                final_result = self.retrieved[key]
+                del self.retrieved[key]
+            for i in range(len(self.retrieving[key])):
+                d = self.retrieving[key].pop(0)
+                d.callback(final_result)
+            del self.retrieving[key]
+
     def storeValue(self, key, value):
         """See L{apt_dht.interfaces.IDHT}."""
         if self.config is None:
@@ -107,7 +125,23 @@ class DHT:
         if not self.joined:
             raise DHTError, "have not joined a network yet"
 
-        self.khashmir.storeValueForKey(key, value)
+        if key in self.storing and value in self.storing[key]:
+            raise DHTError, "already storing that key with the same value"
+
+        d = defer.Deferred()
+        self.khashmir.storeValueForKey(key, value, self._storeValue)
+        self.storing.setdefault(key, {})[value] = d
+        return d
+    
+    def _storeValue(self, key, value, result):
+        if key in self.storing and value in self.storing[key]:
+            if len(result) > 0:
+                self.storing[key][value].callback(result)
+            else:
+                self.storing[key][value].errback(DHTError('could not store value %s in key %s' % (value, key)))
+            del self.storing[key][value]
+            if len(self.storing[key].keys()) == 0:
+                del self.storing[key]
 
 class TestSimpleDHT(unittest.TestCase):
     """Unit tests for the DHT."""
@@ -147,7 +181,49 @@ class TestSimpleDHT(unittest.TestCase):
         d.addCallback(self.node_join)
         d.addCallback(self.lastDefer.callback)
         return self.lastDefer
+
+    def value_stored(self, result, value):
+        self.stored -= 1
+        if self.stored == 0:
+            self.get_values()
         
+    def store_values(self, result):
+        self.stored = 3
+        d = self.a.storeValue(sha.new('4045').digest(), str(4045*3))
+        d.addCallback(self.value_stored, 4045)
+        d = self.a.storeValue(sha.new('4044').digest(), str(4044*2))
+        d.addCallback(self.value_stored, 4044)
+        d = self.b.storeValue(sha.new('4045').digest(), str(4045*2))
+        d.addCallback(self.value_stored, 4045)
+
+    def check_values(self, result, values):
+        self.checked -= 1
+        self.failUnless(len(result) == len(values))
+        for v in result:
+            self.failUnless(v in values)
+        if self.checked == 0:
+            self.lastDefer.callback(1)
+    
+    def get_values(self):
+        self.checked = 4
+        d = self.a.getValue(sha.new('4044').digest())
+        d.addCallback(self.check_values, [str(4044*2)])
+        d = self.b.getValue(sha.new('4044').digest())
+        d.addCallback(self.check_values, [str(4044*2)])
+        d = self.a.getValue(sha.new('4045').digest())
+        d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
+        d = self.b.getValue(sha.new('4045').digest())
+        d.addCallback(self.check_values, [str(4045*2), str(4045*3)])
+
+    def test_store(self):
+        from twisted.internet.base import DelayedCall
+        DelayedCall.debug = True
+        self.lastDefer = defer.Deferred()
+        d = self.a.join()
+        d.addCallback(self.node_join)
+        d.addCallback(self.store_values)
+        return self.lastDefer
+
     def tearDown(self):
         self.a.leave()
         try:
@@ -162,7 +238,7 @@ class TestSimpleDHT(unittest.TestCase):
 
 class TestMultiDHT(unittest.TestCase):
     
-    timeout = 10
+    timeout = 60
     num = 20
     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
@@ -173,7 +249,7 @@ class TestMultiDHT(unittest.TestCase):
 
     def setUp(self):
         self.l = []
-        self.startport = 4088
+        self.startport = 4081
         for i in range(self.num):
             self.l.append(DHT())
             self.l[i].config = self.DHT_DEFAULTS.copy()
@@ -190,11 +266,59 @@ class TestMultiDHT(unittest.TestCase):
             d.addCallback(self.lastDefer.callback)
     
     def test_join(self):
+        self.timeout = 2
         self.lastDefer = defer.Deferred()
         d = self.l[0].join()
         d.addCallback(self.node_join, 1)
         return self.lastDefer
         
+    def value_stored(self, result, value):
+        self.stored -= 1
+        if self.stored == 0:
+            self.get_values()
+        
+    def store_values(self, result):
+        self.stored = 0
+        for i in range(len(self.l)):
+            for j in range(0, i+1):
+                self.stored += 1
+                d = self.l[j].storeValue(sha.new(str(self.startport+i)).digest(), str((self.startport+i)*(j+1)))
+                d.addCallback(self.value_stored, self.startport+i)
+    
+    def check_values(self, result, values):
+        self.checked -= 1
+        self.failUnless(len(result) == len(values))
+        for v in result:
+            self.failUnless(v in values)
+        if self.checked == 0:
+            self.lastDefer.callback(1)
+    
+    def get_values(self):
+        self.checked = 0
+        for i in range(len(self.l)):
+            for j in random.sample(xrange(len(self.l)), 4):
+                self.checked += 1
+                d = self.l[i].getValue(sha.new(str(self.startport+j)).digest())
+                check = []
+                for k in range(self.startport+j, (self.startport+j)*(j+1)+1, self.startport+j):
+                    check.append(str(k))
+                d.addCallback(self.check_values, check)
+
+    def store_join(self, result, next_node):
+        d = self.l[next_node].join()
+        if next_node + 1 < len(self.l):
+            d.addCallback(self.store_join, next_node + 1)
+        else:
+            d.addCallback(self.store_values)
+    
+    def test_store(self):
+        from twisted.internet.base import DelayedCall
+        DelayedCall.debug = True
+        self.lastDefer = defer.Deferred()
+        d = self.l[0].join()
+        d.addCallback(self.store_join, 1)
+        return self.lastDefer
+
     def tearDown(self):
         for i in self.l:
             try:
index 088bc0f2a6476b1a59643c681fbc3304f87d7940..7f8f91128b3e7360ce0cb6811a5a7f06f3b675c2 100644 (file)
@@ -149,7 +149,7 @@ class GetValue(FindNode):
             z = len(dict['values'])
             v = filter(None, map(x, dict['values']))
             if(len(v)):
-                reactor.callLater(0, self.callback, v)
+                reactor.callLater(0, self.callback, self.target, v)
         self.schedule()
         
     ## get value
@@ -178,7 +178,7 @@ class GetValue(FindNode):
         if self.outstanding == 0:
             ## all done, didn't find it!!
             self.finished=1
-            reactor.callLater(0, self.callback,[])
+            reactor.callLater(0, self.callback, self.target, [])
 
     ## get value
     def goWithNodes(self, nodes, found=None):
@@ -210,7 +210,7 @@ class StoreValue(ActionBase):
         self.stored.append(t)
         if len(self.stored) >= self.config['STORE_REDUNDANCY']:
             self.finished=1
-            self.callback(self.stored)
+            self.callback(self.target, self.value, self.stored)
         else:
             if not len(self.stored) + self.outstanding >= self.config['STORE_REDUNDANCY']:
                 self.schedule()
@@ -237,7 +237,7 @@ class StoreValue(ActionBase):
             except IndexError:
                 if self.outstanding == 0:
                     self.finished = 1
-                    self.callback(self.stored)
+                    self.callback(self.target, self.value, self.stored)
             else:
                 if not node.id == self.table.node.id:
                     self.outstanding += 1
index 1baffa17dc25072e47198f903de8e8a9fe26783d..d1843292904962dad43de250d313b8fc395ec030 100644 (file)
@@ -299,7 +299,7 @@ class KhashmirRead(KhashmirBase):
         if searchlocal:
             l = self.retrieveValues(key)
             if len(l) > 0:
-                reactor.callLater(0, callback, (l))
+                reactor.callLater(0, callback, key, l)
         else:
             l = []
         
@@ -336,7 +336,7 @@ class KhashmirWrite(KhashmirRead):
         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
             if not response:
                 # default callback
-                def _storedValueHandler(sender):
+                def _storedValueHandler(key, value, sender):
                     pass
                 response=_storedValueHandler
             action = StoreValue(self.table, key, value, response, self.config)
@@ -432,7 +432,7 @@ class SimpleTests(unittest.TestCase):
         reactor.iterate()
         reactor.iterate()
 
-    def _cb(self, val):
+    def _cb(self, key, val):
         if not val:
             self.assertEqual(self.got, 1)
         elif 'foobar' in val:
@@ -496,14 +496,14 @@ class MultiTest(unittest.TestCase):
             
             for a in range(3):
                 self.done = 0
-                def _scb(val):
+                def _scb(key, value, result):
                     self.done = 1
                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
                 while not self.done:
                     reactor.iterate()
 
 
-                def _rcb(val):
+                def _rcb(key, val):
                     if not val:
                         self.done = 1
                         self.assertEqual(self.got, 1)