X-Git-Url: https://git.mxchange.org/?a=blobdiff_plain;f=apt_dht%2Fdb.py;h=1ab2b37cac96df1bc8d6a7331326ef5b1d252dfe;hb=f67d1f47283729722ac1b3f528a78063b4b36a14;hp=d6a5d6801dfaef950df40725a6b8de1b4a1f039c;hpb=e4eab3f8f3bd287101cf588a77a49bc693f7d201;p=quix0rs-apt-p2p.git diff --git a/apt_dht/db.py b/apt_dht/db.py index d6a5d68..1ab2b37 100644 --- a/apt_dht/db.py +++ b/apt_dht/db.py @@ -5,6 +5,7 @@ from binascii import a2b_base64, b2a_base64 from time import sleep import os +from twisted.python.filepath import FilePath from twisted.trial import unittest assert sqlite.version_info >= (2, 1) @@ -18,91 +19,182 @@ class khash(str): sqlite.register_adapter(khash, b2a_base64) sqlite.register_converter("KHASH", a2b_base64) sqlite.register_converter("khash", a2b_base64) +sqlite.enable_callback_tracebacks(True) class DB: """Database access for storing persistent data.""" def __init__(self, db): self.db = db - try: - os.stat(db) - except OSError: - self._createNewDB(db) + self.db.restat(False) + if self.db.exists(): + self._loadDB() else: - self._loadDB(db) + self._createNewDB() self.conn.text_factory = str self.conn.row_factory = sqlite.Row - def _loadDB(self, db): + def _loadDB(self): try: - self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES) + self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES) except: import traceback raise DBExcept, "Couldn't open DB", traceback.format_exc() - def _createNewDB(self, db): - self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES) + def _createNewDB(self): + if not self.db.parent().exists(): + self.db.parent().makedirs() + self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES) c = self.conn.cursor() - c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urlpath TEXT, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)") -# c.execute("CREATE INDEX files_hash ON files(hash)") + c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)") + c.execute("CREATE INDEX files_hash ON files(hash)") c.execute("CREATE INDEX files_refreshed ON files(refreshed)") - c.execute("CREATE TABLE dirs (path TEXT PRIMARY KEY, urlpath TEXT)") c.close() self.conn.commit() - def storeFile(self, path, hash, urlpath, refreshed): - """Store or update a file in the database.""" - path = os.path.abspath(path) - stat = os.stat(path) + def _removeChanged(self, file, row): + res = None + if row: + file.restat(False) + if file.exists(): + res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime()) + if not res: + c = self.conn.cursor() + c.execute("DELETE FROM files WHERE path = ?", (file.path, )) + self.conn.commit() + c.close() + return res + + def storeFile(self, file, hash): + """Store or update a file in the database. + + @return: True if the hash was not in the database before + (so it needs to be added to the DHT) + """ + new_hash = True + refreshTime = datetime.now() + c = self.conn.cursor() + c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), )) + row = c.fetchone() + if row and row['max_refresh']: + new_hash = False + refreshTime = row['max_refresh'] + c.close() + + file.restat() c = self.conn.cursor() - c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?, ?, ?, ?)", - (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now())) + c.execute("SELECT path FROM files WHERE path = ?", (file.path, )) + row = c.fetchone() + if row: + c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", + (khash(hash), file.getsize(), file.getmtime(), refreshTime)) + else: + c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)", + (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime)) self.conn.commit() c.close() - def isUnchanged(self, path): - """Check if a file in the file system has changed. + return new_hash - If it has changed, it is removed from the table. + def getFile(self, file): + """Get a file from the database. - @return: True if unchanged, False if changed, None if not in database + If it has changed or is missing, it is removed from the database. + + @return: dictionary of info for the file, False if changed, or + None if not in database or missing """ - path = os.path.abspath(path) - stat = os.stat(path) c = self.conn.cursor() - c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, )) + c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, )) row = c.fetchone() res = None if row: - res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime) - if not res: - c.execute("DELETE FROM files WHERE path = ?", path) - self.conn.commit() + res = self._removeChanged(file, row) + if res: + res = {} + res['hash'] = row['hash'] + res['size'] = row['size'] c.close() return res + + def lookupHash(self, hash): + """Find a file by hash in the database. + + If any found files have changed or are missing, they are removed + from the database. + + @return: list of dictionaries of info for the found files + """ + c = self.conn.cursor() + c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), )) + row = c.fetchone() + files = [] + while row: + file = FilePath(row['path']) + res = self._removeChanged(file, row) + if res: + res = {} + res['path'] = file + res['size'] = row['size'] + res['refreshed'] = row['refreshed'] + files.append(res) + row = c.fetchone() + c.close() + return files + + def isUnchanged(self, file): + """Check if a file in the file system has changed. + + If it has changed, it is removed from the table. + + @return: True if unchanged, False if changed, None if not in database + """ + c = self.conn.cursor() + c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, )) + row = c.fetchone() + return self._removeChanged(file, row) + def refreshHash(self, hash): + """Refresh the publishing time all files with a hash.""" + refreshTime = datetime.now() + c = self.conn.cursor() + c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash))) + c.close() + def expiredFiles(self, expireAfter): """Find files that need refreshing after expireAfter seconds. - Also removes any entries from the table that no longer exist. + For each hash that needs refreshing, finds all the files with that hash. + If the file has changed or is missing, it is removed from the table. - @return: dictionary with keys the hashes, values a list of url paths + @return: dictionary with keys the hashes, values a list of FilePaths """ t = datetime.now() - timedelta(seconds=expireAfter) + + # First find the hashes that need refreshing c = self.conn.cursor() - c.execute("SELECT path, hash, urlpath FROM files WHERE refreshed < ?", (t, )) + c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, )) row = c.fetchone() expired = {} - missing = [] while row: - if os.path.exists(row['path']): - expired.setdefault(row['hash'], []).append(row['urlpath']) - else: - missing.append((row['path'],)) + expired.setdefault(row['hash'], []) row = c.fetchone() - if missing: - c.executemany("DELETE FROM files WHERE path = ?", missing) - self.conn.commit() + c.close() + + # Now find the files for each hash + for hash in expired.keys(): + c = self.conn.cursor() + c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), )) + row = c.fetchone() + while row: + res = self._removeChanged(FilePath(row['path']), row) + if res: + expired[hash].append(FilePath(row['path'])) + row = c.fetchone() + if len(expired[hash]) == 0: + del expired[hash] + c.close() + return expired def removeUntrackedFiles(self, dirs): @@ -113,26 +205,26 @@ class DB: @return: list of files that were removed """ assert len(dirs) >= 1 - dirs = dirs.copy() + newdirs = [] sql = "WHERE" - for i in xrange(len(dirs)): - dirs[i] = os.path.abspath(dirs[i]) - sql += " path NOT GLOB ?/* AND" + for dir in dirs: + newdirs.append(dir.child('*').path) + sql += " path NOT GLOB ? AND" sql = sql[:-4] c = self.conn.cursor() - c.execute("SELECT path FROM files " + sql, dirs) + c.execute("SELECT path FROM files " + sql, newdirs) row = c.fetchone() removed = [] while row: - removed.append(row['path']) + removed.append(FilePath(row['path'])) row = c.fetchone() if removed: - c.execute("DELETE FROM files " + sql, dirs) + c.execute("DELETE FROM files " + sql, newdirs) self.conn.commit() return removed - + def close(self): self.conn.close() @@ -140,55 +232,109 @@ class TestDB(unittest.TestCase): """Tests for the khashmir database.""" timeout = 5 - db = '/tmp/khashmir.db' - key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!' + db = FilePath('/tmp/khashmir.db') + hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!' + directory = FilePath('/tmp/apt-dht/') + file = FilePath('/tmp/apt-dht/khashmir.test') + testfile = 'tmp/khashmir.test' + dirs = [FilePath('/tmp/apt-dht/top1'), + FilePath('/tmp/apt-dht/top2/sub1'), + FilePath('/tmp/apt-dht/top2/sub2/')] def setUp(self): + if not self.file.parent().exists(): + self.file.parent().makedirs() + self.file.setContent('fgfhds') + self.file.touch() self.store = DB(self.db) + self.store.storeFile(self.file, self.hash) - def test_selfNode(self): - self.store.saveSelfNode(self.key) - self.failUnlessEqual(self.store.getSelfNode(), self.key) - - def test_Value(self): - self.store.storeValue(self.key, 'foobar') - val = self.store.retrieveValues(self.key) - self.failUnlessEqual(len(val), 1) - self.failUnlessEqual(val[0], 'foobar') - - def test_expireValues(self): - self.store.storeValue(self.key, 'foobar') + def test_openExistsingDB(self): + self.store.close() + self.store = None + sleep(1) + self.store = DB(self.db) + res = self.store.isUnchanged(self.file) + self.failUnless(res) + + def test_getFile(self): + res = self.store.getFile(self.file) + self.failUnless(res) + self.failUnlessEqual(res['hash'], self.hash) + + def test_lookupHash(self): + res = self.store.lookupHash(self.hash) + self.failUnless(res) + self.failUnlessEqual(len(res), 1) + self.failUnlessEqual(res[0]['path'].path, self.file.path) + + def test_isUnchanged(self): + res = self.store.isUnchanged(self.file) + self.failUnless(res) + sleep(2) + self.file.touch() + res = self.store.isUnchanged(self.file) + self.failUnless(res == False) + self.file.remove() + res = self.store.isUnchanged(self.file) + self.failUnless(res == None) + + def test_expiry(self): + res = self.store.expiredFiles(1) + self.failUnlessEqual(len(res.keys()), 0) sleep(2) - self.store.storeValue(self.key, 'barfoo') - self.store.expireValues(1) - val = self.store.retrieveValues(self.key) - self.failUnlessEqual(len(val), 1) - self.failUnlessEqual(val[0], 'barfoo') - - def test_RoutingTable(self): - class dummy: - id = self.key - host = "127.0.0.1" - port = 9977 - def contents(self): - return (self.id, self.host, self.port) - dummy2 = dummy() - dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!' - dummy2.host = '205.23.67.124' - dummy2.port = 12345 - class bl: - def __init__(self): - self.l = [] - bl1 = bl() - bl1.l.append(dummy()) - bl2 = bl() - bl2.l.append(dummy2) - buckets = [bl1, bl2] - self.store.dumpRoutingTable(buckets) - rt = self.store.getRoutingTable() - self.failUnlessIn(dummy().contents(), rt) - self.failUnlessIn(dummy2.contents(), rt) + res = self.store.expiredFiles(1) + self.failUnlessEqual(len(res.keys()), 1) + self.failUnlessEqual(res.keys()[0], self.hash) + self.failUnlessEqual(len(res[self.hash]), 1) + self.store.refreshHash(self.hash) + res = self.store.expiredFiles(1) + self.failUnlessEqual(len(res.keys()), 0) + + def build_dirs(self): + for dir in self.dirs: + file = dir.preauthChild(self.testfile) + if not file.parent().exists(): + file.parent().makedirs() + file.setContent(file.path) + file.touch() + self.store.storeFile(file, self.hash) + + def test_multipleHashes(self): + self.build_dirs() + res = self.store.expiredFiles(1) + self.failUnlessEqual(len(res.keys()), 0) + res = self.store.lookupHash(self.hash) + self.failUnless(res) + self.failUnlessEqual(len(res), 4) + self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed']) + self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed']) + self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed']) + sleep(2) + res = self.store.expiredFiles(1) + self.failUnlessEqual(len(res.keys()), 1) + self.failUnlessEqual(res.keys()[0], self.hash) + self.failUnlessEqual(len(res[self.hash]), 4) + self.store.refreshHash(self.hash) + res = self.store.expiredFiles(1) + self.failUnlessEqual(len(res.keys()), 0) + + def test_removeUntracked(self): + self.build_dirs() + res = self.store.removeUntrackedFiles(self.dirs) + self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res) + self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res) + res = self.store.removeUntrackedFiles(self.dirs) + self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res) + res = self.store.removeUntrackedFiles(self.dirs[1:]) + self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res) + self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res) + res = self.store.removeUntrackedFiles(self.dirs[:1]) + self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res) + self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res) + self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res) def tearDown(self): + self.directory.remove() self.store.close() - os.unlink(self.db) + self.db.remove()