]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht_Khashmir/db.py
Various documentation fixes and additions.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
index 5c62aed9416f865ed551454c039d6a6d7b13539f..47e974cf62122bab33cd4b4a4868a92e873bf822 100644 (file)
@@ -1,4 +1,6 @@
 
+"""An sqlite database for storing nodes and key/value pairs."""
+
 from datetime import datetime, timedelta
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
@@ -13,14 +15,34 @@ class DBExcept(Exception):
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
     
+class dht_value(str):
+    """Dummy class to convert all DHT values to base64 for storing in the DB."""
+
+# Initialize the database to work with 'khash' objects (binary strings)
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
 
+# Initialize the database to work with DHT values (binary strings)
+sqlite.register_adapter(dht_value, b2a_base64)
+sqlite.register_converter("DHT_VALUE", a2b_base64)
+sqlite.register_converter("dht_value", a2b_base64)
+
 class DB:
-    """Database access for storing persistent data."""
+    """An sqlite database for storing persistent node info and key/value pairs.
+    
+    @type db: C{string}
+    @ivar db: the database file to use
+    @type conn: L{pysqlite2.dbapi2.Connection}
+    @ivar conn: an open connection to the sqlite database
+    """
     
     def __init__(self, db):
+        """Load or create the database file.
+        
+        @type db: C{string}
+        @param db: the database file to use
+        """
         self.db = db
         try:
             os.stat(db)
@@ -33,8 +55,10 @@ class DB:
             sqlite.register_converter("text", str)
         else:
             self.conn.text_factory = str
-        
+
+    #{ Loading the DB
     def _loadDB(self, db):
+        """Open a new connection to the existing database file"""
         try:
             self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
         except:
@@ -42,16 +66,23 @@ class DB:
             raise DBExcept, "Couldn't open DB", traceback.format_exc()
         
     def _createNewDB(self, db):
+        """Open a connection to a new database and create the necessary tables."""
         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE kv (key KHASH, value TEXT, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
+        c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, "+
+                                    "PRIMARY KEY (key, value))")
         c.execute("CREATE INDEX kv_key ON kv(key)")
         c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
         c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
         self.conn.commit()
 
+    def close(self):
+        self.conn.close()
+
+    #{ This node's ID
     def getSelfNode(self):
+        """Retrieve this node's ID from a previous run of the program."""
         c = self.conn.cursor()
         c.execute('SELECT id FROM self WHERE num = 0')
         id = c.fetchone()
@@ -61,14 +92,14 @@ class DB:
             return None
         
     def saveSelfNode(self, id):
+        """Store this node's ID for a subsequent run of the program."""
         c = self.conn.cursor()
         c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
         self.conn.commit()
         
+    #{ Routing table
     def dumpRoutingTable(self, buckets):
-        """
-            save routing table nodes to the database
-        """
+        """Save routing table nodes to the database."""
         c = self.conn.cursor()
         c.execute("DELETE FROM nodes WHERE id NOT NULL")
         for bucket in buckets:
@@ -77,14 +108,12 @@ class DB:
         self.conn.commit()
         
     def getRoutingTable(self):
-        """
-            load routing table nodes from database
-            it's usually a good idea to call refreshTable(force=1) after loading the table
-        """
+        """Load routing table nodes from database."""
         c = self.conn.cursor()
         c.execute("SELECT * FROM nodes")
         return c.fetchall()
-            
+
+    #{ Key/value pairs
     def retrieveValues(self, key):
         """Retrieve values from the database."""
         c = self.conn.cursor()
@@ -95,11 +124,21 @@ class DB:
             l.append(row[0])
         return l
 
+    def countValues(self, key):
+        """Count the number of values in the database."""
+        c = self.conn.cursor()
+        c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
+        res = 0
+        row = c.fetchone()
+        if row:
+            res = row[0]
+        return res
+
     def storeValue(self, key, value):
         """Store or update a key and value."""
         c = self.conn.cursor()
         c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", 
-                  (khash(key), value, datetime.now()))
+                  (khash(key), dht_value(value), datetime.now()))
         self.conn.commit()
 
     def expireValues(self, expireAfter):
@@ -109,26 +148,6 @@ class DB:
         c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
         self.conn.commit()
         
-    def refreshValues(self, expireAfter):
-        """Find older values than expireAfter seconds to refresh.
-        
-        @return: a list of the hash keys and a list of dictionaries with
-            key of the value, value is the origination time
-        """
-        t = datetime.now() - timedelta(seconds=expireAfter)
-        c = self.conn.cursor()
-        c.execute("SELECT key, value, FROM kv WHERE last_refresh < ?", (t,))
-        keys = []
-        vals = []
-        rows = c.fetchall()
-        for row in rows:
-            keys.append(row[0])
-            vals.append({row[1]: row[2]})
-        return keys, vals
-        
-    def close(self):
-        self.conn.close()
-
 class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
@@ -144,32 +163,19 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(self.store.getSelfNode(), self.key)
         
     def test_Value(self):
-        self.store.storeValue(self.key, 'foobar', datetime.now())
+        self.store.storeValue(self.key, self.key)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'foobar')
+        self.failUnlessEqual(val[0], self.key)
         
     def test_expireValues(self):
-        self.store.storeValue(self.key, 'foobar', datetime.now())
+        self.store.storeValue(self.key, self.key)
         sleep(2)
-        self.store.storeValue(self.key, 'barfoo', datetime.now())
+        self.store.storeValue(self.key, self.key+self.key)
         self.store.expireValues(1)
         val = self.store.retrieveValues(self.key)
         self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'barfoo')
-        
-    def test_refreshValues(self):
-        self.store.storeValue(self.key, 'foobar', datetime.now())
-        sleep(2)
-        self.store.storeValue(self.key, 'barfoo', datetime.now())
-        keys, vals = self.store.refreshValues(1)
-        self.failUnlessEqual(len(keys), 1)
-        self.failUnlessEqual(keys[0], self.key)
-        self.failUnlessEqual(len(vals), 1)
-        self.failUnlessEqual(len(vals[0].keys()), 1)
-        self.failUnlessEqual(vals[0].keys()[0], 'foobar')
-        val = self.store.retrieveValues(self.key)
-        self.failUnlessEqual(len(val), 2)
+        self.failUnlessEqual(val[0], self.key+self.key)
         
     def test_RoutingTable(self):
         class dummy: