Use FilePath everywhere and create new CacheManager module.
[quix0rs-apt-p2p.git] / apt_dht / db.py
index 9725aa88f10c36e0bd137de1cb5051850d7cde91..1d2e34273ad98161e365ce54ab7b4d5ecb7f76f0 100644 (file)
@@ -5,6 +5,7 @@ from binascii import a2b_base64, b2a_base64
 from time import sleep
 import os
 
+from twisted.python.filepath import FilePath
 from twisted.trial import unittest
 
 assert sqlite.version_info >= (2, 1)
@@ -25,24 +26,25 @@ class DB:
     
     def __init__(self, db):
         self.db = db
-        try:
-            os.stat(db)
-        except OSError:
-            self._createNewDB(db)
+        self.db.restat(False)
+        if self.db.exists():
+            self._loadDB()
         else:
-            self._loadDB(db)
+            self._createNewDB()
         self.conn.text_factory = str
         self.conn.row_factory = sqlite.Row
         
-    def _loadDB(self, db):
+    def _loadDB(self):
         try:
-            self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
+            self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         except:
             import traceback
             raise DBExcept, "Couldn't open DB", traceback.format_exc()
         
-    def _createNewDB(self, db):
-        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
+    def _createNewDB(self):
+        if not self.db.parent().exists():
+            self.db.parent().makedirs()
+        self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
         c.execute("CREATE INDEX files_urldir ON files(urldir)")
@@ -52,49 +54,43 @@ class DB:
         c.close()
         self.conn.commit()
 
-    def _removeChanged(self, path, row):
+    def _removeChanged(self, file, row):
         res = None
         if row:
-            try:
-                stat = os.stat(path)
-            except:
-                stat = None
-            if stat:
-                res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
+            file.restat(False)
+            if file.exists():
+                res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
             if not res:
                 c = self.conn.cursor()
-                c.execute("DELETE FROM files WHERE path = ?", (path, ))
+                c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
                 self.conn.commit()
                 c.close()
         return res
         
-    def storeFile(self, path, hash, directory):
+    def storeFile(self, file, hash, directory):
         """Store or update a file in the database.
         
         @return: the urlpath to access the file, and whether a
             new url top-level directory was needed
         """
-        path = os.path.abspath(path)
-        directory = os.path.abspath(directory)
-        assert path.startswith(directory)
-        stat = os.stat(path)
+        file.restat()
         c = self.conn.cursor()
-        c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (path, ))
+        c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (file.path, ))
         row = c.fetchone()
         if row and directory == row['directory']:
             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
-                      (khash(hash), stat.st_size, stat.st_mtime, datetime.now()))
+                      (khash(hash), file.getsize(), file.getmtime(), datetime.now()))
             newdir = False
             urldir = row['urldir']
         else:
             urldir, newdir = self.findDirectory(directory)
             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
-                      (path, khash(hash), urldir, len(directory), stat.st_size, stat.st_mtime, datetime.now()))
+                      (file.path, khash(hash), urldir, len(directory.path), file.getsize(), file.getmtime(), datetime.now()))
         self.conn.commit()
         c.close()
-        return '/~' + str(urldir) + path[len(directory):], newdir
+        return '/~' + str(urldir) + file.path[len(directory.path):], newdir
         
-    def getFile(self, path):
+    def getFile(self, file):
         """Get a file from the database.
         
         If it has changed or is missing, it is removed from the database.
@@ -102,45 +98,43 @@ class DB:
         @return: dictionary of info for the file, False if changed, or
             None if not in database or missing
         """
-        path = os.path.abspath(path)
         c = self.conn.cursor()
-        c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (path, ))
+        c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (file.path, ))
         row = c.fetchone()
-        res = self._removeChanged(path, row)
+        res = self._removeChanged(file, row)
         if res:
             res = {}
             res['hash'] = row['hash']
-            res['urlpath'] = '/~' + str(row['urldir']) + path[row['dirlength']:]
+            res['size'] = row['size']
+            res['urlpath'] = '/~' + str(row['urldir']) + file.path[row['dirlength']:]
         c.close()
         return res
         
-    def isUnchanged(self, path):
+    def isUnchanged(self, file):
         """Check if a file in the file system has changed.
         
         If it has changed, it is removed from the table.
         
         @return: True if unchanged, False if changed, None if not in database or missing
         """
-        path = os.path.abspath(path)
         c = self.conn.cursor()
-        c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
+        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
         row = c.fetchone()
-        return self._removeChanged(path, row)
+        return self._removeChanged(file, row)
 
-    def refreshFile(self, path):
+    def refreshFile(self, file):
         """Refresh the publishing time of a file.
         
         If it has changed or is missing, it is removed from the table.
         
         @return: True if unchanged, False if changed, None if not in database or missing
         """
-        path = os.path.abspath(path)
         c = self.conn.cursor()
-        c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
+        c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
         row = c.fetchone()
-        res = self._removeChanged(path, row)
+        res = self._removeChanged(file, row)
         if res:
-            c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
+            c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
         return res
     
     def expiredFiles(self, expireAfter):
@@ -156,7 +150,7 @@ class DB:
         row = c.fetchone()
         expired = {}
         while row:
-            res = self._removeChanged(row['path'], row)
+            res = self._removeChanged(FilePath(row['path']), row)
             if res:
                 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
             row = c.fetchone()
@@ -174,7 +168,7 @@ class DB:
         newdirs = []
         sql = "WHERE"
         for dir in dirs:
-            newdirs.append(os.path.abspath(dir) + os.sep + '*')
+            newdirs.append(dir.child('*').path)
             sql += " path NOT GLOB ? AND"
         sql = sql[:-4]
 
@@ -183,7 +177,7 @@ class DB:
         row = c.fetchone()
         removed = []
         while row:
