2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
# Refuse to run against a pysqlite older than 2.1.
assert sqlite.version_info >= (2, 1)


class DBExcept(Exception):
    """Raised when the database can not be opened."""


class khash(str):
    """Dummy class to convert all hashes to base64 for storing in the DB."""


# khash values are written as base64 text; columns declared KHASH (in
# either capitalisation) are decoded back to the raw hash when read.
sqlite.register_adapter(khash, b2a_base64)
for _decltype in ("KHASH", "khash"):
    sqlite.register_converter(_decltype, a2b_base64)
sqlite.enable_callback_tracebacks(True)
25 """Database access for storing persistent data."""
def __init__(self, db):
    """Open the database at C{db}, creating it first if it does not exist.

    @param db: the location of the database file; it is used here through
        its C{path}, C{exists()}, C{restat()} and C{parent()} members
    @raise DBExcept: if an existing database can not be opened
    """
    self.db = db
    # Refresh any cached stat result before deciding whether the file exists.
    self.db.restat(False)
    if self.db.exists():
        self._loadDB()
    else:
        self._createNewDB()
    # Hand TEXT back as str, and make rows indexable by column name.
    self.conn.text_factory = str
    self.conn.row_factory = sqlite.Row

def _loadDB(self):
    """Connect to the already existing database file.

    @raise DBExcept: if the connection can not be established
    """
    try:
        self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
    except Exception:
        import traceback
        # The third element of a three-part raise must be a traceback
        # object; a formatted string there makes the raise itself fail
        # with TypeError.  Carry the formatted traceback in the message.
        raise DBExcept("Couldn't open DB: " + traceback.format_exc())
def _createNewDB(self):
    """Create the database file with its empty tables and indexes.

    The parent directory is made if it does not exist.  Two tables are
    created: C{files}, keyed by file path, and C{dirs}, keyed by an
    auto-incrementing url directory number.
    """
    parent = self.db.parent()
    if not parent.exists():
        parent.makedirs()
    self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)

    schema = (
        "CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)",
        "CREATE INDEX files_hash ON files(hash)",
        "CREATE INDEX files_urldir ON files(urldir)",
        "CREATE INDEX files_refreshed ON files(refreshed)",
        "CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)",
        "CREATE INDEX dirs_path ON dirs(path)",
    )
    c = self.conn.cursor()
    try:
        for statement in schema:
            c.execute(statement)
    finally:
        # release the cursor even if one of the statements fails
        c.close()
    self.conn.commit()
def _removeChanged(self, file, row):
    """Check a stored row against the file on disk, deleting a stale row.

    @param file: the file the row describes
    @param row: the row found for the file (must have C{size} and
        C{mtime} columns), or None if the file is not in the database
    @return: True if unchanged, False if changed, None if not in database
        or the file is missing
    """
    if not row:
        return None

    res = None
    # Look at the file as it is now, not at an earlier cached stat.
    file.restat(False)
    if file.exists():
        res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
    if not res:
        # Changed (False) or missing (None): the stored entry is stale.
        c = self.conn.cursor()
        try:
            c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
        finally:
            c.close()
        self.conn.commit()
    return res
def storeFile(self, file, hash, directory):
    """Store or update a file in the database.

    @param file: the file to store
    @param hash: the hash of the file
    @param directory: the top-level directory the file is served from;
        C{file.path} with C{directory.path} stripped off the front forms
        the tail of the returned url path
    @return: the urlpath to access the file, and whether a
        new url top-level directory was needed
    """
    # Make sure the size and mtime recorded below are current.
    file.restat()
    c = self.conn.cursor()
    try:
        c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (file.path, ))
        row = c.fetchone()
        # The dirs table holds directory.path (a string), so compare
        # against the path rather than against the directory object.
        if row and directory.path == row['directory']:
            # Already stored under this directory: refresh this row only.
            # Without the WHERE clause every row in the table is rewritten.
            c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ? WHERE path = ?",
                      (khash(hash), file.getsize(), file.getmtime(), datetime.now(), file.path))
            newdir = False
            urldir = row['urldir']
        else:
            urldir, newdir = self.findDirectory(directory)
            c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
                      (file.path, khash(hash), urldir, len(directory.path), file.getsize(), file.getmtime(), datetime.now()))
    finally:
        c.close()
    self.conn.commit()
    return '/~' + str(urldir) + file.path[len(directory.path):], newdir
