Documented the db module.
authorCameron Dale <camrdale@gmail.com>
Sat, 1 Mar 2008 02:44:40 +0000 (18:44 -0800)
committerCameron Dale <camrdale@gmail.com>
Sat, 1 Mar 2008 02:44:40 +0000 (18:44 -0800)
apt_dht/db.py

index e1d6d7b474f7f69bce218b85af50235e5d6f1dcb..f72b104d8c5965950ad744a432af86d3431e103d 100644 (file)
@@ -1,4 +1,6 @@
 
+"""An sqlite database for storing persistent files and hashes."""
+
 from datetime import datetime, timedelta
 from pysqlite2 import dbapi2 as sqlite
 from binascii import a2b_base64, b2a_base64
@@ -11,20 +13,33 @@ from twisted.trial import unittest
 assert sqlite.version_info >= (2, 1)
 
 class DBExcept(Exception):
+    """An error occurred in accessing the database."""
     pass
 
 class khash(str):
     """Dummy class to convert all hashes to base64 for storing in the DB."""
-    
+
+# Initialize the database to work with 'khash' objects (binary strings)
 sqlite.register_adapter(khash, b2a_base64)
 sqlite.register_converter("KHASH", a2b_base64)
 sqlite.register_converter("khash", a2b_base64)
 sqlite.enable_callback_tracebacks(True)
 
 class DB:
-    """Database access for storing persistent data."""
+    """An sqlite database for storing persistent files and hashes.
+    
+    @type db: L{twisted.python.filepath.FilePath}
+    @ivar db: the database file to use
+    @type conn: L{pysqlite2.dbapi2.Connection}
+    @ivar conn: an open connection to the sqlite database
+    """
     
     def __init__(self, db):
+        """Load or create the database file.
+        
+        @type db: L{twisted.python.filepath.FilePath}
+        @param db: the database file to use
+        """
         self.db = db
         self.db.restat(False)
         if self.db.exists():
@@ -35,6 +50,7 @@ class DB:
         self.conn.row_factory = sqlite.Row
         
     def _loadDB(self):
+        """Open a new connection to the existing database file"""
         try:
             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         except:
@@ -42,6 +58,7 @@ class DB:
             raise DBExcept, "Couldn't open DB", traceback.format_exc()
         
     def _createNewDB(self):
+        """Open a connection to a new database and create the necessary tables."""
         if not self.db.parent().exists():
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
@@ -57,12 +74,24 @@ class DB:
         self.conn.commit()
 
     def _removeChanged(self, file, row):
+        """If the file has changed or is missing, remove it from the DB.
+        
+        @type file: L{twisted.python.filepath.FilePath}
+        @param file: the file to check
+        @type row: C{dictionary}-like object
+        @param row: contains the expected 'size' and 'mtime' of the file
+        @rtype: C{boolean}
+        @return: True if the file is unchanged, False if it is changed,
+            and None if it is missing
+        """
         res = None
         if row:
             file.restat(False)
             if file.exists():
+                # Compare the current with the expected file properties
                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
             if not res:
+                # Remove the file from the database
                 c = self.conn.cursor()
                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                 self.conn.commit()
@@ -72,13 +101,23 @@ class DB:
     def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
+        @type file: L{twisted.python.filepath.FilePath}
+        @param file: the file to check
+        @type hash: C{string}
+        @param hash: the hash of the file
+        @type pieces: C{string}
+        @param pieces: the concatenated list of the hashes of the pieces of
+            the file (optional, defaults to the empty string)
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
+        # Hash the pieces to get the piecehash
         piecehash = ''
         if pieces:
             s = sha.new().update(pieces)
             piecehash = sha.digest()
+            
+        # Check the database for the hash
         c = self.conn.cursor()
         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
@@ -87,13 +126,15 @@ class DB:
             new_hash = False
             hashID = row['hashID']
         else:
+            # Add the new hash to the database
             c = self.conn.cursor()
             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
             self.conn.commit()
             new_hash = True
             hashID = c.lastrowid
-        
+
+        # Add the file to the database
         file.restat()
         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
                   (file.path, hashID, file.getsize(), file.getmtime()))
@@ -107,6 +148,8 @@ class DB:
         
         If it has changed or is missing, it is removed from the database.
         