-            removed.append(row['path'])
+            removed.append(FilePath(row['path']))
             row = c.fetchone()
 
         if removed:
@@ -196,9 +190,8 @@ class DB:
         
         @return: the index of the url directory, and whether it is new or not
         """
-        directory = os.path.abspath(directory)
         c = self.conn.cursor()
-        c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory, ))
+        c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory.path, ))
         row = c.fetchone()
         c.close()
         if row['urldir']:
@@ -206,7 +199,7 @@ class DB:
 
         # Not found, need to add a new one
         c = self.conn.cursor()
-        c.execute("INSERT INTO dirs (path) VALUES (?)", (directory, ))
+        c.execute("INSERT INTO dirs (path) VALUES (?)", (directory.path, ))
         self.conn.commit()
         urldir = c.lastrowid
         c.close()
@@ -219,7 +212,7 @@ class DB:
         row = c.fetchone()
         dirs = {}
         while row:
-            dirs['~' + str(row['urldir'])] = row['path']
+            dirs['~' + str(row['urldir'])] = FilePath(row['path'])
             row = c.fetchone()
         c.close()
         return dirs
@@ -238,23 +231,34 @@ class TestDB(unittest.TestCase):
     """Tests for the khashmir database."""
     
     timeout = 5
-    db = '/tmp/khashmir.db'
-    path = '/tmp/khashmir.test'
+    db = FilePath('/tmp/khashmir.db')
+    file = FilePath('/tmp/apt-dht/khashmir.test')
     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
-    directory = '/tmp/'
+    directory = FilePath('/tmp/apt-dht/')
     urlpath = '/~1/khashmir.test'
-    dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
+    testfile = 'tmp/khashmir.test'
+    dirs = [FilePath('/tmp/apt-dht/top1'),
+            FilePath('/tmp/apt-dht/top2/sub1'),
+            FilePath('/tmp/apt-dht/top2/sub2/')]
 
     def setUp(self):
-        f = open(self.path, 'w')
-        f.write('fgfhds')
-        f.close()
-        os.utime(self.path, None)
+        if not self.file.parent().exists():
+            self.file.parent().makedirs()
+        self.file.setContent('fgfhds')
+        self.file.touch()
         self.store = DB(self.db)
-        self.store.storeFile(self.path, self.hash, self.directory)
+        self.store.storeFile(self.file, self.hash, self.directory)
+
+    def test_openExistsingDB(self):
+        self.store.close()
+        self.store = None
+        sleep(1)
+        self.store = DB(self.db)
+        res = self.store.isUnchanged(self.file)
+        self.failUnless(res)
 
     def test_getFile(self):
-        res = self.store.getFile(self.path)
+        res = self.store.getFile(self.file)
         self.failUnless(res)
         self.failUnlessEqual(res['hash'], self.hash)
         self.failUnlessEqual(res['urlpath'], self.urlpath)
@@ -264,17 +268,17 @@ class TestDB(unittest.TestCase):
         self.failUnless(res)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], '~1')
-        self.failUnlessEqual(res['~1'], os.path.abspath(self.directory))
+        self.failUnlessEqual(res['~1'], self.directory)
         
     def test_isUnchanged(self):
-        res = self.store.isUnchanged(self.path)
+        res = self.store.isUnchanged(self.file)
         self.failUnless(res)
         sleep(2)
-        os.utime(self.path, None)
-        res = self.store.isUnchanged(self.path)
+        self.file.touch()
+        res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
-        os.unlink(self.path)
-        res = self.store.isUnchanged(self.path)
+        self.file.remove()
+        res = self.store.isUnchanged(self.file)
         self.failUnless(res == None)
         
     def test_expiry(self):
@@ -286,35 +290,34 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res.keys()[0], self.hash)
         self.failUnlessEqual(len(res[self.hash]), 1)
         self.failUnlessEqual(res[self.hash][0], self.urlpath)
-        res = self.store.refreshFile(self.path)
+        res = self.store.refreshFile(self.file)
         self.failUnless(res)
         res = self.store.expiredFiles(1)
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
-            path = os.path.join(dir, self.path[1:])
-            os.makedirs(os.path.dirname(path))
-            f = open(path, 'w')
-            f.write(path)
-            f.close()
-            os.utime(path, None)
-            self.store.storeFile(path, self.hash, dir)
+            file = dir.preauthChild(self.testfile)
+            if not file.parent().exists():
+                file.parent().makedirs()
+            file.setContent(file.path)
+            file.touch()
+            self.store.storeFile(file, self.hash, dir)
     
     def test_removeUntracked(self):
         self.build_dirs()
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
-        self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs)
         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[1:])
         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
-        self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
+        self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
         res = self.store.removeUntrackedFiles(self.dirs[:1])
         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
-        self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
-        self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
+        self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
+        self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
         
     def test_reconcileDirectories(self):
         self.build_dirs()
@@ -338,17 +341,13 @@ class TestDB(unittest.TestCase):
         res = self.store.getAllDirectories()
         self.failUnless(res)
         self.failUnlessEqual(len(res.keys()), 1)
-        res = self.store.removeUntrackedFiles(['/what'])
+        res = self.store.removeUntrackedFiles([FilePath('/what')])
         res = self.store.reconcileDirectories()
         self.failUnlessEqual(res, True)
         res = self.store.getAllDirectories()
         self.failUnlessEqual(len(res.keys()), 0)
         
     def tearDown(self):
-        for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
-            for name in files:
-                os.remove(os.path.join(root, name))
-            for name in dirs:
-                os.rmdir(os.path.join(root, name))
+        self.directory.remove()
         self.store.close()
-        os.unlink(self.db)
+        self.db.remove()