def getFile(self, file):
    """Get a file from the database.

    If it has changed or is missing, it is removed from the database.

    @param file: the file to look up by its C{path}
    @return: dictionary of info for the file (keys C{hash}, C{size} and
        C{urlpath}), False if changed, or None if not in database or
        missing
    """
    c = self.conn.cursor()
    c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (file.path, ))
    row = c.fetchone()
    c.close()
    if not row:
        return None

    status = self._removeChanged(file, row)
    if not status:
        # False (changed) and None (missing) are passed straight through
        return status
    return {'hash': row['hash'],
            'size': row['size'],
            'urlpath': '/~' + str(row['urldir']) + file.path[row['dirlength']:]}
def lookupHash(self, hash):
    """Find a file by hash in the database.

    If any found files have changed or are missing, they are removed
    from the database.

    @param hash: the hash to look for
    @return: list of dictionaries of info for the found files (keys
        C{path}, C{size} and C{urlpath}), ordered by url directory
    """
    c = self.conn.cursor()
    try:
        c.execute("SELECT path, urldir, dirlength, size, mtime FROM files WHERE hash = ? ORDER BY urldir", (khash(hash), ))
        # Read every match before checking any of them: _removeChanged
        # deletes stale rows through this same connection, which should
        # not happen while this cursor is part-way through its results.
        rows = c.fetchall()
    finally:
        c.close()

    files = []
    for row in rows:
        found = FilePath(row['path'])
        if self._removeChanged(found, row):
            files.append({'path': found,
                          'size': row['size'],
                          'urlpath': '/~' + str(row['urldir']) + found.path[row['dirlength']:]})
    return files
def isUnchanged(self, file):
    """Check if a file in the file system has changed.

    If it has changed, it is removed from the table.

    @param file: the file to check, looked up by its C{path}
    @return: True if unchanged, False if changed, None if not in database
    """
    lookup = self.conn.cursor()
    lookup.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
    stored = lookup.fetchone()
    # the comparison, and the removal of a stale row, are both done there
    return self._removeChanged(file, stored)
def refreshFile(self, file):
    """Refresh the publishing time of a file.

    If it has changed or is missing, it is removed from the table.

    @param file: the file to refresh, looked up by its C{path}
    @return: True if unchanged, False if changed, None if not in database
    """
    c = self.conn.cursor()
    try:
        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
        row = c.fetchone()
        if not row:
            return None
        res = self._removeChanged(file, row)
        if res:
            c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
            # Commit now so the new timestamp is not left in an open
            # transaction until some other method happens to commit.
            self.conn.commit()
        return res
    finally:
        c.close()
def expiredFiles(self, expireAfter):
    """Find files that need refreshing after expireAfter seconds.

    Also removes any entries from the table that no longer exist.

    @param expireAfter: the age in seconds past which a file's last
        refresh counts as expired
    @return: dictionary with keys the hashes, values a list of url paths
    """
    t = datetime.now() - timedelta(seconds=expireAfter)
    c = self.conn.cursor()
    try:
        c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
        # Read every candidate first: _removeChanged deletes stale rows
        # through this same connection, which should not happen while
        # this cursor is part-way through its results.
        rows = c.fetchall()
    finally:
        c.close()

    expired = {}
    for row in rows:
        if self._removeChanged(FilePath(row['path']), row):
            urlpath = '/~' + str(row['urldir']) + row['path'][row['dirlength']:]
            expired.setdefault(row['hash'], []).append(urlpath)
    return expired
def removeUntrackedFiles(self, dirs):
    """Find files that are no longer tracked and so should be removed.

    Also removes the entries from the table.

    @param dirs: a non-empty list of the directories still being tracked
    @return: list of files that were removed
    """
    assert len(dirs) >= 1
    # One "path NOT GLOB <dir>/*" test per directory: a file is untracked
    # when it lies below none of them.
    patterns = [tracked.child('*').path for tracked in dirs]
    where = "WHERE " + " AND ".join(["path NOT GLOB ?"] * len(patterns))

    c = self.conn.cursor()
    c.execute("SELECT path FROM files " + where, patterns)
    removed = [FilePath(row['path']) for row in c.fetchall()]

    if removed:
        c.execute("DELETE FROM files " + where, patterns)
    self.conn.commit()
    return removed
def findDirectory(self, directory):
    """Store or update a directory in the database.

    @param directory: the directory to find, matched on its C{path}
    @return: the index of the url directory, and whether it is new or not
    """
    lookup = self.conn.cursor()
    lookup.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory.path, ))
    # min() always yields one row; its value is NULL when nothing matched
    existing = lookup.fetchone()['urldir']
    lookup.close()
    if existing:
        return existing, False

    # Not found, need to add a new one
    insert = self.conn.cursor()
    insert.execute("INSERT INTO dirs (path) VALUES (?)", (directory.path, ))
    self.conn.commit()
    urldir = insert.lastrowid
    insert.close()
    return urldir, True