+        @type file: L{twisted.python.filepath.FilePath}
+        @param file: the file to check
         @return: dictionary of info for the file, False if changed, or
             None if not in database or missing
         """
@@ -133,11 +176,13 @@ class DB:
         
         @return: list of dictionaries of info for the found files
         """
+        # Try to find the hash in the files table
         c = self.conn.cursor()
         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
+            # Save the file to the list of found files
             file = FilePath(row['path'])
             res = self._removeChanged(file, row)
             if res:
@@ -150,6 +195,7 @@ class DB:
             row = c.fetchone()
             
         if not filesOnly and not files:
+            # No files were found, so check the piecehashes as well
             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
             row = c.fetchone()
             if row:
@@ -164,7 +210,7 @@ class DB:
     def isUnchanged(self, file):
         """Check if a file in the file system has changed.
         
-        If it has changed, it is removed from the table.
+        If it has changed, it is removed from the database.
         
         @return: True if unchanged, False if changed, None if not in database
         """
@@ -174,7 +220,7 @@ class DB:
         return self._removeChanged(file, row)
 
     def refreshHash(self, hash):
-        """Refresh the publishing time all files with a hash."""
+        """Refresh the publishing time of a hash."""
         c = self.conn.cursor()
         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
@@ -189,7 +235,7 @@ class DB:
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
-        # First find the hashes that need refreshing
+        # Find all the hashes that need refreshing
         c = self.conn.cursor()
         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
@@ -212,6 +258,7 @@ class DB:
                     valid = True
                 row = c.fetchone()
             if not valid:
+                # Remove hashes for which no files are still available
                 del expired[hash['hash']]
                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
                 
@@ -221,13 +268,15 @@ class DB:
         return expired
         
     def removeUntrackedFiles(self, dirs):
-        """Find files that are no longer tracked and so should be removed.
-        
-        Also removes the entries from the table.
+        """Remove files that are no longer tracked by the program.
         
+        @type dirs: C{list} of L{twisted.python.filepath.FilePath}
+        @param dirs: a list of the directories that we are tracking
         @return: list of files that were removed
         """
         assert len(dirs) >= 1
+        
+        # Create a list of globs and an SQL statement for the directories
         newdirs = []
         sql = "WHERE"
         for dir in dirs:
@@ -235,6 +284,7 @@ class DB:
             sql += " path NOT GLOB ? AND"
         sql = sql[:-4]
 
+        # Get a listing of all the files that will be removed
         c = self.conn.cursor()
         c.execute("SELECT path FROM files " + sql, newdirs)
         row = c.fetchone()
@@ -243,12 +293,15 @@ class DB:
             removed.append(FilePath(row['path']))
             row = c.fetchone()
 
+        # Delete all the removed files from the database
         if removed:
             c.execute("DELETE FROM files " + sql, newdirs)
         self.conn.commit()
+
         return removed
     
     def close(self):
+        """Close the database connection."""
         self.conn.close()
 
 class TestDB(unittest.TestCase):
@@ -273,6 +326,7 @@ class TestDB(unittest.TestCase):
         self.store.storeFile(self.file, self.hash)
 
     def test_openExistingDB(self):
+        """Tests opening an existing database."""
         self.store.close()
         self.store = None
         sleep(1)
@@ -281,17 +335,20 @@ class TestDB(unittest.TestCase):
         self.failUnless(res)
 
     def test_getFile(self):
+        """Tests retrieving a file from the database."""
         res = self.store.getFile(self.file)
         self.failUnless(res)
         self.failUnlessEqual(res['hash'], self.hash)
         
     def test_lookupHash(self):
+        """Tests looking up a hash in the database."""
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 1)
         self.failUnlessEqual(res[0]['path'].path, self.file.path)
         
     def test_isUnchanged(self):
+        """Tests checking if a file in the database is unchanged."""
         res = self.store.isUnchanged(self.file)
         self.failUnless(res)
         sleep(2)
@@ -302,6 +359,7 @@ class TestDB(unittest.TestCase):
         self.failUnless(res is None)
         
     def test_expiry(self):
+        """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
@@ -322,6 +380,7 @@ class TestDB(unittest.TestCase):
             self.store.storeFile(file, self.hash)
     
     def test_multipleHashes(self):
+        """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
@@ -340,6 +399,7 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
+        """Tests removing untracked files from the database."""
         self.build_dirs()
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)