]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/db.py
WIP on final version of accepted INFOCOM paper.
[quix0rs-apt-p2p.git] / apt_p2p / db.py
index 44e692b416d92e47122deee6ca9fc9788bdd697e..4bb5a754b498c32e2cac9b53ae21c1558f7f8c82 100644 (file)
@@ -65,7 +65,7 @@ class DB:
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
-                                      "size NUMBER, mtime NUMBER)")
+                                      "dht BOOL, size NUMBER, mtime NUMBER)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
@@ -106,13 +106,15 @@ class DB:
                 c.close()
         return res
         
                 c.close()
         return res
         
-    def storeFile(self, file, hash, pieces = ''):
+    def storeFile(self, file, hash, dht = True, pieces = ''):
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
+        @param dht: whether the file is added to the DHT
+            (optional, defaults to true)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
@@ -143,8 +145,8 @@ class DB:
 
         # Add the file to the database
         file.restat()
 
         # Add the file to the database
         file.restat()
-        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
-                  (file.path, hashID, file.getsize(), file.getmtime()))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, dht, size, mtime) VALUES (?, ?, ?, ?, ?)",
+                  (file.path, hashID, dht, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
         self.conn.commit()
         c.close()
         
@@ -238,36 +240,46 @@ class DB:
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
-        @return: dictionary with keys the hashes, values a list of FilePaths
+        @return: a list of dictionaries of each hash needing refreshing, sorted by age
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ? ORDER BY refreshed", (t, ))
         row = c.fetchone()
         row = c.fetchone()
-        expired = {}
+        expired = []
         while row:
         while row:
-            res = expired.setdefault(row['hash'], {})
-            res['hashID'] = row['hashID']
+            res = {}
             res['hash'] = row['hash']
             res['hash'] = row['hash']
+            res['hashID'] = row['hashID']
             res['pieces'] = row['pieces']
             res['pieces'] = row['pieces']
+            expired.append(res)
             row = c.fetchone()
 
             row = c.fetchone()
 
-        # Make sure there are still valid files for each hash
-        for hash in expired.values():
-            valid = False
-            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
+        # Make sure there are still valid DHT files for each hash
+        for i in xrange(len(expired)-1, -1, -1):
+            dht = False
+            non_dht = False
+            hash = expired[i]
+            c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
             row = c.fetchone()
             while row:
-                res = self._removeChanged(FilePath(row['path']), row)
-                if res:
-                    valid = True
+                if row['dht']:
+                    dht = True
+                else:
+                    non_dht = True
                 row = c.fetchone()
                 row = c.fetchone()
-            if not valid:
-                # Remove hashes for which no files are still available
-                del expired[hash['hash']]
-                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+            if not dht:
+                # Remove hashes for which no DHT files are still available
+                del expired[i]
+                if not non_dht:
+                    # Remove hashes for which no files are still available
+                    c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                else:
+                    # There are still some non-DHT files available, so refresh them
+                    c.execute("UPDATE hashes SET refreshed = ? WHERE hashID = ?",
+                              (datetime.now(), hash['hashID']))
                 
         self.conn.commit()
         c.close()
                 
         self.conn.commit()
         c.close()
@@ -303,6 +315,15 @@ class DB:
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
+            self.conn.commit()
+        
+        c.execute("SELECT path FROM files")
+        rows = c.fetchall()
+        for row in rows:
+            if not os.path.exists(row['path']):
+                # Leave hashes, they will be removed on next refresh
+                c.execute("DELETE FROM files WHERE path = ?", (row['path'], ))
+                removed.append(FilePath(row['path']))
         self.conn.commit()
 
         return removed
         self.conn.commit()
 
         return removed
@@ -408,14 +429,14 @@ class TestDB(unittest.TestCase):
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         sleep(2)
         res = self.store.expiredHashes(1)
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
         
     def build_dirs(self):
         for dir in self.dirs:
@@ -430,7 +451,7 @@ class TestDB(unittest.TestCase):
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
@@ -439,20 +460,28 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
+        file = self.dirs[0].child('test.khashmir')
+        file.setContent(file.path)
+        file.touch()
+        self.store.storeFile(file, self.hash)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
+        file.remove()
+        res = self.store.removeUntrackedFiles(self.dirs)
+        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.dirs[0].child('test.khashmir'), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)