2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.trial import unittest
# Exception raised by DB._loadDB when the database file cannot be opened
# (its message there is "Couldn't open DB").
10 class DBExcept(Exception):
14 """Marker type for hashes (node ids and keys): values wrapped in khash are written to the DB base64-encoded by the sqlite adapter registered for this class."""
# khash parameters are base64-encoded on the way into SQLite, and columns
# whose declared type is KHASH are base64-decoded on the way out (decoding
# happens when a connection is opened with PARSE_DECLTYPES). Both spellings
# of the type name are registered, presumably because converter lookup is
# case-sensitive on some pysqlite versions.
sqlite.register_adapter(khash, b2a_base64)
for _decltype in ("KHASH", "khash"):
    sqlite.register_converter(_decltype, a2b_base64)
del _decltype
21 """Database access for storing persistent data."""
# Persists this node's own id, the routing table and stored key/value
# pairs in SQLite (see the tables created in _createNewDB).
# Set up the connection for database path `db`. Only part of the body is
# visible; the visible lines adjust how TEXT columns are returned.
23 def __init__(self, db):
# pysqlite before 2.1: register str as the converter for columns declared
# TEXT (both spellings), so they are returned as str rather than unicode.
31 if sqlite.version_info < (2, 1):
32 sqlite.register_converter("TEXT", str)
33 sqlite.register_converter("text", str)
# NOTE(review): presumably the pysqlite >= 2.1 branch (the line between is
# not visible); text_factory = str makes TEXT values come back as str.
35 self.conn.text_factory = str
# Open the existing database at `db`. PARSE_DECLTYPES makes declared
# column types (e.g. KHASH, registered above) select converters on read.
37 def _loadDB(self, db):
39 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
# NOTE(review): BUG. In Python 2's three-expression form `raise E, V, T`,
# T must be a traceback object or None. traceback.format_exc() returns a
# str, so this statement itself raises TypeError and DBExcept never
# reaches the caller. Suggested fix:
#     raise DBExcept, "Couldn't open DB", sys.exc_info()[2]   (needs sys)
# or fold the text into the message:
#     raise DBExcept("Couldn't open DB\n" + traceback.format_exc())
42 raise DBExcept, "Couldn't open DB", traceback.format_exc()
# Create a fresh database at `db` with this class's schema:
#   kv    - stored (key, value) pairs plus origination and last-refresh times
#   nodes - saved routing-table contacts: (id, host, port)
#   self  - this node's own id, kept in the row with num = 0
44 def _createNewDB(self, db):
45 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
46 c = self.conn.cursor()
# A key may hold several values; the (key, value) pair is unique.
47 c.execute("CREATE TABLE kv (key KHASH, value TEXT, originated TIMESTAMP, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
# NOTE(review): likely redundant. SQLite can use the automatic index behind
# PRIMARY KEY (key, value) for lookups on `key` alone, since key is the
# leftmost column. Check with EXPLAIN QUERY PLAN before dropping it.
48 c.execute("CREATE INDEX kv_key ON kv(key)")
# Support the time-range scans in expireValues (originated) and
# refreshValues (last_refresh).
49 c.execute("CREATE INDEX kv_originated ON kv(originated)")
50 c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
51 c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
52 c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
# Look up this node's stored id: the row with num = 0 in the `self` table,
# as written by saveSelfNode.
55 def getSelfNode(self):
56 c = self.conn.cursor()
57 c.execute('SELECT id FROM self WHERE num = 0')
# Persist this node's id as row num = 0 of the `self` table. Because num is
# the primary key, INSERT OR REPLACE overwrites any previously saved id.
# The id is wrapped in khash, so it is stored base64-encoded.
# NOTE(review): the parameter name `id` shadows the builtin; harmless here.
64 def saveSelfNode(self, id):
65 c = self.conn.cursor()
66 c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
69 def dumpRoutingTable(self, buckets):
71 save routing table nodes to the database
73 c = self.conn.cursor()
# Clear the previously saved table, then insert one row per node.
# NOTE(review): `id` is a non-INTEGER primary key, and SQLite allows NULL
# in such columns, so any NULL-id row would survive this DELETE. A plain
# "DELETE FROM nodes" clears everything.
74 c.execute("DELETE FROM nodes WHERE id NOT NULL")
75 for bucket in buckets:
# The inner loop that binds `node` is not visible here. Each node must
# provide .id, .host and .port.
77 c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
80 def getRoutingTable(self):
82 load routing table nodes from database
83 it's usually a good idea to call refreshTable(force=1) after loading the table
85 c = self.conn.cursor()
# SELECT * yields rows in the nodes column order: (id, host, port).
86 c.execute("SELECT * FROM nodes")
89 def retrieveValues(self, key):
90 """Retrieve values from the database."""
91 c = self.conn.cursor()
# Wrapping `key` in khash makes the comparison use the same base64
# encoding that storeValue wrote.
92 c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
99 def storeValue(self, key, value, originated):
100 """Store or update a key and value."""
101 c = self.conn.cursor()
# (key, value) is the primary key, so storing an existing pair replaces its
# row. `originated` is stored as given (the tests pass a datetime), and
# last_refresh is reset to the current naive local time.
102 c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?, ?)",
103 (khash(key), value, originated, datetime.now()))
106 def expireValues(self, expireAfter):
107 """Expire older values after expireAfter seconds."""
# Cutoff = now minus expireAfter seconds. datetime.now() is naive local
# time, so clock or DST changes shift the cutoff.
108 t = datetime.now() - timedelta(seconds=expireAfter)
109 c = self.conn.cursor()
# Delete every value whose origination time is before the cutoff.
110 c.execute("DELETE FROM kv WHERE originated < ?", (t, ))
113 def refreshValues(self, expireAfter):
114 """Find values last refreshed more than expireAfter seconds ago.
116 @return: a list of the hash keys and a list of dictionaries, each
117 mapping a stored value to its origination time
119 t = datetime.now() - timedelta(seconds=expireAfter)
120 c = self.conn.cursor()
# Rows whose last_refresh is older than the cutoff `t` are due for refresh.
# Nothing is deleted here.
121 c.execute("SELECT key, value, originated FROM kv WHERE last_refresh < ?", (t,))
# Each selected row contributes a one-entry dict {value: originated},
# taken from row[1] and row[2] of the SELECT above.
127 vals.append({row[1]: row[2]})
# Trial test case that exercises the DB class against an on-disk database.
133 class TestDB(unittest.TestCase):
134 """Tests for the khashmir database."""
# NOTE(review): this fixed path in /tmp is shared by concurrent runs. Trial's
# self.mktemp() would give each run its own file.
137 db = '/tmp/khashmir.db'
# 20-byte raw binary key used by all tests.
138 key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
# Presumably inside setUp (its header is not visible): open a DB at self.db.
141 self.store = DB(self.db)
# The id saved with saveSelfNode must come back unchanged from getSelfNode.
143 def test_selfNode(self):
144 self.store.saveSelfNode(self.key)
145 self.failUnlessEqual(self.store.getSelfNode(), self.key)
# A single stored value is returned by retrieveValues.
147 def test_Value(self):
148 self.store.storeValue(self.key, 'foobar', datetime.now())
149 val = self.store.retrieveValues(self.key)
150 self.failUnlessEqual(len(val), 1)
151 self.failUnlessEqual(val[0], 'foobar')
# After expireValues(1), only the later-stored value ('barfoo') should remain.
# NOTE(review): this needs 'foobar' to be more than 1 second old when
# expireValues runs. The line between the two stores is not visible and
# presumably waits; confirm.
153 def test_expireValues(self):
154 self.store.storeValue(self.key, 'foobar', datetime.now())
156 self.store.storeValue(self.key, 'barfoo', datetime.now())
157 self.store.expireValues(1)
158 val = self.store.retrieveValues(self.key)
159 self.failUnlessEqual(len(val), 1)
160 self.failUnlessEqual(val[0], 'barfoo')
# refreshValues(1) should report only the older value ('foobar') under the
# single key, and both values must still be retrievable afterwards.
# NOTE(review): this relies on a delay between the two stores, as in
# test_expireValues.
162 def test_refreshValues(self):
163 self.store.storeValue(self.key, 'foobar', datetime.now())
165 self.store.storeValue(self.key, 'barfoo', datetime.now())
166 keys, vals = self.store.refreshValues(1)
167 self.failUnlessEqual(len(keys), 1)
168 self.failUnlessEqual(keys[0], self.key)
169 self.failUnlessEqual(len(vals), 1)
170 self.failUnlessEqual(len(vals[0].keys()), 1)
171 self.failUnlessEqual(vals[0].keys()[0], 'foobar')
172 val = self.store.retrieveValues(self.key)
173 self.failUnlessEqual(len(val), 2)
# Dump a routing table holding a default dummy node and a second node with
# its own id and host. Then check that getRoutingTable() returns both nodes'
# (id, host, port) tuples. Parts of the helper class and bucket setup are
# not visible here.
175 def test_RoutingTable(self):
# The helper's contents(): (id, host, port), matching the columns that
# dumpRoutingTable writes.
181 return (self.id, self.host, self.port)
183 dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
184 dummy2.host = '205.23.67.124'
190 bl1.l.append(dummy())
194 self.store.dumpRoutingTable(buckets)
195 rt = self.store.getRoutingTable()
196 self.failUnlessIn(dummy().contents(), rt)
197 self.failUnlessIn(dummy2.contents(), rt)