Also remove changed cache files during directory scan.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
10
11 assert sqlite.version_info >= (2, 1)
12
13 class DBExcept(Exception):
14     pass
15
16 class khash(str):
17     """Dummy class to convert all hashes to base64 for storing in the DB."""
18     
19 sqlite.register_adapter(khash, b2a_base64)
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
22 sqlite.enable_callback_tracebacks(True)
23
24 class DB:
25     """Database access for storing persistent data."""
26     
27     def __init__(self, db):
28         self.db = db
29         self.db.restat(False)
30         if self.db.exists():
31             self._loadDB()
32         else:
33             self._createNewDB()
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self):
38         try:
39             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self):
45         if not self.db.parent().exists():
46             self.db.parent().makedirs()
47         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48         c = self.conn.cursor()
49         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
50         c.execute("CREATE INDEX files_hash ON files(hash)")
51         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
52         c.close()
53         self.conn.commit()
54
55     def _removeChanged(self, file, row):
56         res = None
57         if row:
58             file.restat(False)
59             if file.exists():
60                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
61             if not res:
62                 c = self.conn.cursor()
63                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
64                 self.conn.commit()
65                 c.close()
66         return res
67         
68     def storeFile(self, file, hash):
69         """Store or update a file in the database.
70         
71         @return: True if the hash was not in the database before
72             (so it needs to be added to the DHT)
73         """
74         new_hash = True
75         refreshTime = datetime.now()
76         c = self.conn.cursor()
77         c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), ))
78         row = c.fetchone()
79         if row and row['max_refresh']:
80             new_hash = False
81             refreshTime = row['max_refresh']
82         c.close()
83         
84         file.restat()
85         c = self.conn.cursor()
86         c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
87         row = c.fetchone()
88         if row:
89             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
90                       (khash(hash), file.getsize(), file.getmtime(), refreshTime))
91         else:
92             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
93                       (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime))
94         self.conn.commit()
95         c.close()
96         
97         return new_hash
98         
99     def getFile(self, file):
100         """Get a file from the database.
101         
102         If it has changed or is missing, it is removed from the database.
103         
104         @return: dictionary of info for the file, False if changed, or
105             None if not in database or missing
106         """
107         c = self.conn.cursor()
108         c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
109         row = c.fetchone()
110         res = None
111         if row:
112             res = self._removeChanged(file, row)
113             if res:
114                 res = {}
115                 res['hash'] = row['hash']
116                 res['size'] = row['size']
117         c.close()
118         return res
119         
120     def lookupHash(self, hash):
121         """Find a file by hash in the database.
122         
123         If any found files have changed or are missing, they are removed
124         from the database.
125         
126         @return: list of dictionaries of info for the found files
127         """
128         c = self.conn.cursor()
129         c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), ))
130         row = c.fetchone()
131         files = []
132         while row:
133             file = FilePath(row['path'])
134             res = self._removeChanged(file, row)
135             if res:
136                 res = {}
137                 res['path'] = file
138                 res['size'] = row['size']
139                 res['refreshed'] = row['refreshed']
140                 files.append(res)
141             row = c.fetchone()
142         c.close()
143         return files
144         
145     def isUnchanged(self, file):
146         """Check if a file in the file system has changed.
147         
148         If it has changed, it is removed from the table.
149         
150         @return: True if unchanged, False if changed, None if not in database
151         """
152         c = self.conn.cursor()
153         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
154         row = c.fetchone()
155         return self._removeChanged(file, row)
156
157     def refreshHash(self, hash):
158         """Refresh the publishing time all files with a hash."""
159         refreshTime = datetime.now()
160         c = self.conn.cursor()
161         c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash)))
162         c.close()
163     
164     def expiredFiles(self, expireAfter):
165         """Find files that need refreshing after expireAfter seconds.
166         
167         For each hash that needs refreshing, finds all the files with that hash.
168         If the file has changed or is missing, it is removed from the table.
169         
170         @return: dictionary with keys the hashes, values a list of FilePaths
171         """
172         t = datetime.now() - timedelta(seconds=expireAfter)
173         
174         # First find the hashes that need refreshing
175         c = self.conn.cursor()
176         c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, ))
177         row = c.fetchone()
178         expired = {}
179         while row:
180             expired.setdefault(row['hash'], [])
181             row = c.fetchone()
182         c.close()
183
184         # Now find the files for each hash
185         for hash in expired.keys():
186             c = self.conn.cursor()
187             c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
188             row = c.fetchone()
189             while row:
190                 res = self._removeChanged(FilePath(row['path']), row)
191                 if res:
192                     expired[hash].append(FilePath(row['path']))
193                 row = c.fetchone()
194             if len(expired[hash]) == 0:
195                 del expired[hash]
196             c.close()
197         
198         return expired
199         
200     def removeUntrackedFiles(self, dirs):
201         """Find files that are no longer tracked and so should be removed.
202         
203         Also removes the entries from the table.
204         
205         @return: list of files that were removed
206         """
207         assert len(dirs) >= 1
208         newdirs = []
209         sql = "WHERE"
210         for dir in dirs:
211             newdirs.append(dir.child('*').path)
212             sql += " path NOT GLOB ? AND"
213         sql = sql[:-4]
214
215         c = self.conn.cursor()
216         c.execute("SELECT path FROM files " + sql, newdirs)
217         row = c.fetchone()
218         removed = []
219         while row:
220             removed.append(FilePath(row['path']))
221             row = c.fetchone()
222
223         if removed:
224             c.execute("DELETE FROM files " + sql, newdirs)
225         self.conn.commit()
226         return removed
227     
228     def close(self):
229         self.conn.close()
230
231 class TestDB(unittest.TestCase):
232     """Tests for the khashmir database."""
233     
234     timeout = 5
235     db = FilePath('/tmp/khashmir.db')
236     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
237     directory = FilePath('/tmp/apt-dht/')
238     file = FilePath('/tmp/apt-dht/khashmir.test')
239     testfile = 'tmp/khashmir.test'
240     dirs = [FilePath('/tmp/apt-dht/top1'),
241             FilePath('/tmp/apt-dht/top2/sub1'),
242             FilePath('/tmp/apt-dht/top2/sub2/')]
243
244     def setUp(self):
245         if not self.file.parent().exists():
246             self.file.parent().makedirs()
247         self.file.setContent('fgfhds')
248         self.file.touch()
249         self.store = DB(self.db)
250         self.store.storeFile(self.file, self.hash)
251
252     def test_openExistsingDB(self):
253         self.store.close()
254         self.store = None
255         sleep(1)
256         self.store = DB(self.db)
257         res = self.store.isUnchanged(self.file)
258         self.failUnless(res)
259
260     def test_getFile(self):
261         res = self.store.getFile(self.file)
262         self.failUnless(res)
263         self.failUnlessEqual(res['hash'], self.hash)
264         
265     def test_lookupHash(self):
266         res = self.store.lookupHash(self.hash)
267         self.failUnless(res)
268         self.failUnlessEqual(len(res), 1)
269         self.failUnlessEqual(res[0]['path'].path, self.file.path)
270         
271     def test_isUnchanged(self):
272         res = self.store.isUnchanged(self.file)
273         self.failUnless(res)
274         sleep(2)
275         self.file.touch()
276         res = self.store.isUnchanged(self.file)
277         self.failUnless(res == False)
278         self.file.remove()
279         res = self.store.isUnchanged(self.file)
280         self.failUnless(res == None)
281         
282     def test_expiry(self):
283         res = self.store.expiredFiles(1)
284         self.failUnlessEqual(len(res.keys()), 0)
285         sleep(2)
286         res = self.store.expiredFiles(1)
287         self.failUnlessEqual(len(res.keys()), 1)
288         self.failUnlessEqual(res.keys()[0], self.hash)
289         self.failUnlessEqual(len(res[self.hash]), 1)
290         self.store.refreshHash(self.hash)
291         res = self.store.expiredFiles(1)
292         self.failUnlessEqual(len(res.keys()), 0)
293         
294     def build_dirs(self):
295         for dir in self.dirs:
296             file = dir.preauthChild(self.testfile)
297             if not file.parent().exists():
298                 file.parent().makedirs()
299             file.setContent(file.path)
300             file.touch()
301             self.store.storeFile(file, self.hash)
302     
303     def test_multipleHashes(self):
304         self.build_dirs()
305         res = self.store.expiredFiles(1)
306         self.failUnlessEqual(len(res.keys()), 0)
307         res = self.store.lookupHash(self.hash)
308         self.failUnless(res)
309         self.failUnlessEqual(len(res), 4)
310         self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
311         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
312         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
313         sleep(2)
314         res = self.store.expiredFiles(1)
315         self.failUnlessEqual(len(res.keys()), 1)
316         self.failUnlessEqual(res.keys()[0], self.hash)
317         self.failUnlessEqual(len(res[self.hash]), 4)
318         self.store.refreshHash(self.hash)
319         res = self.store.expiredFiles(1)
320         self.failUnlessEqual(len(res.keys()), 0)
321     
322     def test_removeUntracked(self):
323         self.build_dirs()
324         res = self.store.removeUntrackedFiles(self.dirs)
325         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
326         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
327         res = self.store.removeUntrackedFiles(self.dirs)
328         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
329         res = self.store.removeUntrackedFiles(self.dirs[1:])
330         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
331         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
332         res = self.store.removeUntrackedFiles(self.dirs[:1])
333         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
334         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
335         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
336         
337     def tearDown(self):
338         self.directory.remove()
339         self.store.close()
340         self.db.remove()