]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht_Khashmir/db.py
Break up the find_value into 2 parts (with get_value).
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
index bea40dbef0db15babe95000af281d1944e98614d..7d40176a4ffb012fe0b64ec5ae663bda6350c3e7 100644 (file)
@@ -13,9 +13,15 @@ class DBExcept(Exception):
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
     
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
     
+class dht_value(str):
+    """Dummy class to convert all DHT values to base64 for storing in the DB."""
+    
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
+sqlite.register_adapter(dht_value, b2a_base64)
+sqlite.register_converter("DHT_VALUE", a2b_base64)
+sqlite.register_converter("dht_value", a2b_base64)
 
 class DB:
     """Database access for storing persistent data."""
 
 class DB:
     """Database access for storing persistent data."""
@@ -28,7 +34,11 @@ class DB:
             self._createNewDB(db)
         else:
             self._loadDB(db)
             self._createNewDB(db)
         else:
             self._loadDB(db)
-        self.conn.text_factory = str
+        if sqlite.version_info < (2, 1):
+            sqlite.register_converter("TEXT", str)
+            sqlite.register_converter("text", str)
+        else:
+            self.conn.text_factory = str
         
     def _loadDB(self, db):
         try:
         
     def _loadDB(self, db):
         try:
@@ -40,9 +50,9 @@ class DB:
     def _createNewDB(self, db):
         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
     def _createNewDB(self, db):
         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE kv (key KHASH, value TEXT, time TIMESTAMP, PRIMARY KEY (key, value))")
+        c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
         c.execute("CREATE INDEX kv_key ON kv(key)")
         c.execute("CREATE INDEX kv_key ON kv(key)")
-        c.execute("CREATE INDEX kv_timestamp ON kv(time)")
+        c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
         c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
         self.conn.commit()
         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
         c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
         self.conn.commit()
@@ -82,26 +92,37 @@ class DB:
         return c.fetchall()
             
     def retrieveValues(self, key):
         return c.fetchall()
             
     def retrieveValues(self, key):
+        """Retrieve values from the database."""
         c = self.conn.cursor()
         c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
         c = self.conn.cursor()
         c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
-        t = c.fetchone()
         l = []
         l = []
-        while t:
-            l.append(t[0])
-            t = c.fetchone()
+        rows = c.fetchall()
+        for row in rows:
+            l.append(row[0])
         return l
 
         return l
 
+    def countValues(self, key):
+        """Count the number of values in the database."""
+        c = self.conn.cursor()
+        c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
+        res = 0
+        row = c.fetchone()
+        if row:
+            res = row[0]
+        return res
+
     def storeValue(self, key, value):
         """Store or update a key and value."""
         c = self.conn.cursor()
     def storeValue(self, key, value):
         """Store or update a key and value."""
         c = self.conn.cursor()
-        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", (khash(key), value, datetime.now()))
+        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", 
+                  (khash(key), dht_value(value), datetime.now()))
         self.conn.commit()
 
     def expireValues(self, expireAfter):
         """Expire older values after expireAfter seconds."""
         t = datetime.now() - timedelta(seconds=expireAfter)
         c = self.conn.cursor()
         self.conn.commit()
 
     def expireValues(self, expireAfter):
         """Expire older values after expireAfter seconds."""
         t = datetime.now() - timedelta(seconds=expireAfter)
         c = self.conn.cursor()
-        c.execute("DELETE FROM kv WHERE time < ?", (t, ))
+        c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
         self.conn.commit()
         
     def close(self):
         self.conn.commit()
         
     def close(self):
@@ -122,19 +143,19 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(self.store.getSelfNode(), self.key)
         
     def test_Value(self):
         self.failUnlessEqual(self.store.getSelfNode(), self.key)
         
     def test_Value(self):
-        self.store.storeValue(self.key, 'foobar')
+        self.store.storeValue(self.key, self.key)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'foobar')
+        self.failUnlessEqual(val[0], self.key)
         
     def test_expireValues(self):
         
     def test_expireValues(self):
-        self.store.storeValue(self.key, 'foobar')
+        self.store.storeValue(self.key, self.key)
         sleep(2)
         sleep(2)
-        self.store.storeValue(self.key, 'barfoo')
+        self.store.storeValue(self.key, self.key+self.key)
         self.store.expireValues(1)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
         self.store.expireValues(1)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'barfoo')
+        self.failUnlessEqual(val[0], self.key+self.key)
         
     def test_RoutingTable(self):
         class dummy:
         
     def test_RoutingTable(self):
         class dummy: