Scanning cache directories on startup waits for DHT storeValue to return.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
10
11 assert sqlite.version_info >= (2, 1)
12
13 class DBExcept(Exception):
14     pass
15
16 class khash(str):
17     """Dummy class to convert all hashes to base64 for storing in the DB."""
18     
19 sqlite.register_adapter(khash, b2a_base64)
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
22 sqlite.enable_callback_tracebacks(True)
23
24 class DB:
25     """Database access for storing persistent data."""
26     
27     def __init__(self, db):
28         self.db = db
29         self.db.restat(False)
30         if self.db.exists():
31             self._loadDB()
32         else:
33             self._createNewDB()
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self):
38         try:
39             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self):
45         if not self.db.parent().exists():
46             self.db.parent().makedirs()
47         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48         c = self.conn.cursor()
49         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
50         c.execute("CREATE INDEX files_hash ON files(hash)")
51         c.execute("CREATE INDEX files_urldir ON files(urldir)")
52         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
53         c.execute("CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)")
54         c.execute("CREATE INDEX dirs_path ON dirs(path)")
55         c.close()
56         self.conn.commit()
57
58     def _removeChanged(self, file, row):
59         res = None
60         if row:
61             file.restat(False)
62             if file.exists():
63                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
64             if not res:
65                 c = self.conn.cursor()
66                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
67                 self.conn.commit()
68                 c.close()
69         return res
70         
71     def storeFile(self, file, hash, directory):
72         """Store or update a file in the database.
73         
74         @return: the urlpath to access the file, and whether a
75             new url top-level directory was needed
76         """
77         file.restat()
78         c = self.conn.cursor()
79         c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (file.path, ))
80         row = c.fetchone()
81         if row and directory == row['directory']:
82             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
83                       (khash(hash), file.getsize(), file.getmtime(), datetime.now()))
84             newdir = False
85             urldir = row['urldir']
86         else:
87             urldir, newdir = self.findDirectory(directory)
88             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
89                       (file.path, khash(hash), urldir, len(directory.path), file.getsize(), file.getmtime(), datetime.now()))
90         self.conn.commit()
91         c.close()
92         return '/~' + str(urldir) + file.path[len(directory.path):], newdir
93         
94     def getFile(self, file):
95         """Get a file from the database.
96         
97         If it has changed or is missing, it is removed from the database.
98         
99         @return: dictionary of info for the file, False if changed, or
100             None if not in database or missing
101         """
102         c = self.conn.cursor()
103         c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (file.path, ))
104         row = c.fetchone()
105         res = None
106         if row:
107             res = self._removeChanged(file, row)
108             if res:
109                 res = {}
110                 res['hash'] = row['hash']
111                 res['size'] = row['size']
112                 res['urlpath'] = '/~' + str(row['urldir']) + file.path[row['dirlength']:]
113         c.close()
114         return res
115         
116     def lookupHash(self, hash):
117         """Find a file by hash in the database.
118         
119         If any found files have changed or are missing, they are removed
120         from the database.
121         
122         @return: list of dictionaries of info for the found files
123         """
124         c = self.conn.cursor()
125         c.execute("SELECT path, urldir, dirlength, size, mtime FROM files WHERE hash = ? ORDER BY urldir", (khash(hash), ))
126         row = c.fetchone()
127         files = []
128         while row:
129             file = FilePath(row['path'])
130             res = self._removeChanged(file, row)
131             if res:
132                 res = {}
133                 res['path'] = file
134                 res['size'] = row['size']
135                 res['urlpath'] = '/~' + str(row['urldir']) + file.path[row['dirlength']:]
136                 files.append(res)
137             row = c.fetchone()
138         c.close()
139         return files
140         
141     def isUnchanged(self, file):
142         """Check if a file in the file system has changed.
143         
144         If it has changed, it is removed from the table.
145         
146         @return: True if unchanged, False if changed, None if not in database
147         """
148         c = self.conn.cursor()
149         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
150         row = c.fetchone()
151         return self._removeChanged(file, row)
152
153     def refreshFile(self, file):
154         """Refresh the publishing time of a file.
155         
156         If it has changed or is missing, it is removed from the table.
157         
158         @return: True if unchanged, False if changed, None if not in database
159         """
160         c = self.conn.cursor()
161         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
162         row = c.fetchone()
163         res = None
164         if row:
165             res = self._removeChanged(file, row)
166             if res:
167                 c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
168         return res
169     
170     def expiredFiles(self, expireAfter):
171         """Find files that need refreshing after expireAfter seconds.
172         
173         Also removes any entries from the table that no longer exist.
174         
175         @return: dictionary with keys the hashes, values a list of url paths
176         """
177         t = datetime.now() - timedelta(seconds=expireAfter)
178         c = self.conn.cursor()
179         c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
180         row = c.fetchone()
181         expired = {}
182         while row:
183             res = self._removeChanged(FilePath(row['path']), row)
184             if res:
185                 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
186             row = c.fetchone()
187         c.close()
188         return expired
189         
190     def removeUntrackedFiles(self, dirs):
191         """Find files that are no longer tracked and so should be removed.
192         
193         Also removes the entries from the table.
194         
195         @return: list of files that were removed
196         """
197         assert len(dirs) >= 1
198         newdirs = []
199         sql = "WHERE"
200         for dir in dirs:
201             newdirs.append(dir.child('*').path)
202             sql += " path NOT GLOB ? AND"
203         sql = sql[:-4]
204
205         c = self.conn.cursor()
206         c.execute("SELECT path FROM files " + sql, newdirs)
207         row = c.fetchone()
208         removed = []
209         while row:
210             removed.append(FilePath(row['path']))
211             row = c.fetchone()
212
213         if removed:
214             c.execute("DELETE FROM files " + sql, newdirs)
215         self.conn.commit()
216         return removed
217     
218     def findDirectory(self, directory):
219         """Store or update a directory in the database.
220         
221         @return: the index of the url directory, and whether it is new or not
222         """
223         c = self.conn.cursor()
224         c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory.path, ))
225         row = c.fetchone()
226         c.close()
227         if row['urldir']:
228             return row['urldir'], False
229
230         # Not found, need to add a new one
231         c = self.conn.cursor()
232         c.execute("INSERT INTO dirs (path) VALUES (?)", (directory.path, ))
233         self.conn.commit()
234         urldir = c.lastrowid
235         c.close()
236         return urldir, True
237         
238     def getAllDirectories(self):
239         """Get all the current directories avaliable."""
240         c = self.conn.cursor()
241         c.execute("SELECT urldir, path FROM dirs")
242         row = c.fetchone()
243         dirs = {}
244         while row:
245             dirs['~' + str(row['urldir'])] = FilePath(row['path'])
246             row = c.fetchone()
247         c.close()
248         return dirs
249     
250     def reconcileDirectories(self):
251         """Remove any unneeded directories by checking which are used by files."""
252         c = self.conn.cursor()
253         c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
254         self.conn.commit()
255         return bool(c.rowcount)
256         
257     def close(self):
258         self.conn.close()
259
260 class TestDB(unittest.TestCase):
261     """Tests for the khashmir database."""
262     
263     timeout = 5
264     db = FilePath('/tmp/khashmir.db')
265     file = FilePath('/tmp/apt-dht/khashmir.test')
266     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
267     directory = FilePath('/tmp/apt-dht/')
268     urlpath = '/~1/khashmir.test'
269     testfile = 'tmp/khashmir.test'
270     dirs = [FilePath('/tmp/apt-dht/top1'),
271             FilePath('/tmp/apt-dht/top2/sub1'),
272             FilePath('/tmp/apt-dht/top2/sub2/')]
273
274     def setUp(self):
275         if not self.file.parent().exists():
276             self.file.parent().makedirs()
277         self.file.setContent('fgfhds')
278         self.file.touch()
279         self.store = DB(self.db)
280         self.store.storeFile(self.file, self.hash, self.directory)
281
282     def test_openExistsingDB(self):
283         self.store.close()
284         self.store = None
285         sleep(1)
286         self.store = DB(self.db)
287         res = self.store.isUnchanged(self.file)
288         self.failUnless(res)
289
290     def test_getFile(self):
291         res = self.store.getFile(self.file)
292         self.failUnless(res)
293         self.failUnlessEqual(res['hash'], self.hash)
294         self.failUnlessEqual(res['urlpath'], self.urlpath)
295         
296     def test_getAllDirectories(self):
297         res = self.store.getAllDirectories()
298         self.failUnless(res)
299         self.failUnlessEqual(len(res.keys()), 1)
300         self.failUnlessEqual(res.keys()[0], '~1')
301         self.failUnlessEqual(res['~1'], self.directory)
302         
303     def test_isUnchanged(self):
304         res = self.store.isUnchanged(self.file)
305         self.failUnless(res)
306         sleep(2)
307         self.file.touch()
308         res = self.store.isUnchanged(self.file)
309         self.failUnless(res == False)
310         self.file.remove()
311         res = self.store.isUnchanged(self.file)
312         self.failUnless(res == None)
313         
314     def test_expiry(self):
315         res = self.store.expiredFiles(1)
316         self.failUnlessEqual(len(res.keys()), 0)
317         sleep(2)
318         res = self.store.expiredFiles(1)
319         self.failUnlessEqual(len(res.keys()), 1)
320         self.failUnlessEqual(res.keys()[0], self.hash)
321         self.failUnlessEqual(len(res[self.hash]), 1)
322         self.failUnlessEqual(res[self.hash][0], self.urlpath)
323         res = self.store.refreshFile(self.file)
324         self.failUnless(res)
325         res = self.store.expiredFiles(1)
326         self.failUnlessEqual(len(res.keys()), 0)
327         
328     def build_dirs(self):
329         for dir in self.dirs:
330             file = dir.preauthChild(self.testfile)
331             if not file.parent().exists():
332                 file.parent().makedirs()
333             file.setContent(file.path)
334             file.touch()
335             self.store.storeFile(file, self.hash, dir)
336     
337     def test_removeUntracked(self):
338         self.build_dirs()
339         res = self.store.removeUntrackedFiles(self.dirs)
340         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
341         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
342         res = self.store.removeUntrackedFiles(self.dirs)
343         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
344         res = self.store.removeUntrackedFiles(self.dirs[1:])
345         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
346         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
347         res = self.store.removeUntrackedFiles(self.dirs[:1])
348         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
349         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
350         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
351         
352     def test_reconcileDirectories(self):
353         self.build_dirs()
354         res = self.store.getAllDirectories()
355         self.failUnless(res)
356         self.failUnlessEqual(len(res.keys()), 4)
357         res = self.store.reconcileDirectories()
358         self.failUnlessEqual(res, False)
359         res = self.store.getAllDirectories()
360         self.failUnless(res)
361         self.failUnlessEqual(len(res.keys()), 4)
362         res = self.store.removeUntrackedFiles(self.dirs)
363         res = self.store.reconcileDirectories()
364         self.failUnlessEqual(res, True)
365         res = self.store.getAllDirectories()
366         self.failUnless(res)
367         self.failUnlessEqual(len(res.keys()), 3)
368         res = self.store.removeUntrackedFiles(self.dirs[:1])
369         res = self.store.reconcileDirectories()
370         self.failUnlessEqual(res, True)
371         res = self.store.getAllDirectories()
372         self.failUnless(res)
373         self.failUnlessEqual(len(res.keys()), 1)
374         res = self.store.removeUntrackedFiles([FilePath('/what')])
375         res = self.store.reconcileDirectories()
376         self.failUnlessEqual(res, True)
377         res = self.store.getAllDirectories()
378         self.failUnlessEqual(len(res.keys()), 0)
379         
380     def tearDown(self):
381         self.directory.remove()
382         self.store.close()
383         self.db.remove()