02021dc34c16a359b360f11c5100af260a5220ca
[quix0rs-apt-p2p.git] / apt_dht / CacheManager.py
1
2 from bz2 import BZ2Decompressor
3 from zlib import decompressobj, MAX_WBITS
4 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
5 from urlparse import urlparse
6 import os
7
8 from twisted.python import log
9 from twisted.python.filepath import FilePath
10 from twisted.internet import defer, reactor
11 from twisted.trial import unittest
12 from twisted.web2 import stream
13 from twisted.web2.http import splitHostPort
14
15 from Hash import HashObject
16
17 aptpkg_dir='apt-packages'
18
19 DECOMPRESS_EXTS = ['.gz', '.bz2']
20 DECOMPRESS_FILES = ['release', 'sources', 'packages']
21
class ProxyFileStream(stream.SimpleStream):
    """Saves a stream to a file while providing a new stream.
    
    Every chunk read from the wrapped input stream is hashed and written
    to a cache file (and optionally also decompressed to a second file)
    before being passed on unchanged to the consumer of this stream.
    When the input is exhausted, C{doneDefer} fires with the completed
    L{Hash.HashObject}.
    """
    
    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
        """Initializes the proxy.
        
        @type stream: C{twisted.web2.stream.IByteStream}
        @param stream: the input stream to read from
        @type outFile: C{twisted.python.FilePath}
        @param outFile: the file to write to
        @type hash: L{Hash.HashObject}
        @param hash: the hash object to use for the file
        @type decompress: C{string}
        @param decompress: also decompress the file as this type
            (currently only '.gz' and '.bz2' are supported)
        @type decFile: C{twisted.python.FilePath}
        @param decFile: the file to write the decompressed data to
        """
        self.stream = stream
        self.outFile = outFile.open('w')
        self.hash = hash
        self.hash.new()
        self.gzfile = None
        self.bz2file = None
        if decompress == ".gz":
            # zlib is used in raw-deflate mode (negative wbits), so the
            # gzip header must be stripped by hand before decompressing.
            self.gzheader = True
            self.gzfile = decFile.open('w')
            self.gzdec = decompressobj(-MAX_WBITS)
        elif decompress == ".bz2":
            self.bz2file = decFile.open('w')
            self.bz2dec = BZ2Decompressor()
        self.length = self.stream.length
        self.start = 0
        self.doneDefer = defer.Deferred()

    def _done(self, result = None):
        """Close the output file(s) and fire C{doneDefer} with the hash.
        
        @param result: ignored.  BUG FIX: this method is attached as a
            deferred errback in L{read}, so it must accept the Failure
            argument; previously it took no argument and raised a
            C{TypeError} whenever the input stream errored out.
        """
        if not self.outFile.closed:
            self.outFile.close()
            self.hash.digest()
            if self.gzfile:
                # Flush any data still buffered inside the zlib decompressor.
                data_dec = self.gzdec.flush()
                self.gzfile.write(data_dec)
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                # BZ2Decompressor has no flush method, just close the file.
                self.bz2file.close()
                self.bz2file = None
                
            self.doneDefer.callback(self.hash)
    
    def read(self):
        """Read some data from the stream.
        
        @return: the next chunk of data, a deferred that will fire with
            it, or C{None} once the stream has been closed
        """
        if self.outFile.closed:
            return None
        
        data = self.stream.read()
        if isinstance(data, defer.Deferred):
            data.addCallbacks(self._write, self._done)
            return data
        
        self._write(data)
        return data
    
    def _write(self, data):
        """Write the stream data to the file and return it for others to use.
        
        A C{None} chunk marks the end of the stream and triggers cleanup.
        """
        if data is None:
            self._done()
            return data
        
        self.outFile.write(data)
        self.hash.update(data)
        if self.gzfile:
            if self.gzheader:
                # Only the very first chunk carries the gzip header.
                # NOTE(review): assumes the entire header fits inside the
                # first chunk read from the stream -- TODO confirm.
                self.gzheader = False
                new_data = self._remove_gzip_header(data)
                dec_data = self.gzdec.decompress(new_data)
            else:
                dec_data = self.gzdec.decompress(data)
            self.gzfile.write(dec_data)
        if self.bz2file:
            dec_data = self.bz2dec.decompress(data)
            self.bz2file.write(dec_data)
        return data
    
    def _remove_gzip_header(self, data):
        """Strip the gzip header (RFC 1952) from the start of the data.
        
        @return: the data with the header bytes removed
        @raise IOError: if the data is not in gzip format
        """
        if data[:2] != '\037\213':
            raise IOError('Not a gzipped file')
        if ord(data[2]) != 8:
            raise IOError('Unknown compression method')
        flag = ord(data[3])
        # modtime = self.fileobj.read(4)
        # extraflag = self.fileobj.read(1)
        # os = self.fileobj.read(1)

        # The fixed portion of the gzip header is 10 bytes long.
        skip = 10
        if flag & FEXTRA:
            # Read & discard the extra field, if present
            xlen = ord(data[10])
            xlen = xlen + 256*ord(data[11])
            skip = skip + 2 + xlen
        if flag & FNAME:
            # Read and discard a null-terminated string containing the filename
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FCOMMENT:
            # Read and discard a null-terminated string containing a comment
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FHCRC:
            skip += 2     # Read & discard the 16-bit header CRC
        return data[skip:]

    def close(self):
        """Clean everything up and return None to future reads."""
        self.length = 0
        self._done()
        self.stream.close()
