2 """An sqlite database for storing persistent files and hashes."""
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
# Refuse to run against an incompatible pysqlite2. This is an explicit check
# rather than ``assert`` so that it is not stripped when Python runs with -O.
if sqlite.version_info < (2, 1):
    raise ImportError("pysqlite2 >= 2.1 is required, found version %r"
                      % (sqlite.version_info,))
class DBExcept(Exception):
    """Raised when the database file cannot be accessed."""
20 """Dummy class to convert all hashes to base64 for storing in the DB."""
# Make sqlite understand 'khash' objects (binary hash strings). They are
# written to the database as base64 text, and any column declared with the
# KHASH type (either spelling) is decoded back to the raw binary string when
# a connection is opened with detect_types=PARSE_DECLTYPES.
sqlite.register_adapter(khash, b2a_base64)
for _typename in ("KHASH", "khash"):
    sqlite.register_converter(_typename, a2b_base64)
del _typename
# Report exceptions raised inside the adapter/converter callbacks instead of
# letting sqlite swallow them.
sqlite.enable_callback_tracebacks(True)
29 """An sqlite database for storing persistent files and hashes.
31 @type db: L{twisted.python.filepath.FilePath}
32 @ivar db: the database file to use
33 @type conn: L{pysqlite2.dbapi2.Connection}
34 @ivar conn: an open connection to the sqlite database
37 def __init__(self, db):
38 """Load or create the database file.
40 @type db: L{twisted.python.filepath.FilePath}
41 @param db: the database file to use
49 self.conn.text_factory = str
50 self.conn.row_factory = sqlite.Row
54 """Open a new connection to the existing database file"""
56 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
59 raise DBExcept, "Couldn't open DB", traceback.format_exc()
61 def _createNewDB(self):
62 """Open a connection to a new database and create the necessary tables."""
63 if not self.db.parent().exists():
64 self.db.parent().makedirs()
65 self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
66 c = self.conn.cursor()
67 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
68 "size NUMBER, mtime NUMBER)")
69 c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
70 "hash KHASH UNIQUE, pieces KHASH, " +
71 "piecehash KHASH, refreshed TIMESTAMP)")
72 c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
73 c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
74 c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
75 c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
80 """Close the database connection."""
84 def _removeChanged(self, file, row):
85 """If the file has changed or is missing, remove it from the DB.
87 @type file: L{twisted.python.filepath.FilePath}
88 @param file: the file to check
89 @type row: C{dictionary}-like object
90 @param row: contains the expected 'size' and 'mtime' of the file
92 @return: True if the file is unchanged, False if it is changed,
93 and None if it is missing
99 # Compare the current with the expected file properties
100 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
102 # Remove the file from the database
103 c = self.conn.cursor()
104 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
109 def storeFile(self, file, hash, pieces = ''):
110 """Store or update a file in the database.
112 @type file: L{twisted.python.filepath.FilePath}
113 @param file: the file to check
114 @type hash: C{string}
115 @param hash: the hash of the file
116 @type pieces: C{string}
117 @param pieces: the concatenated list of the hashes of the pieces of
118 the file (optional, defaults to the empty string)
119 @return: True if the hash was not in the database before
120 (so it needs to be added to the DHT)
122 # Hash the pieces to get the piecehash
125 piecehash = sha.new(pieces).digest()
127 # Check the database for the hash
128 c = self.conn.cursor()
129 c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
132 assert piecehash == row['piecehash']
134 hashID = row['hashID']
136 # Add the new hash to the database
137 c = self.conn.cursor()
138 c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
139 (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
144 # Add the file to the database
146 c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
147 (file.path, hashID, file.getsize(), file.getmtime()))
153 def getFile(self, file):
154 """Get a file from the database.
156 If it has changed or is missing, it is removed from the database.
158 @type file: L{twisted.python.filepath.FilePath}
159 @param file: the file to check
160 @return: dictionary of info for the file, False if changed, or
161 None if not in database or missing
163 c = self.conn.cursor()
164 c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
168 res = self._removeChanged(file, row)
171 res['hash'] = row['hash']
172 res['size'] = row['size']
173 res['pieces'] = row['pieces']
177 def lookupHash(self, hash, filesOnly = False):
178 """Find a file by hash in the database.
180 If any found files have changed or are missing, they are removed
181 from the database. If filesOnly is False then it will also look for
182 piece string hashes if no files can be found.
184 @return: list of dictionaries of info for the found files
186 # Try to find the hash in the files table
187 c = self.conn.cursor()
188 c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
192 # Save the file to the list of found files
193 file = FilePath(row['path'])
194 res = self._removeChanged(file, row)
198 res['size'] = row['size']
199 res['refreshed'] = row['refreshed']
200 res['pieces'] = row['pieces']
204 if not filesOnly and not files:
205 # No files were found, so check the piecehashes as well
206 c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
210 res['refreshed'] = row['refreshed']
211 res['pieces'] = row['pieces']
217 def isUnchanged(self, file):
218 """Check if a file in the file system has changed.
220 If it has changed, it is removed from the database.
222 @return: True if unchanged, False if changed, None if not in database
224 c = self.conn.cursor()
225 c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
227 return self._removeChanged(file, row)
229 def refreshHash(self, hash):
230 """Refresh the publishing time of a hash."""
231 c = self.conn.cursor()
232 c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
235 def expiredHashes(self, expireAfter):
236 """Find files that need refreshing after expireAfter seconds.
238 For each hash that needs refreshing, finds all the files with that hash.
239 If the file has changed or is missing, it is removed from the table.
241 @return: dictionary with keys the hashes, values a list of FilePaths
243 t = datetime.now() - timedelta(seconds=expireAfter)
245 # Find all the hashes that need refreshing
246 c = self.conn.cursor()
247 c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
251 res = expired.setdefault(row['hash'], {})
252 res['hashID'] = row['hashID']
253 res['hash'] = row['hash']
254 res['pieces'] = row['pieces']
257 # Make sure there are still valid files for each hash
258 for hash in expired.values():
260 c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
263 res = self._removeChanged(FilePath(row['path']), row)
268 # Remove hashes for which no files are still available
269 del expired[hash['hash']]
270 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
277 def removeUntrackedFiles(self, dirs):
278 """Remove files that are no longer tracked by the program.
280 @type dirs: C{list} of L{twisted.python.filepath.FilePath}
281 @param dirs: a list of the directories that we are tracking
282 @return: list of files that were removed
284 assert len(dirs) >= 1
286 # Create a list of globs and an SQL statement for the directories
290 newdirs.append(dir.child('*').path)
291 sql += " path NOT GLOB ? AND"
294 # Get a listing of all the files that will be removed
295 c = self.conn.cursor()
296 c.execute("SELECT path FROM files " + sql, newdirs)
300 removed.append(FilePath(row['path']))
303 # Delete all the removed files from the database
305 c.execute("DELETE FROM files " + sql, newdirs)
312 """Count the total number of files and hashes in the database.
314 @rtype: (C{int}, C{int})
315 @return: the number of distinct hashes and total files in the database
317 c = self.conn.cursor()
318 c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
323 c.execute("SELECT COUNT(path) as num_files FROM files")
331 """Retrieve the saved statistics from the DB.
333 @return: dictionary of statistics
335 c = self.conn.cursor()
336 c.execute("SELECT param, value FROM stats")
340 stats[row['param']] = row['value']
345 def saveStats(self, stats):
346 """Save the statistics to the DB."""
347 c = self.conn.cursor()
349 c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
350 (param, stats[param]))
354 class TestDB(unittest.TestCase):
355 """Tests for the khashmir database."""
358 db = FilePath('/tmp/khashmir.db')
359 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
360 directory = FilePath('/tmp/apt-p2p/')
361 file = FilePath('/tmp/apt-p2p/khashmir.test')
362 testfile = 'tmp/khashmir.test'
363 dirs = [FilePath('/tmp/apt-p2p/top1'),
364 FilePath('/tmp/apt-p2p/top2/sub1'),
365 FilePath('/tmp/apt-p2p/top2/sub2/')]
368 if not self.file.parent().exists():
369 self.file.parent().makedirs()
370 self.file.setContent('fgfhds')
372 self.store = DB(self.db)
373 self.store.storeFile(self.file, self.hash)
375 def test_openExistingDB(self):
376 """Tests opening an existing database."""
380 self.store = DB(self.db)
381 res = self.store.isUnchanged(self.file)
384 def test_getFile(self):
385 """Tests retrieving a file from the database."""
386 res = self.store.getFile(self.file)
388 self.failUnlessEqual(res['hash'], self.hash)
390 def test_lookupHash(self):
391 """Tests looking up a hash in the database."""
392 res = self.store.lookupHash(self.hash)
394 self.failUnlessEqual(len(res), 1)
395 self.failUnlessEqual(res[0]['path'].path, self.file.path)
397 def test_isUnchanged(self):
398 """Tests checking if a file in the database is unchanged."""
399 res = self.store.isUnchanged(self.file)
403 res = self.store.isUnchanged(self.file)
404 self.failUnless(res == False)
405 res = self.store.isUnchanged(self.file)
406 self.failUnless(res is None)
408 def test_expiry(self):
409 """Tests retrieving the files from the database that have expired."""
410 res = self.store.expiredHashes(1)
411 self.failUnlessEqual(len(res.keys()), 0)
413 res = self.store.expiredHashes(1)
414 self.failUnlessEqual(len(res.keys()), 1)
415 self.failUnlessEqual(res.keys()[0], self.hash)
416 self.store.refreshHash(self.hash)
417 res = self.store.expiredHashes(1)
418 self.failUnlessEqual(len(res.keys()), 0)
420 def build_dirs(self):
421 for dir in self.dirs:
422 file = dir.preauthChild(self.testfile)
423 if not file.parent().exists():
424 file.parent().makedirs()
425 file.setContent(file.path)
427 self.store.storeFile(file, self.hash)
429 def test_multipleHashes(self):
430 """Tests looking up a hash with multiple files in the database."""
432 res = self.store.expiredHashes(1)
433 self.failUnlessEqual(len(res.keys()), 0)
434 res = self.store.lookupHash(self.hash)
436 self.failUnlessEqual(len(res), 4)
437 self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
438 self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
439 self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
441 res = self.store.expiredHashes(1)
442 self.failUnlessEqual(len(res.keys()), 1)
443 self.failUnlessEqual(res.keys()[0], self.hash)
444 self.store.refreshHash(self.hash)
445 res = self.store.expiredHashes(1)
446 self.failUnlessEqual(len(res.keys()), 0)
448 def test_removeUntracked(self):
449 """Tests removing untracked files from the database."""
451 res = self.store.removeUntrackedFiles(self.dirs)
452 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
453 self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
454 res = self.store.removeUntrackedFiles(self.dirs)
455 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
456 res = self.store.removeUntrackedFiles(self.dirs[1:])
457 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
458 self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
459 res = self.store.removeUntrackedFiles(self.dirs[:1])
460 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
461 self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
462 self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
465 self.directory.remove()