2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
# Module-level guard: refuse to run with a pysqlite older than 2.1.
# NOTE(review): `assert` statements are stripped under `python -O`, so this
# check vanishes in optimized runs; an explicit `if ...: raise ImportError(...)`
# would always run.
11 assert sqlite.version_info >= (2, 1)
# Exception type meant to signal that the database could not be opened
# (see the `raise DBExcept, "Couldn't open DB", ...` in DB). The class body
# is not part of this excerpt.
13 class DBExcept(Exception):
# Docstring of the `khash` marker type (its `class` line is not in this
# excerpt). Hash strings are wrapped as khash(...) before being bound as query
# parameters, so the module-level adapter stores them base64-encoded.
17 """Dummy class to convert all hashes to base64 for storing in the DB."""
# Register the base64 round trip for khash values with pysqlite.
# khash query parameters are written with b2a_base64. Result columns declared
# KHASH are decoded with a2b_base64. Converters only apply to connections
# opened with detect_types=PARSE_DECLTYPES, which both connect() calls shown
# in DB use.
19 sqlite.register_adapter(khash, b2a_base64)
# NOTE(review): both spellings are registered. pysqlite/sqlite3 versions that
# upper-case converter names make the lowercase registration redundant. Keep
# it unless the minimum supported pysqlite is confirmed to behave that way.
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
# Print tracebacks for exceptions raised inside user-defined callbacks (such
# as converters) instead of silently discarding them.
22 sqlite.enable_callback_tracebacks(True)
# DB wraps a pysqlite connection (self.conn) to the database file at
# self.db.path. The tests pass a twisted FilePath as `db`. Only fragments of
# the class docstring, __init__ and the open-existing-database path are
# visible in this excerpt.
25 """Database access for storing persistent data."""
# __init__(db): once a connection exists, TEXT values come back as plain
# `str` and rows as sqlite.Row. That lets columns be read by name
# (row['size'], row['hash'], ...) throughout the class.
27 def __init__(self, db):
34 self.conn.text_factory = str
35 self.conn.row_factory = sqlite.Row
# Open an existing database file. PARSE_DECLTYPES is what makes the KHASH
# converter apply to query results.
39 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
# BUG(review): in Python 2's `raise E, V, T` form, T must be a traceback
# object or None. traceback.format_exc() returns a string, so this statement
# raises TypeError ("arg 3 must be a traceback or None") instead of DBExcept.
# Intended is probably:
#     raise DBExcept("Couldn't open DB: " + traceback.format_exc())
# or pass sys.exc_info()[2] as the third operand.
42 raise DBExcept, "Couldn't open DB", traceback.format_exc()
# Create a new database at self.db.path, making the parent directory first if
# needed, and build the schema:
#   files  - one row per stored path: the hashID of its hashes row plus the
#            size and mtime recorded by storeFile.
#   hashes - one row per distinct hash: the hash, its piece string, the hash
#            of that piece string (piecehash) and the `refreshed` time.
# hashes_refreshed supports the `refreshed < ?` scan in expiredHashes.
# hashes_piecehash supports the `piecehash = ?` fallback in lookupHash.
# NOTE(review): `PRIMARY KEY UNIQUE` on files.path is redundant, since a
# primary key is already unique. Nothing enforces that files.hashID refers to
# an existing hashes.hashID.
44 def _createNewDB(self):
45 if not self.db.parent().exists():
46 self.db.parent().makedirs()
47 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48 c = self.conn.cursor()
49 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
50 "size NUMBER, mtime NUMBER)")
51 c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
52 "hash KHASH UNIQUE, pieces KHASH, " +
53 "piecehash KHASH, refreshed TIMESTAMP)")
54 c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
55 c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
# Validate a stored files `row` against `file` on disk. `res` is True only if
# both the stored size and the stored mtime still match the file.
# When the row is stale (the guarding condition is on lines not shown), the
# files row for file.path is deleted through a fresh cursor. The return
# statement is not shown either. Callers document the result as True
# (unchanged), False (changed) or None (not in the database).
59 def _removeChanged(self, file, row):
64 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
66 c = self.conn.cursor()
67 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
# Store `file` under `hash`, with an optional piece string `pieces`.
# If a hashes row for `hash` already exists it is reused. Otherwise a new one
# is inserted, stamped with datetime.now(). The files row (keyed by path) is
# then upserted with the file's current size and mtime. Per the docstring,
# the return value is True when the hash was new and must be added to the DHT.
72 def storeFile(self, file, hash, pieces = ''):
73 """Store or update a file in the database.
75 @return: True if the hash was not in the database before
76 (so it needs to be added to the DHT)
# BUG(review): assuming `sha` is the Python 2 `sha` module (its import is not
# among the lines shown), sha.new().update(...) returns None, so `s` is always
# None. The module also has no module-level digest(), so the next line raises
# AttributeError whenever it runs. Intended is almost certainly:
#     piecehash = sha.new(pieces).digest()
# (`sha` is deprecated in favour of hashlib.sha1.)
80 s = sha.new().update(pieces)
81 piecehash = sha.digest()
82 c = self.conn.cursor()
83 c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
# An already-known hash must have been stored with the same piecehash.
# NOTE(review): this consistency check disappears under `python -O`.
86 assert piecehash == row['piecehash']
88 hashID = row['hashID']
# NOTE(review): a second cursor replaces `c` without closing the first one.
# Also, `INSERT OR REPLACE` on the UNIQUE hash column would delete an existing
# row and re-insert it under a new hashID, orphaning files rows that point at
# the old one. Presumably this branch only runs when the SELECT found no row
# (the branch lines are not shown); confirm that. A plain INSERT would make
# the assumption explicit.
90 c = self.conn.cursor()
91 c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
92 (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
# files.path is the primary key, so re-storing a path replaces its old row.
98 c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
99 (file.path, hashID, file.getsize(), file.getmtime()))
# Fetch the files row for file.path joined to its hashes row, and re-check it
# against the disk with _removeChanged, which drops stale rows. Then fill a
# result dict with the stored 'hash', 'size' and 'pieces'.
105 def getFile(self, file):
106 """Get a file from the database.
108 If it has changed or is missing, it is removed from the database.
110 @return: dictionary of info for the file, False if changed, or
111 None if not in database or missing
113 c = self.conn.cursor()
114 c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
118 res = self._removeChanged(file, row)
121 res['hash'] = row['hash']
122 res['size'] = row['size']
123 res['pieces'] = row['pieces']
# Find every stored file whose hash equals `hash`.
# Each candidate path is re-validated with _removeChanged, which drops stale
# rows. Each surviving file is described by a dict holding its size, the
# hash's refreshed time and pieces; the tests also read a 'path' FilePath
# from it. If nothing matched and filesOnly is False, `hash` is instead
# treated as the hash of a piece string and looked up in hashes.piecehash.
127 def lookupHash(self, hash, filesOnly = False):
128 """Find a file by hash in the database.
130 If any found files have changed or are missing, they are removed
131 from the database. If filesOnly is False then it will also look for
132 piece string hashes if no files can be found.
134 @return: list of dictionaries of info for the found files
136 c = self.conn.cursor()
137 c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
141 file = FilePath(row['path'])
142 res = self._removeChanged(file, row)
146 res['size'] = row['size']
147 res['refreshed'] = row['refreshed']
148 res['pieces'] = row['pieces']
# Fallback: no files were found, so look for a hash whose piece string hashes
# to `hash`. The visible lines fill in only refreshed/pieces for these results.
152 if not filesOnly and not files:
153 c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
157 res['refreshed'] = row['refreshed']
158 res['pieces'] = row['pieces']
# Look up the stored size/mtime for file.path and delegate the comparison,
# and any deletion of a stale row, to _removeChanged. Its result is returned
# as-is.
164 def isUnchanged(self, file):
165 """Check if a file in the file system has changed.
167 If it has changed, it is removed from the table.
169 @return: True if unchanged, False if changed, None if not in database
171 c = self.conn.cursor()
172 c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
174 return self._removeChanged(file, row)
# Set hashes.refreshed to now for `hash`. The timestamp lives on the single
# hashes row, so this one UPDATE covers every files row sharing the hash.
# NOTE(review): datetime.now() is naive local time, and so is the cutoff in
# expiredHashes. Clock or DST jumps therefore shift expiry accordingly.
176 def refreshHash(self, hash):
177 """Refresh the publishing time of all files with a hash."""
178 c = self.conn.cursor()
179 c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
# Report hashes whose `refreshed` time is older than `expireAfter` seconds.
# Step 1 collects each such hashes row into `expired`, keyed by the raw
# (decoded) hash, as a dict with hashID, hash and pieces.
# Step 2 re-validates each hash's files rows with _removeChanged. A hash left
# without valid files (the condition is on lines not shown) is dropped from
# the result, and its hashes row is deleted.
# NOTE(review): the values built here are dicts, not the "list of FilePaths"
# that the @return line describes. Check which one is current.
182 def expiredHashes(self, expireAfter):
183 """Find files that need refreshing after expireAfter seconds.
185 For each hash that needs refreshing, finds all the files with that hash.
186 If the file has changed or is missing, it is removed from the table.
188 @return: dictionary with keys the hashes, values a list of FilePaths
190 t = datetime.now() - timedelta(seconds=expireAfter)
192 # First find the hashes that need refreshing
193 c = self.conn.cursor()
194 c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
198 res = expired.setdefault(row['hash'], {})
199 res['hashID'] = row['hashID']
200 res['hash'] = row['hash']
201 res['pieces'] = row['pieces']
204 # Make sure there are still valid files for each hash
# Deleting from `expired` inside this loop is safe only because Python 2's
# dict.values() returns a list copy. Under Python 3, iterate over
# list(expired.values()) instead. (`hash` also shadows the builtin.)
205 for hash in expired.values():
207 c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
210 res = self._removeChanged(FilePath(row['path']), row)
215 del expired[hash['hash']]
216 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
# Delete files rows whose path lies outside every directory in `dirs`, and
# return those paths as FilePaths.
# Each directory contributes a `path NOT GLOB ?` term bound to
# dir.child('*').path. SQLite's GLOB `*` also matches '/', so files in nested
# subdirectories still count as tracked; the tests rely on this for
# <dir>/tmp/khashmir.test. The same WHERE clause drives both the SELECT and
# the DELETE. Only fixed SQL text is concatenated; the paths themselves are
# bound parameters.
# NOTE(review): directory names containing GLOB metacharacters (*, ?, [) are
# treated as patterns, not literals. A prefix test such as
# substr(path, 1, length(?)) = ? would be exact. Also, `assert len(dirs) >= 1`
# (stripped under -O) is the only guard against an empty `dirs`.
223 def removeUntrackedFiles(self, dirs):
224 """Find files that are no longer tracked and so should be removed.
226 Also removes the entries from the table.
228 @return: list of files that were removed
230 assert len(dirs) >= 1
234 newdirs.append(dir.child('*').path)
235 sql += " path NOT GLOB ? AND"
238 c = self.conn.cursor()
239 c.execute("SELECT path FROM files " + sql, newdirs)
243 removed.append(FilePath(row['path']))
247 c.execute("DELETE FROM files " + sql, newdirs)
# trial tests for DB. All fixtures live at fixed /tmp paths:
#   - the database at /tmp/khashmir.db;
#   - one test file directly under /tmp/apt-dht;
#   - three "tracked" directories whose files build_dirs creates at the
#     relative path `testfile` inside each one.
# NOTE(review): fixed /tmp paths collide between concurrent runs or users.
# trial's self.mktemp() would give each test its own location.
254 class TestDB(unittest.TestCase):
255 """Tests for the khashmir database."""
258 db = FilePath('/tmp/khashmir.db')
259 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
260 directory = FilePath('/tmp/apt-dht/')
261 file = FilePath('/tmp/apt-dht/khashmir.test')
262 testfile = 'tmp/khashmir.test'
263 dirs = [FilePath('/tmp/apt-dht/top1'),
264 FilePath('/tmp/apt-dht/top2/sub1'),
265 FilePath('/tmp/apt-dht/top2/sub2/')]
# setUp body (its `def` line is not shown): create the test file, open the DB
# and store the file under self.hash.
268 if not self.file.parent().exists():
269 self.file.parent().makedirs()
270 self.file.setContent('fgfhds')
272 self.store = DB(self.db)
273 self.store.storeFile(self.file, self.hash)
# Re-open the database that setUp populated, and check that the stored file
# is still reported by isUnchanged. The assertion line is not shown.
275 def test_openExistingDB(self):
279 self.store = DB(self.db)
280 res = self.store.isUnchanged(self.file)
# getFile returns the (decoded) hash that setUp stored for self.file.
283 def test_getFile(self):
284 res = self.store.getFile(self.file)
286 self.failUnlessEqual(res['hash'], self.hash)
# lookupHash finds exactly the one file stored in setUp, reported by path.
288 def test_lookupHash(self):
289 res = self.store.lookupHash(self.hash)
291 self.failUnlessEqual(len(res), 1)
292 self.failUnlessEqual(res[0]['path'].path, self.file.path)
# The first call sees the untouched file. After the file is modified (on lines
# not shown), isUnchanged reports False and removes the row, so the next call
# returns None.
# NOTE(review): failUnless(res == False) would also pass for 0. An identity
# check (`res is False`) would match the `is None` check below.
294 def test_isUnchanged(self):
295 res = self.store.isUnchanged(self.file)
299 res = self.store.isUnchanged(self.file)
300 self.failUnless(res == False)
301 res = self.store.isUnchanged(self.file)
302 self.failUnless(res is None)
# With a 1-second expiry, nothing is expired right after setUp. After a pause
# (presumably on line 307, not shown), the setUp hash is reported as expired.
# refreshHash then clears it again.
# NOTE(review): this depends on wall-clock time passing, so it is
# timing-sensitive.
304 def test_expiry(self):
305 res = self.store.expiredHashes(1)
306 self.failUnlessEqual(len(res.keys()), 0)
308 res = self.store.expiredHashes(1)
309 self.failUnlessEqual(len(res.keys()), 1)
310 self.failUnlessEqual(res.keys()[0], self.hash)
311 self.store.refreshHash(self.hash)
312 res = self.store.expiredHashes(1)
313 self.failUnlessEqual(len(res.keys()), 0)
# Helper: create `testfile` under each of self.dirs, using the file's own path
# as its content, and store each one under the same self.hash as the setUp
# file.
315 def build_dirs(self):
316 for dir in self.dirs:
317 file = dir.preauthChild(self.testfile)
318 if not file.parent().exists():
319 file.parent().makedirs()
320 file.setContent(file.path)
322 self.store.storeFile(file, self.hash)
# Four files share self.hash: the setUp file plus the three from build_dirs,
# which is presumably called on a line not shown. lookupHash returns all four
# with one identical `refreshed` value, because that timestamp lives on the
# single hashes row. Expiry and refresh then behave as for a single file.
324 def test_multipleHashes(self):
326 res = self.store.expiredHashes(1)
327 self.failUnlessEqual(len(res.keys()), 0)
328 res = self.store.lookupHash(self.hash)
330 self.failUnlessEqual(len(res), 4)
331 self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
332 self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
333 self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
335 res = self.store.expiredHashes(1)
336 self.failUnlessEqual(len(res.keys()), 1)
337 self.failUnlessEqual(res.keys()[0], self.hash)
338 self.store.refreshHash(self.hash)
339 res = self.store.expiredHashes(1)
340 self.failUnlessEqual(len(res.keys()), 0)
# removeUntrackedFiles keeps only files under the given dirs. This assumes
# build_dirs ran on a line not shown.
#  - All three dirs: only the setUp file (directly in /tmp/apt-dht) is removed.
#  - Repeating the call: nothing is left to remove.
#  - dirs[1:]: the file under dirs[0] is removed.
#  - dirs[:1]: the remaining files under dirs[1] and dirs[2] are removed.
342 def test_removeUntracked(self):
344 res = self.store.removeUntrackedFiles(self.dirs)
345 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
346 self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
347 res = self.store.removeUntrackedFiles(self.dirs)
348 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
349 res = self.store.removeUntrackedFiles(self.dirs[1:])
350 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
351 self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
352 res = self.store.removeUntrackedFiles(self.dirs[:1])
353 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
354 self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
355 self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
# Cleanup; the enclosing method's `def` line is not shown. Removes the
# /tmp/apt-dht fixture tree.
# NOTE(review): the database at /tmp/khashmir.db lies outside this tree, so it
# is not removed here unless a line not shown does so. No visible line closes
# the DB connection either.
358 self.directory.remove()