Fix some errors in the new twisted HTTP client's connectionLost() methods.
[quix0rs-apt-p2p.git] / apt_p2p / db.py
1
2 """An sqlite database for storing persistent files and hashes."""
3
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
7 from time import sleep
8 import os, sha
9
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
12
13 assert sqlite.version_info >= (2, 1)
14
15 class DBExcept(Exception):
16     """An error occurred in accessing the database."""
17     pass
18
19 class khash(str):
20     """Dummy class to convert all hashes to base64 for storing in the DB."""
21
22 # Initialize the database to work with 'khash' objects (binary strings)
23 sqlite.register_adapter(khash, b2a_base64)
24 sqlite.register_converter("KHASH", a2b_base64)
25 sqlite.register_converter("khash", a2b_base64)
26 sqlite.enable_callback_tracebacks(True)
27
28 class DB:
29     """An sqlite database for storing persistent files and hashes.
30     
31     @type db: L{twisted.python.filepath.FilePath}
32     @ivar db: the database file to use
33     @type conn: L{pysqlite2.dbapi2.Connection}
34     @ivar conn: an open connection to the sqlite database
35     """
36     
37     def __init__(self, db):
38         """Load or create the database file.
39         
40         @type db: L{twisted.python.filepath.FilePath}
41         @param db: the database file to use
42         """
43         self.db = db
44         self.db.restat(False)
45         if self.db.exists():
46             self._loadDB()
47         else:
48             self._createNewDB()
49         self.conn.text_factory = str
50         self.conn.row_factory = sqlite.Row
51         
52     #{ DB Functions
53     def _loadDB(self):
54         """Open a new connection to the existing database file"""
55         try:
56             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
57         except:
58             import traceback
59             raise DBExcept, "Couldn't open DB", traceback.format_exc()
60         
61     def _createNewDB(self):
62         """Open a connection to a new database and create the necessary tables."""
63         if not self.db.parent().exists():
64             self.db.parent().makedirs()
65         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
66         c = self.conn.cursor()
67         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
68                                       "dht BOOL, size NUMBER, mtime NUMBER)")
69         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
70                                        "hash KHASH UNIQUE, pieces KHASH, " +
71                                        "piecehash KHASH, refreshed TIMESTAMP)")
72         c.execute("CREATE TABLE stats (param TEXT PRIMARY KEY UNIQUE, value NUMERIC)")
73         c.execute("CREATE INDEX hashes_hash ON hashes(hash)")
74         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
75         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
76         c.close()
77         self.conn.commit()
78
79     def close(self):
80         """Close the database connection."""
81         self.conn.close()
82
83     #{ Files and Hashes
84     def _removeChanged(self, file, row):
85         """If the file has changed or is missing, remove it from the DB.
86         
87         @type file: L{twisted.python.filepath.FilePath}
88         @param file: the file to check
89         @type row: C{dictionary}-like object
90         @param row: contains the expected 'size' and 'mtime' of the file
91         @rtype: C{boolean}
92         @return: True if the file is unchanged, False if it is changed,
93             and None if it is missing
94         """
95         res = None
96         if row:
97             file.restat(False)
98             if file.exists():
99                 # Compare the current with the expected file properties
100                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
101             if not res:
102                 # Remove the file from the database
103                 c = self.conn.cursor()
104                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
105                 self.conn.commit()
106                 c.close()
107         return res
108         
109     def storeFile(self, file, hash, dht = True, pieces = ''):
110         """Store or update a file in the database.
111         
112         @type file: L{twisted.python.filepath.FilePath}
113         @param file: the file to check
114         @type hash: C{string}
115         @param hash: the hash of the file
116         @param dht: whether the file is added to the DHT
117             (optional, defaults to true)
118         @type pieces: C{string}
119         @param pieces: the concatenated list of the hashes of the pieces of
120             the file (optional, defaults to the empty string)
121         @return: True if the hash was not in the database before
122             (so it needs to be added to the DHT)
123         """
124         # Hash the pieces to get the piecehash
125         piecehash = ''
126         if pieces:
127             piecehash = sha.new(pieces).digest()
128             
129         # Check the database for the hash
130         c = self.conn.cursor()
131         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
132         row = c.fetchone()
133         if row:
134             assert piecehash == row['piecehash']
135             new_hash = False
136             hashID = row['hashID']
137         else:
138             # Add the new hash to the database
139             c = self.conn.cursor()
140             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
141                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
142             self.conn.commit()
143             new_hash = True
144             hashID = c.lastrowid
145
146         # Add the file to the database
147         file.restat()
148         c.execute("INSERT OR REPLACE INTO files (path, hashID, dht, size, mtime) VALUES (?, ?, ?, ?, ?)",
149                   (file.path, hashID, dht, file.getsize(), file.getmtime()))
150         self.conn.commit()
151         c.close()
152         
153         return new_hash
154         
155     def getFile(self, file):
156         """Get a file from the database.
157         
158         If it has changed or is missing, it is removed from the database.
159         
160         @type file: L{twisted.python.filepath.FilePath}
161         @param file: the file to check
162         @return: dictionary of info for the file, False if changed, or
163             None if not in database or missing
164         """
165         c = self.conn.cursor()
166         c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
167         row = c.fetchone()
168         res = None
169         if row:
170             res = self._removeChanged(file, row)
171             if res:
172                 res = {}
173                 res['hash'] = row['hash']
174                 res['size'] = row['size']
175                 res['pieces'] = row['pieces']
176         c.close()
177         return res
178         
179     def lookupHash(self, hash, filesOnly = False):
180         """Find a file by hash in the database.
181         
182         If any found files have changed or are missing, they are removed
183         from the database. If filesOnly is False then it will also look for
184         piece string hashes if no files can be found.
185         
186         @return: list of dictionaries of info for the found files
187         """
188         # Try to find the hash in the files table
189         c = self.conn.cursor()
190         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
191         row = c.fetchone()
192         files = []
193         while row:
194             # Save the file to the list of found files
195             file = FilePath(row['path'])
196             res = self._removeChanged(file, row)
197             if res:
198                 res = {}
199                 res['path'] = file
200                 res['size'] = row['size']
201                 res['refreshed'] = row['refreshed']
202                 res['pieces'] = row['pieces']
203                 files.append(res)
204             row = c.fetchone()
205             
206         if not filesOnly and not files:
207             # No files were found, so check the piecehashes as well
208             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
209             row = c.fetchone()
210             if row:
211                 res = {}
212                 res['refreshed'] = row['refreshed']
213                 res['pieces'] = row['pieces']
214                 files.append(res)
215
216         c.close()
217         return files
218         
219     def isUnchanged(self, file):
220         """Check if a file in the file system has changed.
221         
222         If it has changed, it is removed from the database.
223         
224         @return: True if unchanged, False if changed, None if not in database
225         """
226         c = self.conn.cursor()
227         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
228         row = c.fetchone()
229         return self._removeChanged(file, row)
230
231     def refreshHash(self, hash):
232         """Refresh the publishing time of a hash."""
233         c = self.conn.cursor()
234         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
235         c.close()
236     
237     def expiredHashes(self, expireAfter):
238         """Find files that need refreshing after expireAfter seconds.
239         
240         For each hash that needs refreshing, finds all the files with that hash.
241         If the file has changed or is missing, it is removed from the table.
242         
243         @return: dictionary with keys the hashes, values a list of FilePaths
244         """
245         t = datetime.now() - timedelta(seconds=expireAfter)
246         
247         # Find all the hashes that need refreshing
248         c = self.conn.cursor()
249         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
250         row = c.fetchone()
251         expired = {}
252         while row:
253             res = expired.setdefault(row['hash'], {})
254             res['hashID'] = row['hashID']
255             res['hash'] = row['hash']
256             res['pieces'] = row['pieces']
257             row = c.fetchone()
258
259         # Make sure there are still valid DHT files for each hash
260         for hash in expired.values():
261             dht = False
262             non_dht = False
263             c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
264             row = c.fetchone()
265             while row:
266                 if row['dht']:
267                     dht = True
268                 else:
269                     non_dht = True
270                 row = c.fetchone()
271             if not dht:
272                 # Remove hashes for which no DHT files are still available
273                 del expired[hash['hash']]
274                 if not non_dht:
275                     # Remove hashes for which no files are still available
276                     c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
277                 else:
278                     # There are still some non-DHT files available, so refresh them
279                     c.execute("UPDATE hashes SET refreshed = ? WHERE hashID = ?",
280                               (datetime.now(), hash['hashID']))
281                 
282         self.conn.commit()
283         c.close()
284         
285         return expired
286         
287     def removeUntrackedFiles(self, dirs):
288         """Remove files that are no longer tracked by the program.
289         
290         @type dirs: C{list} of L{twisted.python.filepath.FilePath}
291         @param dirs: a list of the directories that we are tracking
292         @return: list of files that were removed
293         """
294         assert len(dirs) >= 1
295         
296         # Create a list of globs and an SQL statement for the directories
297         newdirs = []
298         sql = "WHERE"
299         for dir in dirs:
300             newdirs.append(dir.child('*').path)
301             sql += " path NOT GLOB ? AND"
302         sql = sql[:-4]
303
304         # Get a listing of all the files that will be removed
305         c = self.conn.cursor()
306         c.execute("SELECT path FROM files " + sql, newdirs)
307         row = c.fetchone()
308         removed = []
309         while row:
310             removed.append(FilePath(row['path']))
311             row = c.fetchone()
312
313         # Delete all the removed files from the database
314         if removed:
315             c.execute("DELETE FROM files " + sql, newdirs)
316             self.conn.commit()
317         
318         c.execute("SELECT path FROM files")
319         rows = c.fetchall()
320         for row in rows:
321             if not os.path.exists(row['path']):
322                 # Leave hashes, they will be removed on next refresh
323                 c.execute("DELETE FROM files WHERE path = ?", (row['path'], ))
324                 removed.append(FilePath(row['path']))
325         self.conn.commit()
326
327         return removed
328     
329     #{ Statistics
330     def dbStats(self):
331         """Count the total number of files and hashes in the database.
332         
333         @rtype: (C{int}, C{int})
334         @return: the number of distinct hashes and total files in the database
335         """
336         c = self.conn.cursor()
337         c.execute("SELECT COUNT(hash) as num_hashes FROM hashes")
338         hashes = 0
339         row = c.fetchone()
340         if row:
341             hashes = row[0]
342         c.execute("SELECT COUNT(path) as num_files FROM files")
343         files = 0
344         row = c.fetchone()
345         if row:
346             files = row[0]
347         return hashes, files
348
349     def getStats(self):
350         """Retrieve the saved statistics from the DB.
351         
352         @return: dictionary of statistics
353         """
354         c = self.conn.cursor()
355         c.execute("SELECT param, value FROM stats")
356         row = c.fetchone()
357         stats = {}
358         while row:
359             stats[row['param']] = row['value']
360             row = c.fetchone()
361         c.close()
362         return stats
363         
364     def saveStats(self, stats):
365         """Save the statistics to the DB."""
366         c = self.conn.cursor()
367         for param in stats:
368             c.execute("INSERT OR REPLACE INTO stats (param, value) VALUES (?, ?)",
369                       (param, stats[param]))
370             self.conn.commit()
371         c.close()
372         
373 class TestDB(unittest.TestCase):
374     """Tests for the khashmir database."""
375     
376     timeout = 5
377     db = FilePath('/tmp/khashmir.db')
378     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
379     directory = FilePath('/tmp/apt-p2p/')
380     file = FilePath('/tmp/apt-p2p/khashmir.test')
381     testfile = 'tmp/khashmir.test'
382     dirs = [FilePath('/tmp/apt-p2p/top1'),
383             FilePath('/tmp/apt-p2p/top2/sub1'),
384             FilePath('/tmp/apt-p2p/top2/sub2/')]
385
386     def setUp(self):
387         if not self.file.parent().exists():
388             self.file.parent().makedirs()
389         self.file.setContent('fgfhds')
390         self.file.touch()
391         self.store = DB(self.db)
392         self.store.storeFile(self.file, self.hash)
393
394     def test_openExistingDB(self):
395         """Tests opening an existing database."""
396         self.store.close()
397         self.store = None
398         sleep(1)
399         self.store = DB(self.db)
400         res = self.store.isUnchanged(self.file)
401         self.failUnless(res)
402
403     def test_getFile(self):
404         """Tests retrieving a file from the database."""
405         res = self.store.getFile(self.file)
406         self.failUnless(res)
407         self.failUnlessEqual(res['hash'], self.hash)
408         
409     def test_lookupHash(self):
410         """Tests looking up a hash in the database."""
411         res = self.store.lookupHash(self.hash)
412         self.failUnless(res)
413         self.failUnlessEqual(len(res), 1)
414         self.failUnlessEqual(res[0]['path'].path, self.file.path)
415         
416     def test_isUnchanged(self):
417         """Tests checking if a file in the database is unchanged."""
418         res = self.store.isUnchanged(self.file)
419         self.failUnless(res)
420         sleep(2)
421         self.file.touch()
422         res = self.store.isUnchanged(self.file)
423         self.failUnless(res == False)
424         res = self.store.isUnchanged(self.file)
425         self.failUnless(res is None)
426         
427     def test_expiry(self):
428         """Tests retrieving the files from the database that have expired."""
429         res = self.store.expiredHashes(1)
430         self.failUnlessEqual(len(res.keys()), 0)
431         sleep(2)
432         res = self.store.expiredHashes(1)
433         self.failUnlessEqual(len(res.keys()), 1)
434         self.failUnlessEqual(res.keys()[0], self.hash)
435         self.store.refreshHash(self.hash)
436         res = self.store.expiredHashes(1)
437         self.failUnlessEqual(len(res.keys()), 0)
438         
439     def build_dirs(self):
440         for dir in self.dirs:
441             file = dir.preauthChild(self.testfile)
442             if not file.parent().exists():
443                 file.parent().makedirs()
444             file.setContent(file.path)
445             file.touch()
446             self.store.storeFile(file, self.hash)
447     
448     def test_multipleHashes(self):
449         """Tests looking up a hash with multiple files in the database."""
450         self.build_dirs()
451         res = self.store.expiredHashes(1)
452         self.failUnlessEqual(len(res.keys()), 0)
453         res = self.store.lookupHash(self.hash)
454         self.failUnless(res)
455         self.failUnlessEqual(len(res), 4)
456         self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
457         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
458         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
459         sleep(2)
460         res = self.store.expiredHashes(1)
461         self.failUnlessEqual(len(res.keys()), 1)
462         self.failUnlessEqual(res.keys()[0], self.hash)
463         self.store.refreshHash(self.hash)
464         res = self.store.expiredHashes(1)
465         self.failUnlessEqual(len(res.keys()), 0)
466     
467     def test_removeUntracked(self):
468         """Tests removing untracked files from the database."""
469         self.build_dirs()
470         file = self.dirs[0].child('test.khashmir')
471         file.setContent(file.path)
472         file.touch()
473         self.store.storeFile(file, self.hash)
474         res = self.store.removeUntrackedFiles(self.dirs)
475         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
476         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
477         res = self.store.removeUntrackedFiles(self.dirs)
478         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
479         file.remove()
480         res = self.store.removeUntrackedFiles(self.dirs)
481         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
482         self.failUnlessEqual(res[0], self.dirs[0].child('test.khashmir'), 'Got removed paths: %r' % res)
483         res = self.store.removeUntrackedFiles(self.dirs[1:])
484         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
485         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
486         res = self.store.removeUntrackedFiles(self.dirs[:1])
487         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
488         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
489         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
490         
491     def tearDown(self):
492         self.directory.remove()
493         self.store.close()
494         self.db.remove()
495