2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
11 assert sqlite.version_info >= (2, 1)
13 class DBExcept(Exception):
17 """Dummy class to convert all hashes to base64 for storing in the DB."""
19 sqlite.register_adapter(khash, b2a_base64)
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
22 sqlite.enable_callback_tracebacks(True)
25 """Database access for storing persistent data."""
27 def __init__(self, db):
34 self.conn.text_factory = str
35 self.conn.row_factory = sqlite.Row
39 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
# The three-expression form ``raise E, V, T`` requires T to be a traceback
# object; handing it the *string* from format_exc() raises TypeError and
# hides the real failure. Carry the formatted traceback in the exception
# arguments instead so callers actually receive a DBExcept.
raise DBExcept("Couldn't open DB", traceback.format_exc())
44 def _createNewDB(self):
45 if not self.db.parent().exists():
46 self.db.parent().makedirs()
47 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48 c = self.conn.cursor()
49 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
50 c.execute("CREATE INDEX files_hash ON files(hash)")
51 c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
55 def _removeChanged(self, file, row):
60 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
62 c = self.conn.cursor()
63 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
68 def storeFile(self, file, hash):
69 """Store or update a file in the database."""
71 c = self.conn.cursor()
72 c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
# Restrict the update to the row for this file. Without the WHERE clause
# every row in the table would have its hash, size, mtime and refresh
# time overwritten with this file's values.
c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ? WHERE path = ?",
          (khash(hash), file.getsize(), file.getmtime(), datetime.now(), file.path))
78 c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
79 (file.path, khash(hash), file.getsize(), file.getmtime(), datetime.now()))
83 def getFile(self, file):
84 """Get a file from the database.
86 If it has changed or is missing, it is removed from the database.
88 @return: dictionary of info for the file, False if changed, or
89 None if not in database or missing
91 c = self.conn.cursor()
92 c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
96 res = self._removeChanged(file, row)
99 res['hash'] = row['hash']
100 res['size'] = row['size']
104 def lookupHash(self, hash):
105 """Find a file by hash in the database.
107 If any found files have changed or are missing, they are removed
110 @return: list of dictionaries of info for the found files
112 c = self.conn.cursor()
113 c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
117 file = FilePath(row['path'])
118 res = self._removeChanged(file, row)
122 res['size'] = row['size']
128 def isUnchanged(self, file):
129 """Check if a file in the file system has changed.
131 If it has changed, it is removed from the table.
133 @return: True if unchanged, False if changed, None if not in database
135 c = self.conn.cursor()
136 c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
138 return self._removeChanged(file, row)
140 def refreshFile(self, file):
141 """Refresh the publishing time of a file.
143 If it has changed or is missing, it is removed from the table.
145 @return: True if unchanged, False if changed, None if not in database
147 c = self.conn.cursor()
148 c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
152 res = self._removeChanged(file, row)
154 c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
157 def expiredFiles(self, expireAfter):
158 """Find files that need refreshing after expireAfter seconds.
160 Also removes any entries from the table that no longer exist.
162 @return: dictionary with keys the hashes, values a list of FilePaths
164 t = datetime.now() - timedelta(seconds=expireAfter)
165 c = self.conn.cursor()
166 c.execute("SELECT path, hash, size, mtime FROM files WHERE refreshed < ?", (t, ))
170 res = self._removeChanged(FilePath(row['path']), row)
172 expired.setdefault(row['hash'], []).append(FilePath(row['path']))
177 def removeUntrackedFiles(self, dirs):
178 """Find files that are no longer tracked and so should be removed.
180 Also removes the entries from the table.
182 @return: list of files that were removed
184 assert len(dirs) >= 1
188 newdirs.append(dir.child('*').path)
189 sql += " path NOT GLOB ? AND"
192 c = self.conn.cursor()
193 c.execute("SELECT path FROM files " + sql, newdirs)
197 removed.append(FilePath(row['path']))
201 c.execute("DELETE FROM files " + sql, newdirs)
208 class TestDB(unittest.TestCase):
209 """Tests for the khashmir database."""
212 db = FilePath('/tmp/khashmir.db')
213 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
214 directory = FilePath('/tmp/apt-dht/')
215 file = FilePath('/tmp/apt-dht/khashmir.test')
216 testfile = 'tmp/khashmir.test'
217 dirs = [FilePath('/tmp/apt-dht/top1'),
218 FilePath('/tmp/apt-dht/top2/sub1'),
219 FilePath('/tmp/apt-dht/top2/sub2/')]
222 if not self.file.parent().exists():
223 self.file.parent().makedirs()
224 self.file.setContent('fgfhds')
226 self.store = DB(self.db)
227 self.store.storeFile(self.file, self.hash)
229 def test_openExistsingDB(self):
233 self.store = DB(self.db)
234 res = self.store.isUnchanged(self.file)
237 def test_getFile(self):
238 res = self.store.getFile(self.file)
240 self.failUnlessEqual(res['hash'], self.hash)
242 def test_isUnchanged(self):
243 res = self.store.isUnchanged(self.file)
247 res = self.store.isUnchanged(self.file)
248 self.failUnless(res == False)
250 res = self.store.isUnchanged(self.file)
251 self.failUnless(res == None)
253 def test_expiry(self):
254 res = self.store.expiredFiles(1)
255 self.failUnlessEqual(len(res.keys()), 0)
257 res = self.store.expiredFiles(1)
258 self.failUnlessEqual(len(res.keys()), 1)
259 self.failUnlessEqual(res.keys()[0], self.hash)
260 self.failUnlessEqual(len(res[self.hash]), 1)
261 res = self.store.refreshFile(self.file)
263 res = self.store.expiredFiles(1)
264 self.failUnlessEqual(len(res.keys()), 0)
266 def build_dirs(self):
267 for dir in self.dirs:
268 file = dir.preauthChild(self.testfile)
269 if not file.parent().exists():
270 file.parent().makedirs()
271 file.setContent(file.path)
273 self.store.storeFile(file, self.hash)
275 def test_removeUntracked(self):
277 res = self.store.removeUntrackedFiles(self.dirs)
278 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
279 self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
280 res = self.store.removeUntrackedFiles(self.dirs)
281 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
282 res = self.store.removeUntrackedFiles(self.dirs[1:])
283 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
284 self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
285 res = self.store.removeUntrackedFiles(self.dirs[:1])
286 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
287 self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
288 self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
291 self.directory.remove()