]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht/db.py
Documented the db module.
[quix0rs-apt-p2p.git] / apt_dht / db.py
index e1d6d7b474f7f69bce218b85af50235e5d6f1dcb..f72b104d8c5965950ad744a432af86d3431e103d 100644 (file)
@@ -1,4 +1,6 @@
 
 
+"""An sqlite database for storing persistent files and hashes."""
+
 from datetime import datetime, timedelta
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
 from datetime import datetime, timedelta
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
@@ -11,20 +13,33 @@ from twisted.trial import unittest
 assert sqlite.version_info >= (2, 1)
 
 class DBExcept(Exception):
 assert sqlite.version_info >= (2, 1)
 
 class DBExcept(Exception):
+    """An error occurred in accessing the database."""
     pass
 
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
     pass
 
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
-    
+
+# Initialize the database to work with 'khash' objects (binary strings)
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
 sqlite.enable_callback_tracebacks(True)
 
 class DB:
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
 sqlite.enable_callback_tracebacks(True)
 
 class DB:
-    """Database access for storing persistent data."""
+    """An sqlite database for storing persistent files and hashes.
+    
+    @type db: L{twisted.python.filepath.FilePath}
+    @ivar db: the database file to use
+    @type conn: L{pysqlite2.dbapi2.Connection}
+    @ivar conn: an open connection to the sqlite database
+    """
     
     def __init__(self, db):
     
     def __init__(self, db):
+        """Load or create the database file.
+        
+        @type db: L{twisted.python.filepath.FilePath}
+        @param db: the database file to use
+        """
         self.db = db
         self.db.restat(False)
         if self.db.exists():
         self.db = db
         self.db.restat(False)
         if self.db.exists():
@@ -35,6 +50,7 @@ class DB:
         self.conn.row_factory = sqlite.Row
         
     def _loadDB(self):
         self.conn.row_factory = sqlite.Row
         
     def _loadDB(self):
+        """Open a new connection to the existing database file"""
         try:
             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         except:
         try:
             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         except:
@@ -42,6 +58,7 @@ class DB:
             raise DBExcept, "Couldn't open DB", traceback.format_exc()
         
     def _createNewDB(self):
             raise DBExcept, "Couldn't open DB", traceback.format_exc()
         
     def _createNewDB(self):
+        """Open a connection to a new database and create the necessary tables."""
         if not self.db.parent().exists():
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         if not self.db.parent().exists():
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
@@ -57,12 +74,24 @@ class DB:
         self.conn.commit()
 
     def _removeChanged(self, file, row):
         self.conn.commit()
 
     def _removeChanged(self, file, row):
+        """If the file has changed or is missing, remove it from the DB.
+        
+        @type file: L{twisted.python.filepath.FilePath}
+        @param file: the file to check
+        @type row: C{dictionary}-like object
+        @param row: contains the expected 'size' and 'mtime' of the file
+        @rtype: C{boolean}
+        @return: True if the file is unchanged, False if it is changed,
+            and None if it is missing
+        """
         res = None
         if row:
             file.restat(False)
             if file.exists():
         res = None
         if row:
             file.restat(False)
             if file.exists():
+                # Compare the current with the expected file properties
                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
             if not res:
                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
             if not res:
+                # Remove the file from the database
                 c = self.conn.cursor()
                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                 self.conn.commit()
                 c = self.conn.cursor()
                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                 self.conn.commit()
@@ -72,13 +101,23 @@ class DB:
     def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
     def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
+        @type file: L{twisted.python.filepath.FilePath}
+        @param file: the file to check
+        @type hash: C{string}
+        @param hash: the hash of the file
+        @type pieces: C{string}
+        @param pieces: the concatenated list of the hashes of the pieces of
+            the file (optional, defaults to the empty string)
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
+        # Hash the pieces to get the piecehash
         piecehash = ''
         if pieces:
             s = sha.new().update(pieces)
             piecehash = sha.digest()
         piecehash = ''
         if pieces:
             s = sha.new().update(pieces)
             piecehash = sha.digest()
+            
+        # Check the database for the hash
         c = self.conn.cursor()
         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         c = self.conn.cursor()
         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
@@ -87,13 +126,15 @@ class DB:
             new_hash = False
             hashID = row['hashID']
         else:
             new_hash = False
             hashID = row['hashID']
         else:
+            # Add the new hash to the database
             c = self.conn.cursor()
             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
             self.conn.commit()
             new_hash = True
             hashID = c.lastrowid
             c = self.conn.cursor()
             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
             self.conn.commit()
             new_hash = True
             hashID = c.lastrowid
-        
+
+        # Add the file to the database
         file.restat()
         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
                   (file.path, hashID, file.getsize(), file.getmtime()))
         file.restat()
         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
                   (file.path, hashID, file.getsize(), file.getmtime()))
@@ -107,6 +148,8 @@ class DB:
         
         If it has changed or is missing, it is removed from the database.
         
         
         If it has changed or is missing, it is removed from the database.
         
+        @type file: L{twisted.python.filepath.FilePath}
+        @param file: the file to check
         @return: dictionary of info for the file, False if changed, or
             None if not in database or missing
         """
         @return: dictionary of info for the file, False if changed, or
             None if not in database or missing
         """
@@ -133,11 +176,13 @@ class DB:
         
         @return: list of dictionaries of info for the found files
         """
         
         @return: list of dictionaries of info for the found files
         """
+        # Try to find the hash in the files table
         c = self.conn.cursor()
         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
         c = self.conn.cursor()
         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
+            # Save the file to the list of found files
             file = FilePath(row['path'])
             res = self._removeChanged(file, row)
             if res:
             file = FilePath(row['path'])
             res = self._removeChanged(file, row)
             if res:
