]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht_Khashmir/db.py
Standardize the number of values retrieved from the DHT.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
index 8bb499e14512e85b0218941b22485ba0dac63474..7d40176a4ffb012fe0b64ec5ae663bda6350c3e7 100644 (file)
@@ -13,9 +13,15 @@ class DBExcept(Exception):
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
     
+class dht_value(str):
+    """Dummy class to convert all DHT values to base64 for storing in the DB."""
+    
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
+sqlite.register_adapter(dht_value, b2a_base64)
+sqlite.register_converter("DHT_VALUE", a2b_base64)
+sqlite.register_converter("dht_value", a2b_base64)
 
 class DB:
     """Database access for storing persistent data."""
@@ -44,9 +50,9 @@ class DB:
     def _createNewDB(self, db):
         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE kv (key KHASH, value TEXT, time TIMESTAMP, PRIMARY KEY (key, value))")
+        c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
         c.execute("CREATE INDEX kv_key ON kv(key)")
-        c.execute("CREATE INDEX kv_timestamp ON kv(time)")
+        c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
         c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
         self.conn.commit()
@@ -86,26 +92,37 @@ class DB:
         return c.fetchall()
             
     def retrieveValues(self, key):
+        """Retrieve values from the database."""
         c = self.conn.cursor()
         c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
-        t = c.fetchone()
         l = []
-        while t:
-            l.append(t[0])
-            t = c.fetchone()
+        rows = c.fetchall()
+        for row in rows:
+            l.append(row[0])
         return l
 
+    def countValues(self, key):
+        """Count the number of values in the database."""
+        c = self.conn.cursor()
+        c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
+        res = 0
+        row = c.fetchone()
+        if row:
+            res = row[0]
+        return res
+
     def storeValue(self, key, value):
         """Store or update a key and value."""
         c = self.conn.cursor()
-        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", (khash(key), value, datetime.now()))
+        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", 
+                  (khash(key), dht_value(value), datetime.now()))
         self.conn.commit()
 
     def expireValues(self, expireAfter):
         """Expire older values after expireAfter seconds."""
         t = datetime.now() - timedelta(seconds=expireAfter)
         c = self.conn.cursor()
-        c.execute("DELETE FROM kv WHERE time < ?", (t, ))
+        c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
         self.conn.commit()
         
     def close(self):
@@ -126,19 +143,19 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(self.store.getSelfNode(), self.key)
         
     def test_Value(self):
-        self.store.storeValue(self.key, 'foobar')
+        self.store.storeValue(self.key, self.key)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'foobar')
+        self.failUnlessEqual(val[0], self.key)
         
     def test_expireValues(self):
-        self.store.storeValue(self.key, 'foobar')
+        self.store.storeValue(self.key, self.key)
         sleep(2)
-        self.store.storeValue(self.key, 'barfoo')
+        self.store.storeValue(self.key, self.key+self.key)
         self.store.expireValues(1)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'barfoo')
+        self.failUnlessEqual(val[0], self.key+self.key)
         
     def test_RoutingTable(self):
         class dummy: