44e692b416d92e47122deee6ca9fc9788bdd697e
[quix0rs-apt-p2p.git] / apt_p2p / db.py
1
2 """An sqlite database for storing persistent files and hashes."""
3
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
7 from time import sleep
8 import os, sha
9
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
12
13 assert sqlite.version_info >= (2, 1)
14
15 class DBExcept(Exception):
16     """An error occurred in accessing the database."""
17     pass
18
19 class khash(str):
20     """Dummy class to convert all hashes to base64 for storing in the DB."""
21
22 # Initialize the database to work with 'khash' objects (binary strings)
23 sqlite.register_adapter(khash, b2a_base64)
24 sqlite.register_converter("KHASH", a2b_base64)
25 sqlite.register_converter("khash", a2b_base64)
26 sqlite.enable_callback_tracebacks(True)
27
28 class DB:
29     """An sqlite database for storing persistent files and hashes.
30     
31     @type db: L{twisted.python.filepath.FilePath}
32     @ivar db: the database file to use
33     @type conn: L{pysqlite2.dbapi2.Connection}
34     @ivar conn: an open connection to the sqlite database
35     """
36     
37     def __init__(self, db):
38         """Load or create the database file.
39         
40         @type db: L{twisted.python.filepath.FilePath}
41         @param db: the database file to use
42         """
43         self.db = db
44         self.db.restat(False)
45         if self.db.exists():
46             self._loadDB()
47         else:
48             self._createNewDB()
49         self.conn.text_factory = str
50         self.conn.row_factory = sqlite.Row
51         
52     #{ DB Functions
53     def _loadDB(self):
54         """Open a new connection to the existing database file"""
55         try:
56             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
57         except:
58             import traceback
59             raise DBExcept, "Couldn't open DB", traceback.format_exc()
60         
61     def _createNewDB(self):
62         """Open a connection to a new database and create the necessary tables."""
63         if not self.db.parent().exists():
64             self.db.parent().makedirs()
65         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
66         c = self.conn.cursor()
67         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
68                                       "size NUMBER, mtime NUMBER)")
69         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
70                                        "hash KHASH UNIQUE, pieces KHASH, " +
71                                        "piecehash KHASH, refreshed TIMESTAMP)")
72         c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
73         c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
74         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
75         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
76         c.close()
77         self.conn.commit()
78
79     def close(self):
80         """Close the database connection."""
81         self.conn.close()
82
83     #{ Files and Hashes
84     def _removeChanged(self, file, row):
85         """If the file has changed or is missing, remove it from the DB.
86         
87         @type file: L{twisted.python.filepath.FilePath}
88         @param file: the file to check
89         @type row: C{dictionary}-like object
90         @param row: contains the expected 'size' and 'mtime' of the file
91         @rtype: C{boolean}
92         @return: True if the file is unchanged, False if it is changed,
93             and None if it is missing
94         """
95         res = None
96         if row:
97             file.restat(False)
98             if file.exists():
99                 # Compare the current with the expected file properties
100                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
101             if not res:
102                 # Remove the file from the database
103                 c = self.conn.cursor()
104                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
105                 self.conn.commit()
106                 c.close()
107         return res
108         
109     def storeFile(self, file, hash, pieces = ''):
110         """Store or update a file in the database.
111         
112         @type file: L{twisted.python.filepath.FilePath}
113         @param file: the file to check
114         @type hash: C{string}
115         @param hash: the hash of the file
116         @type pieces: C{string}
117         @param pieces: the concatenated list of the hashes of the pieces of
118             the file (optional, defaults to the empty string)
119         @return: True if the hash was not in the database before
120             (so it needs to be added to the DHT)
121         """
122         # Hash the pieces to get the piecehash
123         piecehash = ''
124         if pieces:
125             piecehash = sha.new(pieces).digest()
126             
127         # Check the database for the hash
128         c = self.conn.cursor()
129         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
130         row = c.fetchone()
131         if row:
132             assert piecehash == row['piecehash']
133             new_hash = False
134             hashID = row['hashID']
135         else:
136             # Add the new hash to the database
137             c = self.conn.cursor()
138             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
139                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
140             self.conn.commit()
141             new_hash = True
142             hashID = c.lastrowid
143
144         # Add the file to the database
145         file.restat()
146         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
147                   (file.path, hashID, file.getsize(), file.getmtime()))
148         self.conn.commit()
149         c.close()
150         
151         return new_hash
152         
153     def getFile(self, file):
154         """Get a file from the database.
155         
156         If it has changed or is missing, it is removed from the database.
157         
158         @type file: L{twisted.python.filepath.FilePath}
159         @param file: the file to check
160         @return: dictionary of info for the file, False if changed, or
161             None if not in database or missing
162         """
163         c = self.conn.cursor()
164         c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
165         row = c.fetchone()
166         res = None
167         if row:
168             res = self._removeChanged(file, row)
169             if res:
170                 res = {}
171                 res['hash'] = row['hash']
172                 res['size'] = row['size']
173                 res['pieces'] = row['pieces']
174         c.close()
175         return res
176         
177     def lookupHash(self, hash, filesOnly = False):
178         """Find a file by hash in the database.
179         
180         If any found files have changed or are missing, they are removed
181         from the database. If filesOnly is False then it will also look for
182         piece string hashes if no files can be found.
183         
184         @return: list of dictionaries of info for the found files
185         """
186         # Try to find the hash in the files table
187         c = self.conn.cursor()
188         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
189         row = c.fetchone()
190         files = []
191         while row:
192             # Save the file to the list of found files
193             file = FilePath(row['path'])
194             res = self._removeChanged(file, row)
195             if res:
196                 res = {}
197                 res['path'] = file
198                 res['size'] = row['size']
199                 res['refreshed'] = row['refreshed']
200                 res['pieces'] = row['pieces']
201                 files.append(res)
202             row = c.fetchone()
203             
204         if not filesOnly and not files:
205             # No files were found, so check the piecehashes as well
206             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
207             row = c.fetchone()
208             if row:
209                 res = {}
210                 res['refreshed'] = row['refreshed']
211                 res['pieces'] = row['pieces']
212                 files.append(res)
213
214         c.close()
215         return files
216         
217     def isUnchanged(self, file):
218         """Check if a file in the file system has changed.
219         
220         If it has changed, it is removed from the database.
221         
222         @return: True if unchanged, False if changed, None if not in database
223         """
224         c = self.conn.cursor()
225         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
226         row = c.fetchone()
227         return self._removeChanged(file, row)
228
229     def refreshHash(self, hash):
230         """Refresh the publishing time of a hash."""
231         c = self.conn.cursor()
232         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
233         c.close()
234     
235     def expiredHashes(self, expireAfter):
236         """Find files that need refreshing after expireAfter seconds.
237         
238         For each hash that needs refreshing, finds all the files with that hash.
239         If the file has changed or is missing, it is removed from the table.
240         
241         @return: dictionary with keys the hashes, values a list of FilePaths
242         """
243         t = datetime.now() - timedelta(seconds=expireAfter)
244         
245         # Find all the hashes that need refreshing
246         c = self.conn.cursor()
247         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
248         row = c.fetchone()
249         expired = {}
250         while row:
251             res = expired.setdefault(row['hash'], {})
252             res['hashID'] = row['hashID']
253             res['hash'] = row['hash']
254             res['pieces'] = row['pieces']
255             row = c.fetchone()
256
257         # Make sure there are still valid files for each hash
258         for hash in expired.values():
259             valid = False
260             c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
261             row = c.fetchone()
262             while row:
263                 res = self._removeChanged(FilePath(row['path']), row)
264                 if res:
265                     valid = True
266                 row = c.fetchone()
267             if not valid:
268                 # Remove hashes for which no files are still available
269                 del expired[hash['hash']]
270                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
271                 
272         self.conn.commit()
273         c.close()
274         
275         return expired
276         
277     def removeUntrackedFiles(self, dirs):
278         """Remove files that are no longer tracked by the program.
279         
280         @type dirs: C{list} of L{twisted.python.filepath.FilePath}
281         @param dirs: a list of the directories that we are tracking
282         @return: list of files that were removed
283         """
284         assert len(dirs) >= 1
285         
286         # Create a list of globs and an SQL statement for the directories
287         newdirs = []
288         sql = "WHERE"
289         for dir in dirs:
290             newdirs.append(dir.child('*').path)
291             sql += " path NOT GLOB ? AND"
292         sql = sql[:-4]
293
294         # Get a listing of all the files that will be removed
295         c = self.conn.cursor()
296         c.execute("SELECT path FROM files " + sql, newdirs)
297         row = c.fetchone()
298         removed = []
299         while row:
300             removed.append(FilePath(row['path']))
301             row = c.fetchone()
302
303         # Delete all the removed files from the database
304         if removed:
305             c.execute("DELETE FROM files " + sql, newdirs)
306         self.conn.commit()
307
308         return removed
309     
310     #{ Statistics
311     def dbStats(self):
312         """Count the total number of files and hashes in the database.
313         
314         @rtype: (C{int}, C{int})
315         @return: the number of distinct hashes and total files in the database
316         """
317         c = self.conn.cursor()
318         c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
319         hashes = 0
320         row = c.fetchone()
321         if row:
322             hashes = row[0]
323         c.execute("SELECT COUNT(path) as num_files FROM files")
324         files = 0
325         row = c.fetchone()
326         if row:
327             files = row[0]
328         return hashes, files
329
330     def getStats(self):
331         """Retrieve the saved statistics from the DB.
332         
333         @return: dictionary of statistics
334         """
335         c = self.conn.cursor()
336         c.execute("SELECT param, value FROM stats")
337         row = c.fetchone()
338         stats = {}
339         while row:
340             stats[row['param']] = row['value']
341             row = c.fetchone()
342         c.close()
343         return stats
344         
345     def saveStats(self, stats):
346         """Save the statistics to the DB."""
347         c = self.conn.cursor()
348         for param in stats:
349             c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
350                       (param, stats[param]))
351             self.conn.commit()
352         c.close()
353         
354 class TestDB(unittest.TestCase):
355     """Tests for the khashmir database."""
356     
357     timeout = 5
358     db = FilePath('/tmp/khashmir.db')
359     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
360     directory = FilePath('/tmp/apt-p2p/')
361     file = FilePath('/tmp/apt-p2p/khashmir.test')
362     testfile = 'tmp/khashmir.test'
363     dirs = [FilePath('/tmp/apt-p2p/top1'),
364             FilePath('/tmp/apt-p2p/top2/sub1'),
365             FilePath('/tmp/apt-p2p/top2/sub2/')]
366
367     def setUp(self):
368         if not self.file.parent().exists():
369             self.file.parent().makedirs()
370         self.file.setContent('fgfhds')
371         self.file.touch()
372         self.store = DB(self.db)
373         self.store.storeFile(self.file, self.hash)
374
375     def test_openExistingDB(self):
376         """Tests opening an existing database."""
377         self.store.close()
378         self.store = None
379         sleep(1)
380         self.store = DB(self.db)
381         res = self.store.isUnchanged(self.file)
382         self.failUnless(res)
383
384     def test_getFile(self):
385         """Tests retrieving a file from the database."""
386         res = self.store.getFile(self.file)
387         self.failUnless(res)
388         self.failUnlessEqual(res['hash'], self.hash)
389         
390     def test_lookupHash(self):
391         """Tests looking up a hash in the database."""
392         res = self.store.lookupHash(self.hash)
393         self.failUnless(res)
394         self.failUnlessEqual(len(res), 1)
395         self.failUnlessEqual(res[0]['path'].path, self.file.path)
396         
397     def test_isUnchanged(self):
398         """Tests checking if a file in the database is unchanged."""
399         res = self.store.isUnchanged(self.file)
400         self.failUnless(res)
401         sleep(2)
402         self.file.touch()
403         res = self.store.isUnchanged(self.file)
404         self.failUnless(res == False)
405         res = self.store.isUnchanged(self.file)
406         self.failUnless(res is None)
407         
408     def test_expiry(self):
409         """Tests retrieving the files from the database that have expired."""
410         res = self.store.expiredHashes(1)
411         self.failUnlessEqual(len(res.keys()), 0)
412         sleep(2)
413         res = self.store.expiredHashes(1)
414         self.failUnlessEqual(len(res.keys()), 1)
415         self.failUnlessEqual(res.keys()[0], self.hash)
416         self.store.refreshHash(self.hash)
417         res = self.store.expiredHashes(1)
418         self.failUnlessEqual(len(res.keys()), 0)
419         
420     def build_dirs(self):
421         for dir in self.dirs:
422             file = dir.preauthChild(self.testfile)
423             if not file.parent().exists():
424                 file.parent().makedirs()
425             file.setContent(file.path)
426             file.touch()
427             self.store.storeFile(file, self.hash)
428     
429     def test_multipleHashes(self):
430         """Tests looking up a hash with multiple files in the database."""
431         self.build_dirs()
432         res = self.store.expiredHashes(1)
433         self.failUnlessEqual(len(res.keys()), 0)
434         res = self.store.lookupHash(self.hash)
435         self.failUnless(res)
436         self.failUnlessEqual(len(res), 4)
437         self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
438         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
439         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
440         sleep(2)
441         res = self.store.expiredHashes(1)
442         self.failUnlessEqual(len(res.keys()), 1)
443         self.failUnlessEqual(res.keys()[0], self.hash)
444         self.store.refreshHash(self.hash)
445         res = self.store.expiredHashes(1)
446         self.failUnlessEqual(len(res.keys()), 0)
447     
448     def test_removeUntracked(self):
449         """Tests removing untracked files from the database."""
450         self.build_dirs()
451         res = self.store.removeUntrackedFiles(self.dirs)
452         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
453         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
454         res = self.store.removeUntrackedFiles(self.dirs)
455         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
456         res = self.store.removeUntrackedFiles(self.dirs[1:])
457         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
458         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
459         res = self.store.removeUntrackedFiles(self.dirs[:1])
460         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
461         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
462         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
463         
464     def tearDown(self):
465         self.directory.remove()
466         self.store.close()
467         self.db.remove()
468