Main database finished for now, including unittests.
authorCameron Dale <camrdale@gmail.com>
Sat, 12 Jan 2008 04:14:53 +0000 (20:14 -0800)
committerCameron Dale <camrdale@gmail.com>
Sat, 12 Jan 2008 04:14:53 +0000 (20:14 -0800)
apt_dht/db.py

index d6a5d6801dfaef950df40725a6b8de1b4a1f039c..c451874c62eb8db8c0fb4a47c1c9aa2f11e7a8b6 100644 (file)
@@ -50,16 +50,52 @@ class DB:
         c.close()
         self.conn.commit()
 
-    def storeFile(self, path, hash, urlpath, refreshed):
+    def _removeChanged(self, path, row):
+        res = None
+        if row:
+            try:
+                stat = os.stat(path)
+            except:
+                stat = None
+            if stat:
+                res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
+            if not res:
+                c = self.conn.cursor()
+                c.execute("DELETE FROM files WHERE path = ?", (path, ))
+                self.conn.commit()
+                c.close()
+        return res
+        
+    def storeFile(self, path, hash, urlpath):
         """Store or update a file in the database."""
         path = os.path.abspath(path)
         stat = os.stat(path)
         c = self.conn.cursor()
-        c.execute("INSERT OR REPLACE INTO kv VALUES (?, ?, ?, ?, ?, ?)", 
+        c.execute("INSERT OR REPLACE INTO files VALUES (?, ?, ?, ?, ?, ?)", 
                   (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now()))
         self.conn.commit()
         c.close()
         
+    def getFile(self, path):
+        """Get a file from the database.
+        
+        If it has changed or is missing, it is removed from the database.
+        
+        @return: dictionary of info for the file, False if changed, or
+            None if not in database or missing
+        """
+        path = os.path.abspath(path)
+        c = self.conn.cursor()
+        c.execute("SELECT hash, urlpath, size, mtime FROM files WHERE path = ?", (path, ))
+        row = c.fetchone()
+        res = self._removeChanged(path, row)
+        if res:
+            res = {}
+            res['hash'] = row['hash']
+            res['urlpath'] = row['urlpath']
+        c.close()
+        return res
+        
     def isUnchanged(self, path):
         """Check if a file in the file system has changed.
         
@@ -68,19 +104,27 @@ class DB:
         @return: True if unchanged, False if changed, None if not in database
         """
         path = os.path.abspath(path)
-        stat = os.stat(path)
         c = self.conn.cursor()
         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
         row = c.fetchone()
-        res = None
-        if row:
-            res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
-            if not res:
-                c.execute("DELETE FROM files WHERE path = ?", path)
-                self.conn.commit()
-        c.close()
-        return res
+        return self._removeChanged(path, row)
 
+    def refreshFile(self, path):
+        """Refresh the publishing time of a file.
+        
+        If it has changed or is missing, it is removed from the table.
+        
+        @return: True if unchanged, False if changed, None if not in database
+        """
+        path = os.path.abspath(path)
+        c = self.conn.cursor()
+        c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
+        row = c.fetchone()
+        res = self._removeChanged(path, row)
+        if res:
+            c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
+        return res
+    
     def expiredFiles(self, expireAfter):
         """Find files that need refreshing after expireAfter seconds.
         
@@ -90,19 +134,15 @@ class DB:
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         c = self.conn.cursor()
-        c.execute("SELECT path, hash, urlpath FROM files WHERE refreshed < ?", (t, ))
+        c.execute("SELECT path, hash, urlpath, size, mtime FROM files WHERE refreshed < ?", (t, ))
         row = c.fetchone()
         expired = {}
-        missing = []
         while row:
-            if os.path.exists(row['path']):
+            res = self._removeChanged(row['path'], row)
+            if res:
                 expired.setdefault(row['hash'], []).append(row['urlpath'])
-            else:
-                missing.append((row['path'],))
             row = c.fetchone()
-        if missing:
-            c.executemany("DELETE FROM files WHERE path = ?", missing)
-        self.conn.commit()
+        c.close()
         return expired
         
     def removeUntrackedFiles(self, dirs):
@@ -113,15 +153,15 @@ class DB:
         @return: list of files that were removed
         """
         assert len(dirs) >= 1
-        dirs = dirs.copy()
+        newdirs = []
         sql = "WHERE"
-        for i in xrange(len(dirs)):
-            dirs[i] = os.path.abspath(dirs[i])
-            sql += " path NOT GLOB ?/* AND"
+        for dir in dirs:
+            newdirs.append(os.path.abspath(dir) + os.sep + '*')
+            sql += " path NOT GLOB ? AND"
         sql = sql[:-4]
 
         c = self.conn.cursor()
-        c.execute("SELECT path FROM files " + sql, dirs)
+        c.execute("SELECT path FROM files " + sql, newdirs)
         row = c.fetchone()
         removed = []
         while row:
@@ -129,7 +169,7 @@ class DB:
             row = c.fetchone()
 
         if removed:
-            c.execute("DELETE FROM files " + sql, dirs)
+            c.execute("DELETE FROM files " + sql, newdirs)
         self.conn.commit()
         return removed
         
@@ -141,54 +181,78 @@ class TestDB(unittest.TestCase):
     
     timeout = 5
     db = '/tmp/khashmir.db'
-    key = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+    path = '/tmp/khashmir.test'
+    hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
+    urlpath = '/~1/what/ever/khashmir.test'
+    dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
 
     def setUp(self):
+        f = open(self.path, 'w')
+        f.write('fgfhds')
+        f.close()
+        os.utime(self.path, None)
         self.store = DB(self.db)
+        self.store.storeFile(self.path, self.hash, self.urlpath)
 
-    def test_selfNode(self):
-        self.store.saveSelfNode(self.key)
-        self.failUnlessEqual(self.store.getSelfNode(), self.key)
-        
-    def test_Value(self):
-        self.store.storeValue(self.key, 'foobar')
-        val = self.store.retrieveValues(self.key)
-        self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'foobar')
-        
-    def test_expireValues(self):
-        self.store.storeValue(self.key, 'foobar')
+    def test_getFile(self):
+        res = self.store.getFile(self.path)
+        self.failUnless(res)
+        self.failUnlessEqual(res['hash'], self.hash)
+        self.failUnlessEqual(res['urlpath'], self.urlpath)
+        
+    def test_isUnchanged(self):
+        res = self.store.isUnchanged(self.path)
+        self.failUnless(res)
         sleep(2)
