]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_dht/db.py
Main database finished for now, including unittests.
[quix0rs-apt-p2p.git] / apt_dht / db.py
1
2 from datetime import datetime, timedelta
3 from pysqlite2 import dbapi2 as sqlite
4 from binascii import a2b_base64, b2a_base64
5 from time import sleep
6 import os
7
8 from twisted.trial import unittest
9
10 assert sqlite.version_info >= (2, 1)
11
12 class DBExcept(Exception):
13     pass
14
15 class khash(str):
16     """Dummy class to convert all hashes to base64 for storing in the DB."""
17     
18 sqlite.register_adapter(khash, b2a_base64)
19 sqlite.register_converter("KHASH", a2b_base64)
20 sqlite.register_converter("khash", a2b_base64)
21
22 class DB:
23     """Database access for storing persistent data."""
24     
25     def __init__(self, db):
26         self.db = db
27         try:
28             os.stat(db)
29         except OSError:
30             self._createNewDB(db)
31         else:
32             self._loadDB(db)
33         self.conn.text_factory = str
34         self.conn.row_factory = sqlite.Row
35         
36     def _loadDB(self, db):
37         try:
38             self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
39         except:
40             import traceback
41             raise DBExcept, "Couldn't open DB", traceback.format_exc()
42         
43     def _createNewDB(self, db):
44         self.conn = sqlite.connect(database=db, detect_types=sqlite.PARSE_DECLTYPES)
45         c = self.conn.cursor()
46         c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, urlpath TEXT, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
47 #        c.execute("CREATE INDEX files_hash ON files(hash)")
48         c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
49         c.execute("CREATE TABLE dirs (path TEXT PRIMARY KEY, urlpath TEXT)")
50         c.close()
51         self.conn.commit()
52
53     def _removeChanged(self, path, row):
54         res = None
55         if row:
56             try:
57                 stat = os.stat(path)
58             except:
59                 stat = None
60             if stat:
61                 res = (row['size'] == stat.st_size and row['mtime'] == stat.st_mtime)
62             if not res:
63                 c = self.conn.cursor()
64                 c.execute("DELETE FROM files WHERE path = ?", (path, ))
65                 self.conn.commit()
66                 c.close()
67         return res
68         
69     def storeFile(self, path, hash, urlpath):
70         """Store or update a file in the database."""
71         path = os.path.abspath(path)
72         stat = os.stat(path)
73         c = self.conn.cursor()
74         c.execute("INSERT OR REPLACE INTO files VALUES (?, ?, ?, ?, ?, ?)", 
75                   (path, khash(hash), urlpath, stat.st_size, stat.st_mtime, datetime.now()))
76         self.conn.commit()
77         c.close()
78         
79     def getFile(self, path):
80         """Get a file from the database.
81         
82         If it has changed or is missing, it is removed from the database.
83         
84         @return: dictionary of info for the file, False if changed, or
85             None if not in database or missing
86         """
87         path = os.path.abspath(path)
88         c = self.conn.cursor()
89         c.execute("SELECT hash, urlpath, size, mtime FROM files WHERE path = ?", (path, ))
90         row = c.fetchone()
91         res = self._removeChanged(path, row)
92         if res:
93             res = {}
94             res['hash'] = row['hash']
95             res['urlpath'] = row['urlpath']
96         c.close()
97         return res
98         
99     def isUnchanged(self, path):
100         """Check if a file in the file system has changed.
101         
102         If it has changed, it is removed from the table.
103         
104         @return: True if unchanged, False if changed, None if not in database
105         """
106         path = os.path.abspath(path)
107         c = self.conn.cursor()
108         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
109         row = c.fetchone()
110         return self._removeChanged(path, row)
111
112     def refreshFile(self, path):
113         """Refresh the publishing time of a file.
114         
115         If it has changed or is missing, it is removed from the table.
116         
117         @return: True if unchanged, False if changed, None if not in database
118         """
119         path = os.path.abspath(path)
120         c = self.conn.cursor()
121         c.execute("SELECT size, mtime FROM files WHERE path = ?", (path, ))
122         row = c.fetchone()
123         res = self._removeChanged(path, row)
124         if res:
125             c.execute("UPDATE files SET refreshed = ? WHERE path = ?", (datetime.now(), path))
126         return res
127     
128     def expiredFiles(self, expireAfter):
129         """Find files that need refreshing after expireAfter seconds.
130         
131         Also removes any entries from the table that no longer exist.
132         
133         @return: dictionary with keys the hashes, values a list of url paths
134         """
135         t = datetime.now() - timedelta(seconds=expireAfter)
136         c = self.conn.cursor()
137         c.execute("SELECT path, hash, urlpath, size, mtime FROM files WHERE refreshed < ?", (t, ))
138         row = c.fetchone()
139         expired = {}
140         while row:
141             res = self._removeChanged(row['path'], row)
142             if res:
143                 expired.setdefault(row['hash'], []).append(row['urlpath'])
144             row = c.fetchone()
145         c.close()
146         return expired
147         
148     def removeUntrackedFiles(self, dirs):
149         """Find files that are no longer tracked and so should be removed.
150         
151         Also removes the entries from the table.
152         
153         @return: list of files that were removed
154         """
155         assert len(dirs) >= 1
156         newdirs = []
157         sql = "WHERE"
158         for dir in dirs:
159             newdirs.append(os.path.abspath(dir) + os.sep + '*')
160             sql += " path NOT GLOB ? AND"
161         sql = sql[:-4]
162
163         c = self.conn.cursor()
164         c.execute("SELECT path FROM files " + sql, newdirs)
165         row = c.fetchone()
166         removed = []
167         while row:
168             removed.append(row['path'])
169             row = c.fetchone()
170
171         if removed:
172             c.execute("DELETE FROM files " + sql, newdirs)
173         self.conn.commit()
174         return removed
175         
176     def close(self):
177         self.conn.close()
178
179 class TestDB(unittest.TestCase):
180     """Tests for the khashmir database."""
181     
182     timeout = 5
183     db = '/tmp/khashmir.db'
184     path = '/tmp/khashmir.test'
185     hash = '\xca\xec\xb8\x0c\x00\xe7\x07\xf8~])\x8f\x9d\xe5_B\xff\x1a\xc4!'
186     urlpath = '/~1/what/ever/khashmir.test'
187     dirs = ['/tmp/apt-dht/top1', '/tmp/apt-dht/top2/sub1', '/tmp/apt-dht/top2/sub2/']
188
189     def setUp(self):
190         f = open(self.path, 'w')
191         f.write('fgfhds')
192         f.close()
193         os.utime(self.path, None)
194         self.store = DB(self.db)
195         self.store.storeFile(self.path, self.hash, self.urlpath)
196
197     def test_getFile(self):
198         res = self.store.getFile(self.path)
199         self.failUnless(res)
200         self.failUnlessEqual(res['hash'], self.hash)
201         self.failUnlessEqual(res['urlpath'], self.urlpath)
202         
203     def test_isUnchanged(self):
204         res = self.store.isUnchanged(self.path)
205         self.failUnless(res)
206         sleep(2)
207         os.utime(self.path, None)
208         res = self.store.isUnchanged(self.path)
209         self.failUnless(res == False)
210         os.unlink(self.path)
211         res = self.store.isUnchanged(self.path)
212         self.failUnless(res == None)
213         
214     def test_expiry(self):
215         res = self.store.expiredFiles(1)
216         self.failUnlessEqual(len(res.keys()), 0)
217         sleep(2)
218         res = self.store.expiredFiles(1)
219         self.failUnlessEqual(len(res.keys()), 1)
220         self.failUnlessEqual(res.keys()[0], self.hash)
221         self.failUnlessEqual(len(res[self.hash]), 1)
222         self.failUnlessEqual(res[self.hash][0], self.urlpath)
223         res = self.store.refreshFile(self.path)
224         self.failUnless(res)
225         res = self.store.expiredFiles(1)
226         self.failUnlessEqual(len(res.keys()), 0)
227         
228     def test_removeUntracked(self):
229         for dir in self.dirs:
230             path = os.path.join(dir, self.path[1:])
231             os.makedirs(os.path.dirname(path))
232             f = open(path, 'w')
233             f.write(path)
234             f.close()
235             os.utime(path, None)
236             self.store.storeFile(path, self.hash, self.urlpath)
237         
238         res = self.store.removeUntrackedFiles(self.dirs)
239         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
240         self.failUnlessEqual(res[0], self.path, 'Got removed paths: %r' % res)
241         res = self.store.removeUntrackedFiles(self.dirs)
242         self.failUnlessEqual(len(res), 0, 'Got removed paths: %r' % res)
243         res = self.store.removeUntrackedFiles(self.dirs[1:])
244         self.failUnlessEqual(len(res), 1, 'Got removed paths: %r' % res)
245         self.failUnlessEqual(res[0], os.path.join(self.dirs[0], self.path[1:]), 'Got removed paths: %r' % res)
246         res = self.store.removeUntrackedFiles(self.dirs[:1])
247         self.failUnlessEqual(len(res), 2, 'Got removed paths: %r' % res)
248         self.failUnlessIn(os.path.join(self.dirs[1], self.path[1:]), res, 'Got removed paths: %r' % res)
249         self.failUnlessIn(os.path.join(self.dirs[2], self.path[1:]), res, 'Got removed paths: %r' % res)
250         
251     def tearDown(self):
252         for root, dirs, files in os.walk('/tmp/apt-dht', topdown=False):
253             for name in files:
254                 os.remove(os.path.join(root, name))
255             for name in dirs:
256                 os.rmdir(os.path.join(root, name))
257         self.store.close()
258         os.unlink(self.db)