Workaround old sqlite not having 'select count(distinct key)'.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / db.py
1
2 """An sqlite database for storing nodes and key/value pairs."""
3
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
7 from time import sleep
8 import os
9
10 from twisted.trial import unittest
11
12 class DBExcept(Exception):
13     pass
14
15 class khash(str):
16     """Dummy class to convert all hashes to base64 for storing in the DB."""
17     
18 class dht_value(str):
19     """Dummy class to convert all DHT values to base64 for storing in the DB."""
20
21 # Initialize the database to work with 'khash' objects (binary strings)
22 sqlite.register_adapter(khash, b2a_base64)
23 sqlite.register_converter("KHASH", a2b_base64)
24 sqlite.register_converter("khash", a2b_base64)
25
26 # Initialize the database to work with DHT values (binary strings)
27 sqlite.register_adapter(dht_value, b2a_base64)
28 sqlite.register_converter("DHT_VALUE", a2b_base64)
29 sqlite.register_converter("dht_value", a2b_base64)
30
31 class DB:
32     """An sqlite database for storing persistent node info and key/value pairs.
33     
34     @type db: C{string}
35     @ivar db: the database file to use
36     @type conn: L{pysqlite2.dbapi2.Connection}
37     @ivar conn: an open connection to the sqlite database
38     """
39     
40     def __init__(self, db):
41         """Load or create the database file.
42         
43         @type db: C{string}
44         @param db: the database file to use
45         """
46         self.db = db
47         try:
48             os.stat(db)
49         except OSError:
50             self._createNewDB(db)
51         else:
52             self._loadDB(db)
53         if sqlite.version_info < (2, 1):
54             sqlite.register_converter("TEXT", str)
55             sqlite.register_converter("text", str)
56         else:
57             self.conn.text_factory = str
58
59     #{ Loading the DB
60     def _loadDB(self, db):
61         """Open a new connection to the existing database file"""
62         try:
63             self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
64         except:
65             import traceback
66             raise DBExcept, "Couldn't open DB", traceback.format_exc()
67         
68     def _createNewDB(self, db):
69         """Open a connection to a new database and create the necessary tables."""
70         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
71         c = self.conn.cursor()
72         c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, "+
73                                     "PRIMARY KEY (key, value))")
74         c.execute("CREATE INDEX kv_key ON kv(key)")
75         c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
76         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
77         c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
78         self.conn.commit()
79
80     def close(self):
81         self.conn.close()
82
83     #{ This node's ID
84     def getSelfNode(self):
85         """Retrieve this node's ID from a previous run of the program."""
86         c = self.conn.cursor()
87         c.execute('SELECT id FROM self WHERE num = 0')
88         id = c.fetchone()
89         if id:
90             return id[0]
91         else:
92             return None
93         
94     def saveSelfNode(self, id):
95         """Store this node's ID for a subsequent run of the program."""
96         c = self.conn.cursor()
97         c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
98         self.conn.commit()
99         
100     #{ Routing table
101     def dumpRoutingTable(self, buckets):
102         """Save routing table nodes to the database."""
103         c = self.conn.cursor()
104         c.execute("DELETE FROM nodes WHERE id NOT NULL")
105         for bucket in buckets:
106             for node in bucket.l:
107                 c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
108         self.conn.commit()
109         
110     def getRoutingTable(self):
111         """Load routing table nodes from database."""
112         c = self.conn.cursor()
113         c.execute("SELECT * FROM nodes")
114         return c.fetchall()
115
116     #{ Key/value pairs
117     def retrieveValues(self, key):
118         """Retrieve values from the database."""
119         c = self.conn.cursor()
120         c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
121         l = []
122         rows = c.fetchall()
123         for row in rows:
124             l.append(row[0])
125         return l
126
127     def countValues(self, key):
128         """Count the number of values in the database."""
129         c = self.conn.cursor()
130         c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
131         res = 0
132         row = c.fetchone()
133         if row:
134             res = row[0]
135         return res
136
137     def storeValue(self, key, value):
138         """Store or update a key and value."""
139         c = self.conn.cursor()
140         c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", 
141                   (khash(key), dht_value(value), datetime.now()))
142         self.conn.commit()
143
144     def expireValues(self, expireAfter):
145         """Expire older values after expireAfter seconds."""
146         t = datetime.now() - timedelta(seconds=expireAfter)
147         c = self.conn.cursor()
148         c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
149         self.conn.commit()
150         
151     def keyStats(self):
152         """Count the total number of keys and values in the database.
153         @rtype: (C{int), C{int})
154         @return: the number of distinct keys and total values in the database
155         """
156         c = self.conn.cursor()
157         c.execute("SELECT COUNT(value) as num_values FROM kv")
158         values = 0
159         row = c.fetchone()
160         if row:
161             values = row[0]
162         c.execute("SELECT COUNT(key) as num_keys FROM (SELECT DISTINCT key FROM kv)")
163         keys = 0
164         row = c.fetchone()
165         if row:
166             keys = row[0]
167         return keys, values
168
169 class TestDB(unittest.TestCase):
170     """Tests for the khashmir database."""
171     
172     timeout = 5
173     db = '/tmp/khashmir.db'
174     key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
175
176     def setUp(self):
177         self.store = DB(self.db)
178
179     def test_selfNode(self):
180         self.store.saveSelfNode(self.key)
181         self.failUnlessEqual(self.store.getSelfNode(), self.key)
182         
183     def test_Value(self):
184         self.store.storeValue(self.key, self.key)
185         val = self.store.retrieveValues(self.key)
186         self.failUnlessEqual(len(val), 1)
187         self.failUnlessEqual(val[0], self.key)
188         
189     def test_expireValues(self):
190         self.store.storeValue(self.key, self.key)
191         sleep(2)
192         self.store.storeValue(self.key, self.key+self.key)
193         self.store.expireValues(1)
194         val = self.store.retrieveValues(self.key)
195         self.failUnlessEqual(len(val), 1)
196         self.failUnlessEqual(val[0], self.key+self.key)
197         
198     def test_RoutingTable(self):
199         class dummy:
200             id = self.key
201             host = "127.0.0.1"
202             port = 9977
203             def contents(self):
204                 return (self.id, self.host, self.port)
205         dummy2 = dummy()
206         dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
207         dummy2.host = '205.23.67.124'
208         dummy2.port = 12345
209         class bl:
210             def __init__(self):
211                 self.l = []
212         bl1 = bl()
213         bl1.l.append(dummy())
214         bl2 = bl()
215         bl2.l.append(dummy2)
216         buckets = [bl1, bl2]
217         self.store.dumpRoutingTable(buckets)
218         rt = self.store.getRoutingTable()
219         self.failUnlessIn(dummy().contents(), rt)
220         self.failUnlessIn(dummy2.contents(), rt)
221         
222     def tearDown(self):
223         self.store.close()
224         os.unlink(self.db)