Break up the find_value into 2 parts (with get_value).
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
class DBExcept(Exception):
    """Error raised by DB when its SQLite database file cannot be opened."""
12
class khash(str):
    """Marker type for hashes (keys and node IDs) bound into SQL queries.

    It behaves exactly like ``str``. Wrapping a value in it selects the sqlite
    adapter registered for this type, which stores the hash base64-encoded.
    """
15     
class dht_value(str):
    """Marker type for DHT values bound into SQL queries.

    It behaves exactly like ``str``. Wrapping a value in it selects the sqlite
    adapter registered for this type, which stores the value base64-encoded.
    """
18     
# Teach sqlite about the marker types.
# - Adapter: parameters of a marker type are base64-encoded when bound.
# - Converter: columns declared with the matching type name are base64-decoded
#   when read back (the connections are opened with PARSE_DECLTYPES).
# Both the upper- and lower-case spelling of each declared type name is
# registered.
for _marker_type, _decl_names in ((khash, ("KHASH", "khash")),
                                  (dht_value, ("DHT_VALUE", "dht_value"))):
    sqlite.register_adapter(_marker_type, b2a_base64)
    for _decl_name in _decl_names:
        sqlite.register_converter(_decl_name, a2b_base64)
# Don't leave the loop temporaries behind as module globals.
del _marker_type, _decl_names, _decl_name
25
class DB:
    """Database access for storing persistent data.

    Holds one SQLite connection (``self.conn``) over three tables:
      - ``kv``: stored key/value pairs plus the time each was last refreshed
      - ``nodes``: a snapshot of the routing table (id, host, port)
      - ``self``: this node's own ID, kept in the row with ``num = 0``
    """

    def __init__(self, db):
        """Open the database file at path C{db}, creating it if it is missing.

        @param db: filesystem path of the SQLite database file
        @raise DBExcept: if an existing file cannot be opened
        """
        self.db = db
        try:
            os.stat(db)
        except OSError:
            # No such file yet: create it and its schema.
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Have TEXT columns returned as plain str. For pysqlite older than
        # 2.1 this is done with global converters; otherwise via the
        # connection's text_factory.
        if sqlite.version_info < (2, 1):
            sqlite.register_converter("TEXT", str)
            sqlite.register_converter("text", str)
        else:
            self.conn.text_factory = str

    def _loadDB(self, db):
        """Connect to an existing database file.

        @raise DBExcept: if the connection cannot be established
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            import traceback
            # The previous ``raise DBExcept, msg, traceback.format_exc()``
            # passed a *string* where Python 2 requires a traceback object,
            # so the raise itself failed with TypeError and hid the real
            # cause. Fold the formatted traceback into the message instead.
            raise DBExcept("Couldn't open DB\n" + traceback.format_exc())

    def _createNewDB(self, db):
        """Create a new database file at C{db} with the khashmir schema."""
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        # (key, value) is the primary key, so storing the same pair again
        # replaces the row (see storeValue) instead of duplicating it.
        c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, PRIMARY KEY (key, value))")
        c.execute("CREATE INDEX kv_key ON kv(key)")
        c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
        c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
        c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
        self.conn.commit()

    def getSelfNode(self):
        """Return the saved ID of this node, or None if none has been saved."""
        c = self.conn.cursor()
        c.execute('SELECT id FROM self WHERE num = 0')
        row = c.fetchone()
        if row:
            return row[0]
        return None

    def saveSelfNode(self, id):
        """Save C{id} as this node's ID, replacing any previously saved one."""
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
        self.conn.commit()

    def dumpRoutingTable(self, buckets):
        """
            save routing table nodes to the database

            @param buckets: iterable of buckets; each bucket's C{l} attribute
                lists nodes exposing C{id}, C{host} and C{port}
        """
        c = self.conn.cursor()
        # Replaces the previous snapshot (rows with a non-NULL id).
        c.execute("DELETE FROM nodes WHERE id NOT NULL")
        c.executemany("INSERT INTO nodes VALUES (?, ?, ?)",
                      ((khash(node.id), node.host, node.port)
                       for bucket in buckets for node in bucket.l))
        self.conn.commit()

    def getRoutingTable(self):
        """
            load routing table nodes from database
            it's usually a good idea to call refreshTable(force=1) after loading the table

            @return: list of (id, host, port) rows
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM nodes")
        return c.fetchall()

    def retrieveValues(self, key):
        """Retrieve values from the database.

        @return: list of all values stored under C{key} (empty if none)
        """
        c = self.conn.cursor()
        c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
        return [row[0] for row in c.fetchall()]

    def countValues(self, key):
        """Count the number of values stored under C{key} in the database."""
        c = self.conn.cursor()
        c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
        row = c.fetchone()
        if row:
            return row[0]
        return 0

    def storeValue(self, key, value):
        """Store or update a key and value.

        The refresh time is set to now. Re-storing an existing (key, value)
        pair replaces its row, which refreshes that time.
        """
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)",
                  (khash(key), dht_value(value), datetime.now()))
        self.conn.commit()

    def expireValues(self, expireAfter):
        """Delete values last refreshed more than C{expireAfter} seconds ago."""
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
        self.conn.commit()

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()
130
class TestDB(unittest.TestCase):
    """Tests for the khashmir database."""

    timeout = 5
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'

    def setUp(self):
        # Use a unique per-test path from trial instead of a fixed
        # '/tmp/khashmir.db'. With a shared path, concurrent runs collide.
        # Also, DB() re-opens an existing file rather than recreating it, so
        # a file left behind by an aborted run would leak old rows into the
        # assertions below.
        self.db = self.mktemp()
        self.store = DB(self.db)

    def test_selfNode(self):
        self.store.saveSelfNode(self.key)
        self.failUnlessEqual(self.store.getSelfNode(), self.key)

    def test_Value(self):
        self.store.storeValue(self.key, self.key)
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], self.key)

    def test_expireValues(self):
        # The first value is ~2s old when values older than 1s are expired,
        # so only the second, freshly stored value should survive.
        self.store.storeValue(self.key, self.key)
        sleep(2)
        self.store.storeValue(self.key, self.key + self.key)
        self.store.expireValues(1)
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], self.key + self.key)

    def test_RoutingTable(self):
        # Minimal stand-ins for routing-table nodes and buckets: dumpRoutingTable
        # only reads node.id/host/port and bucket.l.
        class dummy:
            id = self.key
            host = "127.0.0.1"
            port = 9977
            def contents(self):
                return (self.id, self.host, self.port)
        dummy2 = dummy()
        dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
        dummy2.host = '205.23.67.124'
        dummy2.port = 12345
        class bl:
            def __init__(self):
                self.l = []
        bl1 = bl()
        bl1.l.append(dummy())
        bl2 = bl()
        bl2.l.append(dummy2)
        buckets = [bl1, bl2]
        self.store.dumpRoutingTable(buckets)
        rt = self.store.getRoutingTable()
        self.failUnlessIn(dummy().contents(), rt)
        self.failUnlessIn(dummy2.contents(), rt)

    def tearDown(self):
        self.store.close()
        os.unlink(self.db)