2 """An sqlite database for storing persistent files and hashes."""
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
# Refuse to load against a pysqlite release older than 2.1.
assert (2, 1) <= sqlite.version_info
class DBExcept(Exception):
    """Raised when the database cannot be opened or accessed."""
class khash(str):
    """Marker string type for hashes, so they are stored in the DB as base64.

    It adds no behaviour of its own; it only gives the sqlite adapter and
    converter machinery a distinct type to dispatch on.
    """
# Teach sqlite how to handle 'khash' objects (binary strings): they are
# written as base64 text and decoded again for columns declared KHASH.
sqlite.register_adapter(khash, b2a_base64)
for _declared_type in ("KHASH", "khash"):
    sqlite.register_converter(_declared_type, a2b_base64)
# Show tracebacks for errors raised inside the adapter/converter callbacks.
sqlite.enable_callback_tracebacks(True)
# NOTE(review): the listing this class was restored from lacked several short
# statements (cursor fetch/close/commit calls, branch headers and some def
# lines); they were filled in from the surrounding code and docstrings.
# Confirm against the canonical file before merging.
class DB:
    """An sqlite database for storing persistent files and hashes.

    @type db: L{twisted.python.filepath.FilePath}
    @ivar db: the database file to use
    @type conn: L{pysqlite2.dbapi2.Connection}
    @ivar conn: an open connection to the sqlite database
    """

    def __init__(self, db):
        """Load or create the database file.

        @type db: L{twisted.python.filepath.FilePath}
        @param db: the database file to use
        """
        self.db = db
        # Refresh the cached stat information before testing for existence
        self.db.restat(False)
        if self.db.exists():
            self._loadDB()
        else:
            self._createNewDB()
        # Hand back plain strings, and rows that can be indexed by column name
        self.conn.text_factory = str
        self.conn.row_factory = sqlite.Row

    def _loadDB(self):
        """Open a new connection to the existing database file.

        @raise DBExcept: if the connection could not be opened
        """
        try:
            self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
        except Exception:
            import traceback
            # The formatted traceback is a string, so it is passed as an
            # exception argument; the three-expression raise form only
            # accepts a traceback object in that position and would itself
            # fail with a TypeError.
            raise DBExcept("Couldn't open DB", traceback.format_exc())

    def _createNewDB(self):
        """Open a connection to a new database and create the necessary tables."""
        if not self.db.parent().exists():
            self.db.parent().makedirs()
        self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
        c = self.conn.cursor()
        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
                                      "dht BOOL, size NUMBER, mtime NUMBER)")
        c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                       "hash KHASH UNIQUE, pieces KHASH, " +
                                       "piecehash KHASH, refreshed TIMESTAMP)")
        c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
        c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
        c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
        c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
        c.close()
        self.conn.commit()

    def close(self):
        """Close the database connection."""
        self.conn.close()

    def _removeChanged(self, file, row):
        """If the file has changed or is missing, remove it from the DB.

        @type file: L{twisted.python.filepath.FilePath}
        @param file: the file to check
        @type row: C{dictionary}-like object
        @param row: contains the expected 'size' and 'mtime' of the file
        @rtype: C{boolean}
        @return: True if the file is unchanged, False if it is changed,
            and None if it is missing
        """
        res = None
        if row:
            file.restat(False)
            if file.exists():
                # Compare the current with the expected file properties
                res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
            if not res:
                # Remove the file from the database
                c = self.conn.cursor()
                c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                self.conn.commit()
                c.close()
        return res

    def storeFile(self, file, hash, dht = True, pieces = ''):
        """Store or update a file in the database.

        @type file: L{twisted.python.filepath.FilePath}
        @param file: the file to check
        @type hash: C{string}
        @param hash: the hash of the file
        @param dht: whether the file is added to the DHT
            (optional, defaults to true)
        @type pieces: C{string}
        @param pieces: the concatenated list of the hashes of the pieces of
            the file (optional, defaults to the empty string)
        @return: True if the hash was not in the database before
            (so it needs to be added to the DHT)
        """
        # Hash the pieces to get the piecehash
        piecehash = ''
        if pieces:
            piecehash = sha.new(pieces).digest()

        # Check the database for the hash
        c = self.conn.cursor()
        c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
        row = c.fetchone()
        if row:
            # A known hash must always come with the same piece hashes
            assert piecehash == row['piecehash']
            new_hash = False
            hashID = row['hashID']
        else:
            # Add the new hash to the database
            c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
                      (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
            self.conn.commit()
            new_hash = True
            hashID = c.lastrowid

        # Add the file to the database
        file.restat()
        c.execute("INSERT OR REPLACE INTO files (path, hashID, dht, size, mtime) VALUES (?, ?, ?, ?, ?)",
                  (file.path, hashID, dht, file.getsize(), file.getmtime()))
        self.conn.commit()
        c.close()

        return new_hash

    def getFile(self, file):
        """Get a file from the database.

        If it has changed or is missing, it is removed from the database.

        @type file: L{twisted.python.filepath.FilePath}
        @param file: the file to check
        @return: dictionary of info for the file ('hash', 'size' and
            'pieces'), False if changed, or None if not in database or missing
        """
        c = self.conn.cursor()
        c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
        row = c.fetchone()
        c.close()

        res = None
        if row:
            res = self._removeChanged(file, row)
            if res:
                res = {'hash': row['hash'],
                       'size': row['size'],
                       'pieces': row['pieces']}
        return res

    def lookupHash(self, hash, filesOnly = False):
        """Find a file by hash in the database.

        If any found files have changed or are missing, they are removed
        from the database. If filesOnly is False then it will also look for
        piece string hashes if no files can be found.

        @return: list of dictionaries of info for the found files
        """
        # Try to find the hash in the files table
        c = self.conn.cursor()
        c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
        # Fetch everything first: checking a file may delete and commit on
        # this connection, which must not happen while the SELECT is still
        # being stepped through.
        rows = c.fetchall()
        files = []
        for row in rows:
            # Save the file to the list of found files
            file = FilePath(row['path'])
            if self._removeChanged(file, row):
                files.append({'path': file,
                              'size': row['size'],
                              'refreshed': row['refreshed'],
                              'pieces': row['pieces']})

        if not filesOnly and not files:
            # No files were found, so check the piecehashes as well
            c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
            row = c.fetchone()
            if row:
                files.append({'refreshed': row['refreshed'],
                              'pieces': row['pieces']})

        c.close()
        return files

    def isUnchanged(self, file):
        """Check if a file in the file system has changed.

        If it has changed, it is removed from the database.

        @return: True if unchanged, False if changed, None if not in database
        """
        c = self.conn.cursor()
        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
        row = c.fetchone()
        c.close()
        return self._removeChanged(file, row)

    def refreshHash(self, hash):
        """Refresh the publishing time of a hash."""
        c = self.conn.cursor()
        c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
        self.conn.commit()
        c.close()

    def expiredHashes(self, expireAfter):
        """Find hashes that need refreshing after expireAfter seconds.

        For each hash that needs refreshing, looks up the files stored with
        that hash. A hash that no longer has any DHT file is left out of the
        result: if it has no files at all it is deleted from the database,
        otherwise (only non-DHT files remain) its refresh time is updated.

        @return: a list of dictionaries ('hash', 'hashID' and 'pieces') of
            each hash needing refreshing, sorted by age
        """
        t = datetime.now() - timedelta(seconds=expireAfter)

        # Find all the hashes that need refreshing
        c = self.conn.cursor()
        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ? ORDER BY refreshed", (t, ))
        expired = []
        for row in c.fetchall():
            expired.append({'hash': row['hash'],
                            'hashID': row['hashID'],
                            'pieces': row['pieces']})

        # Make sure there are still valid DHT files for each hash
        # (walk backwards so entries can be deleted by index while looping)
        for i in xrange(len(expired)-1, -1, -1):
            dht = False
            non_dht = False
            hash = expired[i]
            c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
            for row in c.fetchall():
                if row['dht']:
                    dht = True
                else:
                    non_dht = True
            if not dht:
                # Remove hashes for which no DHT files are still available
                del expired[i]
                if not non_dht:
                    # Remove hashes for which no files are still available
                    c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
                else:
                    # There are still some non-DHT files available, so refresh them
                    c.execute("UPDATE hashes SET refreshed = ? WHERE hashID = ?",
                              (datetime.now(), hash['hashID']))

        self.conn.commit()
        c.close()

        return expired

    def removeUntrackedFiles(self, dirs):
        """Remove files that are no longer tracked by the program.

        Files stored under none of the given directories are removed first,
        then any remaining file whose path no longer exists on disk.

        @type dirs: C{list} of L{twisted.python.filepath.FilePath}
        @param dirs: a list of the directories that we are tracking
        @return: list of files that were removed
        """
        assert len(dirs) >= 1

        # Create a list of globs and an SQL statement for the directories
        newdirs = [dir.child('*').path for dir in dirs]
        sql = "WHERE " + " AND ".join(["path NOT GLOB ?"] * len(newdirs))

        # Get a listing of all the files that will be removed
        c = self.conn.cursor()
        c.execute("SELECT path FROM files " + sql, newdirs)
        removed = [FilePath(row['path']) for row in c.fetchall()]

        # Delete all the removed files from the database
        if removed:
            c.execute("DELETE FROM files " + sql, newdirs)
            self.conn.commit()

        # Also drop any tracked file that has disappeared from the file system
        c.execute("SELECT path FROM files")
        for row in c.fetchall():
            if not os.path.exists(row['path']):
                # Leave hashes, they will be removed on next refresh
                c.execute("DELETE FROM files WHERE path = ?", (row['path'], ))
                removed.append(FilePath(row['path']))
        self.conn.commit()
        c.close()

        return removed

    def dbStats(self):
        """Count the total number of files and hashes in the database.

        @rtype: (C{int}, C{int})
        @return: the number of distinct hashes and total files in the database
        """
        c = self.conn.cursor()
        c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
        row = c.fetchone()
        hashes = 0
        if row:
            hashes = row['num_hashes']
        c.execute("SELECT COUNT(path) as num_files FROM files")
        row = c.fetchone()
        files = 0
        if row:
            files = row['num_files']
        c.close()
        return hashes, files

    def getStats(self):
        """Retrieve the saved statistics from the DB.

        @return: dictionary of statistics
        """
        c = self.conn.cursor()
        c.execute("SELECT param, value FROM stats")
        stats = {}
        for row in c.fetchall():
            stats[row['param']] = row['value']
        c.close()
        return stats

    def saveStats(self, stats):
        """Save the statistics to the DB.

        @type stats: C{dictionary}
        @param stats: the statistics to save, one row per key
        """
        c = self.conn.cursor()
        for param in stats:
            c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
                      (param, stats[param]))
        self.conn.commit()
        c.close()
