]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p/db.py
Better handling and logging for intermittent HTTP client submission errors.
[quix0rs-apt-p2p.git] / apt_p2p / db.py
index 63cc7e72cbc8f783030afe593c8464b708433f14..26409875f3339c37bb74fb086d7cc07fc8ae972f 100644 (file)
@@ -5,7 +5,7 @@ from datetime import datetime, timedelta
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
 from time import sleep
-import os
+import os, sha
 
 from twisted.python.filepath import FilePath
 from twisted.trial import unittest
@@ -49,6 +49,7 @@ class DB:
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
+    #{ DB Functions
     def _loadDB(self):
         """Open a new connection to the existing database file"""
         try:
@@ -64,15 +65,22 @@ class DB:
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
-                                      "size NUMBER, mtime NUMBER)")
+                                      "dht BOOL, size NUMBER, mtime NUMBER)")
         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
                                        "hash KHASH UNIQUE, pieces KHASH, " +
                                        "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
+        c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
+    def close(self):
+        """Close the database connection."""
+        self.conn.close()
+
+    #{ Files and Hashes
     def _removeChanged(self, file, row):
         """If the file has changed or is missing, remove it from the DB.
         
@@ -98,13 +106,15 @@ class DB:
                 c.close()
         return res
         
-    def storeFile(self, file, hash, pieces = ''):
+    def storeFile(self, file, hash, dht = True, pieces = ''):
         """Store or update a file in the database.
         
         @type file: L{twisted.python.filepath.FilePath}
         @param file: the file to check
         @type hash: C{string}
         @param hash: the hash of the file
+        @param dht: whether the file is added to the DHT
+            (optional, defaults to true)
         @type pieces: C{string}
         @param pieces: the concatenated list of the hashes of the pieces of
             the file (optional, defaults to the empty string)
@@ -135,8 +145,8 @@ class DB:
 
         # Add the file to the database
         file.restat()
-        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
-                  (file.path, hashID, file.getsize(), file.getmtime()))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, dht, size, mtime) VALUES (?, ?, ?, ?, ?)",
+                  (file.path, hashID, dht, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
@@ -246,20 +256,28 @@ class DB:
             res['pieces'] = row['pieces']
             row = c.fetchone()
 
-        # Make sure there are still valid files for each hash
+        # Make sure there are still valid DHT files for each hash
         for hash in expired.values():
-            valid = False
-            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
+            dht = False
+            non_dht = False
+            c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
-                res = self._removeChanged(FilePath(row['path']), row)
-                if res:
-                    valid = True
+                if row['dht']:
+                    dht = True
+                else:
+                    non_dht = True
                 row = c.fetchone()
-            if not valid:
-                # Remove hashes for which no files are still available
+            if not dht:
+                # Remove hashes for which no DHT files are still available
                 del expired[hash['hash']]
-                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                if not non_dht:
+                    # Remove hashes for which no files are still available
+                    c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                else:
+                    # There are still some non-DHT files available, so refresh them
+                    c.execute("UPDATE hashes SET refreshed = ? WHERE hashID = ?",
+                              (datetime.now(), hash['hashID']))
                 
         self.conn.commit()
         c.close()
@@ -295,14 +313,63 @@ class DB:
         # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
+            self.conn.commit()
+        
+        c.execute("SELECT path FROM files")
+        rows = c.fetchall()
+        for row in rows:
+            if not os.path.exists(row['path']):
+                # Leave hashes, they will be removed on next refresh
+                c.execute("DELETE FROM files WHERE path = ?", (row['path'], ))
+                removed.append(FilePath(row['path']))
         self.conn.commit()
 
         return removed
     
-    def close(self):
-        """Close the database connection."""
-        self.conn.close()
+    #{ Statistics
+    def dbStats(self):
+        """Count the total number of files and hashes in the database.
+        
+        @rtype: (C{int}, C{int})
+        @return: the number of distinct hashes and total files in the database
+        """
+        c = self.conn.cursor()
+        c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
+        hashes = 0
+        row = c.fetchone()
+        if row:
+            hashes = row[0]
+        c.execute("SELECT COUNT(path) as num_files FROM files")
+        files = 0
+        row = c.fetchone()
+        if row:
+            files = row[0]
+        return hashes, files
 
+    def getStats(self):
+        """Retrieve the saved statistics from the DB.
+        
+        @return: dictionary of statistics
+        """
+        c = self.conn.cursor()
+        c.execute("SELECT param, value FROM stats")
+        row = c.fetchone()
+        stats = {}
+        while row:
+            stats[row['param']] = row['value']
+            row = c.fetchone()
+        c.close()
+        return stats
+        
+    def saveStats(self, stats):
+        """Save the statistics to the DB."""
+        c = self.conn.cursor()
+        for param in stats:
+            c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
+                      (param, stats[param]))
+            self.conn.commit()
+        c.close()
+        
 class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
@@ -400,11 +467,19 @@ class TestDB(unittest.TestCase):
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
         self.build_dirs()
+        file = self.dirs[0].child('test.khashmir')
+        file.setContent(file.path)
+        file.touch()
+        self.store.storeFile(file, self.hash)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
+        file.remove()
+        res = self.store.removeUntrackedFiles(self.dirs)
+        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.dirs[0].child('test.khashmir'), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)