Remove the originated time from the DHT value storage.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
class DBExcept(Exception):
    """Error raised by the database layer, e.g. when an existing database
    file cannot be opened."""
12
class khash(str):
    """Marker subclass of str used for hash values kept in the DB.

    Wrapping a hash in this type lets the SQLite layer recognise it and
    store it base64-encoded; columns declared as KHASH are decoded back.
    """
# Columns declared with type KHASH are decoded from base64 on the way out.
# Both spellings are registered; presumably converter-name matching is
# case-sensitive in some pysqlite versions -- TODO confirm.
sqlite.register_converter("KHASH", a2b_base64)
sqlite.register_converter("khash", a2b_base64)
# khash parameters are encoded to base64 text when bound into a query.
sqlite.register_adapter(khash, b2a_base64)
19
class DB:
    """Database access for storing persistent data.

    Wraps one SQLite connection holding three tables: ``kv`` (stored
    key/value pairs with their last refresh time), ``nodes`` (the saved
    routing table) and ``self`` (this node's own ID).
    """

    def __init__(self, db):
        """Open the database file, creating and initializing it if missing.

        @param db: the file name of the SQLite database
        @raise DBExcept: if an existing database file cannot be opened
        """
        self.db = db
        try:
            os.stat(db)
        except OSError:
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Have TEXT columns come back as plain str objects; older pysqlite
        # versions lack text_factory, so fall back to global converters.
        if sqlite.version_info < (2, 1):
            sqlite.register_converter("TEXT", str)
            sqlite.register_converter("text", str)
        else:
            self.conn.text_factory = str

    def _loadDB(self, db):
        """Connect to an existing database file.

        @param db: the file name of the SQLite database
        @raise DBExcept: if the connection cannot be opened; the message
            carries the formatted original traceback
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            import traceback
            # A raise's third operand must be a traceback object, not a
            # string, so put the formatted traceback into the message to
            # keep the original cause visible.
            raise DBExcept("Couldn't open DB: %s" % traceback.format_exc())

    def _createNewDB(self, db):
        """Create a new database file and its tables and indexes.

        @param db: the file name of the SQLite database
        """
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE kv (key KHASH, value TEXT, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
        c.execute("CREATE INDEX kv_key ON kv(key)")
        c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
        c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
        c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
        self.conn.commit()

    def getSelfNode(self):
        """Return this node's saved ID, or None if none has been saved."""
        c = self.conn.cursor()
        c.execute('SELECT id FROM self WHERE num = 0')
        row = c.fetchone()
        if row:
            return row[0]
        return None

    def saveSelfNode(self, id):
        """Save (or overwrite) this node's ID."""
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
        self.conn.commit()

    def dumpRoutingTable(self, buckets):
        """
            save routing table nodes to the database

            Replaces every previously saved node with the nodes found in
            each bucket's ``l`` list (each node needs id, host and port).
        """
        c = self.conn.cursor()
        c.execute("DELETE FROM nodes WHERE id NOT NULL")
        for bucket in buckets:
            for node in bucket.l:
                c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
        self.conn.commit()

    def getRoutingTable(self):
        """
            load routing table nodes from database
            it's usually a good idea to call refreshTable(force=1) after loading the table

            @return: a list of (id, host, port) rows
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM nodes")
        return c.fetchall()

    def retrieveValues(self, key):
        """Retrieve all values stored under a key.

        @return: a list of the values (empty if the key is unknown)
        """
        c = self.conn.cursor()
        c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
        return [row[0] for row in c.fetchall()]

    def storeValue(self, key, value):
        """Store a key and value, or update its refresh time if present."""
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO kv (key, value, last_refresh) VALUES (?, ?, ?)",
                  (khash(key), value, datetime.now()))
        self.conn.commit()

    def expireValues(self, expireAfter):
        """Delete values not refreshed within the last expireAfter seconds."""
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
        self.conn.commit()

    def refreshValues(self, expireAfter):
        """Find values not refreshed within the last expireAfter seconds.

        @param expireAfter: the age threshold, in seconds
        @return: a list of the hash keys and a parallel list of
            dictionaries, each mapping the stored value to its last
            refresh time
        """
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        # The origination time is no longer stored, so the last refresh
        # time is reported alongside each value instead.
        c.execute("SELECT key, value, last_refresh FROM kv WHERE last_refresh < ?", (t,))
        keys = []
        vals = []
        for key, value, last_refresh in c.fetchall():
            keys.append(key)
            vals.append({value: last_refresh})
        return keys, vals

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()
131
class TestDB(unittest.TestCase):
    """Tests for the khashmir database."""

    timeout = 5
    db = '/tmp/khashmir.db'
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'

    def setUp(self):
        self.store = DB(self.db)

    def test_selfNode(self):
        self.store.saveSelfNode(self.key)
        self.failUnlessEqual(self.store.getSelfNode(), self.key)

    def test_Value(self):
        # storeValue no longer takes an origination time.
        self.store.storeValue(self.key, 'foobar')
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], 'foobar')

    def test_expireValues(self):
        self.store.storeValue(self.key, 'foobar')
        # Age the first value past the 1 second expiry threshold.
        sleep(2)
        self.store.storeValue(self.key, 'barfoo')
        self.store.expireValues(1)
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], 'barfoo')

    def test_refreshValues(self):
        self.store.storeValue(self.key, 'foobar')
        # Age the first value past the 1 second refresh threshold.
        sleep(2)
        self.store.storeValue(self.key, 'barfoo')
        keys, vals = self.store.refreshValues(1)
        self.failUnlessEqual(len(keys), 1)
        self.failUnlessEqual(keys[0], self.key)
        self.failUnlessEqual(len(vals), 1)
        self.failUnlessEqual(len(vals[0]), 1)
        self.failUnlessEqual(list(vals[0].keys())[0], 'foobar')
        # Refreshing only reports values; it must not delete them.
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 2)

    def test_RoutingTable(self):
        class dummy:
            id = self.key
            host = "127.0.0.1"
            port = 9977
            def contents(self):
                return (self.id, self.host, self.port)
        dummy2 = dummy()
        dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
        dummy2.host = '205.23.67.124'
        dummy2.port = 12345
        class bl:
            def __init__(self):
                self.l = []
        bl1 = bl()
        bl1.l.append(dummy())
        bl2 = bl()
        bl2.l.append(dummy2)
        buckets = [bl1, bl2]
        self.store.dumpRoutingTable(buckets)
        rt = self.store.getRoutingTable()
        self.failUnlessIn(dummy().contents(), rt)
        self.failUnlessIn(dummy2.contents(), rt)

    def tearDown(self):
        self.store.close()
        os.unlink(self.db)