@@ -150,6 +195,7 @@ class DB:
             row = c.fetchone()
             
         if not filesOnly and not files:
             row = c.fetchone()
             
         if not filesOnly and not files:
+            # No files were found, so check the piecehashes as well
             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
             row = c.fetchone()
             if row:
             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
             row = c.fetchone()
             if row:
@@ -164,7 +210,7 @@ class DB:
     def isUnchanged(self, file):
         """Check if a file in the file system has changed.
         
     def isUnchanged(self, file):
         """Check if a file in the file system has changed.
         
-        If it has changed, it is removed from the table.
+        If it has changed, it is removed from the database.
         
         @return: True if unchanged, False if changed, None if not in database
         """
         
         @return: True if unchanged, False if changed, None if not in database
         """
@@ -174,7 +220,7 @@ class DB:
         return self._removeChanged(file, row)
 
     def refreshHash(self, hash):
         return self._removeChanged(file, row)
 
     def refreshHash(self, hash):
-        """Refresh the publishing time all files with a hash."""
+        """Refresh the publishing time of a hash."""
         c = self.conn.cursor()
         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
         c = self.conn.cursor()
         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
@@ -189,7 +235,7 @@ class DB:
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
-        # First find the hashes that need refreshing
+        # Find all the hashes that need refreshing
         c = self.conn.cursor()
         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
         c = self.conn.cursor()
         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
@@ -212,6 +258,7 @@ class DB:
                     valid = True
                 row = c.fetchone()
             if not valid:
                     valid = True
                 row = c.fetchone()
             if not valid:
+                # Remove hashes for which no files are still available
                 del expired[hash['hash']]
                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
                 
                 del expired[hash['hash']]
                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
                 
@@ -221,13 +268,15 @@ class DB:
         return expired
         
     def removeUntrackedFiles(self, dirs):
         return expired
         
     def removeUntrackedFiles(self, dirs):
-        """Find files that are no longer tracked and so should be removed.
-        
-        Also removes the entries from the table.
+        """Remove files that are no longer tracked by the program.
         
         
+        @type dirs: C{list} of L{twisted.python.filepath.FilePath}
+        @param dirs: a list of the directories that we are tracking
         @return: list of files that were removed
         """
         assert len(dirs) >= 1
         @return: list of files that were removed
         """
         assert len(dirs) >= 1
+        
+        # Create a list of globs and an SQL statement for the directories
         newdirs = []
         sql = "WHERE"
         for dir in dirs:
         newdirs = []
         sql = "WHERE"
         for dir in dirs:
@@ -235,6 +284,7 @@ class DB:
             sql += " path NOT GLOB ? AND"
         sql = sql[:-4]
 
             sql += " path NOT GLOB ? AND"
         sql = sql[:-4]
 
+        # Get a listing of all the files that will be removed
         c = self.conn.cursor()
         c.execute("SELECT path FROM files " + sql, newdirs)
         row = c.fetchone()
         c = self.conn.cursor()
         c.execute("SELECT path FROM files " + sql, newdirs)
         row = c.fetchone()
@@ -243,12 +293,15 @@ class DB:
             removed.append(FilePath(row['path']))
             row = c.fetchone()
 
             removed.append(FilePath(row['path']))
             row = c.fetchone()
 
+        # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
         self.conn.commit()
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
         self.conn.commit()
+
         return removed
     
     def close(self):
         return removed
     
     def close(self):
+        """Close the database connection."""
         self.conn.close()
 
 class TestDB(unittest.TestCase):
         self.conn.close()
 
 class TestDB(unittest.TestCase):
@@ -273,6 +326,7 @@ class TestDB(unittest.TestCase):
         self.store.storeFile(self.file, self.hash)
 
     def test_openExistingDB(self):
         self.store.storeFile(self.file, self.hash)
 
     def test_openExistingDB(self):
+        """Tests opening an existing database."""
         self.store.close()
         self.store = None
         sleep(1)
         self.store.close()
         self.store = None
         sleep(1)
@@ -281,17 +335,20 @@ class TestDB(unittest.TestCase):
         self.failUnless(res)
 
     def test_getFile(self):
         self.failUnless(res)
 
     def test_getFile(self):
+        """Tests retrieving a file from the database."""
         res = self.store.getFile(self.file)
         self.failUnless(res)
         self.failUnlessEqual(res['hash'], self.hash)
         
     def test_lookupHash(self):
         res = self.store.getFile(self.file)
         self.failUnless(res)
         self.failUnlessEqual(res['hash'], self.hash)
         
     def test_lookupHash(self):
+        """Tests looking up a hash in the database."""
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 1)
         self.failUnlessEqual(res[0]['path'].path, self.file.path)
         
     def test_isUnchanged(self):
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 1)
         self.failUnlessEqual(res[0]['path'].path, self.file.path)
         
     def test_isUnchanged(self):
+        """Tests checking if a file in the database is unchanged."""
         res = self.store.isUnchanged(self.file)
         self.failUnless(res)
         sleep(2)
         res = self.store.isUnchanged(self.file)
         self.failUnless(res)
         sleep(2)
@@ -302,6 +359,7 @@ class TestDB(unittest.TestCase):
         self.failUnless(res is None)
         
     def test_expiry(self):
         self.failUnless(res is None)
         
     def test_expiry(self):
+        """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
@@ -322,6 +380,7 @@ class TestDB(unittest.TestCase):
             self.store.storeFile(file, self.hash)
     
     def test_multipleHashes(self):
             self.store.storeFile(file, self.hash)
     
     def test_multipleHashes(self):
+        """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         self.build_dirs()
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
@@ -340,6 +399,7 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
+        """Tests removing untracked files from the database."""
         self.build_dirs()
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
         self.build_dirs()
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)