2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.trial import unittest
10 assert sqlite.version_info >= (2, 1)
# Error type this module raises when an existing database file cannot be opened
# (see _loadDB).
12 class DBExcept(Exception):
# NOTE(review): the lines between DBExcept's header and the docstring below are not
# part of this extract. The docstring presumably belongs to the `khash` marker class
# that the adapter registration and storeFile() use -- confirm.
16 """Dummy class to convert all hashes to base64 for storing in the DB."""
# Teach sqlite how to store khash values and read them back:
#  - the adapter base64-encodes every khash passed as a query parameter;
#  - the converters base64-decode columns declared as KHASH (the files.hash column),
#    which works because connections are opened with detect_types=PARSE_DECLTYPES.
# NOTE(review): b2a_base64 appends a trailing newline; a2b_base64 ignores it, so the
# round trip is lossless, but the stored text ends in "\n".
# NOTE(review): converter names appear to be matched case-insensitively by
# pysqlite/sqlite3, which would make the lower-case registration redundant -- verify
# against the pysqlite version in use before removing it.
18 sqlite.register_adapter(khash, b2a_base64)
19 sqlite.register_converter("KHASH", a2b_base64)
20 sqlite.register_converter("khash", a2b_base64)
23 """Database access for storing persistent data."""
# Open the database stored at path `db`.
# NOTE(review): the lines that choose between _loadDB and _createNewDB are not part
# of this extract; both helpers assign self.conn, which the two lines below use.
25 def __init__(self, db):
# Return TEXT columns as plain byte strings instead of unicode objects.
33 self.conn.text_factory = str
# Make rows indexable by column name (row['path'], row['size'], ...), as the query
# methods of this class rely on.
34 self.conn.row_factory = sqlite.Row
36 def _loadDB(self, db):
38 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
41 raise DBExcept, "Couldn't open DB", traceback.format_exc()
# Create a new database at path `db` and build its schema:
#   files -- one row per stored file: absolute path (primary key), hash (KHASH, stored
#            base64-encoded through the khash adapter), URL path, size and mtime as
#            read by os.stat in storeFile, and the time the entry was last refreshed.
#   dirs  -- directory path (primary key) -> URL path.
43 def _createNewDB(self, db):
44 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
45 c = self.conn.cursor()
46 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urlpath TEXT, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
# Hash index is disabled; none of the queries in this extract filter on hash.
47 # c.execute("CREATE INDEX files_hash ON files(hash)")
# Supports the `refreshed < ?` scan in expiredFiles.
48 c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
49 c.execute("CREATE TABLE dirs (path TEXT PRIMARY KEY, urlpath TEXT)")
53 def storeFile(self, path, hash, urlpath, refreshed):
54 """Store or update a file in the database."""
55 path = os.path.abspath(path)
57 c = self.conn.cursor()
58 c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?, ?, ?, ?)",
59 (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now()))
63 def isUnchanged(self, path):
64 """Check if a file in the file system has changed.
66 If it has changed, it is removed from the table.
68 @return: True if unchanged, False if changed, None if not in database
70 path = os.path.abspath(path)
72 c = self.conn.cursor()
73 c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
77 res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
79 c.execute("DELETE FROM files WHERE path = ?", path)
84 def expiredFiles(self, expireAfter):
85 """Find files that need refreshing after expireAfter seconds.
87 Also removes any entries from the table that no longer exist.
89 @return: dictionary with keys the hashes, values a list of url paths
# Cutoff time: entries refreshed before this moment are considered expired.
91 t = datetime.now() - timedelta(seconds=expireAfter)
92 c = self.conn.cursor()
# This scan can use the files_refreshed index.
93 c.execute("SELECT path, hash, urlpath FROM files WHERE refreshed < ?", (t, ))
# NOTE(review): the initialisation of `expired`/`missing` and the loop over the
# result rows are not part of this extract.
# Still on disk: group its URL path under its hash (decoded by the KHASH converter).
98 if os.path.exists(row['path']):
99 expired.setdefault(row['hash'], []).append(row['urlpath'])
# Gone from disk: queue the row for deletion; each entry is a one-element parameter
# tuple for executemany.
101 missing.append((row['path'],))
104 c.executemany("DELETE FROM files WHERE path = ?", missing)
108 def removeUntrackedFiles(self, dirs):
109 """Find files that are no longer tracked and so should be removed.
111 Also removes the entries from the table.
113 @return: list of files that were removed
115 assert len(dirs) >= 1
118 for i in xrange(len(dirs)):
119 dirs[i] = os.path.abspath(dirs[i])
120 sql += " path NOT GLOB ?/* AND"
123 c = self.conn.cursor()
124 c.execute("SELECT path FROM files " + sql, dirs)
128 removed.append(row['path'])
132 c.execute("DELETE FROM files " + sql, dirs)
139 class TestDB(unittest.TestCase):
140 """Tests for the khashmir database."""
# NOTE(review): these tests call saveSelfNode/getSelfNode, storeValue/retrieveValues,
# expireValues and dumpRoutingTable/getRoutingTable. None of these appear among the DB
# methods in this extract (__init__, _loadDB, _createNewDB, storeFile, isUnchanged,
# expiredFiles, removeUntrackedFiles). They look copied from the khashmir DB tests --
# confirm the methods exist, or rewrite the tests for the file-tracking API.
143 db = '/tmp/khashmir.db'
# 20-byte binary key used as the node id / value key in the tests below.
144 key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
# Presumably inside setUp (its def line is not part of this extract).
147 self.store = DB(self.db)
149 def test_selfNode(self):
150 self.store.saveSelfNode(self.key)
151 self.failUnlessEqual(self.store.getSelfNode(), self.key)
153 def test_Value(self):
154 self.store.storeValue(self.key, 'foobar')
155 val = self.store.retrieveValues(self.key)
156 self.failUnlessEqual(len(val), 1)
157 self.failUnlessEqual(val[0], 'foobar')
159 def test_expireValues(self):
160 self.store.storeValue(self.key, 'foobar')
# NOTE(review): a line between the two stores is not shown; presumably it waits long
# enough for 'foobar' to be older than the 1-second expiry used below.
162 self.store.storeValue(self.key, 'barfoo')
163 self.store.expireValues(1)
164 val = self.store.retrieveValues(self.key)
165 self.failUnlessEqual(len(val), 1)
166 self.failUnlessEqual(val[0], 'barfoo')
168 def test_RoutingTable(self):
# Stand-in node objects: contents() returns the (id, host, port) triple that the
# routing table is expected to round-trip.
174 return (self.id, self.host, self.port)
176 dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
177 dummy2.host = '205.23.67.124'
183 bl1.l.append(dummy())
187 self.store.dumpRoutingTable(buckets)
188 rt = self.store.getRoutingTable()
189 self.failUnlessIn(dummy().contents(), rt)
190 self.failUnlessIn(dummy2.contents(), rt)