9725aa88f10c36e0bd137de1cb5051850d7cde91
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
10 assert sqlite.version_info >= (2, 1)
11
12 class DBExcept(Exception):
13     pass
14
15 class khash(str):
16     """Dummy class to convert all hashes to base64 for storing in the DB."""
17     
18 sqlite.register_adapter(khash, b2a_base64)
19 sqlite.register_converter("KHASH", a2b_base64)
20 sqlite.register_converter("khash", a2b_base64)
21 sqlite.enable_callback_tracebacks(True)
22
23 class DB:
24     """Database access for storing persistent data."""
25     
26     def __init__(self, db):
27         self.db = db
28         try:
29             os.stat(db)
30         except OSError:
31             self._createNewDB(db)
32         else:
33             self._loadDB(db)
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self, db):
38         try:
39             self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self, db):
45         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
46         c = self.conn.cursor()
47         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
48         c.execute("CREATE INDEX files_urldir ON files(urldir)")
49         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
50         c.execute("CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)")
51         c.execute("CREATE INDEX dirs_path ON dirs(path)")
52         c.close()
53         self.conn.commit()
54
55     def _removeChanged(self, path, row):
56         res = None
57         if row:
58             try:
59                 stat = os.stat(path)
60             except:
61                 stat = None
62             if stat:
63                 res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
64             if not res:
65                 c = self.conn.cursor()
66                 c.execute("DELETE FROM files WHERE path = ?", (path, ))
67                 self.conn.commit()
68                 c.close()
69         return res
70         
71     def storeFile(self, path, hash, directory):
72         """Store or update a file in the database.
73         
74         @return: the urlpath to access the file, and whether a
75             new url top-level directory was needed
76         """
77         path = os.path.abspath(path)
78         directory = os.path.abspath(directory)
79         assert path.startswith(directory)
80         stat = os.stat(path)
81         c = self.conn.cursor()
82         c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (path, ))
83         row = c.fetchone()
84         if row and directory == row['directory']:
85             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
86                       (khash(hash), stat.st_size, stat.st_mtime, datetime.now()))
87             newdir = False
88             urldir = row['urldir']
89         else:
90             urldir, newdir = self.findDirectory(directory)
91             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
92                       (path, khash(hash), urldir, len(directory), stat.st_size, stat.st_mtime, datetime.now()))
93         self.conn.commit()
94         c.close()
95         return '/~' + str(urldir) + path[len(directory):], newdir
96         
97     def getFile(self, path):
98         """Get a file from the database.
99         
100         If it has changed or is missing, it is removed from the database.
101         
102         @return: dictionary of info for the file, False if changed, or
103             None if not in database or missing
104         """
105         path = os.path.abspath(path)
106         c = self.conn.cursor()
107         c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (path, ))
108         row = c.fetchone()
109         res = self._removeChanged(path, row)
110         if res:
111             res = {}
112             res['hash'] = row['hash']
113             res['urlpath'] = '/~' + str(row['urldir']) + path[row['dirlength']:]
114         c.close()
115         return res
116         
117     def isUnchanged(self, path):
118         """Check if a file in the file system has changed.
119         
120         If it has changed, it is removed from the table.
121         
122         @return: True if unchanged, False if changed, None if not in database
123         """
124         path = os.path.abspath(path)
125         c = self.conn.cursor()
126         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
127         row = c.fetchone()
128         return self._removeChanged(path, row)
129
130     def refreshFile(self, path):
131         """Refresh the publishing time of a file.
132         
133         If it has changed or is missing, it is removed from the table.
134         
135         @return: True if unchanged, False if changed, None if not in database
136         """
137         path = os.path.abspath(path)
138         c = self.conn.cursor()
139         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
140         row = c.fetchone()
141         res = self._removeChanged(path, row)
142         if res:
143             c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
144         return res
145     
146     def expiredFiles(self, expireAfter):
147         """Find files that need refreshing after expireAfter seconds.
148         
149         Also removes any entries from the table that no longer exist.
150         
151         @return: dictionary with keys the hashes, values a list of url paths
152         """
153         t = datetime.now() - timedelta(seconds=expireAfter)
154         c = self.conn.cursor()
155         c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
156         row = c.fetchone()
157         expired = {}
158         while row:
159             res = self._removeChanged(row['path'], row)
160             if res:
161                 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
162             row = c.fetchone()
163         c.close()
164         return expired
165         
166     def removeUntrackedFiles(self, dirs):
167         """Find files that are no longer tracked and so should be removed.
168         
169         Also removes the entries from the table.
170         
171         @return: list of files that were removed
172         """
173         assert len(dirs) >= 1
174         newdirs = []
175         sql = "WHERE"
176         for dir in dirs:
177             newdirs.append(os.path.abspath(dir) + os.sep + '*')
178             sql += " path NOT GLOB ? AND"
179         sql = sql[:-4]
180
181         c = self.conn.cursor()
182         c.execute("SELECT path FROM files " + sql, newdirs)
183         row = c.fetchone()
184         removed = []
185         while row:
186             removed.append(row['path'])
187             row = c.fetchone()
188
189         if removed:
190             c.execute("DELETE FROM files " + sql, newdirs)
191         self.conn.commit()
192         return removed
193     
194     def findDirectory(self, directory):
195         """Store or update a directory in the database.
196         
197         @return: the index of the url directory, and whether it is new or not
198         """
199         directory = os.path.abspath(directory)
200         c = self.conn.cursor()
201         c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory, ))
202         row = c.fetchone()
203         c.close()
204         if row['urldir']:
205             return row['urldir'], False
206
207         # Not found, need to add a new one
208         c = self.conn.cursor()
209         c.execute("INSERT INTO dirs (path) VALUES (?)", (directory, ))
210         self.conn.commit()
211         urldir = c.lastrowid
212         c.close()
213         return urldir, True
214         
215     def getAllDirectories(self):
216         """Get all the current directories avaliable."""
217         c = self.conn.cursor()
218         c.execute("SELECT urldir, path FROM dirs")
219         row = c.fetchone()
220         dirs = {}
221         while row:
222             dirs['~' + str(row['urldir'])] = row['path']
223             row = c.fetchone()
224         c.close()
225         return dirs
226     
227     def reconcileDirectories(self):
228         """Remove any unneeded directories by checking which are used by files."""
229         c = self.conn.cursor()
230         c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
231         self.conn.commit()
232         return bool(c.rowcount)
233         
234     def close(self):
235         self.conn.close()
236
237 class TestDB(unittest.TestCase):
238     """Tests for the khashmir database."""
239     
240     timeout = 5
241     db = '/tmp/khashmir.db'
242     path = '/tmp/khashmir.test'
243     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
244     directory = '/tmp/'
245     urlpath = '/~1/khashmir.test'
246     dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
247
248     def setUp(self):
249         f = open(self.path, 'w')
250         f.write('fgfhds')
251         f.close()
252         os.utime(self.path, None)
253         self.store = DB(self.db)
254         self.store.storeFile(self.path, self.hash, self.directory)
255
256     def test_getFile(self):
257         res = self.store.getFile(self.path)
258         self.failUnless(res)
259         self.failUnlessEqual(res['hash'], self.hash)
260         self.failUnlessEqual(res['urlpath'], self.urlpath)
261         
262     def test_getAllDirectories(self):
263         res = self.store.getAllDirectories()
264         self.failUnless(res)
265         self.failUnlessEqual(len(res.keys()), 1)
266         self.failUnlessEqual(res.keys()[0], '~1')
267         self.failUnlessEqual(res['~1'], os.path.abspath(self.directory))
268         
269     def test_isUnchanged(self):
270         res = self.store.isUnchanged(self.path)
271         self.failUnless(res)
272         sleep(2)
273         os.utime(self.path, None)
274         res = self.store.isUnchanged(self.path)
275         self.failUnless(res == False)
276         os.unlink(self.path)
277         res = self.store.isUnchanged(self.path)
278         self.failUnless(res == None)
279         
280     def test_expiry(self):
281         res = self.store.expiredFiles(1)
282         self.failUnlessEqual(len(res.keys()), 0)
283         sleep(2)
284         res = self.store.expiredFiles(1)
285         self.failUnlessEqual(len(res.keys()), 1)
286         self.failUnlessEqual(res.keys()[0], self.hash)
287         self.failUnlessEqual(len(res[self.hash]), 1)
288         self.failUnlessEqual(res[self.hash][0], self.urlpath)
289         res = self.store.refreshFile(self.path)
290         self.failUnless(res)
291         res = self.store.expiredFiles(1)
292         self.failUnlessEqual(len(res.keys()), 0)
293         
294     def build_dirs(self):
295         for dir in self.dirs:
296             path = os.path.join(dir, self.path[1:])
297             os.makedirs(os.path.dirname(path))
298             f = open(path, 'w')
299             f.write(path)
300             f.close()
301             os.utime(path, None)
302             self.store.storeFile(path, self.hash, dir)
303     
304     def test_removeUntracked(self):
305         self.build_dirs()
306         res = self.store.removeUntrackedFiles(self.dirs)
307         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
308         self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
309         res = self.store.removeUntrackedFiles(self.dirs)
310         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
311         res = self.store.removeUntrackedFiles(self.dirs[1:])
312         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
313         self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
314         res = self.store.removeUntrackedFiles(self.dirs[:1])
315         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
316         self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
317         self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
318         
319     def test_reconcileDirectories(self):
320         self.build_dirs()
321         res = self.store.getAllDirectories()
322         self.failUnless(res)
323         self.failUnlessEqual(len(res.keys()), 4)
324         res = self.store.reconcileDirectories()
325         self.failUnlessEqual(res, False)
326         res = self.store.getAllDirectories()
327         self.failUnless(res)
328         self.failUnlessEqual(len(res.keys()), 4)
329         res = self.store.removeUntrackedFiles(self.dirs)
330         res = self.store.reconcileDirectories()
331         self.failUnlessEqual(res, True)
332         res = self.store.getAllDirectories()
333         self.failUnless(res)
334         self.failUnlessEqual(len(res.keys()), 3)
335         res = self.store.removeUntrackedFiles(self.dirs[:1])
336         res = self.store.reconcileDirectories()
337         self.failUnlessEqual(res, True)
338         res = self.store.getAllDirectories()
339         self.failUnless(res)
340         self.failUnlessEqual(len(res.keys()), 1)
341         res = self.store.removeUntrackedFiles(['/what'])
342         res = self.store.reconcileDirectories()
343         self.failUnlessEqual(res, True)
344         res = self.store.getAllDirectories()
345         self.failUnlessEqual(len(res.keys()), 0)
346         
347     def tearDown(self):
348         for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
349             for name in files:
350                 os.remove(os.path.join(root, name))
351             for name in dirs:
352                 os.rmdir(os.path.join(root, name))
353         self.store.close()
354         os.unlink(self.db)