-        self.store.storeValue(self.key, 'barfoo')
-        self.store.expireValues(1)
-        val = self.store.retrieveValues(self.key)
-        self.failUnlessEqual(len(val), 1)
-        self.failUnlessEqual(val[0], 'barfoo')
-        
-    def test_RoutingTable(self):
-        class dummy:
-            id = self.key
-            host = "127.0.0.1"
-            port = 9977
-            def contents(self):
-                return (self.id, self.host, self.port)
-        dummy2 = dummy()
-        dummy2.id = '\xaa\xbb\xcc\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
-        dummy2.host = '205.23.67.124'
-        dummy2.port = 12345
-        class bl:
-            def __init__(self):
-                self.l = []
-        bl1 = bl()
-        bl1.l.append(dummy())
-        bl2 = bl()
-        bl2.l.append(dummy2)
-        buckets = [bl1, bl2]
-        self.store.dumpRoutingTable(buckets)
-        rt = self.store.getRoutingTable()
-        self.failUnlessIn(dummy().contents(), rt)
-        self.failUnlessIn(dummy2.contents(), rt)
+        os.utime(self.path, None)
+        res = self.store.isUnchanged(self.path)
+        self.failUnless(res == False)
+        os.unlink(self.path)
+        res = self.store.isUnchanged(self.path)
+        self.failUnless(res == None)
+        
+    def test_expiry(self):
+        res = self.store.expiredFiles(1)
+        self.failUnlessEqual(len(res.keys()), 0)
+        sleep(2)
+        res = self.store.expiredFiles(1)
+        self.failUnlessEqual(len(res.keys()), 1)
+        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res[self.hash]), 1)
+        self.failUnlessEqual(res[self.hash][0], self.urlpath)
+        res = self.store.refreshFile(self.path)
+        self.failUnless(res)
+        res = self.store.expiredFiles(1)
+        self.failUnlessEqual(len(res.keys()), 0)
+        
+    def test_removeUntracked(self):
+        for dir in self.dirs:
+            path = os.path.join(dir, self.path[1:])
+            os.makedirs(os.path.dirname(path))
+            f = open(path, 'w')
+            f.write(path)
+            f.close()
+            os.utime(path, None)
+            self.store.storeFile(path, self.hash, self.urlpath)
+        
+        res = self.store.removeUntrackedFiles(self.dirs)
+        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
+        res = self.store.removeUntrackedFiles(self.dirs)
+        self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
+        res = self.store.removeUntrackedFiles(self.dirs[1:])
+        self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
+        res = self.store.removeUntrackedFiles(self.dirs[:1])
+        self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
+        self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
+        self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
         
     def tearDown(self):
+        for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
+            for name in files:
+                os.remove(os.path.join(root, name))
+            for name in dirs:
+                os.rmdir(os.path.join(root, name))
         self.store.close()
         os.unlink(self.db)