]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht/db.py
Store piece hashes in the DB.
[quix0rs-apt-p2p.git] / apt_dht / db.py
index 1ab2b37cac96df1bc8d6a7331326ef5b1d252dfe..e1d6d7b474f7f69bce218b85af50235e5d6f1dcb 100644 (file)
@@ -46,9 +46,13 @@ class DB:
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
-        c.execute("CREATE INDEX files_hash ON files(hash)")
-        c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
+        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
+                                      "size NUMBER, mtime NUMBER)")
+        c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
+                                       "hash KHASH UNIQUE, pieces KHASH, " +
+                                       "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
+        c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
@@ -65,32 +69,34 @@ class DB:
                 c.close()
         return res
         
-    def storeFile(self, file, hash):
+    def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
-        new_hash = True
-        refreshTime = datetime.now()
+        piecehash = ''
+        if pieces:
+            s = sha.new().update(pieces)
+            piecehash = sha.digest()
         c = self.conn.cursor()
-        c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
-        if row and row['max_refresh']:
+        if row:
+            assert piecehash == row['piecehash']
             new_hash = False
-            refreshTime = row['max_refresh']
-        c.close()
+            hashID = row['hashID']
+        else:
+            c = self.conn.cursor()
+            c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
+                      (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
+            self.conn.commit()
+            new_hash = True
+            hashID = c.lastrowid
         
         file.restat()
-        c = self.conn.cursor()
-        c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
-        row = c.fetchone()
-        if row:
-            c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
-                      (khash(hash), file.getsize(), file.getmtime(), refreshTime))
-        else:
-            c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
-                      (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
+                  (file.path, hashID, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
@@ -105,7 +111,7 @@ class DB:
             None if not in database or missing
         """
         c = self.conn.cursor()
-        c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
+        c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
         row = c.fetchone()
         res = None
         if row:
@@ -114,19 +120,21 @@ class DB:
                 res = {}
                 res['hash'] = row['hash']
                 res['size'] = row['size']
+                res['pieces'] = row['pieces']
         c.close()
         return res
         
-    def lookupHash(self, hash):
+    def lookupHash(self, hash, filesOnly = False):
         """Find a file by hash in the database.
         
         If any found files have changed or are missing, they are removed
-        from the database.
+        from the database. If filesOnly is False then it will also look for
+        piece string hashes if no files can be found.
         
         @return: list of dictionaries of info for the found files
         """
         c = self.conn.cursor()
-        c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
@@ -137,8 +145,19 @@ class DB:
                 res['path'] = file
                 res['size'] = row['size']
                 res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
                 files.append(res)
             row = c.fetchone()
+            
+        if not filesOnly and not files:
+            c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
+            row = c.fetchone()
+            if row:
+                res = {}
+                res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
+                files.append(res)
+
         c.close()
         return files
         
@@ -156,12 +175,11 @@ class DB:
 
     def refreshHash(self, hash):
         """Refresh the publishing time all files with a hash."""
-        refreshTime = datetime.now()
         c = self.conn.cursor()
-        c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash)))
+        c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
     
-    def expiredFiles(self, expireAfter):
+    def expiredHashes(self, expireAfter):
         """Find files that need refreshing after expireAfter seconds.
         
         For each hash that needs refreshing, finds all the files with that hash.
@@ -173,27 +191,32 @@ class DB:
         
         # First find the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
         expired = {}
         while row:
-            expired.setdefault(row['hash'], [])
+            res = expired.setdefault(row['hash'], {})
+            res['hashID'] = row['hashID']
+            res['hash'] = row['hash']
+            res['pieces'] = row['pieces']
             row = c.fetchone()
-        c.close()
 
-        # Now find the files for each hash
-        for hash in expired.keys():
-            c = self.conn.cursor()
-            c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
+        # Make sure there are still valid files for each hash
+        for hash in expired.values():
+            valid = False
+            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
                 res = self._removeChanged(FilePath(row['path']), row)
                 if res:
-                    expired[hash].append(FilePath(row['path']))
+                    valid = True
                 row = c.fetchone()
-            if len(expired[hash]) == 0:
-                del expired[hash]
-            c.close()
+            if not valid:
+                del expired[hash['hash']]
+                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                
+        self.conn.commit()
+        c.close()
         
         return expired
         
@@ -249,7 +272,7 @@ class TestDB(unittest.TestCase):
         self.store = DB(self.db)
         self.store.storeFile(self.file, self.hash)
 
-    def test_openExistsingDB(self):
+    def test_openExistingDB(self):
         self.store.close()
         self.store = None
         sleep(1)
@@ -275,20 +298,18 @@ class TestDB(unittest.TestCase):
         self.file.touch()
         res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
-        self.file.remove()
         res = self.store.isUnchanged(self.file)
-        self.failUnless(res == None)
+        self.failUnless(res is None)
         
     def test_expiry(self):
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 1)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
@@ -302,7 +323,7 @@ class TestDB(unittest.TestCase):
     
     def test_multipleHashes(self):
         self.build_dirs()
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
@@ -311,12 +332,11 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 4)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
@@ -338,3 +358,4 @@ class TestDB(unittest.TestCase):
         self.directory.remove()
         self.store.close()
         self.db.remove()
+