]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht_Khashmir/DHT.py
Move the normalization of key lengths from the HashObject to the DHT.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / DHT.py
index 48ae1f4fe5843e06050d358dab1181623f0139cd..5f08dae07498962759264b4dfb1c07169b5da059 100644 (file)
@@ -125,12 +125,27 @@ class DHT:
             self.joined = False
             self.khashmir.shutdown()
         
+    def _normKey(self, key, bits=None, bytes=None):
+        bits = self.config["HASH_LENGTH"]
+        if bits is not None:
+            bytes = (bits - 1) // 8 + 1
+        else:
+            if bytes is None:
+                raise DHTError, "you must specify one of bits or bytes for normalization"
+        if len(key) < bytes:
+            key = key + '\000'*(bytes - len(key))
+        elif len(key) > bytes:
+            key = key[:bytes]
+        return key
+
     def getValue(self, key):
         """See L{apt_dht.interfaces.IDHT}."""
         if self.config is None:
             raise DHTError, "configuration not loaded"
         if not self.joined:
             raise DHTError, "have not joined a network yet"
+        
+        key = self._normKey(key)
 
         d = defer.Deferred()
         if key not in self.retrieving:
@@ -158,6 +173,7 @@ class DHT:
         if not self.joined:
             raise DHTError, "have not joined a network yet"
 
+        key = self._normKey(key)
         bvalue = bencode(value)
 
         if key in self.storing and bvalue in self.storing[key]:
@@ -217,6 +233,18 @@ class TestSimpleDHT(unittest.TestCase):
         d.addCallback(self.lastDefer.callback)
         return self.lastDefer
 
+    def test_normKey(self):
+        h = self.a._normKey('12345678901234567890')
+        self.failUnless(h == '12345678901234567890')
+        h = self.a._normKey('12345678901234567')
+        self.failUnless(h == '12345678901234567\000\000\000')
+        h = self.a._normKey('1234567890123456789012345')
+        self.failUnless(h == '12345678901234567890')
+        h = self.a._normKey('1234567890123456789')
+        self.failUnless(h == '1234567890123456789\000')
+        h = self.a._normKey('123456789012345678901')
+        self.failUnless(h == '12345678901234567890')
+
     def value_stored(self, result, value):
         self.stored -= 1
         if self.stored == 0: