]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht/db.py
HTTPServer now gets it's dictionary of subdirectories externally.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
10 assert sqlite.version_info >= (2, 1)
11
12 class DBExcept(Exception):
13     pass
14
15 class khash(str):
16     """Dummy class to convert all hashes to base64 for storing in the DB."""
17     
18 sqlite.register_adapter(khash, b2a_base64)
19 sqlite.register_converter("KHASH", a2b_base64)
20 sqlite.register_converter("khash", a2b_base64)
21 sqlite.enable_callback_tracebacks(True)
22
23 class DB:
24     """Database access for storing persistent data."""
25     
26     def __init__(self, db):
27         self.db = db
28         try:
29             os.stat(db)
30         except OSError:
31             self._createNewDB(db)
32         else:
33             self._loadDB(db)
34         self.conn.text_factory = str
35         self.conn.row_factory = sqlite.Row
36         
37     def _loadDB(self, db):
38         try:
39             self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
40         except:
41             import traceback
42             raise DBExcept, "Couldn't open DB", traceback.format_exc()
43         
44     def _createNewDB(self, db):
45         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
46         c = self.conn.cursor()
47         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urldir INTEGER, dirlength INTEGER, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
48         c.execute("CREATE INDEX files_urldir ON files(urldir)")
49         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
50         c.execute("CREATE TABLE dirs (urldir INTEGER PRIMARY KEY AUTOINCREMENT, path TEXT)")
51         c.execute("CREATE INDEX dirs_path ON dirs(path)")
52         c.close()
53         self.conn.commit()
54
55     def _removeChanged(self, path, row):
56         res = None
57         if row:
58             try:
59                 stat = os.stat(path)
60             except:
61                 stat = None
62             if stat:
63                 res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
64             if not res:
65                 c = self.conn.cursor()
66                 c.execute("DELETE FROM files WHERE path = ?", (path, ))
67                 self.conn.commit()
68                 c.close()
69         return res
70         
71     def storeFile(self, path, hash, directory):
72         """Store or update a file in the database."""
73         path = os.path.abspath(path)
74         directory = os.path.abspath(directory)
75         assert path.startswith(directory)
76         stat = os.stat(path)
77         c = self.conn.cursor()
78         c.execute("SELECT dirs.urldir AS urldir, dirs.path AS directory FROM dirs JOIN files USING (urldir) WHERE files.path = ?", (path, ))
79         row = c.fetchone()
80         if row and directory == row['directory']:
81             c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
82                       (khash(hash), stat.st_size, stat.st_mtime, datetime.now()))
83             newdir = False
84         else:
85             urldir, newdir = self.findDirectory(directory)
86             c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?, ?, ?)",
87                       (path, khash(hash), urldir, len(directory), stat.st_size, stat.st_mtime, datetime.now()))
88         self.conn.commit()
89         c.close()
90         return newdir
91         
92     def getFile(self, path):
93         """Get a file from the database.
94         
95         If it has changed or is missing, it is removed from the database.
96         
97         @return: dictionary of info for the file, False if changed, or
98             None if not in database or missing
99         """
100         path = os.path.abspath(path)
101         c = self.conn.cursor()
102         c.execute("SELECT hash, urldir, dirlength, size, mtime FROM files WHERE path = ?", (path, ))
103         row = c.fetchone()
104         res = self._removeChanged(path, row)
105         if res:
106             res = {}
107             res['hash'] = row['hash']
108             res['urlpath'] = '/~' + str(row['urldir']) + path[row['dirlength']:]
109         c.close()
110         return res
111         
112     def isUnchanged(self, path):
113         """Check if a file in the file system has changed.
114         
115         If it has changed, it is removed from the table.
116         
117         @return: True if unchanged, False if changed, None if not in database
118         """
119         path = os.path.abspath(path)
120         c = self.conn.cursor()
121         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
122         row = c.fetchone()
123         return self._removeChanged(path, row)
124
125     def refreshFile(self, path):
126         """Refresh the publishing time of a file.
127         
128         If it has changed or is missing, it is removed from the table.
129         
130         @return: True if unchanged, False if changed, None if not in database
131         """
132         path = os.path.abspath(path)
133         c = self.conn.cursor()
134         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
135         row = c.fetchone()
136         res = self._removeChanged(path, row)
137         if res:
138             c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
139         return res
140     
141     def expiredFiles(self, expireAfter):
142         """Find files that need refreshing after expireAfter seconds.
143         
144         Also removes any entries from the table that no longer exist.
145         
146         @return: dictionary with keys the hashes, values a list of url paths
147         """
148         t = datetime.now() - timedelta(seconds=expireAfter)
149         c = self.conn.cursor()
150         c.execute("SELECT path, hash, urldir, dirlength, size, mtime FROM files WHERE refreshed < ?", (t, ))
151         row = c.fetchone()
152         expired = {}
153         while row:
154             res = self._removeChanged(row['path'], row)
155             if res:
156                 expired.setdefault(row['hash'], []).append('/~' + str(row['urldir']) + row['path'][row['dirlength']:])
157             row = c.fetchone()
158         c.close()
159         return expired
160         
161     def removeUntrackedFiles(self, dirs):
162         """Find files that are no longer tracked and so should be removed.
163         
164         Also removes the entries from the table.
165         
166         @return: list of files that were removed
167         """
168         assert len(dirs) >= 1
169         newdirs = []
170         sql = "WHERE"
171         for dir in dirs:
172             newdirs.append(os.path.abspath(dir) + os.sep + '*')
173             sql += " path NOT GLOB ? AND"
174         sql = sql[:-4]
175
176         c = self.conn.cursor()
177         c.execute("SELECT path FROM files " + sql, newdirs)
178         row = c.fetchone()
179         removed = []
180         while row:
181             removed.append(row['path'])
182             row = c.fetchone()
183
184         if removed:
185             c.execute("DELETE FROM files " + sql, newdirs)
186         self.conn.commit()
187         return removed
188     
189     def findDirectory(self, directory):
190         """Store or update a directory in the database.
191         
192         @return: the index of the url directory, and whether it is new or not
193         """
194         directory = os.path.abspath(directory)
195         c = self.conn.cursor()
196         c.execute("SELECT min(urldir) AS urldir FROM dirs WHERE path = ?", (directory, ))
197         row = c.fetchone()
198         c.close()
199         if row['urldir']:
200             return row['urldir'], False
201
202         # Not found, need to add a new one
203         c = self.conn.cursor()
204         c.execute("INSERT INTO dirs (path) VALUES (?)", (directory, ))
205         self.conn.commit()
206         urldir = c.lastrowid
207         c.close()
208         return urldir, True
209         
210     def getAllDirectories(self):
211         """Get all the current directories avaliable."""
212         c = self.conn.cursor()
213         c.execute("SELECT urldir, path FROM dirs")
214         row = c.fetchone()
215         dirs = {}
216         while row:
217             dirs['~' + str(row['urldir'])] = row['path']
218             row = c.fetchone()
219         c.close()
220         return dirs
221     
222     def reconcileDirectories(self):
223         """Remove any unneeded directories by checking which are used by files."""
224         c = self.conn.cursor()
225         c.execute('DELETE FROM dirs WHERE urldir NOT IN (SELECT DISTINCT urldir FROM files)')
226         self.conn.commit()
227         return bool(c.rowcount)
228         
229     def close(self):
230         self.conn.close()
231
232 class TestDB(unittest.TestCase):
233     """Tests for the khashmir database."""
234     
235     timeout = 5
236     db = '/tmp/khashmir.db'
237     path = '/tmp/khashmir.test'
238     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
239     directory = '/tmp/'
240     urlpath = '/~1/khashmir.test'
241     dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
242
243     def setUp(self):
244         f = open(self.path, 'w')
245         f.write('fgfhds')
246         f.close()
247         os.utime(self.path, None)
248         self.store = DB(self.db)
249         self.store.storeFile(self.path, self.hash, self.directory)
250
251     def test_getFile(self):
252         res = self.store.getFile(self.path)
253         self.failUnless(res)
254         self.failUnlessEqual(res['hash'], self.hash)
255         self.failUnlessEqual(res['urlpath'], self.urlpath)
256         
257     def test_getAllDirectories(self):
258         res = self.store.getAllDirectories()
259         self.failUnless(res)
260         self.failUnlessEqual(len(res.keys()), 1)
261         self.failUnlessEqual(res.keys()[0], '~1')
262         self.failUnlessEqual(res['~1'], os.path.abspath(self.directory))
263         
264     def test_isUnchanged(self):
265         res = self.store.isUnchanged(self.path)
266         self.failUnless(res)
267         sleep(2)
268         os.utime(self.path, None)
269         res = self.store.isUnchanged(self.path)
270         self.failUnless(res == False)
271         os.unlink(self.path)
272         res = self.store.isUnchanged(self.path)
273         self.failUnless(res == None)
274         
275     def test_expiry(self):
276         res = self.store.expiredFiles(1)
277         self.failUnlessEqual(len(res.keys()), 0)
278         sleep(2)
279         res = self.store.expiredFiles(1)
280         self.failUnlessEqual(len(res.keys()), 1)
281         self.failUnlessEqual(res.keys()[0], self.hash)
282         self.failUnlessEqual(len(res[self.hash]), 1)
283         self.failUnlessEqual(res[self.hash][0], self.urlpath)
284         res = self.store.refreshFile(self.path)
285         self.failUnless(res)
286         res = self.store.expiredFiles(1)
287         self.failUnlessEqual(len(res.keys()), 0)
288         
289     def build_dirs(self):
290         for dir in self.dirs:
291             path = os.path.join(dir, self.path[1:])
292             os.makedirs(os.path.dirname(path))
293             f = open(path, 'w')
294             f.write(path)
295             f.close()
296             os.utime(path, None)
297             self.store.storeFile(path, self.hash, dir)
298     
299     def test_removeUntracked(self):
300         self.build_dirs()
301         res = self.store.removeUntrackedFiles(self.dirs)
302         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
303         self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
304         res = self.store.removeUntrackedFiles(self.dirs)
305         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
306         res = self.store.removeUntrackedFiles(self.dirs[1:])
307         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
308         self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
309         res = self.store.removeUntrackedFiles(self.dirs[:1])
310         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
311         self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
312         self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
313         
314     def test_reconcileDirectories(self):
315         self.build_dirs()
316         res = self.store.getAllDirectories()
317         self.failUnless(res)
318         self.failUnlessEqual(len(res.keys()), 4)
319         res = self.store.reconcileDirectories()
320         self.failUnlessEqual(res, False)
321         res = self.store.getAllDirectories()
322         self.failUnless(res)
323         self.failUnlessEqual(len(res.keys()), 4)
324         res = self.store.removeUntrackedFiles(self.dirs)
325         res = self.store.reconcileDirectories()
326         self.failUnlessEqual(res, True)
327         res = self.store.getAllDirectories()
328         self.failUnless(res)
329         self.failUnlessEqual(len(res.keys()), 3)
330         res = self.store.removeUntrackedFiles(self.dirs[:1])
331         res = self.store.reconcileDirectories()
332         self.failUnlessEqual(res, True)
333         res = self.store.getAllDirectories()
334         self.failUnless(res)
335         self.failUnlessEqual(len(res.keys()), 1)
336         res = self.store.removeUntrackedFiles(['/what'])
337         res = self.store.reconcileDirectories()
338         self.failUnlessEqual(res, True)
339         res = self.store.getAllDirectories()
340         self.failUnlessEqual(len(res.keys()), 0)
341         
342     def tearDown(self):
343         for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
344             for name in files:
345                 os.remove(os.path.join(root, name))
346             for name in dirs:
347                 os.rmdir(os.path.join(root, name))
348         self.store.close()
349         os.unlink(self.db)