X-Git-Url: https://git.mxchange.org/?a=blobdiff_plain;ds=sidebyside;f=apt_dht%2Fdb.py;h=fd265f6f36503ca31cf9d8a0dda11b8c78bc9f89;hb=b36fa9c7f5ad4827c8de990fbfca1a379bb60e39;hp=c451874c62eb8db8c0fb4a47c1c9aa2f11e7a8b6;hpb=dd40276e78eb835e6e85458028284d8ca40046d6;p=quix0rs-apt-p2p.git diff --git a/apt_dht/db.py b/apt_dht/db.py index c451874..fd265f6 100644 --- a/apt_dht/db.py +++ b/apt_dht/db.py @@ -1,258 +1,421 @@ +"""An sqlite database for storing persistent files and hashes.""" + from datetime import datetime, timedelta from pysqlite2 import dbapi2 as sqlite from binascii import a2b_base64, b2a_base64 from time import sleep import os +from twisted.python.filepath import FilePath from twisted.trial import unittest assert sqlite.version_info >= (2, 1) class DBExcept(Exception): + """An error occurred in accessing the database.""" pass class khash(str): """Dummy class to convert all hashes to base64 for storing in the DB.""" - + +# Initialize the database to work with 'khash' objects (binary strings) sqlite.register_adapter(khash, b2a_base64) sqlite.register_converter("KHASH", a2b_base64) sqlite.register_converter("khash", a2b_base64) +sqlite.enable_callback_tracebacks(True) class DB: - """Database access for storing persistent data.""" + """An sqlite database for storing persistent files and hashes. + + @type db: L{twisted.python.filepath.FilePath} + @ivar db: the database file to use + @type conn: L{pysqlite2.dbapi2.Connection} + @ivar conn: an open connection to the sqlite database + """ def __init__(self, db): + """Load or create the database file. + + @type db: L{twisted.python.filepath.FilePath} + @param db: the database file to use + """ self.db = db - try: - os.stat(db) - except OSError: - self._createNewDB(db) + self.db.restat(False) + if self.db.exists(): + self._loadDB() else: - self._loadDB(db) + self._createNewDB() self.conn.text_factory = str self.conn.row_factory = sqlite.Row - def _loadDB(self, db): + def _loadDB(self): + """Open a new connection to the existing database file""" try: - self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES) + self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES) except: import traceback raise DBExcept, "Couldn't open DB", traceback.format_exc() - def _createNewDB(self, db): - self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES) + def _createNewDB(self): + """Open a connection to a new database and create the necessary tables.""" + if not self.db.parent().exists(): + self.db.parent().makedirs() + self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES) c = self.conn.cursor() - c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urlpath TEXT, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)") -# c.execute("CREATE INDEX files_hash ON files(hash)") - c.execute("CREATE INDEX files_refreshed ON files(refreshed)") - c.execute("CREATE TABLE dirs (path TEXT PRIMARY KEY, urlpath TEXT)") + c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " + + "size NUMBER, mtime NUMBER)") + c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " + + "hash KHASH UNIQUE, pieces KHASH, " + + "piecehash KHASH, refreshed TIMESTAMP)") + c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)") + c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)") c.close() self.conn.commit() - def _removeChanged(self, path, row): + def _removeChanged(self, file, row): + """If the file has changed or is 
missing, remove it from the DB. + + @type file: L{twisted.python.filepath.FilePath} + @param file: the file to check + @type row: C{dictionary}-like object + @param row: contains the expected 'size' and 'mtime' of the file + @rtype: C{boolean} + @return: True if the file is unchanged, False if it is changed, + and None if it is missing + """ res = None if row: - try: - stat = os.stat(path) - except: - stat = None - if stat: - res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime) + file.restat(False) + if file.exists(): + # Compare the current with the expected file properties + res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime()) if not res: + # Remove the file from the database c = self.conn.cursor() - c.execute("DELETE FROM files WHERE path = ?", (path, )) + c.execute("DELETE FROM files WHERE path = ?", (file.path, )) self.conn.commit() c.close() return res - def storeFile(self, path, hash, urlpath): - """Store or update a file in the database.""" - path = os.path.abspath(path) - stat = os.stat(path) + def storeFile(self, file, hash, pieces = ''): + """Store or update a file in the database. + + @type file: L{twisted.python.filepath.FilePath} + @param file: the file to check + @type hash: C{string} + @param hash: the hash of the file + @type pieces: C{string} + @param pieces: the concatenated list of the hashes of the pieces of + the file (optional, defaults to the empty string) + @return: True if the hash was not in the database before + (so it needs to be added to the DHT) + """ + # Hash the pieces to get the piecehash + piecehash = '' + if pieces: + s = sha.new().update(pieces) + piecehash = sha.digest() + + # Check the database for the hash c = self.conn.cursor() - c.execute("INSERT OR REPLACE INTO files VALUES (?, ?, ?, ?, ?, ?)", - (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now())) + c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), )) + row = c.fetchone() + if row: + assert piecehash == row['piecehash'] + new_hash = False + hashID = row['hashID'] + else: + # Add the new hash to the database + c = self.conn.cursor() + c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)", + (khash(hash), khash(pieces), khash(piecehash), datetime.now())) + self.conn.commit() + new_hash = True + hashID = c.lastrowid + + # Add the file to the database + file.restat() + c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)", + (file.path, hashID, file.getsize(), file.getmtime())) self.conn.commit() c.close() - def getFile(self, path): + return new_hash + + def getFile(self, file): """Get a file from the database. If it has changed or is missing, it is removed from the database. 
+ @type file: L{twisted.python.filepath.FilePath} + @param file: the file to check @return: dictionary of info for the file, False if changed, or None if not in database or missing """ - path = os.path.abspath(path) c = self.conn.cursor() - c.execute("SELECT hash, urlpath, size, mtime FROM files WHERE path = ?", (path, )) + c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, )) row = c.fetchone() - res = self._removeChanged(path, row) - if res: - res = {} - res['hash'] = row['hash'] - res['urlpath'] = row['urlpath'] + res = None + if row: + res = self._removeChanged(file, row) + if res: + res = {} + res['hash'] = row['hash'] + res['size'] = row['size'] + res['pieces'] = row['pieces'] c.close() return res - def isUnchanged(self, path): - """Check if a file in the file system has changed. + def lookupHash(self, hash, filesOnly = False): + """Find a file by hash in the database. - If it has changed, it is removed from the table. + If any found files have changed or are missing, they are removed + from the database. If filesOnly is False then it will also look for + piece string hashes if no files can be found. - @return: True if unchanged, False if changed, None if not in database + @return: list of dictionaries of info for the found files """ - path = os.path.abspath(path) + # Try to find the hash in the files table c = self.conn.cursor() - c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, )) + c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), )) row = c.fetchone() - return self._removeChanged(path, row) + files = [] + while row: + # Save the file to the list of found files + file = FilePath(row['path']) + res = self._removeChanged(file, row) + if res: + res = {} + res['path'] = file + res['size'] = row['size'] + res['refreshed'] = row['refreshed'] + res['pieces'] = row['pieces'] + files.append(res) + row = c.fetchone() + + if not filesOnly and not files: + # No files were found, so check the piecehashes as well + c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), )) + row = c.fetchone() + if row: + res = {} + res['refreshed'] = row['refreshed'] + res['pieces'] = row['pieces'] + files.append(res) - def refreshFile(self, path): - """Refresh the publishing time of a file. + c.close() + return files + + def isUnchanged(self, file): + """Check if a file in the file system has changed. - If it has changed or is missing, it is removed from the table. + If it has changed, it is removed from the database. @return: True if unchanged, False if changed, None if not in database """ - path = os.path.abspath(path) c = self.conn.cursor() - c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, )) + c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, )) row = c.fetchone() - res = self._removeChanged(path, row) - if res: - c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path)) - return res + return self._removeChanged(file, row) + + def refreshHash(self, hash): + """Refresh the publishing time of a hash.""" + c = self.conn.cursor() + c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash))) + c.close() - def expiredFiles(self, expireAfter): + def expiredHashes(self, expireAfter): """Find files that need refreshing after expireAfter seconds. - Also removes any entries from the table that no longer exist. 
+ For each hash that needs refreshing, finds all the files with that hash. + If the file has changed or is missing, it is removed from the table. - @return: dictionary with keys the hashes, values a list of url paths + @return: dictionary with keys the hashes, values a list of FilePaths """ t = datetime.now() - timedelta(seconds=expireAfter) + + # Find all the hashes that need refreshing c = self.conn.cursor() - c.execute("SELECT path, hash, urlpath, size, mtime FROM files WHERE refreshed < ?", (t, )) + c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, )) row = c.fetchone() expired = {} while row: - res = self._removeChanged(row['path'], row) - if res: - expired.setdefault(row['hash'], []).append(row['urlpath']) + res = expired.setdefault(row['hash'], {}) + res['hashID'] = row['hashID'] + res['hash'] = row['hash'] + res['pieces'] = row['pieces'] row = c.fetchone() + + # Make sure there are still valid files for each hash + for hash in expired.values(): + valid = False + c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], )) + row = c.fetchone() + while row: + res = self._removeChanged(FilePath(row['path']), row) + if res: + valid = True + row = c.fetchone() + if not valid: + # Remove hashes for which no files are still available + del expired[hash['hash']] + c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], )) + + self.conn.commit() c.close() + return expired def removeUntrackedFiles(self, dirs): - """Find files that are no longer tracked and so should be removed. - - Also removes the entries from the table. + """Remove files that are no longer tracked by the program. + @type dirs: C{list} of L{twisted.python.filepath.FilePath} + @param dirs: a list of the directories that we are tracking @return: list of files that were removed """ assert len(dirs) >= 1 + + # Create a list of globs and an SQL statement for the directories newdirs = [] sql = "WHERE" for dir in dirs: - newdirs.append(os.path.abspath(dir) + os.sep + '*') + newdirs.append(dir.child('*').path) sql += " path NOT GLOB ? AND" sql = sql[:-4] + # Get a listing of all the files that will be removed c = self.conn.cursor() c.execute("SELECT path FROM files " + sql, newdirs) row = c.fetchone() removed = [] while row: - removed.append(row['path']) + removed.append(FilePath(row['path'])) row = c.fetchone() + # Delete all the removed files from the database if removed: c.execute("DELETE FROM files " + sql, newdirs) self.conn.commit() + return removed - + def close(self): + """Close the database connection.""" self.conn.close() class TestDB(unittest.TestCase): """Tests for the khashmir database.""" timeout = 5 - db = '/tmp/khashmir.db' - path = '/tmp/khashmir.test' + db = FilePath('/tmp/khashmir.db') hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!' 
- urlpath = '/~1/what/ever/khashmir.test' - dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/'] + directory = FilePath('/tmp/apt-dht/') + file = FilePath('/tmp/apt-dht/khashmir.test') + testfile = 'tmp/khashmir.test' + dirs = [FilePath('/tmp/apt-dht/top1'), + FilePath('/tmp/apt-dht/top2/sub1'), + FilePath('/tmp/apt-dht/top2/sub2/')] def setUp(self): - f = open(self.path, 'w') - f.write('fgfhds') - f.close() - os.utime(self.path, None) + if not self.file.parent().exists(): + self.file.parent().makedirs() + self.file.setContent('fgfhds') + self.file.touch() + self.store = DB(self.db) + self.store.storeFile(self.file, self.hash) + + def test_openExistingDB(self): + """Tests opening an existing database.""" + self.store.close() + self.store = None + sleep(1) self.store = DB(self.db) - self.store.storeFile(self.path, self.hash, self.urlpath) + res = self.store.isUnchanged(self.file) + self.failUnless(res) def test_getFile(self): - res = self.store.getFile(self.path) + """Tests retrieving a file from the database.""" + res = self.store.getFile(self.file) self.failUnless(res) self.failUnlessEqual(res['hash'], self.hash) - self.failUnlessEqual(res['urlpath'], self.urlpath) + + def test_lookupHash(self): + """Tests looking up a hash in the database.""" + res = self.store.lookupHash(self.hash) + self.failUnless(res) + self.failUnlessEqual(len(res), 1) + self.failUnlessEqual(res[0]['path'].path, self.file.path) def test_isUnchanged(self): - res = self.store.isUnchanged(self.path) + """Tests checking if a file in the database is unchanged.""" + res = self.store.isUnchanged(self.file) self.failUnless(res) sleep(2) - os.utime(self.path, None) - res = self.store.isUnchanged(self.path) + self.file.touch() + res = self.store.isUnchanged(self.file) self.failUnless(res == False) - os.unlink(self.path) - res = self.store.isUnchanged(self.path) - self.failUnless(res == None) + res = self.store.isUnchanged(self.file) + self.failUnless(res is None) def test_expiry(self): - res = self.store.expiredFiles(1) + """Tests retrieving the files from the database that have expired.""" + res = self.store.expiredHashes(1) self.failUnlessEqual(len(res.keys()), 0) sleep(2) - res = self.store.expiredFiles(1) + res = self.store.expiredHashes(1) self.failUnlessEqual(len(res.keys()), 1) self.failUnlessEqual(res.keys()[0], self.hash) - self.failUnlessEqual(len(res[self.hash]), 1) - self.failUnlessEqual(res[self.hash][0], self.urlpath) - res = self.store.refreshFile(self.path) - self.failUnless(res) - res = self.store.expiredFiles(1) + self.store.refreshHash(self.hash) + res = self.store.expiredHashes(1) self.failUnlessEqual(len(res.keys()), 0) - def test_removeUntracked(self): + def build_dirs(self): for dir in self.dirs: - path = os.path.join(dir, self.path[1:]) - os.makedirs(os.path.dirname(path)) - f = open(path, 'w') - f.write(path) - f.close() - os.utime(path, None) - self.store.storeFile(path, self.hash, self.urlpath) - + file = dir.preauthChild(self.testfile) + if not file.parent().exists(): + file.parent().makedirs() + file.setContent(file.path) + file.touch() + self.store.storeFile(file, self.hash) + + def test_multipleHashes(self): + """Tests looking up a hash with multiple files in the database.""" + self.build_dirs() + res = self.store.expiredHashes(1) + self.failUnlessEqual(len(res.keys()), 0) + res = self.store.lookupHash(self.hash) + self.failUnless(res) + self.failUnlessEqual(len(res), 4) + self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed']) + 
self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed']) + self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed']) + sleep(2) + res = self.store.expiredHashes(1) + self.failUnlessEqual(len(res.keys()), 1) + self.failUnlessEqual(res.keys()[0], self.hash) + self.store.refreshHash(self.hash) + res = self.store.expiredHashes(1) + self.failUnlessEqual(len(res.keys()), 0) + + def test_removeUntracked(self): + """Tests removing untracked files from the database.""" + self.build_dirs() res = self.store.removeUntrackedFiles(self.dirs) self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res) - self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res) + self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res) res = self.store.removeUntrackedFiles(self.dirs) self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res) res = self.store.removeUntrackedFiles(self.dirs[1:]) self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res) - self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res) + self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res) res = self.store.removeUntrackedFiles(self.dirs[:1]) self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res) - self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res) - self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res) + self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res) + self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res) def tearDown(self): - for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False): - for name in files: - os.remove(os.path.join(root, name)) - for name in dirs: - os.rmdir(os.path.join(root, name)) + self.directory.remove() self.store.close() - os.unlink(self.db) + self.db.remove() +
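
The khash subclass together with the register_adapter/register_converter calls above is what lets raw binary hashes be handed straight to the database: they are written out as base64 text and converted back to binary when read. A minimal round-trip sketch (Python 2; illustrative only, not part of the commit; it assumes the stdlib sqlite3 module, which exposes the same interface as pysqlite2):

    import sqlite3 as sqlite
    from binascii import a2b_base64, b2a_base64

    class khash(str):
        """Dummy class to convert all hashes to base64 for storing in the DB."""

    # Only khash instances are adapted, so ordinary strings are stored untouched
    sqlite.register_adapter(khash, b2a_base64)
    sqlite.register_converter("KHASH", a2b_base64)

    conn = sqlite.connect(':memory:', detect_types=sqlite.PARSE_DECLTYPES)
    conn.text_factory = str
    c = conn.cursor()
    c.execute("CREATE TABLE files (hash KHASH)")

    binary = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
    c.execute("INSERT INTO files VALUES (?)", (khash(binary), ))  # stored as base64 text
    c.execute("SELECT hash FROM files")
    assert c.fetchone()[0] == binary  # converted back to binary on the way out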
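
_loadDB() wraps connection failures in DBExcept with the three-expression raise statement; in Python 2 the third expression must be a traceback object (or None), not the string returned by traceback.format_exc(). A sketch of equivalent error handling in that form, folding the formatted traceback into the message (the open_db helper is illustrative):

    import sys, traceback
    from pysqlite2 import dbapi2 as sqlite

    class DBExcept(Exception):
        """An error occurred in accessing the database."""

    def open_db(path):
        """Open a connection to an existing database file."""
        try:
            return sqlite.connect(database=path, detect_types=sqlite.PARSE_DECLTYPES)
        except:
            # Re-raise as DBExcept while preserving the original traceback
            raise DBExcept, "Couldn't open DB: " + traceback.format_exc(), sys.exc_info()[2]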
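
storeFile() reduces the concatenated piece hashes to a single piecehash before consulting the hashes table. A sketch of that calculation with Python 2's sha module (which db.py would need to import for this code path), keeping the hash object in a local variable since sha.update() returns None; the piece_hash helper name is illustrative:

    import sha  # Python 2 standard library

    def piece_hash(pieces):
        """Hash the concatenated piece hashes down to a single value."""
        if not pieces:
            return ''
        s = sha.new()
        s.update(pieces)  # update() modifies s in place and returns None
        return s.digest()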
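
A short sketch of how a caller might exercise the new interface end to end; the cache paths, the apt_dht.db import and the 9-hour refresh window are assumptions, not taken from the commit:

    import sha
    from twisted.python.filepath import FilePath
    from apt_dht.db import DB  # assumed import path for the class above

    store = DB(FilePath('/var/cache/apt-dht/files.db'))

    file = FilePath('/var/cache/apt-dht/some_package.deb')
    hash = sha.new(file.getContent()).digest()
    new_hash = store.storeFile(file, hash)  # True if the hash still needs to be added to the DHT

    for match in store.lookupHash(hash, filesOnly = True):
        print match['path'].path, match['size']

    # Re-announce hashes published more than 9 hours ago, then mark them refreshed
    for expired in store.expiredHashes(9 * 60 * 60).keys():
        store.refreshHash(expired)

    store.close()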
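
removeUntrackedFiles() builds its WHERE clause from one NOT GLOB pattern per tracked directory, so only paths outside all of them are selected and deleted. A trace of that assembly for the three test directories above (illustrative, mirroring the code in the hunk):

    from twisted.python.filepath import FilePath

    dirs = [FilePath('/tmp/apt-dht/top1'),
            FilePath('/tmp/apt-dht/top2/sub1'),
            FilePath('/tmp/apt-dht/top2/sub2/')]

    newdirs = [dir.child('*').path for dir in dirs]
    # newdirs == ['/tmp/apt-dht/top1/*', '/tmp/apt-dht/top2/sub1/*', '/tmp/apt-dht/top2/sub2/*']

    sql = "WHERE"
    for pattern in newdirs:
        sql += " path NOT GLOB ? AND"
    sql = sql[:-4]
    # "SELECT path FROM files " + sql matches every tracked file lying outside
    # all three directories; those are the rows that get deleted.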