HTTPServer uses the hash to lookup the file in the DB (no more directories).
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
10
11 assert sqlite.version_info >= (2, 1)
12
13 class DBExcept(Exception):
14     pass
15
16 class khash(str):
17     """Dummy class to convert all hashes to base64 for storing in the DB."""
18     
19 sqlite.register_adapter(khash, b2a_base64)
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
22 sqlite.enable_callback_tracebacks(True)
23
24 class DB:
25     """Database access for storing persistent data."""
26     
27     def __init__(self, db):
28         self.db = db
29         self.db.restat(False)
30         if self.db.exists():
31             self._loadDB()
32         else:
33             self._createNewDB()
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self):
38         try:
39             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self):
45         if not self.db.parent().exists():
46             self.db.parent().makedirs()
47         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48         c = self.conn.cursor()
49         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
50         c.execute("CREATE INDEX files_hash ON files(hash)")
51         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
52         c.close()
53         self.conn.commit()
54
55     def _removeChanged(self, file, row):
56         res = None
57         if row:
58             file.restat(False)
59             if file.exists():
60                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
61             if not res:
62                 c = self.conn.cursor()
63                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
64                 self.conn.commit()
65                 c.close()
66         return res
67         
68     def storeFile(self, file, hash):
69         """Store or update a file in the database."""
70         file.restat()
71         c = self.conn.cursor()
72         c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
73         row = c.fetchone()
74         if row:
75             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
76                       (khash(hash), file.getsize(), file.getmtime(), datetime.now()))
77         else:
78             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
79                       (file.path, khash(hash), file.getsize(), file.getmtime(), datetime.now()))
80         self.conn.commit()
81         c.close()
82         
83     def getFile(self, file):
84         """Get a file from the database.
85         
86         If it has changed or is missing, it is removed from the database.
87         
88         @return: dictionary of info for the file, False if changed, or
89             None if not in database or missing
90         """
91         c = self.conn.cursor()
92         c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
93         row = c.fetchone()
94         res = None
95         if row:
96             res = self._removeChanged(file, row)
97             if res:
98                 res = {}
99                 res['hash'] = row['hash']
100                 res['size'] = row['size']
101         c.close()
102         return res
103         
104     def lookupHash(self, hash):
105         """Find a file by hash in the database.
106         
107         If any found files have changed or are missing, they are removed
108         from the database.
109         
110         @return: list of dictionaries of info for the found files
111         """
112         c = self.conn.cursor()
113         c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
114         row = c.fetchone()
115         files = []
116         while row:
117             file = FilePath(row['path'])
118             res = self._removeChanged(file, row)
119             if res:
120                 res = {}
121                 res['path'] = file
122                 res['size'] = row['size']
123                 files.append(res)
124             row = c.fetchone()
125         c.close()
126         return files
127         
128     def isUnchanged(self, file):
129         """Check if a file in the file system has changed.
130         
131         If it has changed, it is removed from the table.
132         
133         @return: True if unchanged, False if changed, None if not in database
134         """
135         c = self.conn.cursor()
136         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
137         row = c.fetchone()
138         return self._removeChanged(file, row)
139
140     def refreshFile(self, file):
141         """Refresh the publishing time of a file.
142         
143         If it has changed or is missing, it is removed from the table.
144         
145         @return: True if unchanged, False if changed, None if not in database
146         """
147         c = self.conn.cursor()
148         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
149         row = c.fetchone()
150         res = None
151         if row:
152             res = self._removeChanged(file, row)
153             if res:
154                 c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), file.path))
155         return res
156     
157     def expiredFiles(self, expireAfter):
158         """Find files that need refreshing after expireAfter seconds.
159         
160         Also removes any entries from the table that no longer exist.
161         
162         @return: dictionary with keys the hashes, values a list of FilePaths
163         """
164         t = datetime.now() - timedelta(seconds=expireAfter)
165         c = self.conn.cursor()
166         c.execute("SELECT path, hash, size, mtime FROM files WHERE refreshed < ?", (t, ))
167         row = c.fetchone()
168         expired = {}
169         while row:
170             res = self._removeChanged(FilePath(row['path']), row)
171             if res:
172                 expired.setdefault(row['hash'], []).append(FilePath(row['path']))
173             row = c.fetchone()
174         c.close()
175         return expired
176         
177     def removeUntrackedFiles(self, dirs):
178         """Find files that are no longer tracked and so should be removed.
179         
180         Also removes the entries from the table.
181         
182         @return: list of files that were removed
183         """
184         assert len(dirs) >= 1
185         newdirs = []
186         sql = "WHERE"
187         for dir in dirs:
188             newdirs.append(dir.child('*').path)
189             sql += " path NOT GLOB ? AND"
190         sql = sql[:-4]
191
192         c = self.conn.cursor()
193         c.execute("SELECT path FROM files " + sql, newdirs)
194         row = c.fetchone()
195         removed = []
196         while row:
197             removed.append(FilePath(row['path']))
198             row = c.fetchone()
199
200         if removed:
201             c.execute("DELETE FROM files " + sql, newdirs)
202         self.conn.commit()
203         return removed
204     
205     def close(self):
206         self.conn.close()
207
208 class TestDB(unittest.TestCase):
209     """Tests for the khashmir database."""
210     
211     timeout = 5
212     db = FilePath('/tmp/khashmir.db')
213     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
214     directory = FilePath('/tmp/apt-dht/')
215     file = FilePath('/tmp/apt-dht/khashmir.test')
216     testfile = 'tmp/khashmir.test'
217     dirs = [FilePath('/tmp/apt-dht/top1'),
218             FilePath('/tmp/apt-dht/top2/sub1'),
219             FilePath('/tmp/apt-dht/top2/sub2/')]
220
221     def setUp(self):
222         if not self.file.parent().exists():
223             self.file.parent().makedirs()
224         self.file.setContent('fgfhds')
225         self.file.touch()
226         self.store = DB(self.db)
227         self.store.storeFile(self.file, self.hash)
228
229     def test_openExistsingDB(self):
230         self.store.close()
231         self.store = None
232         sleep(1)
233         self.store = DB(self.db)
234         res = self.store.isUnchanged(self.file)
235         self.failUnless(res)
236
237     def test_getFile(self):
238         res = self.store.getFile(self.file)
239         self.failUnless(res)
240         self.failUnlessEqual(res['hash'], self.hash)
241         
242     def test_isUnchanged(self):
243         res = self.store.isUnchanged(self.file)
244         self.failUnless(res)
245         sleep(2)
246         self.file.touch()
247         res = self.store.isUnchanged(self.file)
248         self.failUnless(res == False)
249         self.file.remove()
250         res = self.store.isUnchanged(self.file)
251         self.failUnless(res == None)
252         
253     def test_expiry(self):
254         res = self.store.expiredFiles(1)
255         self.failUnlessEqual(len(res.keys()), 0)
256         sleep(2)
257         res = self.store.expiredFiles(1)
258         self.failUnlessEqual(len(res.keys()), 1)
259         self.failUnlessEqual(res.keys()[0], self.hash)
260         self.failUnlessEqual(len(res[self.hash]), 1)
261         res = self.store.refreshFile(self.file)
262         self.failUnless(res)
263         res = self.store.expiredFiles(1)
264         self.failUnlessEqual(len(res.keys()), 0)
265         
266     def build_dirs(self):
267         for dir in self.dirs:
268             file = dir.preauthChild(self.testfile)
269             if not file.parent().exists():
270                 file.parent().makedirs()
271             file.setContent(file.path)
272             file.touch()
273             self.store.storeFile(file, self.hash)
274     
275     def test_removeUntracked(self):
276         self.build_dirs()
277         res = self.store.removeUntrackedFiles(self.dirs)
278         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
279         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
280         res = self.store.removeUntrackedFiles(self.dirs)
281         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
282         res = self.store.removeUntrackedFiles(self.dirs[1:])
283         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
284         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
285         res = self.store.removeUntrackedFiles(self.dirs[:1])
286         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
287         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
288         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
289         
290     def tearDown(self):
291         self.directory.remove()
292         self.store.close()
293         self.db.remove()