2 """An sqlite database for storing persistent files and hashes."""
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
# Import-time guard: everything below needs pysqlite 2.1 or newer. On an older
# version this raises AssertionError (the check is skipped under python -O).
13 assert sqlite.version_info >= (2, 1)
# Module-specific exception type; the "Couldn't open DB" raise further down is
# the one place in this file that uses it.
15 class DBExcept(Exception):
16 """An error occurred in accessing the database."""
20 """Dummy class to convert all hashes to base64 for storing in the DB."""
22 # Initialize the database to work with 'khash' objects (binary strings)
# Writing: a khash query parameter is base64-encoded to text before storage.
23 sqlite.register_adapter(khash, b2a_base64)
# Reading: values from columns declared KHASH (either spelling) are
# base64-decoded back to the binary string. The connections below are opened
# with detect_types=sqlite.PARSE_DECLTYPES, which is what keys the converter
# lookup on the declared column type.
24 sqlite.register_converter("KHASH", a2b_base64)
25 sqlite.register_converter("khash", a2b_base64)
# Have sqlite report exceptions raised inside the adapter/converter callbacks.
26 sqlite.enable_callback_tracebacks(True)
29 """An sqlite database for storing persistent files and hashes.
31 @type db: L{twisted.python.filepath.FilePath}
32 @ivar db: the database file to use
33 @type conn: L{pysqlite2.dbapi2.Connection}
34 @ivar conn: an open connection to the sqlite database
37 def __init__(self, db):
38 """Load or create the database file.
40 @type db: L{twisted.python.filepath.FilePath}
41 @param db: the database file to use
# Return TEXT values as plain str objects.
49 self.conn.text_factory = str
# Rows come back as sqlite.Row so the rest of the class can index them by
# column name (row['size'], row['hashID'], ...).
50 self.conn.row_factory = sqlite.Row
53 """Open a new connection to the existing database file"""
# PARSE_DECLTYPES lets sqlite pick the registered KHASH converters from each
# column's declared type.
55 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
# NOTE(review): in a Python 2 three-expression raise the third expression must
# be a traceback object (or None). traceback.format_exc() returns a string, so
# this statement fails with TypeError instead of raising DBExcept. Folding the
# formatted text into the message avoids that, e.g.
#     raise DBExcept("Couldn't open DB: " + traceback.format_exc())
58 raise DBExcept, "Couldn't open DB", traceback.format_exc()
60 def _createNewDB(self):
61 """Open a connection to a new database and create the necessary tables."""
# Make sure the directory that will hold the database file exists.
62 if not self.db.parent().exists():
63 self.db.parent().makedirs()
64 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
65 c = self.conn.cursor()
# files: one row per tracked path, pointing at its hash row (hashID) and
# recording the size and mtime the file had when it was stored.
66 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
67 "size NUMBER, mtime NUMBER)")
# hashes: one row per distinct file hash, with the piece string, the hash of
# that piece string and the time the hash was last refreshed. The KHASH
# columns go through the base64 adapter/converter registered at module level.
68 c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
69 "hash KHASH UNIQUE, pieces KHASH, " +
70 "piecehash KHASH, refreshed TIMESTAMP)")
# Indexes for the "refreshed < ?" query in expiredHashes and the
# "piecehash = ?" query in lookupHash.
71 c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
72 c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
76 def _removeChanged(self, file, row):
77 """If the file has changed or is missing, remove it from the DB.
79 @type file: L{twisted.python.filepath.FilePath}
80 @param file: the file to check
81 @type row: C{dictionary}-like object
82 @param row: contains the expected 'size' and 'mtime' of the file
84 @return: True if the file is unchanged, False if it is changed,
85 and None if it is missing
91 # Compare the current with the expected file properties
# Unchanged means both the recorded size and the recorded mtime are exactly
# equal to what the file system reports now.
92 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
94 # Remove the file from the database
# Only the files row for this path is deleted; the hashes table is not touched
# by this statement.
95 c = self.conn.cursor()
96 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
def storeFile(self, file, hash, pieces = ''):
    """Store or update a file in the database.

    The hash row is created if it does not exist yet, then the file row
    is inserted (or replaced) with the file's current size and mtime.

    @type file: L{twisted.python.filepath.FilePath}
    @param file: the file to check
    @type hash: C{string}
    @param hash: the hash of the file
    @type pieces: C{string}
    @param pieces: the concatenated list of the hashes of the pieces of
        the file (optional, defaults to the empty string)
    @return: True if the hash was not in the database before
        (so it needs to be added to the DHT)
    """
    # Hash the pieces to get the piecehash; no pieces means no piecehash.
    piecehash = ''
    if pieces:
        # Digest the hash *object*. The previous code kept the result of
        # update() (which is None) and then called digest() on the sha
        # module itself, so any non-empty pieces string raised
        # AttributeError before the database was touched.
        piecehash = sha.new(pieces).digest()

    # Check the database for the hash
    c = self.conn.cursor()
    c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
    row = c.fetchone()
    if row:
        # A known hash must always come with the same piece string.
        assert piecehash == row['piecehash']
        new_hash = False
        hashID = row['hashID']
    else:
        # Add the new hash to the database
        c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
                  (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
        self.conn.commit()
        new_hash = True
        # The AUTOINCREMENT key of the row that was just inserted.
        hashID = c.lastrowid

    # Add the file to the database
    # NOTE(review): refresh FilePath's cached stat so the recorded size and
    # mtime describe the file as it is now -- confirm this matches the
    # caller's expectations.
    file.restat()
    c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
              (file.path, hashID, file.getsize(), file.getmtime()))
    self.conn.commit()
    c.close()

    return new_hash
146 def getFile(self, file):
147 """Get a file from the database.
149 If it has changed or is missing, it is removed from the database.
151 @type file: L{twisted.python.filepath.FilePath}
152 @param file: the file to check
153 @return: dictionary of info for the file, False if changed, or
154 None if not in database or missing
# Look the path up together with its hash row (joined on hashID).
156 c = self.conn.cursor()
157 c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
# _removeChanged also deletes the files row when the file no longer matches.
161 res = self._removeChanged(file, row)
# The info dictionary handed back carries the keys 'hash', 'size' and 'pieces'.
164 res['hash'] = row['hash']
165 res['size'] = row['size']
166 res['pieces'] = row['pieces']
170 def lookupHash(self, hash, filesOnly = False):
171 """Find a file by hash in the database.
173 If any found files have changed or are missing, they are removed
174 from the database. If filesOnly is False then it will also look for
175 piece string hashes if no files can be found.
@type hash: C{string}
@param hash: the hash to look for; it is matched against the file
hashes first and, unless filesOnly is set, against the piecehashes
@type filesOnly: C{boolean}
@param filesOnly: if True, never fall back to the piecehash lookup
(optional, defaults to False)
177 @return: list of dictionaries of info for the found files
179 # Try to find the hash in the files table
180 c = self.conn.cursor()
181 c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
185 # Save the file to the list of found files
186 file = FilePath(row['path'])
# Files that changed or vanished are dropped from the database here.
187 res = self._removeChanged(file, row)
# Info copied from the row for a file entry: 'size', 'refreshed' and 'pieces'.
191 res['size'] = row['size']
192 res['refreshed'] = row['refreshed']
193 res['pieces'] = row['pieces']
197 if not filesOnly and not files:
198 # No files were found, so check the piecehashes as well
199 c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
# Info copied from the row for a piecehash match: 'refreshed' and 'pieces'.
203 res['refreshed'] = row['refreshed']
204 res['pieces'] = row['pieces']
210 def isUnchanged(self, file):
211 """Check if a file in the file system has changed.
213 If it has changed, it is removed from the database.
@type file: L{twisted.python.filepath.FilePath}
@param file: the file to check
215 @return: True if unchanged, False if changed, None if not in database
# Fetch the size and mtime recorded for this path.
217 c = self.conn.cursor()
218 c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
# The comparison, and the removal of a stale row, is done by _removeChanged.
220 return self._removeChanged(file, row)
222 def refreshHash(self, hash):
223 """Refresh the publishing time of a hash."""
# Stamps the hashes row for this hash with the current local time
# (datetime.now()); expiredHashes compares against the same column.
224 c = self.conn.cursor()
225 c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
228 def expiredHashes(self, expireAfter):
229 """Find files that need refreshing after expireAfter seconds.
231 For each hash that needs refreshing, finds all the files with that hash.
232 If the file has changed or is missing, it is removed from the table.
@type expireAfter: C{int}
@param expireAfter: the age in seconds after which a hash needs
refreshing
234 @return: dictionary with keys the hashes, values a list of FilePaths
# Hashes last refreshed before this cut-off time are expired.
236 t = datetime.now() - timedelta(seconds=expireAfter)
238 # Find all the hashes that need refreshing
239 c = self.conn.cursor()
240 c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
# NOTE(review): each value stored here is a dictionary (with at least the keys
# 'hashID', 'hash' and 'pieces'), not the plain list of FilePaths that the
# @return line above describes -- confirm and correct the docstring.
244 res = expired.setdefault(row['hash'], {})
245 res['hashID'] = row['hashID']
246 res['hash'] = row['hash']
247 res['pieces'] = row['pieces']
250 # Make sure there are still valid files for each hash
# NOTE(review): entries are deleted from `expired` inside this loop. That is
# safe only because Python 2's dict.values() returns a list copy; iterating a
# live view instead would fail once the dictionary changes size.
251 for hash in expired.values():
253 c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
# Changed or missing files are removed from the files table by this call.
256 res = self._removeChanged(FilePath(row['path']), row)
261 # Remove hashes for which no files are still available
262 del expired[hash['hash']]
263 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
270 def removeUntrackedFiles(self, dirs):
271 """Remove files that are no longer tracked by the program.
273 @type dirs: C{list} of L{twisted.python.filepath.FilePath}
274 @param dirs: a list of the directories that we are tracking
275 @return: list of files that were removed
# NOTE(review): this is argument validation done with assert, so the check
# disappears under python -O.
277 assert len(dirs) >= 1
279 # Create a list of globs and an SQL statement for the directories
# One "<dir>/*" glob pattern and one "path NOT GLOB ?" term per directory; the
# terms are chained with AND, so a path is selected only when it matches none
# of the tracked directories. The patterns are passed as bound parameters.
283 newdirs.append(dir.child('*').path)
284 sql += " path NOT GLOB ? AND"
287 # Get a listing of all the files that will be removed
288 c = self.conn.cursor()
289 c.execute("SELECT path FROM files " + sql, newdirs)
# Removed paths are reported to the caller as FilePath objects.
293 removed.append(FilePath(row['path']))
296 # Delete all the removed files from the database
# Same condition and parameters as the SELECT above, so exactly the rows that
# were just listed are deleted.
298 c.execute("DELETE FROM files " + sql, newdirs)
304 """Close the database connection."""
307 class TestDB(unittest.TestCase):
308 """Tests for the khashmir database."""
# Fixtures shared by all tests.
# NOTE(review): these are fixed locations under /tmp, so two test runs on the
# same machine at the same time would share them -- consider a temporary
# directory.
311 db = FilePath('/tmp/khashmir.db')
312 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
313 directory = FilePath('/tmp/apt-p2p/')
314 file = FilePath('/tmp/apt-p2p/khashmir.test')
315 testfile = 'tmp/khashmir.test'
316 dirs = [FilePath('/tmp/apt-p2p/top1'),
317 FilePath('/tmp/apt-p2p/top2/sub1'),
318 FilePath('/tmp/apt-p2p/top2/sub2/')]
# Fixture setup: write the test file, open the database and store the file
# under self.hash.
321 if not self.file.parent().exists():
322 self.file.parent().makedirs()
323 self.file.setContent('fgfhds')
325 self.store = DB(self.db)
326 self.store.storeFile(self.file, self.hash)
328 def test_openExistingDB(self):
329 """Tests opening an existing database."""
# Re-open the same database file and query the file stored by the fixture.
333 self.store = DB(self.db)
334 res = self.store.isUnchanged(self.file)
337 def test_getFile(self):
338 """Tests retrieving a file from the database."""
339 res = self.store.getFile(self.file)
341 self.failUnlessEqual(res['hash'], self.hash)
343 def test_lookupHash(self):
344 """Tests looking up a hash in the database."""
345 res = self.store.lookupHash(self.hash)
# Exactly one file was stored under this hash, and its 'path' entry is a
# FilePath.
347 self.failUnlessEqual(len(res), 1)
348 self.failUnlessEqual(res[0]['path'].path, self.file.path)
350 def test_isUnchanged(self):
351 """Tests checking if a file in the database is unchanged."""
352 res = self.store.isUnchanged(self.file)
356 res = self.store.isUnchanged(self.file)
# '== False' rather than 'not res' on purpose: None must not pass this check.
357 self.failUnless(res == False)
# The changed file was removed by the previous call, so it is now unknown.
358 res = self.store.isUnchanged(self.file)
359 self.failUnless(res is None)
361 def test_expiry(self):
362 """Tests retrieving the files from the database that have expired."""
# Freshly stored, so nothing is older than one second yet.
363 res = self.store.expiredHashes(1)
364 self.failUnlessEqual(len(res.keys()), 0)
# NOTE(review): this expects the hash to have expired, so more than one second
# must have passed since it was stored -- presumably a delay precedes this call.
366 res = self.store.expiredHashes(1)
367 self.failUnlessEqual(len(res.keys()), 1)
368 self.failUnlessEqual(res.keys()[0], self.hash)
# Refreshing resets the timestamp, so the hash is no longer expired.
369 self.store.refreshHash(self.hash)
370 res = self.store.expiredHashes(1)
371 self.failUnlessEqual(len(res.keys()), 0)
# Helper: create one test file below each directory in self.dirs and store it
# under the same hash as the fixture file.
373 def build_dirs(self):
374 for dir in self.dirs:
375 file = dir.preauthChild(self.testfile)
376 if not file.parent().exists():
377 file.parent().makedirs()
378 file.setContent(file.path)
380 self.store.storeFile(file, self.hash)
382 def test_multipleHashes(self):
383 """Tests looking up a hash with multiple files in the database."""
385 res = self.store.expiredHashes(1)
386 self.failUnlessEqual(len(res.keys()), 0)
387 res = self.store.lookupHash(self.hash)
# Four files share the hash: the fixture file plus one per entry of self.dirs
# (assuming build_dirs has run). They share a single hashes row, hence the
# identical 'refreshed' values.
389 self.failUnlessEqual(len(res), 4)
390 self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
391 self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
392 self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
394 res = self.store.expiredHashes(1)
395 self.failUnlessEqual(len(res.keys()), 1)
396 self.failUnlessEqual(res.keys()[0], self.hash)
397 self.store.refreshHash(self.hash)
398 res = self.store.expiredHashes(1)
399 self.failUnlessEqual(len(res.keys()), 0)
401 def test_removeUntracked(self):
402 """Tests removing untracked files from the database."""
# Tracking all of self.dirs leaves only the fixture file untracked.
404 res = self.store.removeUntrackedFiles(self.dirs)
405 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
406 self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
# A second pass with the same directories has nothing left to remove.
407 res = self.store.removeUntrackedFiles(self.dirs)
408 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
# Dropping the first directory from the tracked list removes its file.
409 res = self.store.removeUntrackedFiles(self.dirs[1:])
410 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
411 self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
# Tracking only the first directory removes the files of the other two.
412 res = self.store.removeUntrackedFiles(self.dirs[:1])
413 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
414 self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
415 self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
# Cleanup: delete the test directory tree.
418 self.directory.remove()