Move the normalization of key lengths from the HashObject to the DHT.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.python.filepath import FilePath
9 from twisted.trial import unittest
10
11 assert sqlite.version_info >= (2, 1)
12
13 class DBExcept(Exception):
14     pass
15
16 class khash(str):
17     """Dummy class to convert all hashes to base64 for storing in the DB."""
18     
19 sqlite.register_adapter(khash, b2a_base64)
20 sqlite.register_converter("KHASH", a2b_base64)
21 sqlite.register_converter("khash", a2b_base64)
22 sqlite.enable_callback_tracebacks(True)
23
24 class DB:
25     """Database access for storing persistent data."""
26     
27     def __init__(self, db):
28         self.db = db
29         self.db.restat(False)
30         if self.db.exists():
31             self._loadDB()
32         else:
33             self._createNewDB()
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self):
38         try:
39             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self):
45         if not self.db.parent().exists():
46             self.db.parent().makedirs()
47         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
48         c = self.conn.cursor()
49         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
50                                       "size NUMBER, mtime NUMBER)")
51         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
52                                        "hash KHASH UNIQUE, pieces KHASH, " +
53                                        "piecehash KHASH, refreshed TIMESTAMP)")
54         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
55         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
56         c.close()
57         self.conn.commit()
58
59     def _removeChanged(self, file, row):
60         res = None
61         if row:
62             file.restat(False)
63             if file.exists():
64                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
65             if not res:
66                 c = self.conn.cursor()
67                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
68                 self.conn.commit()
69                 c.close()
70         return res
71         
72     def storeFile(self, file, hash, pieces = ''):
73         """Store or update a file in the database.
74         
75         @return: True if the hash was not in the database before
76             (so it needs to be added to the DHT)
77         """
78         piecehash = ''
79         if pieces:
80             s = sha.new().update(pieces)
81             piecehash = sha.digest()
82         c = self.conn.cursor()
83         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
84         row = c.fetchone()
85         if row:
86             assert piecehash == row['piecehash']
87             new_hash = False
88             hashID = row['hashID']
89         else:
90             c = self.conn.cursor()
91             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
92                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
93             self.conn.commit()
94             new_hash = True
95             hashID = c.lastrowid
96         
97         file.restat()
98         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
99                   (file.path, hashID, file.getsize(), file.getmtime()))
100         self.conn.commit()
101         c.close()
102         
103         return new_hash
104         
105     def getFile(self, file):
106         """Get a file from the database.
107         
108         If it has changed or is missing, it is removed from the database.
109         
110         @return: dictionary of info for the file, False if changed, or
111             None if not in database or missing
112         """
113         c = self.conn.cursor()
114         c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
115         row = c.fetchone()
116         res = None
117         if row:
118             res = self._removeChanged(file, row)
119             if res:
120                 res = {}
121                 res['hash'] = row['hash']
122                 res['size'] = row['size']
123                 res['pieces'] = row['pieces']
124         c.close()
125         return res
126         
127     def lookupHash(self, hash, filesOnly = False):
128         """Find a file by hash in the database.
129         
130         If any found files have changed or are missing, they are removed
131         from the database. If filesOnly is False then it will also look for
132         piece string hashes if no files can be found.
133         
134         @return: list of dictionaries of info for the found files
135         """
136         c = self.conn.cursor()
137         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
138         row = c.fetchone()
139         files = []
140         while row:
141             file = FilePath(row['path'])
142             res = self._removeChanged(file, row)
143             if res:
144                 res = {}
145                 res['path'] = file
146                 res['size'] = row['size']
147                 res['refreshed'] = row['refreshed']
148                 res['pieces'] = row['pieces']
149                 files.append(res)
150             row = c.fetchone()
151             
152         if not filesOnly and not files:
153             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
154             row = c.fetchone()
155             if row:
156                 res = {}
157                 res['refreshed'] = row['refreshed']
158                 res['pieces'] = row['pieces']
159                 files.append(res)
160
161         c.close()
162         return files
163         
164     def isUnchanged(self, file):
165         """Check if a file in the file system has changed.
166         
167         If it has changed, it is removed from the table.
168         
169         @return: True if unchanged, False if changed, None if not in database
170         """
171         c = self.conn.cursor()
172         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
173         row = c.fetchone()
174         return self._removeChanged(file, row)
175
176     def refreshHash(self, hash):
177         """Refresh the publishing time all files with a hash."""
178         c = self.conn.cursor()
179         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
180         c.close()
181     
182     def expiredHashes(self, expireAfter):
183         """Find files that need refreshing after expireAfter seconds.
184         
185         For each hash that needs refreshing, finds all the files with that hash.
186         If the file has changed or is missing, it is removed from the table.
187         
188         @return: dictionary with keys the hashes, values a list of FilePaths
189         """
190         t = datetime.now() - timedelta(seconds=expireAfter)
191         
192         # First find the hashes that need refreshing
193         c = self.conn.cursor()
194         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
195         row = c.fetchone()
196         expired = {}
197         while row:
198             res = expired.setdefault(row['hash'], {})
199             res['hashID'] = row['hashID']
200             res['hash'] = row['hash']
201             res['pieces'] = row['pieces']
202             row = c.fetchone()
203
204         # Make sure there are still valid files for each hash
205         for hash in expired.values():
206             valid = False
207             c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
208             row = c.fetchone()
209             while row:
210                 res = self._removeChanged(FilePath(row['path']), row)
211                 if res:
212                     valid = True
213                 row = c.fetchone()
214             if not valid:
215                 del expired[hash['hash']]
216                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
217                 
218         self.conn.commit()
219         c.close()
220         
221         return expired
222         
223     def removeUntrackedFiles(self, dirs):
224         """Find files that are no longer tracked and so should be removed.
225         
226         Also removes the entries from the table.
227         
228         @return: list of files that were removed
229         """
230         assert len(dirs) >= 1
231         newdirs = []
232         sql = "WHERE"
233         for dir in dirs:
234             newdirs.append(dir.child('*').path)
235             sql += " path NOT GLOB ? AND"
236         sql = sql[:-4]
237
238         c = self.conn.cursor()
239         c.execute("SELECT path FROM files " + sql, newdirs)
240         row = c.fetchone()
241         removed = []
242         while row:
243             removed.append(FilePath(row['path']))
244             row = c.fetchone()
245
246         if removed:
247             c.execute("DELETE FROM files " + sql, newdirs)
248         self.conn.commit()
249         return removed
250     
251     def close(self):
252         self.conn.close()
253
254 class TestDB(unittest.TestCase):
255     """Tests for the khashmir database."""
256     
257     timeout = 5
258     db = FilePath('/tmp/khashmir.db')
259     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
260     directory = FilePath('/tmp/apt-dht/')
261     file = FilePath('/tmp/apt-dht/khashmir.test')
262     testfile = 'tmp/khashmir.test'
263     dirs = [FilePath('/tmp/apt-dht/top1'),
264             FilePath('/tmp/apt-dht/top2/sub1'),
265             FilePath('/tmp/apt-dht/top2/sub2/')]
266
267     def setUp(self):
268         if not self.file.parent().exists():
269             self.file.parent().makedirs()
270         self.file.setContent('fgfhds')
271         self.file.touch()
272         self.store = DB(self.db)
273         self.store.storeFile(self.file, self.hash)
274
275     def test_openExistingDB(self):
276         self.store.close()
277         self.store = None
278         sleep(1)
279         self.store = DB(self.db)
280         res = self.store.isUnchanged(self.file)
281         self.failUnless(res)
282
283     def test_getFile(self):
284         res = self.store.getFile(self.file)
285         self.failUnless(res)
286         self.failUnlessEqual(res['hash'], self.hash)
287         
288     def test_lookupHash(self):
289         res = self.store.lookupHash(self.hash)
290         self.failUnless(res)
291         self.failUnlessEqual(len(res), 1)
292         self.failUnlessEqual(res[0]['path'].path, self.file.path)
293         
294     def test_isUnchanged(self):
295         res = self.store.isUnchanged(self.file)
296         self.failUnless(res)
297         sleep(2)
298         self.file.touch()
299         res = self.store.isUnchanged(self.file)
300         self.failUnless(res == False)
301         res = self.store.isUnchanged(self.file)
302         self.failUnless(res is None)
303         
304     def test_expiry(self):
305         res = self.store.expiredHashes(1)
306         self.failUnlessEqual(len(res.keys()), 0)
307         sleep(2)
308         res = self.store.expiredHashes(1)
309         self.failUnlessEqual(len(res.keys()), 1)
310         self.failUnlessEqual(res.keys()[0], self.hash)
311         self.store.refreshHash(self.hash)
312         res = self.store.expiredHashes(1)
313         self.failUnlessEqual(len(res.keys()), 0)
314         
315     def build_dirs(self):
316         for dir in self.dirs:
317             file = dir.preauthChild(self.testfile)
318             if not file.parent().exists():
319                 file.parent().makedirs()
320             file.setContent(file.path)
321             file.touch()
322             self.store.storeFile(file, self.hash)
323     
324     def test_multipleHashes(self):
325         self.build_dirs()
326         res = self.store.expiredHashes(1)
327         self.failUnlessEqual(len(res.keys()), 0)
328         res = self.store.lookupHash(self.hash)
329         self.failUnless(res)
330         self.failUnlessEqual(len(res), 4)
331         self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
332         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
333         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
334         sleep(2)
335         res = self.store.expiredHashes(1)
336         self.failUnlessEqual(len(res.keys()), 1)
337         self.failUnlessEqual(res.keys()[0], self.hash)
338         self.store.refreshHash(self.hash)
339         res = self.store.expiredHashes(1)
340         self.failUnlessEqual(len(res.keys()), 0)
341     
342     def test_removeUntracked(self):
343         self.build_dirs()
344         res = self.store.removeUntrackedFiles(self.dirs)
345         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
346         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
347         res = self.store.removeUntrackedFiles(self.dirs)
348         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
349         res = self.store.removeUntrackedFiles(self.dirs[1:])
350         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
351         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
352         res = self.store.removeUntrackedFiles(self.dirs[:1])
353         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
354         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
355         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
356         
357     def tearDown(self):
358         self.directory.remove()
359         self.store.close()
360         self.db.remove()
361