git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht/db.py
Fix a minor error in the DB's storeFile function.
[quix0rs-apt-p2p.git] / apt_dht / db.py
index 1ab2b37cac96df1bc8d6a7331326ef5b1d252dfe..cdd86c5e2fa89b033c2a3ededa45d02e71a1dfc2 100644 (file)
@@ -46,9 +46,13 @@ class DB:
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
-        c.execute("CREATE INDEX files_hash ON files(hash)")
-        c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
+        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
+                                      "size NUMBER, mtime NUMBER)")
+        c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
+                                       "hash KHASH UNIQUE, pieces KHASH, " +
+                                       "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
+        c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
         c.close()
         self.conn.commit()
 
@@ -65,32 +69,34 @@ class DB:
                 c.close()
         return res
         
                 c.close()
         return res
         
-    def storeFile(self, file, hash):
+    def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
         """Store or update a file in the database.
         
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
-        new_hash = True
-        refreshTime = datetime.now()
+        piecehash = ''
+        if pieces:
+            s = sha.new().update(pieces)
+            piecehash = sha.digest()
         c = self.conn.cursor()
         c = self.conn.cursor()
-        c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         row = c.fetchone()
-        if row and row['max_refresh']:
+        if row:
+            assert piecehash == row['piecehash']
             new_hash = False
             new_hash = False
-            refreshTime = row['max_refresh']
-        c.close()
+            hashID = row['hashID']
+        else:
+            c = self.conn.cursor()
+            c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
+                      (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
+            self.conn.commit()
+            new_hash = True
+            hashID = c.lastrowid
         
         file.restat()
         
         file.restat()
-        c = self.conn.cursor()
-        c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
-        row = c.fetchone()
-        if row:
-            c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
-                      (khash(hash), file.getsize(), file.getmtime(), refreshTime))
-        else:
-            c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
-                      (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
+                  (file.path, hashID, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
         self.conn.commit()
         c.close()
         
@@ -105,7 +111,7 @@ class DB:
             None if not in database or missing
         """
         c = self.conn.cursor()
             None if not in database or missing
         """
         c = self.conn.cursor()
-        c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
+        c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
         row = c.fetchone()
         res = None
         if row:
         row = c.fetchone()
         res = None
         if row:
@@ -114,19 +120,21 @@ class DB:
                 res = {}
                 res['hash'] = row['hash']
                 res['size'] = row['size']
                 res = {}
                 res['hash'] = row['hash']
                 res['size'] = row['size']
+                res['pieces'] = row['pieces']
         c.close()
         return res
         
         c.close()
         return res
         
-    def lookupHash(self, hash):
+    def lookupHash(self, hash, filesOnly = False):
         """Find a file by hash in the database.
         
         If any found files have changed or are missing, they are removed
         """Find a file by hash in the database.
         
         If any found files have changed or are missing, they are removed
-        from the database.
+        from the database. If filesOnly is False then it will also look for
+        piece string hashes if no files can be found.
         
         @return: list of dictionaries of info for the found files
         """
         c = self.conn.cursor()
         
         @return: list of dictionaries of info for the found files
         """
         c = self.conn.cursor()
-        c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
         row = c.fetchone()
         files = []
         while row:
@@ -137,8 +145,19 @@ class DB:
                 res['path'] = file
                 res['size'] = row['size']
                 res['refreshed'] = row['refreshed']
                 res['path'] = file
                 res['size'] = row['size']
                 res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
                 files.append(res)
             row = c.fetchone()
                 files.append(res)
             row = c.fetchone()
+            
+        if not filesOnly and not files:
+            c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
+            row = c.fetchone()
+            if row:
+                res = {}
+                res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
+                files.append(res)
+
         c.close()
         return files
         
         c.close()
         return files
         
@@ -156,12 +175,11 @@ class DB:
 
     def refreshHash(self, hash):
         """Refresh the publishing time all files with a hash."""
 
     def refreshHash(self, hash):
         """Refresh the publishing time all files with a hash."""
-        refreshTime = datetime.now()
         c = self.conn.cursor()
         c = self.conn.cursor()
-        c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash)))
+        c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
     
         c.close()
     
-    def expiredFiles(self, expireAfter):
+    def expiredHashes(self, expireAfter):
         """Find files that need refreshing after expireAfter seconds.
         
         For each hash that needs refreshing, finds all the files with that hash.
         """Find files that need refreshing after expireAfter seconds.
         
         For each hash that needs refreshing, finds all the files with that hash.
@@ -173,27 +191,32 @@ class DB:
         
         # First find the hashes that need refreshing
         c = self.conn.cursor()
         
         # First find the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
         expired = {}
         while row:
         row = c.fetchone()
         expired = {}
         while row:
-            expired.setdefault(row['hash'], [])
+            res = expired.setdefault(row['hash'], {})
+            res['hashID'] = row['hashID']
+            res['hash'] = row['hash']
+            res['pieces'] = row['pieces']
             row = c.fetchone()
             row = c.fetchone()
-        c.close()
 
 
-        # Now find the files for each hash
-        for hash in expired.keys():
-            c = self.conn.cursor()
-            c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
+        # Make sure there are still valid files for each hash
+        for hash in expired.values():
+            valid = False
+            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
                 res = self._removeChanged(FilePath(row['path']), row)
                 if res:
             row = c.fetchone()
             while row:
                 res = self._removeChanged(FilePath(row['path']), row)
                 if res:
-                    expired[hash].append(FilePath(row['path']))
+                    valid = True
                 row = c.fetchone()
                 row = c.fetchone()
-            if len(expired[hash]) == 0:
-                del expired[hash]
-            c.close()
+            if not valid:
+                del expired[hash['hash']]
+                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                
+        self.conn.commit()
+        c.close()
         
         return expired
         
         
         return expired
         
@@ -249,7 +272,7 @@ class TestDB(unittest.TestCase):
         self.store = DB(self.db)
         self.store.storeFile(self.file, self.hash)
 
         self.store = DB(self.db)
         self.store.storeFile(self.file, self.hash)
 
-    def test_openExistsingDB(self):
+    def test_openExistingDB(self):
         self.store.close()
         self.store = None
         sleep(1)
         self.store.close()
         self.store = None
         sleep(1)
@@ -275,20 +298,18 @@ class TestDB(unittest.TestCase):
         self.file.touch()
         res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
         self.file.touch()
         res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
-        self.file.remove()
         res = self.store.isUnchanged(self.file)
         res = self.store.isUnchanged(self.file)
-        self.failUnless(res == None)
+        self.failUnless(res is None)
         
     def test_expiry(self):
         
     def test_expiry(self):
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 1)
         self.store.refreshHash(self.hash)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
@@ -302,7 +323,7 @@ class TestDB(unittest.TestCase):
     
     def test_multipleHashes(self):
         self.build_dirs()
     
     def test_multipleHashes(self):
         self.build_dirs()
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res.keys()), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
@@ -311,12 +332,11 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 4)
         self.store.refreshHash(self.hash)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
@@ -338,3 +358,4 @@ class TestDB(unittest.TestCase):
         self.directory.remove()
         self.store.close()
         self.db.remove()
         self.directory.remove()
         self.store.close()
         self.db.remove()
+