1d2e34273ad98161e365ce54ab7b4d5ecb7f76f0
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
10
11 assert sqlite.version_info >= (2, 1)
12
13 class DBExcept(Exception):
14     pass
15
16 class khash(str):
17     """Dummy class to convert all hashes to base64 for storing in the DB."""
18     
19 sqlite.register_adapter(khash, b2a_base64)
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
22 sqlite.enable_callback_tracebacks(True)
23
24 class DB:
25     """Database access for storing persistent data."""
26     
27     def __init__(self, db):
28         self.db = db
29         self.db.restat(False)
30         if self.db.exists():
31             self._loadDB()
32         else:
33             self._createNewDB()
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self):
38         try:
39             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self):
45         if not self.db.parent().exists():
46             self.db.parent().makedirs()
47         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48         c = self.conn.cursor()
49         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
50         c.execute("CREATE INDEX files_urldir ON files(urldir)")
51         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
52         c.execute("CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)")
53         c.execute("CREATE INDEX dirs_path ON dirs(path)")
54         c.close()
55         self.conn.commit()
56
57     def _removeChanged(self, file, row):
58         res = None
59         if row:
60             file.restat(False)
61             if file.exists():
62                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
63             if not res:
64                 c = self.conn.cursor()
65                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
66                 self.conn.commit()
67                 c.close()
68         return res
69         
70     def storeFile(self, file, hash, directory):
71         """Store or update a file in the database.
72         
73         @return: the urlpath to access the file, and whether a
74             new url top-level directory was needed
75         """
76         file.restat()
77         c = self.conn.cursor()
78         c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (file.path, ))
79         row = c.fetchone()
80         if row and directory == row['directory']:
81             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
82                       (khash(hash), file.getsize(), file.getmtime(), datetime.now()))
83             newdir = False
84             urldir = row['urldir']
85         else:
86             urldir, newdir = self.findDirectory(directory)
87             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
88                       (file.path, khash(hash), urldir, len(directory.path), file.getsize(), file.getmtime(), datetime.now()))
89         self.conn.commit()
90         c.close()
91         return '/~' + str(urldir) + file.path[len(directory.path):], newdir
92         
93     def getFile(self, file):
94         """Get a file from the database.
95         
96         If it has changed or is missing, it is removed from the database.
97         
98         @return: dictionary of info for the file, False if changed, or
99             None if not in database or missing
100         """
101         c = self.conn.cursor()
102         c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (file.path, ))
103         row = c.fetchone()
104         res = self._removeChanged(file, row)
105         if res:
106             res = {}
107             res['hash'] = row['hash']
108             res['size'] = row['size']
109             res['urlpath'] = '/~' + str(row['urldir']) + file.path[row['dirlength']:]
110         c.close()
111         return res
112         
113     def isUnchanged(self, file):
114         """Check if a file in the file system has changed.
115         
116         If it has changed, it is removed from the table.
117         
118         @return: True if unchanged, False if changed, None if not in database
119         """
120         c = self.conn.cursor()
121         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
122         row = c.fetchone()
123         return self._removeChanged(file, row)
124
125     def refreshFile(self, file):
126         """Refresh the publishing time of a file.
127         
128         If it has changed or is missing, it is removed from the table.
129         
130         @return: True if unchanged, False if changed, None if not in database
131         """
132         c = self.conn.cursor()
133         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
134         row = c.fetchone()
135         res = self._removeChanged(file, row)
136         if res:
137             c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
138         return res
139     
140     def expiredFiles(self, expireAfter):
141         """Find files that need refreshing after expireAfter seconds.
142         
143         Also removes any entries from the table that no longer exist.
144         
145         @return: dictionary with keys the hashes, values a list of url paths
146         """
147         t = datetime.now() - timedelta(seconds=expireAfter)
148         c = self.conn.cursor()
149         c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
150         row = c.fetchone()
151         expired = {}
152         while row:
153             res = self._removeChanged(FilePath(row['path']), row)
154             if res:
155                 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
156             row = c.fetchone()
157         c.close()
158         return expired
159         
160     def removeUntrackedFiles(self, dirs):
161         """Find files that are no longer tracked and so should be removed.
162         
163         Also removes the entries from the table.
164         
165         @return: list of files that were removed
166         """
167         assert len(dirs) >= 1
168         newdirs = []
169         sql = "WHERE"
170         for dir in dirs:
171             newdirs.append(dir.child('*').path)
172             sql += " path NOT GLOB ? AND"
173         sql = sql[:-4]
174
175         c = self.conn.cursor()
176         c.execute("SELECT path FROM files " + sql, newdirs)
177         row = c.fetchone()
178         removed = []
179         while row:
180             removed.append(FilePath(row['path']))
181             row = c.fetchone()
182
183         if removed:
184             c.execute("DELETE FROM files " + sql, newdirs)
185         self.conn.commit()
186         return removed
187     
188     def findDirectory(self, directory):
189         """Store or update a directory in the database.
190         
191         @return: the index of the url directory, and whether it is new or not
192         """
193         c = self.conn.cursor()
194         c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory.path, ))
195         row = c.fetchone()
196         c.close()
197         if row['urldir']:
198             return row['urldir'], False
199
200         # Not found, need to add a new one
201         c = self.conn.cursor()
202         c.execute("INSERT INTO dirs (path) VALUES (?)", (directory.path, ))
203         self.conn.commit()
204         urldir = c.lastrowid
205         c.close()
206         return urldir, True
207         
208     def getAllDirectories(self):
209         """Get all the current directories avaliable."""
210         c = self.conn.cursor()
211         c.execute("SELECT urldir, path FROM dirs")
212         row = c.fetchone()
213         dirs = {}
214         while row:
215             dirs['~' + str(row['urldir'])] = FilePath(row['path'])
216             row = c.fetchone()
217         c.close()
218         return dirs
219     
220     def reconcileDirectories(self):
221         """Remove any unneeded directories by checking which are used by files."""
222         c = self.conn.cursor()
223         c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
224         self.conn.commit()
225         return bool(c.rowcount)
226         
227     def close(self):
228         self.conn.close()
229
230 class TestDB(unittest.TestCase):
231     """Tests for the khashmir database."""
232     
233     timeout = 5
234     db = FilePath('/tmp/khashmir.db')
235     file = FilePath('/tmp/apt-dht/khashmir.test')
236     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
237     directory = FilePath('/tmp/apt-dht/')
238     urlpath = '/~1/khashmir.test'
239     testfile = 'tmp/khashmir.test'
240     dirs = [FilePath('/tmp/apt-dht/top1'),
241             FilePath('/tmp/apt-dht/top2/sub1'),
242             FilePath('/tmp/apt-dht/top2/sub2/')]
243
244     def setUp(self):
245         if not self.file.parent().exists():
246             self.file.parent().makedirs()
247         self.file.setContent('fgfhds')
248         self.file.touch()
249         self.store = DB(self.db)
250         self.store.storeFile(self.file, self.hash, self.directory)
251
252     def test_openExistsingDB(self):
253         self.store.close()
254         self.store = None
255         sleep(1)
256         self.store = DB(self.db)
257         res = self.store.isUnchanged(self.file)
258         self.failUnless(res)
259
260     def test_getFile(self):
261         res = self.store.getFile(self.file)
262         self.failUnless(res)
263         self.failUnlessEqual(res['hash'], self.hash)
264         self.failUnlessEqual(res['urlpath'], self.urlpath)
265         
266     def test_getAllDirectories(self):
267         res = self.store.getAllDirectories()
268         self.failUnless(res)
269         self.failUnlessEqual(len(res.keys()), 1)
270         self.failUnlessEqual(res.keys()[0], '~1')
271         self.failUnlessEqual(res['~1'], self.directory)
272         
273     def test_isUnchanged(self):
274         res = self.store.isUnchanged(self.file)
275         self.failUnless(res)
276         sleep(2)
277         self.file.touch()
278         res = self.store.isUnchanged(self.file)
279         self.failUnless(res == False)
280         self.file.remove()
281         res = self.store.isUnchanged(self.file)
282         self.failUnless(res == None)
283         
284     def test_expiry(self):
285         res = self.store.expiredFiles(1)
286         self.failUnlessEqual(len(res.keys()), 0)
287         sleep(2)
288         res = self.store.expiredFiles(1)
289         self.failUnlessEqual(len(res.keys()), 1)
290         self.failUnlessEqual(res.keys()[0], self.hash)
291         self.failUnlessEqual(len(res[self.hash]), 1)
292         self.failUnlessEqual(res[self.hash][0], self.urlpath)
293         res = self.store.refreshFile(self.file)
294         self.failUnless(res)
295         res = self.store.expiredFiles(1)
296         self.failUnlessEqual(len(res.keys()), 0)
297         
298     def build_dirs(self):
299         for dir in self.dirs:
300             file = dir.preauthChild(self.testfile)
301             if not file.parent().exists():
302                 file.parent().makedirs()
303             file.setContent(file.path)
304             file.touch()
305             self.store.storeFile(file, self.hash, dir)
306     
307     def test_removeUntracked(self):
308         self.build_dirs()
309         res = self.store.removeUntrackedFiles(self.dirs)
310         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
311         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
312         res = self.store.removeUntrackedFiles(self.dirs)
313         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
314         res = self.store.removeUntrackedFiles(self.dirs[1:])
315         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
316         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
317         res = self.store.removeUntrackedFiles(self.dirs[:1])
318         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
319         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
320         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
321         
322     def test_reconcileDirectories(self):
323         self.build_dirs()
324         res = self.store.getAllDirectories()
325         self.failUnless(res)
326         self.failUnlessEqual(len(res.keys()), 4)
327         res = self.store.reconcileDirectories()
328         self.failUnlessEqual(res, False)
329         res = self.store.getAllDirectories()
330         self.failUnless(res)
331         self.failUnlessEqual(len(res.keys()), 4)
332         res = self.store.removeUntrackedFiles(self.dirs)
333         res = self.store.reconcileDirectories()
334         self.failUnlessEqual(res, True)
335         res = self.store.getAllDirectories()
336         self.failUnless(res)
337         self.failUnlessEqual(len(res.keys()), 3)
338         res = self.store.removeUntrackedFiles(self.dirs[:1])
339         res = self.store.reconcileDirectories()
340         self.failUnlessEqual(res, True)
341         res = self.store.getAllDirectories()
342         self.failUnless(res)
343         self.failUnlessEqual(len(res.keys()), 1)
344         res = self.store.removeUntrackedFiles([FilePath('/what')])
345         res = self.store.reconcileDirectories()
346         self.failUnlessEqual(res, True)
347         res = self.store.getAllDirectories()
348         self.failUnlessEqual(len(res.keys()), 0)
349         
350     def tearDown(self):
351         self.directory.remove()
352         self.store.close()
353         self.db.remove()