Documented the db module.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 """An sqlite database for storing persistent files and hashes."""
3
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
7 from time import sleep
8 import os
9
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
12
13 assert sqlite.version_info >= (2, 1)
14
15 class DBExcept(Exception):
16     """An error occurred in accessing the database."""
17     pass
18
19 class khash(str):
20     """Dummy class to convert all hashes to base64 for storing in the DB."""
21
22 # Initialize the database to work with 'khash' objects (binary strings)
23 sqlite.register_adapter(khash, b2a_base64)
24 sqlite.register_converter("KHASH", a2b_base64)
25 sqlite.register_converter("khash", a2b_base64)
26 sqlite.enable_callback_tracebacks(True)
27
28 class DB:
29     """An sqlite database for storing persistent files and hashes.
30     
31     @type db: L{twisted.python.filepath.FilePath}
32     @ivar db: the database file to use
33     @type conn: L{pysqlite2.dbapi2.Connection}
34     @ivar conn: an open connection to the sqlite database
35     """
36     
37     def __init__(self, db):
38         """Load or create the database file.
39         
40         @type db: L{twisted.python.filepath.FilePath}
41         @param db: the database file to use
42         """
43         self.db = db
44         self.db.restat(False)
45         if self.db.exists():
46             self._loadDB()
47         else:
48             self._createNewDB()
49         self.conn.text_factory = str
50         self.conn.row_factory = sqlite.Row
51         
52     def _loadDB(self):
53         """Open a new connection to the existing database file"""
54         try:
55             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
56         except:
57             import traceback
58             raise DBExcept, "Couldn't open DB", traceback.format_exc()
59         
60     def _createNewDB(self):
61         """Open a connection to a new database and create the necessary tables."""
62         if not self.db.parent().exists():
63             self.db.parent().makedirs()
64         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
65         c = self.conn.cursor()
66         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
67                                       "size NUMBER, mtime NUMBER)")
68         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
69                                        "hash KHASH UNIQUE, pieces KHASH, " +
70                                        "piecehash KHASH, refreshed TIMESTAMP)")
71         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
72         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
73         c.close()
74         self.conn.commit()
75
76     def _removeChanged(self, file, row):
77         """If the file has changed or is missing, remove it from the DB.
78         
79         @type file: L{twisted.python.filepath.FilePath}
80         @param file: the file to check
81         @type row: C{dictionary}-like object
82         @param row: contains the expected 'size' and 'mtime' of the file
83         @rtype: C{boolean}
84         @return: True if the file is unchanged, False if it is changed,
85             and None if it is missing
86         """
87         res = None
88         if row:
89             file.restat(False)
90             if file.exists():
91                 # Compare the current with the expected file properties
92                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
93             if not res:
94                 # Remove the file from the database
95                 c = self.conn.cursor()
96                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
97                 self.conn.commit()
98                 c.close()
99         return res
100         
101     def storeFile(self, file, hash, pieces = ''):
102         """Store or update a file in the database.
103         
104         @type file: L{twisted.python.filepath.FilePath}
105         @param file: the file to check
106         @type hash: C{string}
107         @param hash: the hash of the file
108         @type pieces: C{string}
109         @param pieces: the concatenated list of the hashes of the pieces of
110             the file (optional, defaults to the empty string)
111         @return: True if the hash was not in the database before
112             (so it needs to be added to the DHT)
113         """
114         # Hash the pieces to get the piecehash
115         piecehash = ''
116         if pieces:
117             s = sha.new().update(pieces)
118             piecehash = sha.digest()
119             
120         # Check the database for the hash
121         c = self.conn.cursor()
122         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
123         row = c.fetchone()
124         if row:
125             assert piecehash == row['piecehash']
126             new_hash = False
127             hashID = row['hashID']
128         else:
129             # Add the new hash to the database
130             c = self.conn.cursor()
131             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
132                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
133             self.conn.commit()
134             new_hash = True
135             hashID = c.lastrowid
136
137         # Add the file to the database
138         file.restat()
139         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
140                   (file.path, hashID, file.getsize(), file.getmtime()))
141         self.conn.commit()
142         c.close()
143         
144         return new_hash
145         
146     def getFile(self, file):
147         """Get a file from the database.
148         
149         If it has changed or is missing, it is removed from the database.
150         
151         @type file: L{twisted.python.filepath.FilePath}
152         @param file: the file to check
153         @return: dictionary of info for the file, False if changed, or
154             None if not in database or missing
155         """
156         c = self.conn.cursor()
157         c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
158         row = c.fetchone()
159         res = None
160         if row:
161             res = self._removeChanged(file, row)
162             if res:
163                 res = {}
164                 res['hash'] = row['hash']
165                 res['size'] = row['size']
166                 res['pieces'] = row['pieces']
167         c.close()
168         return res
169         
170     def lookupHash(self, hash, filesOnly = False):
171         """Find a file by hash in the database.
172         
173         If any found files have changed or are missing, they are removed
174         from the database. If filesOnly is False then it will also look for
175         piece string hashes if no files can be found.
176         
177         @return: list of dictionaries of info for the found files
178         """
179         # Try to find the hash in the files table
180         c = self.conn.cursor()
181         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
182         row = c.fetchone()
183         files = []
184         while row:
185             # Save the file to the list of found files
186             file = FilePath(row['path'])
187             res = self._removeChanged(file, row)
188             if res:
189                 res = {}
190                 res['path'] = file
191                 res['size'] = row['size']
192                 res['refreshed'] = row['refreshed']
193                 res['pieces'] = row['pieces']
194                 files.append(res)
195             row = c.fetchone()
196             
197         if not filesOnly and not files:
198             # No files were found, so check the piecehashes as well
199             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
200             row = c.fetchone()
201             if row:
202                 res = {}
203                 res['refreshed'] = row['refreshed']
204                 res['pieces'] = row['pieces']
205                 files.append(res)
206
207         c.close()
208         return files
209         
210     def isUnchanged(self, file):
211         """Check if a file in the file system has changed.
212         
213         If it has changed, it is removed from the database.
214         
215         @return: True if unchanged, False if changed, None if not in database
216         """
217         c = self.conn.cursor()
218         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
219         row = c.fetchone()
220         return self._removeChanged(file, row)
221
222     def refreshHash(self, hash):
223         """Refresh the publishing time of a hash."""
224         c = self.conn.cursor()
225         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
226         c.close()
227     
228     def expiredHashes(self, expireAfter):
229         """Find files that need refreshing after expireAfter seconds.
230         
231         For each hash that needs refreshing, finds all the files with that hash.
232         If the file has changed or is missing, it is removed from the table.
233         
234         @return: dictionary with keys the hashes, values a list of FilePaths
235         """
236         t = datetime.now() - timedelta(seconds=expireAfter)
237         
238         # Find all the hashes that need refreshing
239         c = self.conn.cursor()
240         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
241         row = c.fetchone()
242         expired = {}
243         while row:
244             res = expired.setdefault(row['hash'], {})
245             res['hashID'] = row['hashID']
246             res['hash'] = row['hash']
247             res['pieces'] = row['pieces']
248             row = c.fetchone()
249
250         # Make sure there are still valid files for each hash
251         for hash in expired.values():
252             valid = False
253             c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
254             row = c.fetchone()
255             while row:
256                 res = self._removeChanged(FilePath(row['path']), row)
257                 if res:
258                     valid = True
259                 row = c.fetchone()
260             if not valid:
261                 # Remove hashes for which no files are still available
262                 del expired[hash['hash']]
263                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
264                 
265         self.conn.commit()
266         c.close()
267         
268         return expired
269         
270     def removeUntrackedFiles(self, dirs):
271         """Remove files that are no longer tracked by the program.
272         
273         @type dirs: C{list} of L{twisted.python.filepath.FilePath}
274         @param dirs: a list of the directories that we are tracking
275         @return: list of files that were removed
276         """
277         assert len(dirs) >= 1
278         
279         # Create a list of globs and an SQL statement for the directories
280         newdirs = []
281         sql = "WHERE"
282         for dir in dirs:
283             newdirs.append(dir.child('*').path)
284             sql += " path NOT GLOB ? AND"
285         sql = sql[:-4]
286
287         # Get a listing of all the files that will be removed
288         c = self.conn.cursor()
289         c.execute("SELECT path FROM files " + sql, newdirs)
290         row = c.fetchone()
291         removed = []
292         while row:
293             removed.append(FilePath(row['path']))
294             row = c.fetchone()
295
296         # Delete all the removed files from the database
297         if removed:
298             c.execute("DELETE FROM files " + sql, newdirs)
299         self.conn.commit()
300
301         return removed
302     
303     def close(self):
304         """Close the database connection."""
305         self.conn.close()
306
307 class TestDB(unittest.TestCase):
308     """Tests for the khashmir database."""
309     
310     timeout = 5
311     db = FilePath('/tmp/khashmir.db')
312     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
313     directory = FilePath('/tmp/apt-dht/')
314     file = FilePath('/tmp/apt-dht/khashmir.test')
315     testfile = 'tmp/khashmir.test'
316     dirs = [FilePath('/tmp/apt-dht/top1'),
317             FilePath('/tmp/apt-dht/top2/sub1'),
318             FilePath('/tmp/apt-dht/top2/sub2/')]
319
320     def setUp(self):
321         if not self.file.parent().exists():
322             self.file.parent().makedirs()
323         self.file.setContent('fgfhds')
324         self.file.touch()
325         self.store = DB(self.db)
326         self.store.storeFile(self.file, self.hash)
327
328     def test_openExistingDB(self):
329         """Tests opening an existing database."""
330         self.store.close()
331         self.store = None
332         sleep(1)
333         self.store = DB(self.db)
334         res = self.store.isUnchanged(self.file)
335         self.failUnless(res)
336
337     def test_getFile(self):
338         """Tests retrieving a file from the database."""
339         res = self.store.getFile(self.file)
340         self.failUnless(res)
341         self.failUnlessEqual(res['hash'], self.hash)
342         
343     def test_lookupHash(self):
344         """Tests looking up a hash in the database."""
345         res = self.store.lookupHash(self.hash)
346         self.failUnless(res)
347         self.failUnlessEqual(len(res), 1)
348         self.failUnlessEqual(res[0]['path'].path, self.file.path)
349         
350     def test_isUnchanged(self):
351         """Tests checking if a file in the database is unchanged."""
352         res = self.store.isUnchanged(self.file)
353         self.failUnless(res)
354         sleep(2)
355         self.file.touch()
356         res = self.store.isUnchanged(self.file)
357         self.failUnless(res == False)
358         res = self.store.isUnchanged(self.file)
359         self.failUnless(res is None)
360         
361     def test_expiry(self):
362         """Tests retrieving the files from the database that have expired."""
363         res = self.store.expiredHashes(1)
364         self.failUnlessEqual(len(res.keys()), 0)
365         sleep(2)
366         res = self.store.expiredHashes(1)
367         self.failUnlessEqual(len(res.keys()), 1)
368         self.failUnlessEqual(res.keys()[0], self.hash)
369         self.store.refreshHash(self.hash)
370         res = self.store.expiredHashes(1)
371         self.failUnlessEqual(len(res.keys()), 0)
372         
373     def build_dirs(self):
374         for dir in self.dirs:
375             file = dir.preauthChild(self.testfile)
376             if not file.parent().exists():
377                 file.parent().makedirs()
378             file.setContent(file.path)
379             file.touch()
380             self.store.storeFile(file, self.hash)
381     
382     def test_multipleHashes(self):
383         """Tests looking up a hash with multiple files in the database."""
384         self.build_dirs()
385         res = self.store.expiredHashes(1)
386         self.failUnlessEqual(len(res.keys()), 0)
387         res = self.store.lookupHash(self.hash)
388         self.failUnless(res)
389         self.failUnlessEqual(len(res), 4)
390         self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
391         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
392         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
393         sleep(2)
394         res = self.store.expiredHashes(1)
395         self.failUnlessEqual(len(res.keys()), 1)
396         self.failUnlessEqual(res.keys()[0], self.hash)
397         self.store.refreshHash(self.hash)
398         res = self.store.expiredHashes(1)
399         self.failUnlessEqual(len(res.keys()), 0)
400     
401     def test_removeUntracked(self):
402         """Tests removing untracked files from the database."""
403         self.build_dirs()
404         res = self.store.removeUntrackedFiles(self.dirs)
405         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
406         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
407         res = self.store.removeUntrackedFiles(self.dirs)
408         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
409         res = self.store.removeUntrackedFiles(self.dirs[1:])
410         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
411         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
412         res = self.store.removeUntrackedFiles(self.dirs[:1])
413         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
414         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
415         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
416         
417     def tearDown(self):
418         self.directory.remove()
419         self.store.close()
420         self.db.remove()
421