]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/db.py
WIP on final version of accepted INFOCOM paper.
[quix0rs-apt-p2p.git] / apt_p2p / db.py
index 569de9a38d3c45f0cb1fc4a20d4dfbb97a3f2277..4bb5a754b498c32e2cac9b53ae21c1558f7f8c82 100644 (file)
@@ -49,6 +49,7 @@ class DB:
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
+    #{ DB Functions
     def _loadDB(self):
         """Open a new connection to the existing database file"""
         try:
     def _loadDB(self):
         """Open a new connection to the existing database file"""
         try:
@@ -64,15 +65,22 @@ class DB:
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
-                                      "size NUMBER, mtime NUMBER)")
+                                      "dht BOOL, size NUMBER, mtime NUMBER)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
+        c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
+    def close(self):
+        """Close the database connection."""
+        self.conn.close()
+
+    #{ Files and Hashes
     def _removeChanged(self, file, row):
         """If the file has changed or is missing, remove it from the DB.
         
     def _removeChanged(self, file, row):
         """If the file has changed or is missing, remove it from the DB.
         
@@ -98,13 +106,15 @@ class DB:
                 c.close()
         return res
         
                 c.close()
         return res
         
-    def storeFile(self, file, hash, pieces = ''):
+    def storeFile(self, file, hash, dht = True, pieces = ''):
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
+        @param dht: whether the file is added to the DHT
+            (optional, defaults to true)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
@@ -135,8 +145,8 @@ class DB:
 
         # Add the file to the database
         file.restat()
 
         # Add the file to the database
         file.restat()
-        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
-                  (file.path, hashID, file.getsize(), file.getmtime()))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, dht, size, mtime) VALUES (?, ?, ?, ?, ?)",
+                  (file.path, hashID, dht, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
         self.conn.commit()
         c.close()
         
@@ -230,36 +240,46 @@ class DB:
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
-        @return: dictionary with keys the hashes, values a list of FilePaths
+        @return: a list of dictionaries of each hash needing refreshing, sorted by age
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ? ORDER BY refreshed", (t, ))
         row = c.fetchone()
         row = c.fetchone()
-        expired = {}
+        expired = []
         while row:
         while row:
-            res = expired.setdefault(row['hash'], {})
-            res['hashID'] = row['hashID']
+            res = {}
             res['hash'] = row['hash']
             res['hash'] = row['hash']
+            res['hashID'] = row['hashID']
             res['pieces'] = row['pieces']
             res['pieces'] = row['pieces']
+            expired.append(res)
             row = c.fetchone()
 
             row = c.fetchone()
 
-        # Make sure there are still valid files for each hash
-        for hash in expired.values():
-            valid = False
-            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
+        # Make sure there are still valid DHT files for each hash
+        for i in xrange(len(expired)-1, -1, -1):
+            dht = False
+            non_dht = False
+            hash = expired[i]
+            c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
             row = c.fetchone()
             while row:
-                res = self._removeChanged(FilePath(row['path']), row)
-                if res:
-                    valid = True
+                if row['dht']:
+                    dht = True
+                else:
+                    non_dht = True
                 row = c.fetchone()
                 row = c.fetchone()
-            if not valid:
-                # Remove hashes for which no files are still available
-                del expired[hash['hash']]
-                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+            if not dht:
+                # Remove hashes for which no DHT files are still available
+                del expired[i]
+                if not non_dht:
+                    # Remove hashes for which no files are still available
+                    c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                else:
+                    # There are still some non-DHT files available, so refresh them
+                    c.execute("UPDATE hashes SET refreshed = ? WHERE hashID = ?",
+                              (datetime.now(), hash['hashID']))
                 
         self.conn.commit()
         c.close()
                 
         self.conn.commit()
         c.close()
@@ -295,14 +315,63 @@ class DB:
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
+            self.conn.commit()
+        
+        c.execute("SELECT path FROM files")
+        rows = c.fetchall()
+        for row in rows:
+            if not os.path.exists(row['path']):
+                # Leave hashes, they will be removed on next refresh
+                c.execute("DELETE FROM files WHERE path = ?", (row['path'], ))
+                removed.append(FilePath(row['path']))
         self.conn.commit()
 
         return removed
     
         self.conn.commit()
 
         return removed
     
-    def close(self):
-        """Close the database connection."""
-        self.conn.close()
+    #{ Statistics
+    def dbStats(self):
+        """Count the total number of files and hashes in the database.
+        
+        @rtype: (C{int}, C{int})
+        @return: the number of distinct hashes and total files in the database
+        """
+        c = self.conn.cursor()
+        c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
+        hashes = 0
+        row = c.fetchone()
+        if row:
+            hashes = row[0]
+        c.execute("SELECT COUNT(path) as num_files FROM files")
+        files = 0
+        row = c.fetchone()
+        if row:
+            files = row[0]
+        return hashes, files
 
 
+    def getStats(self):
+        """Retrieve the saved statistics from the DB.
+        
+        @return: dictionary of statistics
+        """
+        c = self.conn.cursor()
+        c.execute("SELECT param, value FROM stats")
+        row = c.fetchone()
+        stats = {}
+        while row:
+            stats[row['param']] = row['value']
+            row = c.fetchone()
+        c.close()
+        return stats
+        
+    def saveStats(self, stats):
+        """Save the statistics to the DB."""
+        c = self.conn.cursor()
+        for param in stats:
+            c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
+                      (param, stats[param]))
+            self.conn.commit()
+        c.close()
+        
 class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
 class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
@@ -360,14 +429,14 @@ class TestDB(unittest.TestCase):
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         sleep(2)
         res = self.store.expiredHashes(1)
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
         
     def build_dirs(self):
         for dir in self.dirs:
@@ -382,7 +451,7 @@ class TestDB(unittest.TestCase):
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
@@ -391,20 +460,28 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
+        file = self.dirs[0].child('test.khashmir')
+        file.setContent(file.path)
+        file.touch()
+        self.store.storeFile(file, self.hash)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
+        file.remove()
+        res = self.store.removeUntrackedFiles(self.dirs)
+        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.dirs[0].child('test.khashmir'), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)