Khashmir's storeValue now takes the origination date.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
class DBExcept(Exception):
    """Error raised by DB when the database file cannot be opened."""
12
class khash(str):
    """Marker string type for hash values destined for the database.

    Values wrapped in khash are picked up by the SQLite adapter registered
    for this type and stored base64-encoded; in every other respect a khash
    is an ordinary string.
    """
# khash values are written to SQLite as base64 text ...
sqlite.register_adapter(khash, b2a_base64)
# ... and decoded back for columns declared with the KHASH type.  Both
# spellings of the declared type name are registered.
for _decl_type in ("KHASH", "khash"):
    sqlite.register_converter(_decl_type, a2b_base64)
del _decl_type
19
class DB:
    """Database access for storing persistent data.

    Holds one SQLite connection (C{self.conn}) over three tables:
      - C{kv}: stored key/value pairs with their origination and
        last-refresh timestamps,
      - C{nodes}: the saved routing table (id, host, port),
      - C{self}: this node's own ID (single row with num = 0).
    """

    def __init__(self, db):
        """Open the database file at C{db}, creating it if it is missing.

        @param db: path of the SQLite database file
        @raise DBExcept: if an existing database file cannot be opened
        """
        self.db = db
        try:
            os.stat(db)
        except OSError:
            # No file yet: create it together with the schema.
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Make TEXT columns come back as str objects.  Before pysqlite 2.1
        # this is done with converters; later versions use text_factory.
        if sqlite.version_info < (2, 1):
            sqlite.register_converter("TEXT", str)
            sqlite.register_converter("text", str)
        else:
            self.conn.text_factory = str

    def _loadDB(self, db):
        """Connect to an existing database file.

        @param db: path of the SQLite database file
        @raise DBExcept: if the connection cannot be established
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except sqlite.Error:
            import traceback
            # The three-argument form of raise requires a traceback object
            # as its third argument (a string raises TypeError instead), so
            # include the formatted traceback in the message.
            raise DBExcept("Couldn't open DB: " + traceback.format_exc())

    def _createNewDB(self, db):
        """Create a new database file at C{db} and set up the schema."""
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE kv (key KHASH, value TEXT, originated TIMESTAMP, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
        c.execute("CREATE INDEX kv_key ON kv(key)")
        c.execute("CREATE INDEX kv_originated ON kv(originated)")
        c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
        c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
        c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
        self.conn.commit()

    def getSelfNode(self):
        """Return the saved ID of this node, or None if none was saved."""
        c = self.conn.cursor()
        c.execute('SELECT id FROM self WHERE num = 0')
        row = c.fetchone()
        if row:
            return row[0]
        return None

    def saveSelfNode(self, id):
        """Save (or overwrite) the ID of this node."""
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
        self.conn.commit()

    def dumpRoutingTable(self, buckets):
        """Save the routing table nodes to the database.

        Previously saved nodes are replaced.  If any insert fails, the whole
        rewrite is rolled back and the error re-raised.

        @param buckets: iterable of buckets, each with a list attribute
            C{l} of nodes having C{id}, C{host} and C{port} attributes
        """
        c = self.conn.cursor()
        try:
            c.execute("DELETE FROM nodes WHERE id NOT NULL")
            for bucket in buckets:
                for node in bucket.l:
                    c.execute("INSERT INTO nodes VALUES (?, ?, ?)",
                              (khash(node.id), node.host, node.port))
        except Exception:
            # Without this, the pending DELETE would be committed by the
            # next commit on this connection, wiping the saved table.
            self.conn.rollback()
            raise
        self.conn.commit()

    def getRoutingTable(self):
        """Load routing table nodes from the database.

        It's usually a good idea to call refreshTable(force=1) after loading
        the table.

        @return: list of (id, host, port) rows
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM nodes")
        return c.fetchall()

    def retrieveValues(self, key):
        """Return a list of all values stored under C{key} (empty if none)."""
        c = self.conn.cursor()
        c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
        return [row[0] for row in c.fetchall()]

    def storeValue(self, key, value, originated):
        """Store or update a key and value.

        @param key: the hash key to store the value under
        @param value: the value to store
        @param originated: datetime when the value was originated; used by
            L{expireValues}
        The value's last-refresh time is set to the current time.
        """
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?, ?)",
                  (khash(key), value, originated, datetime.now()))
        self.conn.commit()

    def expireValues(self, expireAfter):
        """Delete values that originated more than C{expireAfter} seconds ago."""
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("DELETE FROM kv WHERE originated < ?", (t, ))
        self.conn.commit()

    def refreshValues(self, expireAfter):
        """Find values last refreshed more than C{expireAfter} seconds ago.

        @return: a list of the hash keys and a parallel list of dictionaries,
            each mapping the stored value to its origination time
        """
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("SELECT key, value, originated FROM kv WHERE last_refresh < ?", (t,))
        rows = c.fetchall()
        keys = [row[0] for row in rows]
        vals = [{row[1]: row[2]} for row in rows]
        return keys, vals

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()
132
class TestDB(unittest.TestCase):
    """Tests for the khashmir database."""

    timeout = 5
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'

    def setUp(self):
        # A fresh per-test path (instead of a fixed /tmp file) keeps a stale
        # file from an aborted or concurrent run from leaking rows in here.
        self.db = self.mktemp()
        self.store = DB(self.db)

    def test_selfNode(self):
        self.store.saveSelfNode(self.key)
        self.failUnlessEqual(self.store.getSelfNode(), self.key)

    def test_Value(self):
        self.store.storeValue(self.key, 'foobar', datetime.now())
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], 'foobar')

    def test_expireValues(self):
        # Expiry is driven by the origination time passed in, so back-date
        # the first value rather than sleeping.
        self.store.storeValue(self.key, 'foobar', datetime.now() - timedelta(seconds=2))
        self.store.storeValue(self.key, 'barfoo', datetime.now())
        self.store.expireValues(1)
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], 'barfoo')

    def test_refreshValues(self):
        # last_refresh is stamped inside storeValue, so a real delay is
        # needed to make the first value stale.
        self.store.storeValue(self.key, 'foobar', datetime.now())
        sleep(2)
        self.store.storeValue(self.key, 'barfoo', datetime.now())
        keys, vals = self.store.refreshValues(1)
        self.failUnlessEqual(len(keys), 1)
        self.failUnlessEqual(keys[0], self.key)
        self.failUnlessEqual(len(vals), 1)
        self.failUnlessEqual(len(vals[0].keys()), 1)
        self.failUnlessEqual(vals[0].keys()[0], 'foobar')
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 2)

    def test_RoutingTable(self):
        class dummy:
            id = self.key
            host = "127.0.0.1"
            port = 9977
            def contents(self):
                return (self.id, self.host, self.port)
        dummy2 = dummy()
        dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
        dummy2.host = '205.23.67.124'
        dummy2.port = 12345
        class bl:
            def __init__(self):
                self.l = []
        bl1 = bl()
        bl1.l.append(dummy())
        bl2 = bl()
        bl2.l.append(dummy2)
        buckets = [bl1, bl2]
        self.store.dumpRoutingTable(buckets)
        rt = self.store.getRoutingTable()
        self.failUnlessIn(dummy().contents(), rt)
        self.failUnlessIn(dummy2.contents(), rt)

    def tearDown(self):
        self.store.close()
        os.unlink(self.db)