2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.trial import unittest
class DBExcept(Exception):
    """Error raised by the database layer, e.g. when the DB cannot be opened."""
14 """Dummy class to convert all hashes to base64 for storing in the DB."""
17 """Dummy class to convert all DHT values to base64 for storing in the DB."""
# Hashes and DHT values are stored base64-encoded: each wrapper class is
# adapted with b2a_base64 on the way in, and a converter decoding with
# a2b_base64 is registered under both the upper- and lower-case spelling of
# the matching declared column type.
for _wrapper, _decltype in ((khash, "KHASH"), (dht_value, "DHT_VALUE")):
    sqlite.register_adapter(_wrapper, b2a_base64)
    sqlite.register_converter(_decltype, a2b_base64)
    sqlite.register_converter(_decltype.lower(), a2b_base64)
# Do not leave the loop variables behind as module attributes.
del _wrapper, _decltype
27 """Database access for storing persistent data."""
def __init__(self, db):
    """Open the database stored in the file *db*, creating it when absent.

    An existing file is opened with _loadDB(); otherwise _createNewDB()
    connects and builds the schema.  Either way self.conn holds the
    connection afterwards.
    """
    import os.path

    # Remember the path the store was opened with.
    self.db = db
    if os.path.exists(db):
        self._loadDB(db)
    else:
        self._createNewDB(db)

    # TEXT columns are handed back as plain str.
    # NOTE(review): presumably pysqlite older than 2.1 has no per-connection
    # text_factory, hence the converter registration there -- confirm.
    if sqlite.version_info < (2, 1):
        sqlite.register_converter("TEXT", str)
        sqlite.register_converter("text", str)
    else:
        self.conn.text_factory = str
def _loadDB(self, db):
    """Open the existing database file *db* and keep the connection in self.conn.

    Raises DBExcept, carrying the formatted traceback of the underlying
    error as its second argument, if the connection cannot be opened.
    """
    try:
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
    except Exception:
        import traceback
        # The traceback text travels as an exception argument: the third
        # slot of a three-expression raise only accepts a traceback object,
        # so handing it a string would itself fail with TypeError.
        raise DBExcept("Couldn't open DB", traceback.format_exc())
def _createNewDB(self, db):
    """Connect to *db* and create the tables and indexes of a fresh store."""
    schema = (
        "CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, PRIMARY KEY (key, value))",
        "CREATE INDEX kv_key ON kv(key)",
        "CREATE INDEX kv_last_refresh ON kv(last_refresh)",
        "CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)",
        "CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)",
    )
    self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
    cursor = self.conn.cursor()
    # Statements run in the listed order: tables before the indexes on them.
    for statement in schema:
        cursor.execute(statement)
def getSelfNode(self):
    """Return the node id saved by saveSelfNode(), or None if none is stored."""
    c = self.conn.cursor()
    c.execute('SELECT id FROM self WHERE num = 0')
    row = c.fetchone()
    if row is None:
        # Nothing has been saved under slot 0 yet.
        return None
    return row[0]
def saveSelfNode(self, id):
    """Store *id* as this node's own id, replacing any previously saved one."""
    c = self.conn.cursor()
    # Slot 0 of the 'self' table is the only row ever written.
    c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
    # Without a commit the write stays in the connection's open transaction
    # and is lost when the connection is closed.
    self.conn.commit()
def dumpRoutingTable(self, buckets):
    """Save the routing table nodes to the database.

    The previously saved nodes are deleted first.  Every bucket in
    *buckets* must expose its nodes as the list attribute ``l``; each node
    must carry ``id``, ``host`` and ``port``.
    """
    c = self.conn.cursor()
    c.execute("DELETE FROM nodes WHERE id NOT NULL")
    for bucket in buckets:
        for node in bucket.l:
            c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
    # Commit once at the end so the delete and the inserts are applied
    # together and actually reach the database file.
    self.conn.commit()
def getRoutingTable(self):
    """Load the routing table nodes from the database.

    Returns a list with one (id, host, port) row per saved node.  It's
    usually a good idea to call refreshTable(force=1) after loading the
    table.
    """
    c = self.conn.cursor()
    c.execute("SELECT * FROM nodes")
    return c.fetchall()
def retrieveValues(self, key):
    """Retrieve values from the database.

    Returns a list of every value stored under *key*; the list is empty
    when the key has no values.
    """
    c = self.conn.cursor()
    c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
    # Each row holds just the selected 'value' column.
    return [row[0] for row in c.fetchall()]
def countValues(self, key):
    """Count the number of values in the database stored under *key*."""
    c = self.conn.cursor()
    c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
    row = c.fetchone()
    if row is None:
        # An aggregate query normally yields one row; treat a missing row
        # as "no values" rather than failing.
        return 0
    return row[0]
def storeValue(self, key, value):
    """Store or update a key and value.

    The (key, value) pair is inserted, or replaced if it already exists,
    and its last_refresh column is set to the current local time.
    """
    c = self.conn.cursor()
    c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)",
              (khash(key), dht_value(value), datetime.now()))
    # Without a commit the write stays in the connection's open transaction
    # and is lost when the connection is closed.
    self.conn.commit()
def expireValues(self, expireAfter):
    """Expire older values after expireAfter seconds.

    Deletes every row of the kv table whose last_refresh is more than
    *expireAfter* seconds before the current local time.
    """
    t = datetime.now() - timedelta(seconds=expireAfter)
    c = self.conn.cursor()
    c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
    # Without a commit the delete stays in the connection's open
    # transaction and is undone when the connection is closed.
    self.conn.commit()
131 class TestDB(unittest.TestCase):
132 """Tests for the khashmir database."""
135 db = '/tmp/khashmir.db'
136 key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
139 self.store = DB(self.db)
def test_selfNode(self):
    # A saved node id must be read back unchanged.
    node_id = self.key
    self.store.saveSelfNode(node_id)
    loaded = self.store.getSelfNode()
    self.failUnlessEqual(loaded, node_id)
def test_Value(self):
    # Storing one value under a key must yield exactly that value back.
    key = value = self.key
    self.store.storeValue(key, value)
    found = self.store.retrieveValues(key)
    self.failUnlessEqual(len(found), 1)
    self.failUnlessEqual(found[0], value)
def test_expireValues(self):
    # Only the value older than the expiry limit may be removed.
    from time import sleep

    self.store.storeValue(self.key, self.key)
    # Let the first value age past the one-second limit used below;
    # without the pause neither row is old enough to expire.
    sleep(2)
    self.store.storeValue(self.key, self.key+self.key)
    self.store.expireValues(1)
    val = self.store.retrieveValues(self.key)
    self.failUnlessEqual(len(val), 1)
    self.failUnlessEqual(val[0], self.key+self.key)
160 def test_RoutingTable(self):
166 return (self.id, self.host, self.port)
168 dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
169 dummy2.host = '205.23.67.124'
175 bl1.l.append(dummy())
179 self.store.dumpRoutingTable(buckets)
180 rt = self.store.getRoutingTable()
181 self.failUnlessIn(dummy().contents(), rt)
182 self.failUnlessIn(dummy2.contents(), rt)