]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/db.py
Add a NEWS entry for port forwarding.
[quix0rs-apt-p2p.git] / apt_p2p / db.py
index 396f419da12b791d5f2ce0408ce6ce2e79a700e0..4bb5a754b498c32e2cac9b53ae21c1558f7f8c82 100644 (file)
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
 from time import sleep
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
 from time import sleep
-import os
+import os, sha
 
 from twisted.python.filepath import FilePath
 from twisted.trial import unittest
 
 from twisted.python.filepath import FilePath
 from twisted.trial import unittest
@@ -49,6 +49,7 @@ class DB:
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
+    #{ DB Functions
     def _loadDB(self):
         """Open a new connection to the existing database file"""
         try:
     def _loadDB(self):
         """Open a new connection to the existing database file"""
         try:
@@ -64,15 +65,22 @@ class DB:
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
-                                      "size NUMBER, mtime NUMBER)")
+                                      "dht BOOL, size NUMBER, mtime NUMBER)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
+        c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
+    def close(self):
+        """Close the database connection."""
+        self.conn.close()
+
+    #{ Files and Hashes
     def _removeChanged(self, file, row):
         """If the file has changed or is missing, remove it from the DB.
         
     def _removeChanged(self, file, row):
         """If the file has changed or is missing, remove it from the DB.
         
@@ -98,13 +106,15 @@ class DB:
                 c.close()
         return res
         
                 c.close()
         return res
         
-    def storeFile(self, file, hash, pieces = ''):
+    def storeFile(self, file, hash, dht = True, pieces = ''):
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
+        @param dht: whether the file is added to the DHT
+            (optional, defaults to true)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
@@ -114,8 +124,7 @@ class DB:
         # Hash the pieces to get the piecehash
         piecehash = ''
         if pieces:
         # Hash the pieces to get the piecehash
         piecehash = ''
         if pieces:
-            s = sha.new().update(pieces)
-            piecehash = sha.digest()
+            piecehash = sha.new(pieces).digest()
             
         # Check the database for the hash
         c = self.conn.cursor()
             
         # Check the database for the hash
         c = self.conn.cursor()
@@ -136,8 +145,8 @@ class DB:
 
         # Add the file to the database
         file.restat()
 
         # Add the file to the database
         file.restat()
-        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
-                  (file.path, hashID, file.getsize(), file.getmtime()))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, dht, size, mtime) VALUES (?, ?, ?, ?, ?)",
+                  (file.path, hashID, dht, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
         self.conn.commit()
         c.close()
         
@@ -231,36 +240,46 @@ class DB:
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
-        @return: dictionary with keys the hashes, values a list of FilePaths
+        @return: a list of dictionaries of each hash needing refreshing, sorted by age
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ? ORDER BY refreshed", (t, ))
         row = c.fetchone()
         row = c.fetchone()
-        expired = {}
+        expired = []
         while row:
         while row:
-            res = expired.setdefault(row['hash'], {})
-            res['hashID'] = row['hashID']
+            res = {}
             res['hash'] = row['hash']
             res['hash'] = row['hash']
+            res['hashID'] = row['hashID']
             res['pieces'] = row['pieces']
             res['pieces'] = row['pieces']
+            expired.append(res)
             row = c.fetchone()
 
             row = c.fetchone()
 
-        # Make sure there are still valid files for each hash
-        for hash in expired.values():
-            valid = False
-            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
+        # Make sure there are still valid DHT files for each hash
+        for i in xrange(len(expired)-1, -1, -1):
+            dht = False
+            non_dht = False
+            hash = expired[i]
+            c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
             row = c.fetchone()
             while row:
-                res = self._removeChanged(FilePath(row['path']), row)
-                if res:
-                    valid = True
+                if row['dht']:
+                    dht = True
+                else:
+                    non_dht = True
                 row = c.fetchone()
                 row = c.fetchone()
-            if not valid:
-                # Remove hashes for which no files are still available
-                del expired[hash['hash']]
-                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+            if not dht:
+                # Remove hashes for which no DHT files are still available
+                del expired[i]
+                if not non_dht:
+                    # Remove hashes for which no files are still available
+                    c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                else:
+                    # There are still some non-DHT files available, so refresh them
+                    c.execute("UPDATE hashes SET refreshed = ? WHERE hashID = ?",
+                              (datetime.now(), hash['hashID']))
                 
         self.conn.commit()
         c.close()
                 
         self.conn.commit()
         c.close()
@@ -296,14 +315,63 @@ class DB:
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
+            self.conn.commit()
+        
+        c.execute("SELECT path FROM files")
+        rows = c.fetchall()
+        for row in rows:
+            if not os.path.exists(row['path']):
+                # Leave hashes, they will be removed on next refresh
+                c.execute("DELETE FROM files WHERE path = ?", (row['path'], ))
+                removed.append(FilePath(row['path']))
         self.conn.commit()
 
         return removed
     
         self.conn.commit()
 
         return removed
     
-    def close(self):
-        """Close the database connection."""
-        self.conn.close()
+    #{ Statistics
+    def dbStats(self):
+        """Count the total number of files and hashes in the database.
+        
+        @rtype: (C{int}, C{int})
+        @return: the number of distinct hashes and total files in the database
+        """
+        c = self.conn.cursor()
+        c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
+        hashes = 0
+        row = c.fetchone()
+        if row:
+            hashes = row[0]
+        c.execute("SELECT COUNT(path) as num_files FROM files")
+        files = 0
+        row = c.fetchone()
+        if row:
+            files = row[0]
+        return hashes, files
 
 
+    def getStats(self):
+        """Retrieve the saved statistics from the DB.
+        
+        @return: dictionary of statistics
+        """
+        c = self.conn.cursor()
+        c.execute("SELECT param, value FROM stats")
+        row = c.fetchone()
+        stats = {}
+        while row:
+            stats[row['param']] = row['value']
+            row = c.fetchone()
+        c.close()
+        return stats
+        
+    def saveStats(self, stats):
+        """Save the statistics to the DB."""
+        c = self.conn.cursor()
+        for param in stats:
+            c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
+                      (param, stats[param]))
+            self.conn.commit()
+        c.close()
+        
 class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
 class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
@@ -361,14 +429,14 @@ class TestDB(unittest.TestCase):
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         sleep(2)
         res = self.store.expiredHashes(1)
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
         
     def build_dirs(self):
         for dir in self.dirs:
@@ -383,7 +451,7 @@ class TestDB(unittest.TestCase):
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
@@ -392,20 +460,28 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
+        file = self.dirs[0].child('test.khashmir')
+        file.setContent(file.path)
+        file.touch()
+        self.store.storeFile(file, self.hash)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
+        file.remove()
+        res = self.store.removeUntrackedFiles(self.dirs)
+        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.dirs[0].child('test.khashmir'), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)