dc97352f8644beb07d9b20bd39edb031571e562f
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
1
2 """An sqlite database for storing nodes and key/value pairs."""
3
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
7 from time import sleep
8 import os
9
10 from twisted.trial import unittest
11
12 class DBExcept(Exception):
13     pass
14
15 class khash(str):
16     """Dummy class to convert all hashes to base64 for storing in the DB."""
17     
18 class dht_value(str):
19     """Dummy class to convert all DHT values to base64 for storing in the DB."""
20
21 # Initialize the database to work with 'khash' objects (binary strings)
22 sqlite.register_adapter(khash, b2a_base64)
23 sqlite.register_converter("KHASH", a2b_base64)
24 sqlite.register_converter("khash", a2b_base64)
25
26 # Initialize the database to work with DHT values (binary strings)
27 sqlite.register_adapter(dht_value, b2a_base64)
28 sqlite.register_converter("DHT_VALUE", a2b_base64)
29 sqlite.register_converter("dht_value", a2b_base64)
30
31 class DB:
32     """An sqlite database for storing persistent node info and key/value pairs.
33     
34     @type db: C{string}
35     @ivar db: the database file to use
36     @type conn: L{pysqlite2.dbapi2.Connection}
37     @ivar conn: an open connection to the sqlite database
38     """
39     
40     def __init__(self, db):
41         """Load or create the database file.
42         
43         @type db: C{string}
44         @param db: the database file to use
45         """
46         self.db = db
47         try:
48             os.stat(db)
49         except OSError:
50             self._createNewDB(db)
51         else:
52             self._loadDB(db)
53         if sqlite.version_info < (2, 1):
54             sqlite.register_converter("TEXT", str)
55             sqlite.register_converter("text", str)
56         else:
57             self.conn.text_factory = str
58
59     #{ Loading the DB
60     def _loadDB(self, db):
61         """Open a new connection to the existing database file"""
62         try:
63             self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
64         except:
65             import traceback
66             raise DBExcept, "Couldn't open DB", traceback.format_exc()
67         
68     def _createNewDB(self, db):
69         """Open a connection to a new database and create the necessary tables."""
70         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
71         c = self.conn.cursor()
72         c.execute("CREATE TABLE kv (key KHASH, value DHT_VALUE, last_refresh TIMESTAMP, "+
73                                     "PRIMARY KEY (key, value))")
74         c.execute("CREATE INDEX kv_key ON kv(key)")
75         c.execute("CREATE INDEX kv_last_refresh ON kv(last_refresh)")
76         c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
77         c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
78         self.conn.commit()
79
80     def close(self):
81         self.conn.close()
82
83     #{ This node's ID
84     def getSelfNode(self):
85         """Retrieve this node's ID from a previous run of the program."""
86         c = self.conn.cursor()
87         c.execute('SELECT id FROM self WHERE num = 0')
88         id = c.fetchone()
89         if id:
90             return id[0]
91         else:
92             return None
93         
94     def saveSelfNode(self, id):
95         """Store this node's ID for a subsequent run of the program."""
96         c = self.conn.cursor()
97         c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
98         self.conn.commit()
99         
100     #{ Routing table
101     def dumpRoutingTable(self, buckets):
102         """Save routing table nodes to the database."""
103         c = self.conn.cursor()
104         c.execute("DELETE FROM nodes WHERE id NOT NULL")
105         for bucket in buckets:
106             for node in bucket.l:
107                 c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
108         self.conn.commit()
109         
110     def getRoutingTable(self):
111         """Load routing table nodes from database.
112         
113         It's usually a good idea to call refreshTable(force=1) after loading the table.
114         """
115         c = self.conn.cursor()
116         c.execute("SELECT * FROM nodes")
117         return c.fetchall()
118
119     #{ Key/value pairs
120     def retrieveValues(self, key):
121         """Retrieve values from the database."""
122         c = self.conn.cursor()
123         c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
124         l = []
125         rows = c.fetchall()
126         for row in rows:
127             l.append(row[0])
128         return l
129
130     def countValues(self, key):
131         """Count the number of values in the database."""
132         c = self.conn.cursor()
133         c.execute("SELECT COUNT(value) as num_values FROM kv WHERE key = ?", (khash(key),))
134         res = 0
135         row = c.fetchone()
136         if row:
137             res = row[0]
138         return res
139
140     def storeValue(self, key, value):
141         """Store or update a key and value."""
142         c = self.conn.cursor()
143         c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", 
144                   (khash(key), dht_value(value), datetime.now()))
145         self.conn.commit()
146
147     def expireValues(self, expireAfter):
148         """Expire older values after expireAfter seconds."""
149         t = datetime.now() - timedelta(seconds=expireAfter)
150         c = self.conn.cursor()
151         c.execute("DELETE FROM kv WHERE last_refresh < ?", (t, ))
152         self.conn.commit()
153         
154 class TestDB(unittest.TestCase):
155     """Tests for the khashmir database."""
156     
157     timeout = 5
158     db = '/tmp/khashmir.db'
159     key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
160
161     def setUp(self):
162         self.store = DB(self.db)
163
164     def test_selfNode(self):
165         self.store.saveSelfNode(self.key)
166         self.failUnlessEqual(self.store.getSelfNode(), self.key)
167         
168     def test_Value(self):
169         self.store.storeValue(self.key, self.key)
170         val = self.store.retrieveValues(self.key)
171         self.failUnlessEqual(len(val), 1)
172         self.failUnlessEqual(val[0], self.key)
173         
174     def test_expireValues(self):
175         self.store.storeValue(self.key, self.key)
176         sleep(2)
177         self.store.storeValue(self.key, self.key+self.key)
178         self.store.expireValues(1)
179         val = self.store.retrieveValues(self.key)
180         self.failUnlessEqual(len(val), 1)
181         self.failUnlessEqual(val[0], self.key+self.key)
182         
183     def test_RoutingTable(self):
184         class dummy:
185             id = self.key
186             host = "127.0.0.1"
187             port = 9977
188             def contents(self):
189                 return (self.id, self.host, self.port)
190         dummy2 = dummy()
191         dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
192         dummy2.host = '205.23.67.124'
193         dummy2.port = 12345
194         class bl:
195             def __init__(self):
196                 self.l = []
197         bl1 = bl()
198         bl1.l.append(dummy())
199         bl2 = bl()
200         bl2.l.append(dummy2)
201         buckets = [bl1, bl2]
202         self.store.dumpRoutingTable(buckets)
203         rt = self.store.getRoutingTable()
204         self.failUnlessIn(dummy().contents(), rt)
205         self.failUnlessIn(dummy2.contents(), rt)
206         
207     def tearDown(self):
208         self.store.close()
209         os.unlink(self.db)