def getAllDirectories(self):
    """Get all the current directories available.

    @return: dictionary with keys '~' followed by the url directory
        index, values the directory stored for that index
    """
    c = self.conn.cursor()
    c.execute("SELECT urldir, path FROM dirs")
    dirs = {}
    for row in c.fetchall():
        dirs['~' + str(row['urldir'])] = FilePath(row['path'])
    c.close()
    return dirs
def reconcileDirectories(self):
    """Remove any unneeded directories by checking which are used by files.

    @return: True if any directories were removed, False otherwise
    """
    c = self.conn.cursor()
    c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
    deleted = c.rowcount
    self.conn.commit()
    return bool(deleted)
260 class TestDB(unittest.TestCase):
261 """Tests for the khashmir database."""
264 db = FilePath('/tmp/khashmir.db')
265 file = FilePath('/tmp/apt-dht/khashmir.test')
266 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
267 directory = FilePath('/tmp/apt-dht/')
268 urlpath = '/~1/khashmir.test'
269 testfile = 'tmp/khashmir.test'
270 dirs = [FilePath('/tmp/apt-dht/top1'),
271 FilePath('/tmp/apt-dht/top2/sub1'),
272 FilePath('/tmp/apt-dht/top2/sub2/')]
275 if not self.file.parent().exists():
276 self.file.parent().makedirs()
277 self.file.setContent('fgfhds')
279 self.store = DB(self.db)
280 self.store.storeFile(self.file, self.hash, self.directory)
282 def test_openExistsingDB(self):
286 self.store = DB(self.db)
287 res = self.store.isUnchanged(self.file)
290 def test_getFile(self):
291 res = self.store.getFile(self.file)
293 self.failUnlessEqual(res['hash'], self.hash)
294 self.failUnlessEqual(res['urlpath'], self.urlpath)
296 def test_getAllDirectories(self):
297 res = self.store.getAllDirectories()
299 self.failUnlessEqual(len(res.keys()), 1)
300 self.failUnlessEqual(res.keys()[0], '~1')
301 self.failUnlessEqual(res['~1'], self.directory)
303 def test_isUnchanged(self):
304 res = self.store.isUnchanged(self.file)
308 res = self.store.isUnchanged(self.file)
309 self.failUnless(res == False)
311 res = self.store.isUnchanged(self.file)
312 self.failUnless(res == None)
314 def test_expiry(self):
315 res = self.store.expiredFiles(1)
316 self.failUnlessEqual(len(res.keys()), 0)
318 res = self.store.expiredFiles(1)
319 self.failUnlessEqual(len(res.keys()), 1)
320 self.failUnlessEqual(res.keys()[0], self.hash)
321 self.failUnlessEqual(len(res[self.hash]), 1)
322 self.failUnlessEqual(res[self.hash][0], self.urlpath)
323 res = self.store.refreshFile(self.file)
325 res = self.store.expiredFiles(1)
326 self.failUnlessEqual(len(res.keys()), 0)
328 def build_dirs(self):
329 for dir in self.dirs:
330 file = dir.preauthChild(self.testfile)
331 if not file.parent().exists():
332 file.parent().makedirs()
333 file.setContent(file.path)
335 self.store.storeFile(file, self.hash, dir)
337 def test_removeUntracked(self):
339 res = self.store.removeUntrackedFiles(self.dirs)
340 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
341 self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
342 res = self.store.removeUntrackedFiles(self.dirs)
343 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
344 res = self.store.removeUntrackedFiles(self.dirs[1:])
345 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
346 self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
347 res = self.store.removeUntrackedFiles(self.dirs[:1])
348 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
349 self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
350 self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
352 def test_reconcileDirectories(self):
354 res = self.store.getAllDirectories()
356 self.failUnlessEqual(len(res.keys()), 4)
357 res = self.store.reconcileDirectories()
358 self.failUnlessEqual(res, False)
359 res = self.store.getAllDirectories()
361 self.failUnlessEqual(len(res.keys()), 4)
362 res = self.store.removeUntrackedFiles(self.dirs)
363 res = self.store.reconcileDirectories()
364 self.failUnlessEqual(res, True)
365 res = self.store.getAllDirectories()
367 self.failUnlessEqual(len(res.keys()), 3)
368 res = self.store.removeUntrackedFiles(self.dirs[:1])
369 res = self.store.reconcileDirectories()
370 self.failUnlessEqual(res, True)
371 res = self.store.getAllDirectories()
373 self.failUnlessEqual(len(res.keys()), 1)
374 res = self.store.removeUntrackedFiles([FilePath('/what')])
375 res = self.store.reconcileDirectories()
376 self.failUnlessEqual(res, True)
377 res = self.store.getAllDirectories()
378 self.failUnlessEqual(len(res.keys()), 0)
381 self.directory.remove()