Allow arbitrary strings to be stored in the DHT database.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
index 5c62aed9416f865ed551454c039d6a6d7b13539f..f1a63a8100edfc3887745f6774e2ff76b5c60af6 100644 (file)
@@ -13,9 +13,15 @@ class DBExcept(Exception):
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
     
+class dht_value(str):
+    """Dummy class to convert all DHT values to base64 for storing in the DB."""
+    
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
+sqlite.register_adapter(dht_value, b2a_base64)
+sqlite.register_converter("DHT_VALUE", a2b_base64)
+sqlite.register_converter("dht_value", a2b_base64)
 
 class DB:
     """Database access for storing persistent data."""
@@ -44,7 +50,7 @@ class DB:
     def _createNewDB(self, db):
         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE kv (key KHASH, value TEXT, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
+        c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
         c.execute("CREATE INDEX kv_key ON kv(key)")
         c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
@@ -99,7 +105,7 @@ class DB:
         """Store or update a key and value."""
         c = self.conn.cursor()
         c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", 
-                  (khash(key), value, datetime.now()))
+                  (khash(key), dht_value(value), datetime.now()))
         self.conn.commit()
 
     def expireValues(self, expireAfter):
@@ -109,23 +115,6 @@ class DB:
         c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
         self.conn.commit()
         
-    def refreshValues(self, expireAfter):
-        """Find older values than expireAfter seconds to refresh.
-        
-        @return: a list of the hash keys and a list of dictionaries with
-            key of the value, value is the origination time
-        """
-        t = datetime.now() - timedelta(seconds=expireAfter)
-        c = self.conn.cursor()
-        c.execute("SELECT key, value, FROM kv WHERE last_refresh < ?", (t,))
-        keys = []
-        vals = []
-        rows = c.fetchall()
-        for row in rows:
-            keys.append(row[0])
-            vals.append({row[1]: row[2]})
-        return keys, vals
-        
     def close(self):
         self.conn.close()
 
@@ -144,32 +133,19 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(self.store.getSelfNode(), self.key)
         
     def test_Value(self):
-        self.store.storeValue(self.key, 'foobar', datetime.now())
+        self.store.storeValue(self.key, self.key)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'foobar')
+        self.failUnlessEqual(val[0], self.key)
         
     def test_expireValues(self):
-        self.store.storeValue(self.key, 'foobar', datetime.now())
+        self.store.storeValue(self.key, self.key)
         sleep(2)
-        self.store.storeValue(self.key, 'barfoo', datetime.now())
+        self.store.storeValue(self.key, self.key+self.key)
         self.store.expireValues(1)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'barfoo')
-        
-    def test_refreshValues(self):
-        self.store.storeValue(self.key, 'foobar', datetime.now())
-        sleep(2)
-        self.store.storeValue(self.key, 'barfoo', datetime.now())
-        keys, vals = self.store.refreshValues(1)
-        self.failUnlessEqual(len(keys), 1)
-        self.failUnlessEqual(keys[0], self.key)
-        self.failUnlessEqual(len(vals), 1)
-        self.failUnlessEqual(len(vals[0].keys()), 1)
-        self.failUnlessEqual(vals[0].keys()[0], 'foobar')
-        val = self.store.retrieveValues(self.key)
-        self.failUnlessEqual(len(val), 2)
+        self.failUnlessEqual(val[0], self.key+self.key)
         
     def test_RoutingTable(self):
         class dummy: