]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht_Khashmir/db.py
8bb499e14512e85b0218941b22485ba0dac63474
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
class DBExcept(Exception):
    """Error raised by DB, e.g. when an existing database file cannot be opened."""
12
class khash(str):
    """Marker subclass of str for hash values stored in the database.

    Wrapping a hash in khash makes SQLite bind it through the registered
    base64 adapter, so binary hash bytes are stored as base64 text.
    """
15     
# Bind khash parameters as base64 text, and decode columns declared with the
# KHASH type back into raw strings (requires PARSE_DECLTYPES on the connection).
sqlite.register_adapter(khash, b2a_base64)
for _khash_decltype in ("KHASH", "khash"):
    sqlite.register_converter(_khash_decltype, a2b_base64)
del _khash_decltype
19
class DB:
    """Database access for storing persistent data.

    Wraps one SQLite connection holding three tables: ``kv`` (stored
    key/value pairs plus the time they were stored), ``nodes`` (the saved
    routing table) and ``self`` (this node's own ID).
    """

    def __init__(self, db):
        """Open the SQLite database file at path ``db``, creating it if missing.

        Raises DBExcept if an existing database file cannot be opened.
        """
        self.db = db
        try:
            os.stat(db)
        except OSError:
            # No file yet: create it together with the schema.
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Return TEXT columns as plain str objects instead of unicode. Old
        # pysqlite versions lack text_factory, so fall back to converters.
        if sqlite.version_info < (2, 1):
            sqlite.register_converter("TEXT", str)
            sqlite.register_converter("text", str)
        else:
            self.conn.text_factory = str

    def _loadDB(self, db):
        """Connect to the existing database file ``db``.

        Raises DBExcept, with the formatted traceback in its message, if the
        connection fails.
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except sqlite.Error:
            # The previous three-argument raise passed a string where a
            # traceback object is required, which raised TypeError instead
            # of DBExcept. Put the traceback text in the message instead.
            import traceback
            raise DBExcept("Couldn't open DB:\n" + traceback.format_exc())

    def _createNewDB(self, db):
        """Create the database file ``db`` and its tables and indexes."""
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE kv (key KHASH, value TEXT, time TIMESTAMP, PRIMARY KEY (key, value))")
        c.execute("CREATE INDEX kv_key ON kv(key)")
        c.execute("CREATE INDEX kv_timestamp ON kv(time)")
        c.execute("CREATE TABLE nodes (id KHASH PRIMARY KEY, host TEXT, port NUMBER)")
        c.execute("CREATE TABLE self (num NUMBER PRIMARY KEY, id KHASH)")
        self.conn.commit()

    def getSelfNode(self):
        """Return this node's saved ID, or None if none has been saved."""
        c = self.conn.cursor()
        c.execute('SELECT id FROM self WHERE num = 0')
        row = c.fetchone()
        if row:
            return row[0]
        return None

    def saveSelfNode(self, id):
        """Save ``id`` as this node's ID, replacing any previously saved one."""
        c = self.conn.cursor()
        # The ID always lives in the row with num = 0, so REPLACE overwrites it.
        c.execute("INSERT OR REPLACE INTO self VALUES (0, ?)", (khash(id),))
        self.conn.commit()

    def dumpRoutingTable(self, buckets):
        """
            save routing table nodes to the database

            Replaces all previously saved nodes. Each bucket is expected to
            expose its nodes as a list in ``bucket.l``; each node must have
            ``id``, ``host`` and ``port`` attributes.
        """
        c = self.conn.cursor()
        try:
            c.execute("DELETE FROM nodes WHERE id NOT NULL")
            for bucket in buckets:
                for node in bucket.l:
                    c.execute("INSERT INTO nodes VALUES (?, ?, ?)",
                              (khash(node.id), node.host, node.port))
            self.conn.commit()
        except Exception:
            # Without this, the DELETE and any INSERTs that succeeded stay in
            # the open transaction. The next commit() elsewhere (e.g. in
            # storeValue) would then persist a wiped or partial routing table.
            self.conn.rollback()
            raise

    def getRoutingTable(self):
        """
            load routing table nodes from database
            it's usually a good idea to call refreshTable(force=1) after loading the table

            Returns a list of (id, host, port) rows.
        """
        c = self.conn.cursor()
        c.execute("SELECT * FROM nodes")
        return c.fetchall()

    def retrieveValues(self, key):
        """Return a list of all values stored under ``key`` (empty if none)."""
        c = self.conn.cursor()
        c.execute("SELECT value FROM kv WHERE key = ?", (khash(key),))
        return [row[0] for row in c.fetchall()]

    def storeValue(self, key, value):
        """Store or update a key and value."""
        c = self.conn.cursor()
        # (key, value) is the primary key, so re-storing an existing pair
        # only refreshes its timestamp.
        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?)", (khash(key), value, datetime.now()))
        self.conn.commit()

    def expireValues(self, expireAfter):
        """Expire older values after expireAfter seconds.

        Deletes every stored pair whose timestamp is older than
        ``expireAfter`` seconds before now.
        """
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("DELETE FROM kv WHERE time < ?", (t, ))
        self.conn.commit()

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()
113
class TestDB(unittest.TestCase):
    """Exercise DB against a scratch database file that is deleted after every test."""

    timeout = 5
    db = '/tmp/khashmir.db'
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'

    def setUp(self):
        # Every test starts by opening (creating if needed) the scratch DB.
        self.store = DB(self.db)

    def tearDown(self):
        self.store.close()
        os.unlink(self.db)

    def test_selfNode(self):
        self.store.saveSelfNode(self.key)
        saved = self.store.getSelfNode()
        self.failUnlessEqual(saved, self.key)

    def test_Value(self):
        self.store.storeValue(self.key, 'foobar')
        found = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(found), 1)
        self.failUnlessEqual(found[0], 'foobar')

    def test_expireValues(self):
        self.store.storeValue(self.key, 'foobar')
        # Age the first value past the 1-second expiry used below.
        sleep(2)
        self.store.storeValue(self.key, 'barfoo')
        self.store.expireValues(1)
        survivors = self.store.retrieveValues(self.key)
        self.failUnlessEqual(len(survivors), 1)
        self.failUnlessEqual(survivors[0], 'barfoo')

    def test_RoutingTable(self):
        class FakeNode:
            def __init__(self, node_id, host, port):
                self.id = node_id
                self.host = host
                self.port = port

            def contents(self):
                return (self.id, self.host, self.port)

        class FakeBucket:
            def __init__(self, *nodes):
                self.l = list(nodes)

        local = FakeNode(self.key, "127.0.0.1", 9977)
        remote = FakeNode('\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!',
                          '205.23.67.124', 12345)
        self.store.dumpRoutingTable([FakeBucket(local), FakeBucket(remote)])
        table = self.store.getRoutingTable()
        for node in (local, remote):
            self.failUnlessIn(node.contents(), table)