git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht/db.py
Added new database module for the main code.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
# Fail at import time unless the sqlite binding reports version 2.1 or newer.
assert sqlite.version_info >= (2, 1)
11
class DBExcept(Exception):
    """Error type used by this module for database failures."""
14
class khash(str):
    """Marker string type whose values are base64-encoded when stored in the DB."""

# Columns declared as KHASH are base64-decoded when read back; the declared
# type name is registered in both upper and lower case.
sqlite.register_converter("khash", a2b_base64)
sqlite.register_converter("KHASH", a2b_base64)
# khash values bound as statement parameters are base64-encoded first.
sqlite.register_adapter(khash, b2a_base64)
21
class DB:
    """Database access for storing persistent data.

    Wraps one SQLite connection (C{self.conn}) to a database file holding a
    C{files} table (path, hash, urlpath, size, mtime, refreshed) and a
    C{dirs} table (path, urlpath).
    """

    def __init__(self, db):
        """Open the database file, creating it with its schema if missing.

        @param db: path of the SQLite database file
        @raise DBExcept: if an existing database file cannot be opened
        """
        self.db = db
        try:
            os.stat(db)
        except OSError:
            # The file cannot be stat'ed: start a fresh database.
            self._createNewDB(db)
        else:
            self._loadDB(db)
        # Return TEXT values as plain str and rows addressable by column name.
        self.conn.text_factory = str
        self.conn.row_factory = sqlite.Row

    def _loadDB(self, db):
        """Connect to an already existing database file.

        @raise DBExcept: if the connection cannot be established
        """
        try:
            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        except sqlite.Error:
            import traceback
            # Carry the original failure in the message so it is not lost.
            raise DBExcept("Couldn't open DB: " + traceback.format_exc())

    def _createNewDB(self, db):
        """Create the database file and its tables and indexes."""
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urlpath TEXT, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
#        c.execute("CREATE INDEX files_hash ON files(hash)")
        c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
        c.execute("CREATE TABLE dirs (path TEXT PRIMARY KEY, urlpath TEXT)")
        c.close()
        self.conn.commit()

    def storeFile(self, path, hash, urlpath, refreshed):
        """Store or update a file in the database.

        The file's current size and modification time are read from the
        file system and recorded along with the hash and URL path.

        @param path: path of the file (made absolute before storing)
        @param hash: the hash of the file, stored as a L{khash}
        @param urlpath: the URL path recorded for the file
        @param refreshed: accepted but currently unused; the refresh time
            recorded is always the current time
        @raise OSError: if the file cannot be stat'ed
        """
        path = os.path.abspath(path)
        stat = os.stat(path)
        c = self.conn.cursor()
        # Values follow the column order of the files table.
        c.execute("INSERT OR REPLACE INTO files VALUES (?, ?, ?, ?, ?, ?)",
                  (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now()))
        self.conn.commit()
        c.close()

    def isUnchanged(self, path):
        """Check if a file in the file system has changed.

        The stored size and modification time are compared with the
        current ones. If it has changed, it is removed from the table.

        @param path: path of the file (made absolute before the lookup)
        @return: True if unchanged, False if changed, None if not in database
        @raise OSError: if the file cannot be stat'ed
        """
        path = os.path.abspath(path)
        stat = os.stat(path)
        c = self.conn.cursor()
        c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
        row = c.fetchone()
        res = None
        if row:
            res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
            if not res:
                # Parameters must be a sequence; a bare string would be
                # bound one character at a time.
                c.execute("DELETE FROM files WHERE path = ?", (path, ))
                self.conn.commit()
        c.close()
        return res

    def expiredFiles(self, expireAfter):
        """Find files that need refreshing after expireAfter seconds.

        Also removes any entries from the table that no longer exist.

        @param expireAfter: age in seconds after which an entry is expired
        @return: dictionary with keys the hashes, values a list of url paths
        """
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        c.execute("SELECT path, hash, urlpath FROM files WHERE refreshed < ?", (t, ))
        expired = {}
        missing = []
        for row in c.fetchall():
            if os.path.exists(row['path']):
                expired.setdefault(row['hash'], []).append(row['urlpath'])
            else:
                # The file is gone from disk: drop its entry below.
                missing.append((row['path'],))
        if missing:
            c.executemany("DELETE FROM files WHERE path = ?", missing)
        self.conn.commit()
        c.close()
        return expired

    def removeUntrackedFiles(self, dirs):
        """Find files that are no longer tracked and so should be removed.

        A file is untracked if its path is not below any of the given
        directories. Also removes the entries from the table. The caller's
        list is not modified.

        @param dirs: non-empty list of directories that are still tracked
        @return: list of files that were removed
        """
        assert len(dirs) >= 1
        # The glob suffix has to be part of the bound value: in the SQL text
        # itself "/*" would start a comment.
        patterns = [os.path.join(os.path.abspath(d), '*') for d in dirs]
        sql = "WHERE " + " AND ".join(["path NOT GLOB ?"] * len(patterns))

        c = self.conn.cursor()
        c.execute("SELECT path FROM files " + sql, patterns)
        removed = [row['path'] for row in c.fetchall()]

        if removed:
            c.execute("DELETE FROM files " + sql, patterns)
        self.conn.commit()
        c.close()
        return removed

    def close(self):
        """Close the database connection."""
        self.conn.close()
138
class TestDB(unittest.TestCase):
    """Tests for the file database (the DB class)."""

    timeout = 5
    db = '/tmp/khashmir.db'
    file = '/tmp/khashmir.test'
    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    urlpath = '/~test/khashmir.test'

    def _writeFile(self, data):
        """(Re)write the test file with the given contents."""
        f = open(self.file, 'w')
        try:
            f.write(data)
        finally:
            f.close()

    def setUp(self):
        # Each test starts with a file on disk that is stored in the DB.
        self._writeFile('some test data')
        self.store = DB(self.db)
        self.store.storeFile(self.file, self.key, self.urlpath, datetime.now())

    def test_isUnchanged(self):
        self.failUnlessEqual(self.store.isUnchanged(self.file), True)
        # Changing the size must be detected and must drop the entry.
        self._writeFile('some different and longer test data')
        self.failUnlessEqual(self.store.isUnchanged(self.file), False)
        self.failUnlessEqual(self.store.isUnchanged(self.file), None)

    def test_expiredFiles(self):
        self.failUnlessEqual(self.store.expiredFiles(3600), {})
        sleep(2)
        expired = self.store.expiredFiles(1)
        self.failUnlessEqual(expired, {self.key: [self.urlpath]})

    def test_expiredMissingFile(self):
        # An expired entry whose file is gone is removed, not reported.
        os.unlink(self.file)
        sleep(2)
        self.failUnlessEqual(self.store.expiredFiles(1), {})
        self._writeFile('some test data')
        self.failUnlessEqual(self.store.isUnchanged(self.file), None)

    def test_removeUntrackedFiles(self):
        tracked = os.path.dirname(os.path.abspath(self.file))
        self.failUnlessEqual(self.store.removeUntrackedFiles([tracked]), [])
        removed = self.store.removeUntrackedFiles(['/nonexistent-dir'])
        self.failUnlessEqual(removed, [os.path.abspath(self.file)])
        self.failUnlessEqual(self.store.isUnchanged(self.file), None)

    def tearDown(self):
        self.store.close()
        os.unlink(self.db)
        if os.path.exists(self.file):
            os.unlink(self.file)