375 class TestDB(unittest.TestCase):
376 """Tests for the khashmir database."""
379 db = FilePath('/tmp/khashmir.db')
380 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
381 directory = FilePath('/tmp/apt-p2p/')
382 file = FilePath('/tmp/apt-p2p/khashmir.test')
383 testfile = 'tmp/khashmir.test'
384 dirs = [FilePath('/tmp/apt-p2p/top1'),
385 FilePath('/tmp/apt-p2p/top2/sub1'),
386 FilePath('/tmp/apt-p2p/top2/sub2/')]
389 if not self.file.parent().exists():
390 self.file.parent().makedirs()
391 self.file.setContent('fgfhds')
393 self.store = DB(self.db)
394 self.store.storeFile(self.file, self.hash)
396 def test_openExistingDB(self):
397 """Tests opening an existing database."""
401 self.store = DB(self.db)
402 res = self.store.isUnchanged(self.file)
405 def test_getFile(self):
406 """Tests retrieving a file from the database."""
407 res = self.store.getFile(self.file)
409 self.failUnlessEqual(res['hash'], self.hash)
411 def test_lookupHash(self):
412 """Tests looking up a hash in the database."""
413 res = self.store.lookupHash(self.hash)
415 self.failUnlessEqual(len(res), 1)
416 self.failUnlessEqual(res[0]['path'].path, self.file.path)
418 def test_isUnchanged(self):
419 """Tests checking if a file in the database is unchanged."""
420 res = self.store.isUnchanged(self.file)
424 res = self.store.isUnchanged(self.file)
425 self.failUnless(res == False)
426 res = self.store.isUnchanged(self.file)
427 self.failUnless(res is None)
429 def test_expiry(self):
430 """Tests retrieving the files from the database that have expired."""
431 res = self.store.expiredHashes(1)
432 self.failUnlessEqual(len(res), 0)
434 res = self.store.expiredHashes(1)
435 self.failUnlessEqual(len(res), 1)
436 self.failUnlessEqual(res[0]['hash'], self.hash)
437 self.store.refreshHash(self.hash)
438 res = self.store.expiredHashes(1)
439 self.failUnlessEqual(len(res), 0)
441 def build_dirs(self):
442 for dir in self.dirs:
443 file = dir.preauthChild(self.testfile)
444 if not file.parent().exists():
445 file.parent().makedirs()
446 file.setContent(file.path)
448 self.store.storeFile(file, self.hash)
450 def test_multipleHashes(self):
451 """Tests looking up a hash with multiple files in the database."""
453 res = self.store.expiredHashes(1)
454 self.failUnlessEqual(len(res), 0)
455 res = self.store.lookupHash(self.hash)
457 self.failUnlessEqual(len(res), 4)
458 self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
459 self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
460 self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
462 res = self.store.expiredHashes(1)
463 self.failUnlessEqual(len(res), 1)
464 self.failUnlessEqual(res[0]['hash'], self.hash)
465 self.store.refreshHash(self.hash)
466 res = self.store.expiredHashes(1)
467 self.failUnlessEqual(len(res), 0)
469 def test_removeUntracked(self):
470 """Tests removing untracked files from the database."""
472 file = self.dirs[0].child('test.khashmir')
473 file.setContent(file.path)
475 self.store.storeFile(file, self.hash)
476 res = self.store.removeUntrackedFiles(self.dirs)
477 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
478 self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
479 res = self.store.removeUntrackedFiles(self.dirs)
480 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
482 res = self.store.removeUntrackedFiles(self.dirs)
483 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
484 self.failUnlessEqual(res[0], self.dirs[0].child('test.khashmir'), 'Got removed paths: %r' % res)
485 res = self.store.removeUntrackedFiles(self.dirs[1:])
486 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
487 self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
488 res = self.store.removeUntrackedFiles(self.dirs[:1])
489 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
490 self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
491 self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
494 self.directory.remove()