Minor update to the multiple peer downloading (still not working).
[quix0rs-apt-p2p.git] / apt_p2p / db.py
1
2 """An sqlite database for storing persistent files and hashes."""
3
4 from datetime import datetime, timedelta
5 from pysqlite2 import dbapi2 as sqlite
6 from binascii import a2b_base64, b2a_base64
7 from time import sleep
8 import os
9
10 from twisted.python.filepath import FilePath
11 from twisted.trial import unittest
12
13 assert sqlite.version_info >= (2, 1)
14
15 class DBExcept(Exception):
16     """An error occurred in accessing the database."""
17     pass
18
19 class khash(str):
20     """Dummy class to convert all hashes to base64 for storing in the DB."""
21
22 # Initialize the database to work with 'khash' objects (binary strings)
23 sqlite.register_adapter(khash, b2a_base64)
24 sqlite.register_converter("KHASH", a2b_base64)
25 sqlite.register_converter("khash", a2b_base64)
26 sqlite.enable_callback_tracebacks(True)
27
28 class DB:
29     """An sqlite database for storing persistent files and hashes.
30     
31     @type db: L{twisted.python.filepath.FilePath}
32     @ivar db: the database file to use
33     @type conn: L{pysqlite2.dbapi2.Connection}
34     @ivar conn: an open connection to the sqlite database
35     """
36     
37     def __init__(self, db):
38         """Load or create the database file.
39         
40         @type db: L{twisted.python.filepath.FilePath}
41         @param db: the database file to use
42         """
43         self.db = db
44         self.db.restat(False)
45         if self.db.exists():
46             self._loadDB()
47         else:
48             self._createNewDB()
49         self.conn.text_factory = str
50         self.conn.row_factory = sqlite.Row
51         
52     def _loadDB(self):
53         """Open a new connection to the existing database file"""
54         try:
55             self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
56         except:
57             import traceback
58             raise DBExcept, "Couldn't open DB", traceback.format_exc()
59         
60     def _createNewDB(self):
61         """Open a connection to a new database and create the necessary tables."""
62         if not self.db.parent().exists():
63             self.db.parent().makedirs()
64         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
65         c = self.conn.cursor()
66         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
67                                       "size NUMBER, mtime NUMBER)")
68         c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
69                                        "hash KHASH UNIQUE, pieces KHASH, " +
70                                        "piecehash KHASH, refreshed TIMESTAMP)")
71         c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
72         c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
73         c.close()
74         self.conn.commit()
75
76     def _removeChanged(self, file, row):
77         """If the file has changed or is missing, remove it from the DB.
78         
79         @type file: L{twisted.python.filepath.FilePath}
80         @param file: the file to check
81         @type row: C{dictionary}-like object
82         @param row: contains the expected 'size' and 'mtime' of the file
83         @rtype: C{boolean}
84         @return: True if the file is unchanged, False if it is changed,
85             and None if it is missing
86         """
87         res = None
88         if row:
89             file.restat(False)
90             if file.exists():
91                 # Compare the current with the expected file properties
92                 res = (row['size'] == file.getsize() and row['mtime'] == file.getmtime())
93             if not res:
94                 # Remove the file from the database
95                 c = self.conn.cursor()
96                 c.execute("DELETE FROM files WHERE path = ?", (file.path, ))
97                 self.conn.commit()
98                 c.close()
99         return res
100         
101     def storeFile(self, file, hash, pieces = ''):
102         """Store or update a file in the database.
103         
104         @type file: L{twisted.python.filepath.FilePath}
105         @param file: the file to check
106         @type hash: C{string}
107         @param hash: the hash of the file
108         @type pieces: C{string}
109         @param pieces: the concatenated list of the hashes of the pieces of
110             the file (optional, defaults to the empty string)
111         @return: True if the hash was not in the database before
112             (so it needs to be added to the DHT)
113         """
114         # Hash the pieces to get the piecehash
115         piecehash = ''
116         if pieces:
117             piecehash = sha.new(pieces).digest()
118             
119         # Check the database for the hash
120         c = self.conn.cursor()
121         c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
122         row = c.fetchone()
123         if row:
124             assert piecehash == row['piecehash']
125             new_hash = False
126             hashID = row['hashID']
127         else:
128             # Add the new hash to the database
129             c = self.conn.cursor()
130             c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
131                       (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
132             self.conn.commit()
133             new_hash = True
134             hashID = c.lastrowid
135
136         # Add the file to the database
137         file.restat()
138         c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
139                   (file.path, hashID, file.getsize(), file.getmtime()))
140         self.conn.commit()
141         c.close()
142         
143         return new_hash
144         
145     def getFile(self, file):
146         """Get a file from the database.
147         
148         If it has changed or is missing, it is removed from the database.
149         
150         @type file: L{twisted.python.filepath.FilePath}
151         @param file: the file to check
152         @return: dictionary of info for the file, False if changed, or
153             None if not in database or missing
154         """
155         c = self.conn.cursor()
156         c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
157         row = c.fetchone()
158         res = None
159         if row:
160             res = self._removeChanged(file, row)
161             if res:
162                 res = {}
163                 res['hash'] = row['hash']
164                 res['size'] = row['size']
165                 res['pieces'] = row['pieces']
166         c.close()
167         return res
168         
169     def lookupHash(self, hash, filesOnly = False):
170         """Find a file by hash in the database.
171         
172         If any found files have changed or are missing, they are removed
173         from the database. If filesOnly is False then it will also look for
174         piece string hashes if no files can be found.
175         
176         @return: list of dictionaries of info for the found files
177         """
178         # Try to find the hash in the files table
179         c = self.conn.cursor()
180         c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
181         row = c.fetchone()
182         files = []
183         while row:
184             # Save the file to the list of found files
185             file = FilePath(row['path'])
186             res = self._removeChanged(file, row)
187             if res:
188                 res = {}
189                 res['path'] = file
190                 res['size'] = row['size']
191                 res['refreshed'] = row['refreshed']
192                 res['pieces'] = row['pieces']
193                 files.append(res)
194             row = c.fetchone()
195             
196         if not filesOnly and not files:
197             # No files were found, so check the piecehashes as well
198             c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
199             row = c.fetchone()
200             if row:
201                 res = {}
202                 res['refreshed'] = row['refreshed']
203                 res['pieces'] = row['pieces']
204                 files.append(res)
205
206         c.close()
207         return files
208         
209     def isUnchanged(self, file):
210         """Check if a file in the file system has changed.
211         
212         If it has changed, it is removed from the database.
213         
214         @return: True if unchanged, False if changed, None if not in database
215         """
216         c = self.conn.cursor()
217         c.execute("SELECT size, mtime FROM files WHERE path = ?", (file.path, ))
218         row = c.fetchone()
219         return self._removeChanged(file, row)
220
221     def refreshHash(self, hash):
222         """Refresh the publishing time of a hash."""
223         c = self.conn.cursor()
224         c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
225         c.close()
226     
227     def expiredHashes(self, expireAfter):
228         """Find files that need refreshing after expireAfter seconds.
229         
230         For each hash that needs refreshing, finds all the files with that hash.
231         If the file has changed or is missing, it is removed from the table.
232         
233         @return: dictionary with keys the hashes, values a list of FilePaths
234         """
235         t = datetime.now() - timedelta(seconds=expireAfter)
236         
237         # Find all the hashes that need refreshing
238         c = self.conn.cursor()
239         c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
240         row = c.fetchone()
241         expired = {}
242         while row:
243             res = expired.setdefault(row['hash'], {})
244             res['hashID'] = row['hashID']
245             res['hash'] = row['hash']
246             res['pieces'] = row['pieces']
247             row = c.fetchone()
248
249         # Make sure there are still valid files for each hash
250         for hash in expired.values():
251             valid = False
252             c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
253             row = c.fetchone()
254             while row:
255                 res = self._removeChanged(FilePath(row['path']), row)
256                 if res:
257                     valid = True
258                 row = c.fetchone()
259             if not valid:
260                 # Remove hashes for which no files are still available
261                 del expired[hash['hash']]
262                 c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
263                 
264         self.conn.commit()
265         c.close()
266         
267         return expired
268         
269     def removeUntrackedFiles(self, dirs):
270         """Remove files that are no longer tracked by the program.
271         
272         @type dirs: C{list} of L{twisted.python.filepath.FilePath}
273         @param dirs: a list of the directories that we are tracking
274         @return: list of files that were removed
275         """
276         assert len(dirs) >= 1
277         
278         # Create a list of globs and an SQL statement for the directories
279         newdirs = []
280         sql = "WHERE"
281         for dir in dirs:
282             newdirs.append(dir.child('*').path)
283             sql += " path NOT GLOB ? AND"
284         sql = sql[:-4]
285
286         # Get a listing of all the files that will be removed
287         c = self.conn.cursor()
288         c.execute("SELECT path FROM files " + sql, newdirs)
289         row = c.fetchone()
290         removed = []
291         while row:
292             removed.append(FilePath(row['path']))
293             row = c.fetchone()
294
295         # Delete all the removed files from the database
296         if removed:
297             c.execute("DELETE FROM files " + sql, newdirs)
298         self.conn.commit()
299
300         return removed
301     
302     def close(self):
303         """Close the database connection."""
304         self.conn.close()
305
306 class TestDB(unittest.TestCase):
307     """Tests for the khashmir database."""
308     
309     timeout = 5
310     db = FilePath('/tmp/khashmir.db')
311     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
312     directory = FilePath('/tmp/apt-p2p/')
313     file = FilePath('/tmp/apt-p2p/khashmir.test')
314     testfile = 'tmp/khashmir.test'
315     dirs = [FilePath('/tmp/apt-p2p/top1'),
316             FilePath('/tmp/apt-p2p/top2/sub1'),
317             FilePath('/tmp/apt-p2p/top2/sub2/')]
318
319     def setUp(self):
320         if not self.file.parent().exists():
321             self.file.parent().makedirs()
322         self.file.setContent('fgfhds')
323         self.file.touch()
324         self.store = DB(self.db)
325         self.store.storeFile(self.file, self.hash)
326
327     def test_openExistingDB(self):
328         """Tests opening an existing database."""
329         self.store.close()
330         self.store = None
331         sleep(1)
332         self.store = DB(self.db)
333         res = self.store.isUnchanged(self.file)
334         self.failUnless(res)
335
336     def test_getFile(self):
337         """Tests retrieving a file from the database."""
338         res = self.store.getFile(self.file)
339         self.failUnless(res)
340         self.failUnlessEqual(res['hash'], self.hash)
341         
342     def test_lookupHash(self):
343         """Tests looking up a hash in the database."""
344         res = self.store.lookupHash(self.hash)
345         self.failUnless(res)
346         self.failUnlessEqual(len(res), 1)
347         self.failUnlessEqual(res[0]['path'].path, self.file.path)
348         
349     def test_isUnchanged(self):
350         """Tests checking if a file in the database is unchanged."""
351         res = self.store.isUnchanged(self.file)
352         self.failUnless(res)
353         sleep(2)
354         self.file.touch()
355         res = self.store.isUnchanged(self.file)
356         self.failUnless(res == False)
357         res = self.store.isUnchanged(self.file)
358         self.failUnless(res is None)
359         
360     def test_expiry(self):
361         """Tests retrieving the files from the database that have expired."""
362         res = self.store.expiredHashes(1)
363         self.failUnlessEqual(len(res.keys()), 0)
364         sleep(2)
365         res = self.store.expiredHashes(1)
366         self.failUnlessEqual(len(res.keys()), 1)
367         self.failUnlessEqual(res.keys()[0], self.hash)
368         self.store.refreshHash(self.hash)
369         res = self.store.expiredHashes(1)
370         self.failUnlessEqual(len(res.keys()), 0)
371         
372     def build_dirs(self):
373         for dir in self.dirs:
374             file = dir.preauthChild(self.testfile)
375             if not file.parent().exists():
376                 file.parent().makedirs()
377             file.setContent(file.path)
378             file.touch()
379             self.store.storeFile(file, self.hash)
380     
381     def test_multipleHashes(self):
382         """Tests looking up a hash with multiple files in the database."""
383         self.build_dirs()
384         res = self.store.expiredHashes(1)
385         self.failUnlessEqual(len(res.keys()), 0)
386         res = self.store.lookupHash(self.hash)
387         self.failUnless(res)
388         self.failUnlessEqual(len(res), 4)
389         self.failUnlessEqual(res[0]['refreshed'], res[1]['refreshed'])
390         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
391         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
392         sleep(2)
393         res = self.store.expiredHashes(1)
394         self.failUnlessEqual(len(res.keys()), 1)
395         self.failUnlessEqual(res.keys()[0], self.hash)
396         self.store.refreshHash(self.hash)
397         res = self.store.expiredHashes(1)
398         self.failUnlessEqual(len(res.keys()), 0)
399     
400     def test_removeUntracked(self):
401         """Tests removing untracked files from the database."""
402         self.build_dirs()
403         res = self.store.removeUntrackedFiles(self.dirs)
404         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
405         self.failUnlessEqual(res[0], self.file, 'Got removed paths: %r' % res)
406         res = self.store.removeUntrackedFiles(self.dirs)
407         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
408         res = self.store.removeUntrackedFiles(self.dirs[1:])
409         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
410         self.failUnlessEqual(res[0], self.dirs[0].preauthChild(self.testfile), 'Got removed paths: %r' % res)
411         res = self.store.removeUntrackedFiles(self.dirs[:1])
412         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
413         self.failUnlessIn(self.dirs[1].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
414         self.failUnlessIn(self.dirs[2].preauthChild(self.testfile), res, 'Got removed paths: %r' % res)
415         
416     def tearDown(self):
417         self.directory.remove()
418         self.store.close()
419         self.db.remove()
420