2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
# Refuse to load with a pysqlite older than 2.1.
_MIN_PYSQLITE_VERSION = (2, 1)
assert sqlite.version_info >= _MIN_PYSQLITE_VERSION
class DBExcept(Exception):
    """Raised when the persistent database cannot be opened."""
# NOTE(review): only the docstring of this class was visible; the header is
# reconstructed.  The name comes from the adapter registration and the
# khash(hash) calls in this file; the str base is inferred from the value
# being handed to b2a_base64 -- confirm against the original.
class khash(str):
    """Dummy class to convert all hashes to base64 for storing in the DB."""
# khash values are written to the database base64-encoded ...
sqlite.register_adapter(khash, b2a_base64)
# ... and columns declared with the KHASH type (either spelling) are decoded
# back to the raw hash when they are read.
for _decltype in ("KHASH", "khash"):
    sqlite.register_converter(_decltype, a2b_base64)
del _decltype
# Ask the driver to print tracebacks for errors raised inside callbacks such
# as the adapter/converter registered above.
sqlite.enable_callback_tracebacks(True)
class DB(object):
    """Database access for storing persistent data.

    Files are tracked in a single table, C{files}, keyed by path and holding
    the file's hash, size, mtime and the time the hash was last refreshed.

    NOTE(review): the reviewed text of this class was missing its header and
    a number of statements (the fetch/commit/close/return lines, the
    branching in L{__init__} and the name of the loader method).  They are
    reconstructed here from the docstrings and the visible statements --
    confirm against the original.
    """

    def __init__(self, db):
        """Open the database file, creating it first if it does not exist.

        @param db: location of the database file (a FilePath)
        """
        self.db = db
        self.db.restat(False)
        if self.db.exists():
            self._loadDB()
        else:
            self._createNewDB()
        # Return text as plain str and allow row['column'] access.
        self.conn.text_factory = str
        self.conn.row_factory = sqlite.Row

    def _loadDB(self):
        """Connect to an already existing database file.

        @raise DBExcept: if the connection could not be opened; the formatted
            traceback of the underlying error is the second exception argument
        """
        try:
            self.conn = sqlite.connect(database=self.db.path,
                                       detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            # The three-expression raise form needs a traceback *object* as
            # its last item; passing the formatted text there raised
            # TypeError instead of DBExcept.
            import traceback
            raise DBExcept("Couldn't open DB", traceback.format_exc())

    def _createNewDB(self):
        """Create the database file, its C{files} table and the indexes."""
        parent = self.db.parent()
        if not parent.exists():
            parent.makedirs()
        self.conn = sqlite.connect(database=self.db.path,
                                   detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
        c.execute("CREATE INDEX files_hash ON files(hash)")
        c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
        c.close()
        self.conn.commit()

    def _removeChanged(self, file, row):
        """Compare a stored row with the file on disk, dropping a stale row.

        @param file: the file the row describes
        @param row: the row fetched for that file (must have C{size} and
            C{mtime}), or None if no row was found
        @return: True if size and mtime still match, False if the file has
            changed, None if there was no row or the file no longer exists
        """
        res = None
        if row:
            file.restat(False)
            if file.exists():
                res = (row['size'] == file.getsize() and
                       row['mtime'] == file.getmtime())
            if not res:
                # Changed or missing: the stored information is useless.
                c = self.conn.cursor()
                c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                self.conn.commit()
                c.close()
        return res

    def storeFile(self, file, hash):
        """Store or update a file in the database.

        @param file: the file to record
        @param hash: the hash of the file
        @return: True if the hash was not in the database before
            (so it needs to be added to the DHT)
        """
        new_hash = True
        refreshTime = datetime.now()
        c = self.conn.cursor()
        c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), ))
        row = c.fetchone()
        if row and row['max_refresh']:
            # The hash is already stored: reuse its latest refresh time.
            new_hash = False
            refreshTime = row['max_refresh']
        c.close()

        file.restat()
        c = self.conn.cursor()
        c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
        row = c.fetchone()
        if row:
            # Restrict the update to this file's row; without the WHERE
            # clause every row in the table was overwritten.
            c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ? WHERE path = ?",
                      (khash(hash), file.getsize(), file.getmtime(), refreshTime, file.path))
        else:
            c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
                      (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime))
        self.conn.commit()
        c.close()

        return new_hash

    def getFile(self, file):
        """Get a file from the database.

        If it has changed or is missing, it is removed from the database.

        @param file: the file to look up
        @return: dictionary of info for the file, False if changed, or
            None if not in database or missing
        """
        c = self.conn.cursor()
        c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
        row = c.fetchone()
        c.close()
        res = self._removeChanged(file, row)
        if res:
            res = {'hash': row['hash'], 'size': row['size']}
        return res

    def lookupHash(self, hash):
        """Find a file by hash in the database.

        If any found files have changed or are missing, they are removed
        from the database.

        @param hash: the hash to look for
        @return: list of dictionaries of info for the found files
        """
        c = self.conn.cursor()
        c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), ))
        # Read the whole result before checking any file, because the check
        # may delete rows and commit on this same connection.
        rows = c.fetchall()
        c.close()

        files = []
        for row in rows:
            file = FilePath(row['path'])
            if self._removeChanged(file, row):
                files.append({'path': file,
                              'size': row['size'],
                              'refreshed': row['refreshed']})
        return files

    def isUnchanged(self, file):
        """Check if a file in the file system has changed.

        If it has changed, it is removed from the table.

        @param file: the file to check
        @return: True if unchanged, False if changed, None if not in database
        """
        c = self.conn.cursor()
        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
        row = c.fetchone()
        c.close()
        return self._removeChanged(file, row)

    def refreshHash(self, hash):
        """Refresh the publishing time of all files with a hash.

        @param hash: the hash whose files are marked as refreshed now
        """
        refreshTime = datetime.now()
        c = self.conn.cursor()
        c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash)))
        self.conn.commit()
        c.close()

    def expiredFiles(self, expireAfter):
        """Find files that need refreshing after expireAfter seconds.

        For each hash that needs refreshing, finds all the files with that hash.
        If the file has changed or is missing, it is removed from the table.

        @param expireAfter: age in seconds after which a refresh is needed
        @return: dictionary with keys the hashes, values a list of FilePaths
        """
        t = datetime.now() - timedelta(seconds=expireAfter)

        # First find the hashes that need refreshing
        c = self.conn.cursor()
        c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, ))
        stale_hashes = [row['hash'] for row in c.fetchall()]
        c.close()

        # Now find the files for each hash
        expired = {}
        for stale_hash in stale_hashes:
            c = self.conn.cursor()
            c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(stale_hash), ))
            rows = c.fetchall()
            c.close()

            unchanged = []
            for row in rows:
                path = FilePath(row['path'])
                if self._removeChanged(path, row):
                    unchanged.append(path)
            # Hashes left without any usable file are not reported.
            if unchanged:
                expired[stale_hash] = unchanged

        return expired

    def removeUntrackedFiles(self, dirs):
        """Find files that are no longer tracked and so should be removed.

        Also removes the entries from the table.

        @param dirs: non-empty list of directories (FilePaths); stored paths
            matching none of the C{<dir>/*} patterns are removed
        @return: list of files that were removed
        """
        assert len(dirs) >= 1
        patterns = [d.child('*').path for d in dirs]
        where = "WHERE " + " AND ".join(["path NOT GLOB ?"] * len(patterns))

        c = self.conn.cursor()
        c.execute("SELECT path FROM files " + where, patterns)
        removed = [FilePath(row['path']) for row in c.fetchall()]

        if removed:
            c.execute("DELETE FROM files " + where, patterns)
            self.conn.commit()
        c.close()

        return removed

    def close(self):
        """Close the connection to the database."""
        self.conn.close()
