from binascii import a2b_base64, b2a_base64
from datetime import datetime, timedelta
from time import sleep
import traceback

from pysqlite2 import dbapi2 as sqlite
from twisted.python.filepath import FilePath
from twisted.trial import unittest
# The adapter/converter registration and callback tracebacks used below are
# only relied upon from pysqlite 2.1 onwards.
assert sqlite.version_info >= (2, 1)


class DBExcept(Exception):
    """Raised when the database can not be opened."""


# NOTE(review): the class statement for khash is not visible in this excerpt;
# a str subclass is inferred from khash(hash) being passed to b2a_base64.
class khash(str):
    """Dummy class to convert all hashes to base64 for storing in the DB."""


# Values wrapped in khash are written through b2a_base64, and columns declared
# as KHASH (either spelling) are decoded with a2b_base64 when read back.
sqlite.register_adapter(khash, b2a_base64)
sqlite.register_converter("KHASH", a2b_base64)
sqlite.register_converter("khash", a2b_base64)
# Report exceptions raised inside the adapter/converter callbacks.
sqlite.enable_callback_tracebacks(True)
class DB(object):
    """Database access for storing persistent data.

    Two tables are kept: C{files} (path, hash, urldir, dirlength, size, mtime,
    refreshed) and C{dirs} (urldir, path).  A file is published under the URL
    path C{'/~' + urldir + <file path with the directory prefix removed>}.

    NOTE(review): the class statement and several interior lines are not
    visible in this excerpt; they were reconstructed from the visible
    statements and docstrings and should be checked against the full file.
    """

    def __init__(self, db):
        """Open the database, creating it first if it does not exist.

        @param db: location of the database file; C{path}, C{parent()},
            C{exists()} and C{restat()} are used, so presumably a
            C{twisted.python.filepath.FilePath} (confirm with callers)
        @raise DBExcept: if an existing database can not be opened
        """
        self.db = db
        # NOTE(review): the open-or-create dispatch is reconstructed; only the
        # two connect paths and the factory assignments are visible.
        self.db.restat(False)
        if self.db.exists():
            self._loadDB()
        else:
            self._createNewDB()
        # Return plain str values and rows addressable by column name.
        self.conn.text_factory = str
        self.conn.row_factory = sqlite.Row

    def _loadDB(self):
        """Open the existing database file.

        @raise DBExcept: if the connection can not be established
        """
        try:
            self.conn = sqlite.connect(database=self.db.path,
                                       detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            # The formatted traceback travels as a second exception argument.
            # It must not be given as the third expression of a raise
            # statement, which only accepts a traceback object.
            raise DBExcept("Couldn't open DB", traceback.format_exc())

    def _createNewDB(self):
        """Create the database file, its parent directory and the schema."""
        parent = self.db.parent()
        if not parent.exists():
            parent.makedirs()
        self.conn = sqlite.connect(database=self.db.path,
                                   detect_types=sqlite.PARSE_DECLTYPES)
        schema = (
            "CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)",
            "CREATE INDEX files_urldir ON files(urldir)",
            "CREATE INDEX files_refreshed ON files(refreshed)",
            "CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)",
            "CREATE INDEX dirs_path ON dirs(path)",
        )
        c = self.conn.cursor()
        try:
            for statement in schema:
                c.execute(statement)
        finally:
            c.close()
        self.conn.commit()

    def _removeChanged(self, file, row):
        """Compare a file with its table row and drop the row if it is stale.

        @param file: the file to check
        @param row: the row selected for the file (needs the C{size} and
            C{mtime} columns), or None if there was no such row
        @return: True if unchanged, False if changed, None if not in the
            database or the file is missing
        """
        if row is None:
            return None
        res = None
        # Refresh the stat information before comparing against the file
        # system; False avoids an exception when the file is gone.
        file.restat(False)
        if file.exists():
            res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
        if not res:
            # Changed or missing: the stored entry is no longer valid.
            c = self.conn.cursor()
            try:
                c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
            finally:
                c.close()
            self.conn.commit()
        return res

    def storeFile(self, file, hash, directory):
        """Store or update a file in the database.

        @param file: the file to store
        @param hash: the hash of the file (stored base64-encoded via khash)
        @param directory: the top-level directory the file is published
            under; its path is stripped from the front of the file's path
        @return: the urlpath to access the file, and whether a
            new url top-level directory was needed
        """
        # Make sure the size and mtime recorded are the current ones.
        file.restat()
        c = self.conn.cursor()
        try:
            c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (file.path, ))
            row = c.fetchone()
            # dirs.path holds directory.path (see findDirectory), so compare
            # path strings rather than the directory object itself.
            if row is not None and directory.path == row['directory']:
                # Restrict the update to this file; without the WHERE clause
                # every row in the table would be overwritten.
                c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ? WHERE path = ?",
                          (khash(hash), file.getsize(), file.getmtime(), datetime.now(), file.path))
                newdir = False
                urldir = row['urldir']
            else:
                urldir, newdir = self.findDirectory(directory)
                c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
                          (file.path, khash(hash), urldir, len(directory.path), file.getsize(), file.getmtime(), datetime.now()))
        finally:
            c.close()
        self.conn.commit()
        return '/~' + str(urldir) + file.path[len(directory.path):], newdir

    def getFile(self, file):
        """Get a file from the database.

        If it has changed or is missing, it is removed from the database.

        @param file: the file to look up
        @return: dictionary of info for the file (keys C{hash}, C{size} and
            C{urlpath}), False if changed, or None if not in database or
            missing
        """
        c = self.conn.cursor()
        try:
            c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (file.path, ))
            row = c.fetchone()
        finally:
            c.close()
        res = self._removeChanged(file, row)
        if res:
            res = {
                'hash': row['hash'],
                'size': row['size'],
                'urlpath': '/~' + str(row['urldir']) + file.path[row['dirlength']:],
            }
        return res

    def isUnchanged(self, file):
        """Check if a file in the file system has changed.

        If it has changed, it is removed from the table.

        @param file: the file to check
        @return: True if unchanged, False if changed, None if not in database
        """
        c = self.conn.cursor()
        try:
            c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
            row = c.fetchone()
        finally:
            c.close()
        return self._removeChanged(file, row)

    def refreshFile(self, file):
        """Refresh the publishing time of a file.

        If it has changed or is missing, it is removed from the table.

        @param file: the file to refresh
        @return: True if unchanged, False if changed, None if not in database
        """
        c = self.conn.cursor()
        try:
            c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
            row = c.fetchone()
            res = self._removeChanged(file, row)
            if res:
                # Only an unchanged file gets a new publishing time.
                c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
        finally:
            c.close()
        if res:
            self.conn.commit()
        return res

    def expiredFiles(self, expireAfter):
        """Find files that need refreshing after expireAfter seconds.

        Also removes any entries from the table that no longer exist.

        @param expireAfter: the age in seconds after which a file's last
            refresh is considered expired
        @return: dictionary with keys the hashes, values a list of url paths
        """
        t = datetime.now() - timedelta(seconds=expireAfter)
        c = self.conn.cursor()
        try:
            c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
            # Read everything first: _removeChanged deletes and commits on
            # this connection, which must not happen in mid-iteration.
            rows = c.fetchall()
        finally:
            c.close()
        expired = {}
        for row in rows:
            res = self._removeChanged(FilePath(row['path']), row)
            if res:
                expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
        return expired

    def removeUntrackedFiles(self, dirs):
        """Find files that are no longer tracked and so should be removed.

        Also removes the entries from the table.

        @param dirs: non-empty list of the directories still being tracked
        @return: list of files that were removed
        """
        assert len(dirs) >= 1
        # A file is untracked when its path matches none of the 'dir/*' globs.
        patterns = [tracked.child('*').path for tracked in dirs]
        where = "WHERE " + " AND ".join(["path NOT GLOB ?"] * len(patterns))
        c = self.conn.cursor()
        try:
            c.execute("SELECT path FROM files " + where, patterns)
            removed = [FilePath(row['path']) for row in c.fetchall()]
            if removed:
                c.execute("DELETE FROM files " + where, patterns)
        finally:
            c.close()
        if removed:
            self.conn.commit()
        return removed

    def findDirectory(self, directory):
        """Store or update a directory in the database.

        @param directory: the top-level directory to look up or add
        @return: the index of the url directory, and whether it is new or not
        """
        c = self.conn.cursor()
        try:
            c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory.path, ))
            row = c.fetchone()
            # min() always yields exactly one row; a NULL value means that no
            # directory with this path is stored yet.
            if row['urldir'] is not None:
                return row['urldir'], False

            # Not found, need to add a new one
            c.execute("INSERT INTO dirs (path) VALUES (?)", (directory.path, ))
            urldir = c.lastrowid
        finally:
            c.close()
        self.conn.commit()
        return urldir, True

    def getAllDirectories(self):
        """Get all the current directories available.

        @return: dictionary with keys C{'~' + urldir} and values the
            directory paths as C{FilePath}s
        """
        c = self.conn.cursor()
        try:
            c.execute("SELECT urldir, path FROM dirs")
            rows = c.fetchall()
        finally:
            c.close()
        dirs = {}
        for row in rows:
            dirs['~' + str(row['urldir'])] = FilePath(row['path'])
        return dirs

    def reconcileDirectories(self):
        """Remove any unneeded directories by checking which are used by files.

        @return: True if any directory was removed, False otherwise
        """
        c = self.conn.cursor()
        try:
            c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
            # Read the count while the cursor is still open.
            removed = c.rowcount
        finally:
            c.close()
        self.conn.commit()
        return bool(removed)

    def close(self):
        """Close the connection to the database."""
        self.conn.close()
