2 """An sqlite database for storing nodes and key/value pairs."""
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
10 from twisted.trial import unittest
# Exception type used by DB for database failures (e.g. an existing DB file
# that cannot be opened).
12 class DBExcept(Exception):
16 """Wrapper type for binary hash strings; the sqlite adapter registered for it stores them in the DB as base64."""
19 """Wrapper type for binary DHT value strings; the sqlite adapter registered for it stores them in the DB as base64."""
# khash and dht_value both wrap raw binary strings.  For each type, register
# an adapter so pysqlite writes instances as base64 text, then register
# converters for both spellings of the matching declared column type so that
# connections opened with PARSE_DECLTYPES decode such columns back to bytes.
for _binary_type, _column_types in ((khash, ("KHASH", "khash")),
                                    (dht_value, ("DHT_VALUE", "dht_value"))):
    sqlite.register_adapter(_binary_type, b2a_base64)
    for _column_type in _column_types:
        sqlite.register_converter(_column_type, a2b_base64)
# Don't leak the loop variables into the module namespace.
del _binary_type, _column_types, _column_type
32 """An sqlite database for storing persistent node info and key/value pairs.
Tables: C{kv} holds stored key/value pairs with their last refresh time,
C{nodes} the saved routing table (id, host, port), and C{self} this node's ID.
35 @ivar db: the database file to use
36 @type conn: L{pysqlite2.dbapi2.Connection}
37 @ivar conn: an open connection to the sqlite database
40 def __init__(self, db):
41 """Load or create the database file.
44 @param db: the database file to use
# Make TEXT columns come back as plain str (byte strings) rather than the
# pysqlite default of unicode.  With pysqlite older than 2.1 this is done by
# registering module-wide "TEXT"/"text" converters; they take effect because
# connections are opened with PARSE_DECLTYPES.
53 if sqlite.version_info < (2, 1):
54 sqlite.register_converter("TEXT", str)
55 sqlite.register_converter("text", str)
# Otherwise the connection's own text_factory is set to str.
# NOTE(review): presumably this is the else-branch of the version check
# above (that line is not shown here) -- confirm.
57 self.conn.text_factory = str
60 def _loadDB(self, db):
61 """Open a new connection to the existing database file."""
# PARSE_DECLTYPES makes pysqlite apply the converters registered for the
# declared column types (e.g. KHASH, DHT_VALUE) when rows are read.
63 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
# NOTE(review): this raise is presumably inside an exception handler around
# connect() whose lines are not shown here -- confirm.
# NOTE(review): BUG -- in Python 2 the third operand of `raise E, V, T` must
# be a traceback object or None.  traceback.format_exc() returns a string,
# so this statement raises TypeError instead of DBExcept.  Suggested fix:
#     raise DBExcept("Couldn't open DB\n" + traceback.format_exc())
66 raise DBExcept, "Couldn't open DB", traceback.format_exc()
68 def _createNewDB(self, db):
69 """Open a connection to a new database and create the necessary tables."""
70 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
71 c = self.conn.cursor()
# kv: stored key/value pairs.  Because (key, value) is the primary key, one
# key can hold several values, and storeValue's INSERT OR REPLACE refreshes
# an existing pair's last_refresh instead of duplicating the row.
72 c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, "+
73 "PRIMARY KEY (key, value))")
# Indexes for lookups by key and for expiring rows by last_refresh.
74 c.execute("CREATE INDEX kv_key ON kv(key)")
75 c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
# nodes: the saved routing table; self: this node's own ID, kept in the
# row with num = 0.
76 c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
77 c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
84 def getSelfNode(self):
85 """Retrieve this node's ID from a previous run of the program."""
86 c = self.conn.cursor()
# saveSelfNode always writes the ID to the row with num = 0.  The id column
# is declared KHASH, so it is base64-decoded back to the raw ID on read.
87 c.execute('SELECT id FROM self WHERE num = 0')
94 def saveSelfNode(self, id):
95 """Store this node's ID for a subsequent run of the program."""
96 c = self.conn.cursor()
# Only one self row is kept: INSERT OR REPLACE on primary key num = 0
# overwrites any previously saved ID.  Wrapping the ID in khash makes the
# registered adapter store the binary string as base64.
97 c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
101 def dumpRoutingTable(self, buckets):
102 """Save routing table nodes to the database."""
103 c = self.conn.cursor()
# Discard the previously saved routing table (every row with a non-NULL id)
# before writing the current one.
104 c.execute("DELETE FROM nodes WHERE id NOT NULL")
# Each bucket exposes its nodes as the list attribute `l`; each node must
# provide id, host and port.
105 for bucket in buckets:
106 for node in bucket.l:
# NOTE(review): plain INSERT on the id primary key -- a node id present in
# more than one bucket would raise an integrity error; presumably ids are
# unique across buckets.
107 c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
110 def getRoutingTable(self):
111 """Load routing table nodes from database."""
112 c = self.conn.cursor()
# Rows come back in table-definition order: (id, host, port), with id
# decoded from its KHASH column.
# NOTE(review): SELECT * ties the row layout to the CREATE TABLE column
# order; naming id, host, port explicitly would be more robust.
113 c.execute("SELECT * FROM nodes")
117 def retrieveValues(self, key):
118 """Retrieve values from the database."""
119 c = self.conn.cursor()
# Select every value stored under key (a key may hold several values).  The
# key is wrapped in khash so it matches the base64 form written by
# storeValue; values are decoded from their DHT_VALUE column on read.
120 c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
127 def countValues(self, key):
128 """Count the number of values in the database."""
129 c = self.conn.cursor()
# Count the values stored under this one key; the key is wrapped in khash
# so it compares equal to the base64 form stored in kv.
130 c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
137 def storeValue(self, key, value):
138 """Store or update a key and value."""
139 c = self.conn.cursor()
# Insert the (key, value) pair, or -- since (key, value) is the primary
# key -- replace the existing row, which resets its last_refresh to now.
# Timestamps are naive local time, matching expireValues.
140 c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)",
141 (khash(key), dht_value(value), datetime.now()))
144 def expireValues(self, expireAfter):
145 """Expire older values after expireAfter seconds."""
# Cut-off is expireAfter seconds before now, in the same naive local time
# that storeValue records with datetime.now().
146 t = datetime.now() - timedelta(seconds=expireAfter)
147 c = self.conn.cursor()
# Delete every pair not refreshed since the cut-off; the kv_last_refresh
# index can serve this query.
148 c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
152 """Count the total number of keys and values in the database.
153 @rtype: (C{int}, C{int})
154 @return: the number of distinct keys and total values in the database
156 c = self.conn.cursor()
# One aggregate row: distinct keys and total (non-NULL) values in kv.
157 c.execute("SELECT COUNT(DISTINCT key) as num_keys, COUNT(value) as num_values FROM kv")
# Column 0 is num_keys and column 1 is num_values, per the SELECT list.
161 keys, values = row[0], row[1]
164 class TestDB(unittest.TestCase):
165 """Tests for the khashmir database."""
# Fixed on-disk DB path shared by all tests, and a 20-byte binary string used
# as node ID, key and value in the tests below.
# NOTE(review): a fixed /tmp path collides if test runs overlap; consider a
# per-test temporary file.
168 db = '/tmp/khashmir.db'
169 key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
# Presumably the body of setUp (its def line is not shown here) -- opens the
# DB store used by each test; confirm.
172 self.store = DB(self.db)
def test_selfNode(self):
    """An ID saved with saveSelfNode comes back unchanged from getSelfNode."""
    node_id = self.key
    self.store.saveSelfNode(node_id)
    loaded = self.store.getSelfNode()
    self.failUnlessEqual(loaded, node_id)
def test_Value(self):
    """A single value stored under a key is the only one retrieveValues returns."""
    store, key = self.store, self.key
    store.storeValue(key, key)
    values = store.retrieveValues(key)
    self.failUnlessEqual(len(values), 1)
    self.failUnlessEqual(values[0], key)
184 def test_expireValues(self):
# Store one value, later store a second one, then expire anything older than
# one second: only the second value should survive.
185 self.store.storeValue(self.key, self.key)
# NOTE(review): this only works if more than one second passes between the
# two stores (the line doing so is not shown here) -- confirm.
187 self.store.storeValue(self.key, self.key+self.key)
188 self.store.expireValues(1)
189 val = self.store.retrieveValues(self.key)
190 self.failUnlessEqual(len(val), 1)
191 self.failUnlessEqual(val[0], self.key+self.key)
193 def test_RoutingTable(self):
# Save two fake nodes held in fake buckets (objects with an `l` list) and
# check that both come back from getRoutingTable as (id, host, port) tuples.
# contents() builds the (id, host, port) tuple a saved nodes row should equal.
199 return (self.id, self.host, self.port)
201 dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
202 dummy2.host = '205.23.67.124'
208 bl1.l.append(dummy())
212 self.store.dumpRoutingTable(buckets)
213 rt = self.store.getRoutingTable()
214 self.failUnlessIn(dummy().contents(), rt)
215 self.failUnlessIn(dummy2.contents(), rt)