class TestDB(unittest.TestCase):
    """Tests for the khashmir database.

    The tests work on fixed locations under C{/tmp}.

    NOTE(review): the C{setUp}/C{tearDown} headers, the truthiness
    assertions, the calls to L{build_dirs} and the waits/file changes between
    successive checks were not visible in the reviewed text and are
    reconstructed from what each test asserts -- confirm against the
    original.
    """

    db = FilePath('/tmp/khashmir.db')
    hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    directory = FilePath('/tmp/apt-dht/')
    file = FilePath('/tmp/apt-dht/khashmir.test')
    testfile = 'tmp/khashmir.test'
    dirs = [FilePath('/tmp/apt-dht/top1'),
            FilePath('/tmp/apt-dht/top2/sub1'),
            FilePath('/tmp/apt-dht/top2/sub2/')]

    def _pause(self, seconds):
        """Block for a while so that timestamps can move on."""
        import time
        time.sleep(seconds)

    def setUp(self):
        """Create the test file and a database that already stores it."""
        if not self.file.parent().exists():
            self.file.parent().makedirs()
        self.file.setContent('fgfhds')
        self.store = DB(self.db)
        self.store.storeFile(self.file, self.hash)

    def test_openExistsingDB(self):
        """A stored file is still found after reopening the database."""
        self.store.conn.close()
        self.store = None
        self.store = DB(self.db)
        res = self.store.isUnchanged(self.file)
        self.failUnless(res)

    def test_getFile(self):
        """The stored file is returned together with its hash."""
        res = self.store.getFile(self.file)
        self.failUnless(res)
        self.failUnlessEqual(res['hash'], self.hash)

    def test_lookupHash(self):
        """Looking up the hash finds exactly the stored file."""
        res = self.store.lookupHash(self.hash)
        self.failUnless(res)
        self.failUnlessEqual(len(res), 1)
        self.failUnlessEqual(res[0]['path'].path, self.file.path)

    def test_isUnchanged(self):
        """Changed files report False once and are then forgotten."""
        res = self.store.isUnchanged(self.file)
        self.failUnless(res)
        # Let the clock advance so that touching really changes the mtime.
        self._pause(2)
        self.file.touch()
        res = self.store.isUnchanged(self.file)
        self.failUnless(res is False)
        self.file.remove()
        res = self.store.isUnchanged(self.file)
        self.failUnless(res is None)

    def test_expiry(self):
        """A hash expires after the given time and a refresh resets it."""
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res.keys()), 0)
        self._pause(2)
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res.keys()), 1)
        self.failUnlessEqual(list(res)[0], self.hash)
        self.failUnlessEqual(len(res[self.hash]), 1)
        self.store.refreshHash(self.hash)
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res.keys()), 0)

    def build_dirs(self):
        """Store one more file with the same hash below each of C{dirs}."""
        for top in self.dirs:
            file = top.preauthChild(self.testfile)
            if not file.parent().exists():
                file.parent().makedirs()
            file.setContent(file.path)
            self.store.storeFile(file, self.hash)

    def test_multipleHashes(self):
        """Several files sharing one hash share its refresh time."""
        self.build_dirs()
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res.keys()), 0)
        res = self.store.lookupHash(self.hash)
        self.failUnless(res)
        self.failUnlessEqual(len(res), 4)
        self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
        self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
        self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
        self._pause(2)
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res.keys()), 1)
        self.failUnlessEqual(list(res)[0], self.hash)
        self.failUnlessEqual(len(res[self.hash]), 4)
        self.store.refreshHash(self.hash)
        res = self.store.expiredFiles(1)
        self.failUnlessEqual(len(res.keys()), 0)

    def test_removeUntracked(self):
        """Files outside the tracked directories are removed and reported."""
        self.build_dirs()
        res = self.store.removeUntrackedFiles(self.dirs)
        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
        self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
        res = self.store.removeUntrackedFiles(self.dirs)
        self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
        res = self.store.removeUntrackedFiles(self.dirs[1:])
        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
        self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
        res = self.store.removeUntrackedFiles(self.dirs[:1])
        self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
        self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
        self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)

    def tearDown(self):
        """Remove the test files and the database again."""
        self.directory.remove()
        self.store.conn.close()
        self.db.remove()