2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.trial import unittest
10 assert sqlite.version_info >= (2, 1)
12 class DBExcept(Exception):
16 """Dummy class to convert all hashes to base64 for storing in the DB."""
18 sqlite.register_adapter(khash, b2a_base64)
19 sqlite.register_converter("KHASH", a2b_base64)
20 sqlite.register_converter("khash", a2b_base64)
21 sqlite.enable_callback_tracebacks(True)
24 """Database access for storing persistent data."""
26 def __init__(self, db):
34 self.conn.text_factory = str
35 self.conn.row_factory = sqlite.Row
37 def _loadDB(self, db):
39 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
42 raise DBExcept, "Couldn't open DB", traceback.format_exc()
44 def _createNewDB(self, db):
45 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
46 c = self.conn.cursor()
47 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
48 c.execute("CREATE INDEX files_urldir ON files(urldir)")
49 c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
50 c.execute("CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)")
51 c.execute("CREATE INDEX dirs_path ON dirs(path)")
55 def _removeChanged(self, path, row):
63 res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
65 c = self.conn.cursor()
66 c.execute("DELETE FROM files WHERE path = ?", (path, ))
71 def storeFile(self, path, hash, directory):
72 """Store or update a file in the database.
74 @return: the urlpath to access the file, and whether a
75 new url top-level directory was needed
77 path = os.path.abspath(path)
78 directory = os.path.abspath(directory)
79 assert path.startswith(directory)
81 c = self.conn.cursor()
82 c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (path, ))
84 if row and directory == row['directory']:
85 c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?",
86 (khash(hash), stat.st_size, stat.st_mtime, datetime.now()))
88 urldir = row['urldir']
90 urldir, newdir = self.findDirectory(directory)
91 c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
92 (path, khash(hash), urldir, len(directory), stat.st_size, stat.st_mtime, datetime.now()))
95 return '/~' + str(urldir) + path[len(directory):], newdir
97 def getFile(self, path):
98 """Get a file from the database.
100 If it has changed or is missing, it is removed from the database.
102 @return: dictionary of info for the file, False if changed, or
103 None if not in database or missing
105 path = os.path.abspath(path)
106 c = self.conn.cursor()
107 c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (path, ))
109 res = self._removeChanged(path, row)
112 res['hash'] = row['hash']
113 res['urlpath'] = '/~' + str(row['urldir']) + path[row['dirlength']:]
117 def isUnchanged(self, path):
118 """Check if a file in the file system has changed.
120 If it has changed, it is removed from the table.
122 @return: True if unchanged, False if changed, None if not in database
124 path = os.path.abspath(path)
125 c = self.conn.cursor()
126 c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
128 return self._removeChanged(path, row)
130 def refreshFile(self, path):
131 """Refresh the publishing time of a file.
133 If it has changed or is missing, it is removed from the table.
135 @return: True if unchanged, False if changed, None if not in database
137 path = os.path.abspath(path)
138 c = self.conn.cursor()
139 c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
141 res = self._removeChanged(path, row)
143 c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
146 def expiredFiles(self, expireAfter):
147 """Find files that need refreshing after expireAfter seconds.
149 Also removes any entries from the table that no longer exist.
151 @return: dictionary with keys the hashes, values a list of url paths
153 t = datetime.now() - timedelta(seconds=expireAfter)
154 c = self.conn.cursor()
155 c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
159 res = self._removeChanged(row['path'], row)
161 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
166 def removeUntrackedFiles(self, dirs):
167 """Find files that are no longer tracked and so should be removed.
169 Also removes the entries from the table.
171 @return: list of files that were removed
173 assert len(dirs) >= 1
# Each directory contributes one "<absolute dir><os.sep>*" pattern and one
# "path NOT GLOB ?" condition, so a row is selected only when its path falls
# under none of the given directories.
# NOTE(review): directory names containing GLOB metacharacters (*, ?, [) are
# passed as patterns unescaped -- confirm that is acceptable for callers.
177 newdirs.append(os.path.abspath(dir) + os.sep + '*')
178 sql += " path NOT GLOB ? AND"
181 c = self.conn.cursor()
182 c.execute("SELECT path FROM files " + sql, newdirs)
186 removed.append(row['path'])
# The same condition fragment and parameters drive both the SELECT above and
# this DELETE, so the rows reported are the rows removed.
190 c.execute("DELETE FROM files " + sql, newdirs)
194 def findDirectory(self, directory):
195 """Store or update a directory in the database.
197 @return: the index of the url directory, and whether it is new or not
199 directory = os.path.abspath(directory)
200 c = self.conn.cursor()
201 c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory, ))
205 return row['urldir'], False
207 # Not found, need to add a new one
208 c = self.conn.cursor()
209 c.execute("INSERT INTO dirs (path) VALUES (?)", (directory, ))
215 def getAllDirectories(self):
216 """Get all the current directories available."""
217 c = self.conn.cursor()
218 c.execute("SELECT urldir, path FROM dirs")
# Entries are keyed by '~' followed by the numeric urldir; each value is the
# directory path stored for that urldir.
222 dirs['~' + str(row['urldir'])] = row['path']
227 def reconcileDirectories(self):
228 """Remove any unneeded directories by checking which are used by files."""
229 c = self.conn.cursor()
230 c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
232 return bool(c.rowcount)
237 class TestDB(unittest.TestCase):
238 """Tests for the khashmir database."""
241 db = '/tmp/khashmir.db'
242 path = '/tmp/khashmir.test'
243 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
245 urlpath = '/~1/khashmir.test'
246 dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
249 f = open(self.path, 'w')
252 os.utime(self.path, None)
253 self.store = DB(self.db)
254 self.store.storeFile(self.path, self.hash, self.directory)
256 def test_getFile(self):
257 res = self.store.getFile(self.path)
259 self.failUnlessEqual(res['hash'], self.hash)
260 self.failUnlessEqual(res['urlpath'], self.urlpath)
262 def test_getAllDirectories(self):
263 res = self.store.getAllDirectories()
265 self.failUnlessEqual(len(res.keys()), 1)
266 self.failUnlessEqual(res.keys()[0], '~1')
267 self.failUnlessEqual(res['~1'], os.path.abspath(self.directory))
269 def test_isUnchanged(self):
270 res = self.store.isUnchanged(self.path)
273 os.utime(self.path, None)
274 res = self.store.isUnchanged(self.path)
275 self.failUnless(res == False)
277 res = self.store.isUnchanged(self.path)
278 self.failUnless(res == None)
280 def test_expiry(self):
281 res = self.store.expiredFiles(1)
282 self.failUnlessEqual(len(res.keys()), 0)
284 res = self.store.expiredFiles(1)
285 self.failUnlessEqual(len(res.keys()), 1)
286 self.failUnlessEqual(res.keys()[0], self.hash)
287 self.failUnlessEqual(len(res[self.hash]), 1)
288 self.failUnlessEqual(res[self.hash][0], self.urlpath)
289 res = self.store.refreshFile(self.path)
291 res = self.store.expiredFiles(1)
292 self.failUnlessEqual(len(res.keys()), 0)
294 def build_dirs(self):
295 for dir in self.dirs:
296 path = os.path.join(dir, self.path[1:])
297 os.makedirs(os.path.dirname(path))
302 self.store.storeFile(path, self.hash, dir)
304 def test_removeUntracked(self):
306 res = self.store.removeUntrackedFiles(self.dirs)
307 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
308 self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
309 res = self.store.removeUntrackedFiles(self.dirs)
310 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
311 res = self.store.removeUntrackedFiles(self.dirs[1:])
312 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
313 self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
314 res = self.store.removeUntrackedFiles(self.dirs[:1])
315 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
316 self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
317 self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
319 def test_reconcileDirectories(self):
321 res = self.store.getAllDirectories()
323 self.failUnlessEqual(len(res.keys()), 4)
324 res = self.store.reconcileDirectories()
325 self.failUnlessEqual(res, False)
326 res = self.store.getAllDirectories()
328 self.failUnlessEqual(len(res.keys()), 4)
329 res = self.store.removeUntrackedFiles(self.dirs)
330 res = self.store.reconcileDirectories()
331 self.failUnlessEqual(res, True)
332 res = self.store.getAllDirectories()
334 self.failUnlessEqual(len(res.keys()), 3)
335 res = self.store.removeUntrackedFiles(self.dirs[:1])
336 res = self.store.reconcileDirectories()
337 self.failUnlessEqual(res, True)
338 res = self.store.getAllDirectories()
340 self.failUnlessEqual(len(res.keys()), 1)
341 res = self.store.removeUntrackedFiles(['/what'])
342 res = self.store.reconcileDirectories()
343 self.failUnlessEqual(res, True)
344 res = self.store.getAllDirectories()
345 self.failUnlessEqual(len(res.keys()), 0)
348 for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
350 os.remove(os.path.join(root, name))
352 os.rmdir(os.path.join(root, name))