class CacheManager:
    """Manages all requests for cached objects."""
    
    def __init__(self, cache_dir, db, other_dirs = None, manager = None):
        """Initialize the cache manager.
        
        @type cache_dir: C{twisted.python.FilePath}
        @param cache_dir: the directory to store downloaded files in
        @param db: the database to store file and hash information in
        @type other_dirs: C{list} of C{twisted.python.FilePath}
        @param other_dirs: additional directories to scan for files to
            share (optional, defaults to none)
        @param manager: the main program object to notify of new cached
            files (optional, defaults to no notification)
        """
        self.cache_dir = cache_dir
        # BUG FIX: the default used to be a shared mutable list ([]),
        # which every default-constructed instance would alias.
        self.other_dirs = other_dirs if other_dirs is not None else []
        self.all_dirs = self.other_dirs[:]
        self.all_dirs.insert(0, self.cache_dir)
        self.db = db
        self.manager = manager
        self.scanning = []
        
        # Init the database, remove old files
        self.db.removeUntrackedFiles(self.all_dirs)
        
        
    def scanDirectories(self):
        """Scan the cache directories, hashing new and rehashing changed files."""
        assert not self.scanning, "a directory scan is already under way"
        self.scanning = self.all_dirs[:]
        self._scanDirectories()

    def _scanDirectories(self, result = None, walker = None):
        """Process one directory entry, then reschedule itself via the reactor.
        
        @param result: ignored, allows use as a deferred callback
        @param walker: the walk iterator currently in progress, or C{None}
            to begin walking the next unscanned directory
        """
        # Need to start walking a new directory
        if walker is None:
            # If there are any left, get them
            if self.scanning:
                log.msg('started scanning directory: %s' % self.scanning[0].path)
                walker = self.scanning[0].walk()
            else:
                log.msg('cache directory scan complete')
                return
            
        try:
            # Get the next file in the directory
            entry = walker.next()
        except StopIteration:
            # No files left, go to the next directory
            log.msg('done scanning directory: %s' % self.scanning[0].path)
            self.scanning.pop(0)
            reactor.callLater(0, self._scanDirectories)
            return

        # If it's not a file ignore it
        if not entry.isfile():
            log.msg('entering directory: %s' % entry.path)
            reactor.callLater(0, self._scanDirectories, None, walker)
            return

        # If it's already properly in the DB, ignore it
        db_status = self.db.isUnchanged(entry)
        if db_status:
            log.msg('file is unchanged: %s' % entry.path)
            reactor.callLater(0, self._scanDirectories, None, walker)
            return
        
        # Don't hash files in the cache that are not in the DB
        if self.scanning[0] == self.cache_dir:
            if db_status is None:
                log.msg('ignoring unknown cache file: %s' % entry.path)
            else:
                log.msg('removing changed cache file: %s' % entry.path)
                entry.remove()
            reactor.callLater(0, self._scanDirectories, None, walker)
            return

        # Otherwise hash it
        log.msg('start hash checking file: %s' % entry.path)
        hash = HashObject()
        df = hash.hashInThread(entry)
        df.addBoth(self._doneHashing, entry, walker)
        df.addErrback(log.err)
    
    def _doneHashing(self, result, file, walker):
        """Store a freshly hashed file and continue the directory scan.
        
        @param result: the L{Hash.HashObject} on success, or a failure
        @param file: the C{FilePath} of the file that was hashed
        @param walker: the walk iterator to resume afterwards
        """
        if isinstance(result, HashObject):
            log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
            url = None
            if self.scanning[0] == self.cache_dir:
                # Files inside the cache mirror a URL path directly.
                url = 'http:/' + file.path[len(self.cache_dir.path):]
            new_hash = self.db.storeFile(file, result.digest())
            df = self.manager.new_cached_file(file, result, new_hash, url, True)
            if df is None:
                reactor.callLater(0, self._scanDirectories, None, walker)
            else:
                df.addBoth(self._scanDirectories, walker)
        else:
            log.msg('hash check of %s failed' % file.path)
            log.err(result)
            reactor.callLater(0, self._scanDirectories, None, walker)

    def save_file(self, response, hash, url):
        """Save a downloaded file to the cache and stream it.
        
        @param response: the C{twisted.web2} response whose stream to save
        @param hash: the L{Hash.HashObject} to hash the file with
        @param url: the URL the file was downloaded from
        @return: the response, with its stream replaced by a proxying one
        """
        if response.code != 200:
            log.msg('File was not found (%r): %s' % (response, url))
            return response
        
        log.msg('Returning file: %s' % url)
        
        parsed = urlparse(url)
        destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
        
        if destFile.exists():
            log.msg('File already exists, removing: %s' % destFile.path)
            destFile.remove()
        elif not destFile.parent().exists():
            destFile.parent().makedirs()
            
        # Index-style files (Release/Sources/Packages) also get saved
        # decompressed, so apt can use them directly.
        root, ext = os.path.splitext(destFile.basename())
        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
            ext = ext.lower()
            decFile = destFile.sibling(root)
            log.msg('Decompressing to: %s' % decFile.path)
            if decFile.exists():
                log.msg('File already exists, removing: %s' % decFile.path)
                decFile.remove()
        else:
            ext = None
            decFile = None
            
        orig_stream = response.stream
        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
        response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
                                              response.headers.getHeader('Last-Modified'),
                                              ext, decFile)
        response.stream.doneDefer.addErrback(self.save_error, url)
        return response

    def _save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
        """Update the modification time and AptPackages.
        
        @param hash: the L{Hash.HashObject} of the saved file
        @param url: the URL the file was downloaded from
        @param destFile: the C{FilePath} the file was saved to
        @param modtime: the modification time to set on the saved file(s)
            (optional, defaults to leaving the time unchanged)
        @param ext: the compression extension that was stripped, if any
        @param decFile: the C{FilePath} of the decompressed copy, if any
        """
        if modtime:
            os.utime(destFile.path, (modtime, modtime))
            if ext:
                os.utime(decFile.path, (modtime, modtime))
        
        result = hash.verify()
        if result or result is None:
            if result:
                log.msg('Hashes match: %s' % url)
            else:
                # No expected hash was known; record the computed one.
                log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
                
            new_hash = self.db.storeFile(destFile, hash.digest())
            log.msg('now avaliable: %s' % (url))

            if self.manager:
                self.manager.new_cached_file(destFile, hash, new_hash, url)
                if ext:
                    self.manager.new_cached_file(decFile, None, False, url[:-len(ext)])
        else:
            # The download is corrupt; discard it (and any decompressed copy).
            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
            destFile.remove()
            if ext:
                decFile.remove()

    def save_error(self, failure, url):
        """An error has occurred in downloading or saving the file.
        
        @return: the failure, re-raised for further errbacks
        """
        log.msg('Error occurred downloading %s' % url)
        log.err(failure)
        return failure
class TestMirrorManager(unittest.TestCase):
    """Unit tests for the mirror manager."""
    
    timeout = 20
    # Class-level defaults only; setUp gives every test its own list.
    pending_calls = []
    client = None
    
    def setUp(self):
        # BUG FIX: pending_calls was only ever the shared class-level
        # list, so delayed calls leaked between test cases; give each
        # test run its own private list instead.
        self.pending_calls = []
        # NOTE(review): CacheManager requires a 'db' argument as well, so
        # this call looks stale -- TODO confirm against CacheManager.__init__.
        self.client = CacheManager(FilePath('/tmp/.apt-dht'))
        
    def tearDown(self):
        """Cancel any delayed calls still outstanding from the test."""
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.pending_calls = []
        self.client = None