-from time import time
-import sqlite ## find this at http://pysqlite.sourceforge.net/
+from datetime import datetime, timedelta
+from pysqlite2 import dbapi2 as sqlite
+from binascii import a2b_base64, b2a_base64
+from time import sleep
import os
+from twisted.trial import unittest
+
class DBExcept(Exception):
pass
+class khash(str):
+ """Dummy class to convert all hashes to base64 for storing in the DB."""
+
+sqlite.register_adapter(khash, b2a_base64)
+sqlite.register_converter("KHASH", a2b_base64)
+sqlite.register_converter("khash", a2b_base64)
+
class DB:
"""Database access for storing persistent data."""
self._createNewDB(db)
else:
self._loadDB(db)
+ if sqlite.version_info < (2, 1):
+ sqlite.register_converter("TEXT", str)
+ sqlite.register_converter("text", str)
+ else:
+ self.conn.text_factory = str
def _loadDB(self, db):
try:
- self.store = sqlite.connect(db=db)
- #self.store.autocommit = 0
+ self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
except:
import traceback
raise DBExcept, "Couldn't open DB", traceback.format_exc()
def _createNewDB(self, db):
- self.store = sqlite.connect(db=db)
- s = """
- create table kv (key binary, value binary, time timestamp, primary key (key, value));
- create index kv_key on kv(key);
- create index kv_timestamp on kv(time);
-
- create table nodes (id binary primary key, host text, port number);
-
- create table self (num number primary key, id binary);
- """
- c = self.store.cursor()
- c.execute(s)
- self.store.commit()
+ self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
+ c = self.conn.cursor()
+ c.execute("CREATE TABLE kv (key KHASH, value TEXT, originated TIMESTAMP, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
+ c.execute("CREATE INDEX kv_key ON kv(key)")
+ c.execute("CREATE INDEX kv_originated ON kv(originated)")
+ c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
+ c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
+ c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
+ self.conn.commit()
def getSelfNode(self):
- c = self.store.cursor()
- c.execute('select id from self where num = 0;')
- if c.rowcount > 0:
- return c.fetchone()[0]
+ c = self.conn.cursor()
+ c.execute('SELECT id FROM self WHERE num = 0')
+ id = c.fetchone()
+ if id:
+ return id[0]
else:
return None
def saveSelfNode(self, id):
- c = self.store.cursor()
- c.execute('delete from self where num = 0;')
- c.execute("insert into self values (0, %s);", sqlite.encode(id))
- self.store.commit()
+ c = self.conn.cursor()
+ c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
+ self.conn.commit()
def dumpRoutingTable(self, buckets):
"""
save routing table nodes to the database
"""
- c = self.store.cursor()
- c.execute("delete from nodes where id not NULL;")
+ c = self.conn.cursor()
+ c.execute("DELETE FROM nodes WHERE id NOT NULL")
for bucket in buckets:
for node in bucket.l:
- c.execute("insert into nodes values (%s, %s, %s);", (sqlite.encode(node.id), node.host, node.port))
- self.store.commit()
+ c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
+ self.conn.commit()
def getRoutingTable(self):
"""
load routing table nodes from database
it's usually a good idea to call refreshTable(force=1) after loading the table
"""
- c = self.store.cursor()
- c.execute("select * from nodes;")
+ c = self.conn.cursor()
+ c.execute("SELECT * FROM nodes")
return c.fetchall()
def retrieveValues(self, key):
- c = self.store.cursor()
- c.execute("select value from kv where key = %s;", sqlite.encode(key))
- t = c.fetchone()
+ """Retrieve values from the database."""
+ c = self.conn.cursor()
+ c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
l = []
- while t:
- l.append(t['value'])
- t = c.fetchone()
+ rows = c.fetchall()
+ for row in rows:
+ l.append(row[0])
return l
- def storeValue(self, key, value):
+ def storeValue(self, key, value, originated):
"""Store or update a key and value."""
- t = "%0.6f" % time()
- c = self.store.cursor()
- try:
- c.execute("insert into kv values (%s, %s, %s);", (sqlite.encode(key), sqlite.encode(value), t))
- except sqlite.IntegrityError, reason:
- # update last insert time
- c.execute("update kv set time = %s where key = %s and value = %s;", (t, sqlite.encode(key), sqlite.encode(value)))
- self.store.commit()
+ c = self.conn.cursor()
+ c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?, ?)",
+ (khash(key), value, originated, datetime.now()))
+ self.conn.commit()
- def expireValues(self, expireTime):
- """Expire older values than expireTime."""
- t = "%0.6f" % expireTime
- c = self.store.cursor()
- s = "delete from kv where time < '%s';" % t
- c.execute(s)
+ def expireValues(self, expireAfter):
+ """Expire older values after expireAfter seconds."""
+ t = datetime.now() - timedelta(seconds=expireAfter)
+ c = self.conn.cursor()
+ c.execute("DELETE FROM kv WHERE originated < ?", (t, ))
+ self.conn.commit()
+
+ def refreshValues(self, expireAfter):
+ """Find older values than expireAfter seconds to refresh.
+
+ @return: a list of the hash keys and a list of dictionaries with
+ key of the value, value is the origination time
+ """
+ t = datetime.now() - timedelta(seconds=expireAfter)
+ c = self.conn.cursor()
+ c.execute("SELECT key, value, originated FROM kv WHERE last_refresh < ?", (t,))
+ keys = []
+ vals = []
+ rows = c.fetchall()
+ for row in rows:
+ keys.append(row[0])
+ vals.append({row[1]: row[2]})
+ return keys, vals
def close(self):
+ self.conn.close()
+
+class TestDB(unittest.TestCase):
+ """Tests for the khashmir database."""
+
+ timeout = 5
+ db = '/tmp/khashmir.db'
+ key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+
+ def setUp(self):
+ self.store = DB(self.db)
+
+ def test_selfNode(self):
+ self.store.saveSelfNode(self.key)
+ self.failUnlessEqual(self.store.getSelfNode(), self.key)
+
+ def test_Value(self):
+ self.store.storeValue(self.key, 'foobar', datetime.now())
+ val = self.store.retrieveValues(self.key)
+ self.failUnlessEqual(len(val), 1)
+ self.failUnlessEqual(val[0], 'foobar')
+
+ def test_expireValues(self):
+ self.store.storeValue(self.key, 'foobar', datetime.now())
+ sleep(2)
+ self.store.storeValue(self.key, 'barfoo', datetime.now())
+ self.store.expireValues(1)
+ val = self.store.retrieveValues(self.key)
+ self.failUnlessEqual(len(val), 1)
+ self.failUnlessEqual(val[0], 'barfoo')
+
+ def test_refreshValues(self):
+ self.store.storeValue(self.key, 'foobar', datetime.now())
+ sleep(2)
+ self.store.storeValue(self.key, 'barfoo', datetime.now())
+ keys, vals = self.store.refreshValues(1)
+ self.failUnlessEqual(len(keys), 1)
+ self.failUnlessEqual(keys[0], self.key)
+ self.failUnlessEqual(len(vals), 1)
+ self.failUnlessEqual(len(vals[0].keys()), 1)
+ self.failUnlessEqual(vals[0].keys()[0], 'foobar')
+ val = self.store.retrieveValues(self.key)
+ self.failUnlessEqual(len(val), 2)
+
+ def test_RoutingTable(self):
+ class dummy:
+ id = self.key
+ host = "127.0.0.1"
+ port = 9977
+ def contents(self):
+ return (self.id, self.host, self.port)
+ dummy2 = dummy()
+ dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+ dummy2.host = '205.23.67.124'
+ dummy2.port = 12345
+ class bl:
+ def __init__(self):
+ self.l = []
+ bl1 = bl()
+ bl1.l.append(dummy())
+ bl2 = bl()
+ bl2.l.append(dummy2)
+ buckets = [bl1, bl2]
+ self.store.dumpRoutingTable(buckets)
+ rt = self.store.getRoutingTable()
+ self.failUnlessIn(dummy().contents(), rt)
+ self.failUnlessIn(dummy2.contents(), rt)
+
+ def tearDown(self):
self.store.close()
+ os.unlink(self.db)