git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht_Khashmir/db.py
Fixed the storage of binary strings in the database.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
class DBExcept(Exception):
    """Error raised by the khashmir database layer (e.g. when a DB file cannot be opened)."""
12
class khash(str):
    """Marker subclass of str for hash values.

    Wrapping a hash in khash selects the base64 adapter registered for this
    type, so binary hashes are written to the database as base64 text.
    """
# Columns declared KHASH are decoded from base64 on read (requires
# detect_types=PARSE_DECLTYPES on the connection); both spellings of the
# type name are registered.
sqlite.register_converter("KHASH", a2b_base64)
sqlite.register_converter("khash", a2b_base64)
# Any khash passed as a query parameter is encoded to base64 text on write.
sqlite.register_adapter(khash, b2a_base64)
19
class DB:
    """Database access for storing persistent data.

    The database has three tables: ``kv`` (stored key/value pairs and the
    time each was stored), ``nodes`` (the saved routing table) and ``self``
    (this node's own ID). Hash columns are declared ``KHASH`` so they are
    written through the khash base64 adapter and decoded by the KHASH
    converter when read back.
    """

    def __init__(self, db):
        """Open the database file ``db``, creating a new one if it does not exist.

        Raises DBExcept if an existing file could not be opened.
        """
        self.db = db
        try:
            os.stat(db)
        except OSError:
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Return TEXT columns as ``str`` objects (pysqlite's default would
        # be ``unicode`` under Python 2).
        self.conn.text_factory = str

    def _loadDB(self, db):
        """Connect to an existing database file.

        Raises DBExcept, whose message includes the original traceback, if
        the connection fails.
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            import traceback
            # The old ``raise E, msg, tb`` form passed a *string* where a
            # traceback object is required, so it raised TypeError instead
            # of DBExcept. Fold the traceback text into the message instead.
            raise DBExcept("Couldn't open DB: " + traceback.format_exc())

    def _createNewDB(self, db):
        """Create the database file ``db`` with empty kv, nodes and self tables."""
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE kv (key KHASH, value TEXT, time TIMESTAMP, PRIMARY KEY (key, value))")
        c.execute("CREATE INDEX kv_key ON kv(key)")
        c.execute("CREATE INDEX kv_timestamp ON kv(time)")
        c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
        c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
        self.conn.commit()

    def getSelfNode(self):
        """Return this node's saved ID, or None if none has been saved."""
        c = self.conn.cursor()
        c.execute('SELECT id FROM self WHERE num = 0')
        row = c.fetchone()
        if row:
            return row[0]
        return None

    def saveSelfNode(self, id):
        """Save this node's ID, replacing any previously saved one."""
        c = self.conn.cursor()
        # num is the primary key and always 0, so there is at most one row.
        c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
        self.conn.commit()

    def dumpRoutingTable(self, buckets):
        """
            save routing table nodes to the database

            Each bucket is expected to have a list ``l`` of nodes, and each
            node ``id``, ``host`` and ``port`` attributes. Any previously
            saved nodes are removed first.
        """
        c = self.conn.cursor()
        c.execute("DELETE FROM nodes WHERE id NOT NULL")
        for bucket in buckets:
            for node in bucket.l:
                c.execute("INSERT INTO nodes VALUES (?, ?, ?)", (khash(node.id), node.host, node.port))
        self.conn.commit()

    def getRoutingTable(self):
        """
            load routing table nodes from database
            it's usually a good idea to call refreshTable(force=1) after loading the table

            Returns a list of (id, host, port) rows.
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM nodes")
        return c.fetchall()

    def retrieveValues(self, key):
        """Return a list of all values stored under ``key`` (empty if none)."""
        c = self.conn.cursor()
        c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
        return [row[0] for row in c.fetchall()]

    def storeValue(self, key, value):
        """Store or update a key and value.

        A key may hold several values; storing an existing (key, value)
        pair again only refreshes its timestamp.
        """
        c = self.conn.cursor()
        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", (khash(key), value, datetime.now()))
        self.conn.commit()

    def expireValues(self, expireAfter):
        """Delete stored values that are more than ``expireAfter`` seconds old."""
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("DELETE FROM kv WHERE time < ?", (t, ))
        self.conn.commit()

    def close(self):
        """Close the database connection."""
        self.conn.close()
109
class TestDB(unittest.TestCase):
    """Tests for the khashmir database."""

    timeout = 5
    # A 20-byte binary hash containing a NUL and high bytes, to exercise the
    # base64 storage of binary strings.
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'

    def setUp(self):
        # trial's mktemp() returns a unique path that does not exist yet, so
        # every test gets a brand-new database file instead of sharing a
        # fixed /tmp path with concurrent or previously crashed runs (and
        # the test also works where /tmp does not exist).
        self.db = self.mktemp()
        self.store = DB(self.db)

    def test_selfNode(self):
        self.store.saveSelfNode(self.key)
        self.failUnlessEqual(self.store.getSelfNode(), self.key)

    def test_Value(self):
        self.store.storeValue(self.key, 'foobar')
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], 'foobar')

    def test_expireValues(self):
        # The first value ages past the 1-second cutoff; only the second,
        # stored after the sleep, should survive expiry.
        self.store.storeValue(self.key, 'foobar')
        sleep(2)
        self.store.storeValue(self.key, 'barfoo')
        self.store.expireValues(1)
        val = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(val), 1)
        self.failUnlessEqual(val[0], 'barfoo')

    def test_RoutingTable(self):
        class dummy:
            id = self.key
            host = "127.0.0.1"
            port = 9977
            def contents(self):
                return (self.id, self.host, self.port)
        dummy2 = dummy()
        dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
        dummy2.host = '205.23.67.124'
        dummy2.port = 12345
        class bl:
            def __init__(self):
                self.l = []
        bl1 = bl()
        bl1.l.append(dummy())
        bl2 = bl()
        bl2.l.append(dummy2)
        buckets = [bl1, bl2]
        self.store.dumpRoutingTable(buckets)
        rt = self.store.getRoutingTable()
        self.failUnlessIn(dummy().contents(), rt)
        self.failUnlessIn(dummy2.contents(), rt)

    def tearDown(self):
        self.store.close()
        os.unlink(self.db)