class TestDB(unittest.TestCase):
    """Tests for the khashmir database.

    NOTE(review): the setUp/tearDown headers and several fixture steps are
    not visible in this excerpt; they were reconstructed from what the
    visible assertions require and should be checked against the full file.
    """

    db = FilePath('/tmp/khashmir.db')
    file = FilePath('/tmp/apt-dht/khashmir.test')
    hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    directory = FilePath('/tmp/apt-dht/')
    urlpath = '/~1/khashmir.test'
    testfile = 'tmp/khashmir.test'
    dirs = [FilePath('/tmp/apt-dht/top1'),
            FilePath('/tmp/apt-dht/top2/sub1'),
            FilePath('/tmp/apt-dht/top2/sub2/')]

    def setUp(self):
        """Create the test file and store it in a database."""
        if not self.file.parent().exists():
            self.file.parent().makedirs()
        self.file.setContent('fgfhds')
        self.store = DB(self.db)
        self.store.storeFile(self.file, self.hash, self.directory)

    def test_openExistsingDB(self):
        """A stored file is still found after the database is reopened."""
        # Drop the handle from setUp and open the same database file again.
        self.store = None
        self.store = DB(self.db)
        res = self.store.isUnchanged(self.file)
        self.failUnless(res)

    def test_getFile(self):
        """The stored hash and url path are returned for the file."""
        res = self.store.getFile(self.file)
        self.failUnless(res)
        self.failUnlessEqual(res['hash'], self.hash)
        self.failUnlessEqual(res['urlpath'], self.urlpath)

    def test_getAllDirectories(self):
        """Only the directory of the stored file is known."""
        res = self.store.getAllDirectories()
        self.failUnless(res)
        self.failUnlessEqual(len(res), 1)
        self.failUnlessEqual(list(res)[0], '~1')
        self.failUnlessEqual(res['~1'], self.directory)

    def test_isUnchanged(self):
        """A modified file is reported as changed and then forgotten."""
        res = self.store.isUnchanged(self.file)
        self.failUnless(res)
        # NOTE(review): the step that modifies the file is not visible here;
        # rewriting it with content of a different length changes its size.
        self.file.setContent('different content')
        res = self.store.isUnchanged(self.file)
        self.failUnless(res is False)
        # The changed entry was removed by the previous check.
        res = self.store.isUnchanged(self.file)
        self.failUnless(res is None)

    def test_expiry(self):
        """A file expires after the interval and is fresh again once refreshed."""
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res), 0)
        # Wait for longer than the one second expiry interval.
        sleep(2)
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res), 1)
        self.failUnlessEqual(list(res)[0], self.hash)
        self.failUnlessEqual(len(res[self.hash]), 1)
        self.failUnlessEqual(res[self.hash][0], self.urlpath)
        res = self.store.refreshFile(self.file)
        self.failUnless(res)
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res), 0)

    def build_dirs(self):
        """Create and store one test file below each directory in dirs."""
        for top in self.dirs:
            target = top.preauthChild(self.testfile)
            if not target.parent().exists():
                target.parent().makedirs()
            target.setContent(target.path)
            self.store.storeFile(target, self.hash, top)

    def test_removeUntracked(self):
        """Files outside the tracked directories are removed and returned."""
        self.build_dirs()
        res = self.store.removeUntrackedFiles(self.dirs)
        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
        self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
        res = self.store.removeUntrackedFiles(self.dirs)
        self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
        res = self.store.removeUntrackedFiles(self.dirs[1:])
        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
        self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
        res = self.store.removeUntrackedFiles(self.dirs[:1])
        self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
        self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
        self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)

    def test_reconcileDirectories(self):
        """Directories disappear once no stored file uses them."""
        self.build_dirs()
        res = self.store.getAllDirectories()
        self.failUnless(res)
        self.failUnlessEqual(len(res), 4)
        res = self.store.reconcileDirectories()
        self.failUnlessEqual(res, False)
        res = self.store.getAllDirectories()
        self.failUnless(res)
        self.failUnlessEqual(len(res), 4)
        res = self.store.removeUntrackedFiles(self.dirs)
        res = self.store.reconcileDirectories()
        self.failUnlessEqual(res, True)
        res = self.store.getAllDirectories()
        self.failUnless(res)
        self.failUnlessEqual(len(res), 3)
        res = self.store.removeUntrackedFiles(self.dirs[:1])
        res = self.store.reconcileDirectories()
        self.failUnlessEqual(res, True)
        res = self.store.getAllDirectories()
        self.failUnless(res)
        self.failUnlessEqual(len(res), 1)
        res = self.store.removeUntrackedFiles([FilePath('/what')])
        res = self.store.reconcileDirectories()
        self.failUnlessEqual(res, True)
        res = self.store.getAllDirectories()
        self.failUnlessEqual(len(res), 0)

    def tearDown(self):
        """Remove the test files and the database file."""
        self.directory.remove()
        self.store = None
        # Every test expects the first directory to be numbered 1, so the
        # database file must not survive into the next test.
        self.db.restat(False)
        if self.db.exists():
            self.db.remove()