2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
8 from twisted.trial import unittest
# Refuse to go any further unless the pysqlite2 module imported above as
# `sqlite` reports a version of at least 2.1.
# NOTE(review): `assert` statements are removed when Python runs with -O, so
# this guard is silently skipped in optimized mode -- confirm that is acceptable.
10 assert sqlite.version_info >= (2, 1)
12 class DBExcept(Exception):
16 """Dummy class to convert all hashes to base64 for storing in the DB."""
# Tell pysqlite how to move hashes in and out of the database: values wrapped
# in khash are passed through b2a_base64 when bound to a statement, and
# columns whose declared type is KHASH (registered under both spellings) are
# passed back through a2b_base64 when read.
# NOTE(review): declared-type converters are only applied on connections
# opened with detect_types=sqlite.PARSE_DECLTYPES -- keep the connect calls
# in this file in sync with these registrations.
18 sqlite.register_adapter(khash, b2a_base64)
19 sqlite.register_converter("KHASH", a2b_base64)
20 sqlite.register_converter("khash", a2b_base64)
23 """Database access for storing persistent data."""
25 def __init__(self, db):
33 self.conn.text_factory = str
34 self.conn.row_factory = sqlite.Row
def _loadDB(self, db):
    """Open a connection to the database file and keep it on C{self.conn}.

    The connection is opened with C{PARSE_DECLTYPES} so that the converters
    registered for declared column types are applied to values read back.

    @param db: the database location handed to C{sqlite.connect}
    @raise DBExcept: if the connection could not be opened; the exception
        arguments are the message and the formatted traceback of the
        underlying failure
    """
    try:
        self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
    except Exception:
        # Imported here so the handler works no matter what the module
        # header happens to import.
        import traceback
        # The old three-expression form ``raise DBExcept, msg, text`` needs a
        # traceback *object* as its last operand; handing it the formatted
        # string raised TypeError and masked the real error.  Carry the
        # formatted traceback as an ordinary exception argument instead.
        raise DBExcept("Couldn't open DB", traceback.format_exc())
43 def _createNewDB(self, db):
44 self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
45 c = self.conn.cursor()
46 c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urlpath TEXT, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
47 # c.execute("CREATE INDEX files_hash ON files(hash)")
48 c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
49 c.execute("CREATE TABLE dirs (path TEXT PRIMARY KEY, urlpath TEXT)")
53 def _removeChanged(self, path, row):
61 res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
63 c = self.conn.cursor()
64 c.execute("DELETE FROM files WHERE path = ?", (path, ))
69 def storeFile(self, path, hash, urlpath):
70 """Store or update a file in the database."""
71 path = os.path.abspath(path)
73 c = self.conn.cursor()
74 c.execute("INSERT OR REPLACE INTO files VALUES (?, ?, ?, ?, ?, ?)",
75 (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now()))
79 def getFile(self, path):
80 """Get a file from the database.
82 If it has changed or is missing, it is removed from the database.
84 @return: dictionary of info for the file, False if changed, or
85 None if not in database or missing
87 path = os.path.abspath(path)
88 c = self.conn.cursor()
89 c.execute("SELECT hash, urlpath, size, mtime FROM files WHERE path = ?", (path, ))
91 res = self._removeChanged(path, row)
94 res['hash'] = row['hash']
95 res['urlpath'] = row['urlpath']
99 def isUnchanged(self, path):
100 """Check if a file in the file system has changed.
102 If it has changed, it is removed from the table.
104 @return: True if unchanged, False if changed, None if not in database
106 path = os.path.abspath(path)
107 c = self.conn.cursor()
108 c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
110 return self._removeChanged(path, row)
112 def refreshFile(self, path):
113 """Refresh the publishing time of a file.
115 If it has changed or is missing, it is removed from the table.
117 @return: True if unchanged, False if changed, None if not in database
119 path = os.path.abspath(path)
120 c = self.conn.cursor()
121 c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
123 res = self._removeChanged(path, row)
125 c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
128 def expiredFiles(self, expireAfter):
129 """Find files that need refreshing after expireAfter seconds.
131 Also removes any entries from the table that no longer exist.
133 @return: dictionary with keys the hashes, values a list of url paths
135 t = datetime.now() - timedelta(seconds=expireAfter)
136 c = self.conn.cursor()
137 c.execute("SELECT path, hash, urlpath, size, mtime FROM files WHERE refreshed < ?", (t, ))
141 res = self._removeChanged(row['path'], row)
143 expired.setdefault(row['hash'], []).append(row['urlpath'])
148 def removeUntrackedFiles(self, dirs):
149 """Find files that are no longer tracked and so should be removed.
151 Also removes the entries from the table.
153 @return: list of files that were removed
155 assert len(dirs) >= 1
159 newdirs.append(os.path.abspath(dir) + os.sep + '*')
160 sql += " path NOT GLOB ? AND"
163 c = self.conn.cursor()
164 c.execute("SELECT path FROM files " + sql, newdirs)
168 removed.append(row['path'])
172 c.execute("DELETE FROM files " + sql, newdirs)
179 class TestDB(unittest.TestCase):
180 """Tests for the khashmir database."""
183 db = '/tmp/khashmir.db'
184 path = '/tmp/khashmir.test'
185 hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
186 urlpath = '/~1/what/ever/khashmir.test'
187 dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
190 f = open(self.path, 'w')
193 os.utime(self.path, None)
194 self.store = DB(self.db)
195 self.store.storeFile(self.path, self.hash, self.urlpath)
197 def test_getFile(self):
198 res = self.store.getFile(self.path)
200 self.failUnlessEqual(res['hash'], self.hash)
201 self.failUnlessEqual(res['urlpath'], self.urlpath)
203 def test_isUnchanged(self):
204 res = self.store.isUnchanged(self.path)
207 os.utime(self.path, None)
208 res = self.store.isUnchanged(self.path)
209 self.failUnless(res == False)
211 res = self.store.isUnchanged(self.path)
212 self.failUnless(res == None)
214 def test_expiry(self):
215 res = self.store.expiredFiles(1)
216 self.failUnlessEqual(len(res.keys()), 0)
218 res = self.store.expiredFiles(1)
219 self.failUnlessEqual(len(res.keys()), 1)
220 self.failUnlessEqual(res.keys()[0], self.hash)
221 self.failUnlessEqual(len(res[self.hash]), 1)
222 self.failUnlessEqual(res[self.hash][0], self.urlpath)
223 res = self.store.refreshFile(self.path)
225 res = self.store.expiredFiles(1)
226 self.failUnlessEqual(len(res.keys()), 0)
228 def test_removeUntracked(self):
229 for dir in self.dirs:
230 path = os.path.join(dir, self.path[1:])
231 os.makedirs(os.path.dirname(path))
236 self.store.storeFile(path, self.hash, self.urlpath)
238 res = self.store.removeUntrackedFiles(self.dirs)
239 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
240 self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
241 res = self.store.removeUntrackedFiles(self.dirs)
242 self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
243 res = self.store.removeUntrackedFiles(self.dirs[1:])
244 self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
245 self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
246 res = self.store.removeUntrackedFiles(self.dirs[:1])
247 self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
248 self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
249 self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
252 for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
254 os.remove(os.path.join(root, name))
256 os.rmdir(os.path.join(root, name))