Add all files to the DB with their hashes.
[quix0rs-apt-p2p.git] / apt_p2p / CacheManager.py
1
2 """Manage a cache of downloaded files.
3
4 @var DECOMPRESS_EXTS: a list of file extensions that need to be decompressed
5 @var DECOMPRESS_FILES: a list of file names that need to be decompressed
6 """
7
8 from bz2 import BZ2Decompressor
9 from zlib import decompressobj, MAX_WBITS
10 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
11 from urlparse import urlparse
12 import os
13
14 from twisted.python import log
15 from twisted.python.filepath import FilePath
16 from twisted.internet import defer, reactor
17 from twisted.trial import unittest
18 from twisted.web2 import stream
19 from twisted.web2.http import splitHostPort
20
21 from Hash import HashObject
22 from apt_p2p_conf import config
23
24 DECOMPRESS_EXTS = ['.gz', '.bz2']
25 DECOMPRESS_FILES = ['release', 'sources', 'packages']
26
27 class CacheError(Exception):
28     """Error occurred downloading a file to the cache."""
29
30 class ProxyFileStream(stream.SimpleStream):
31     """Saves a stream to a file while providing a new stream.
32     
33     Also optionally decompresses the file while it is being downloaded.
34     
35     @type stream: L{twisted.web2.stream.IByteStream}
36     @ivar stream: the input stream being read
37     @type outFile: L{twisted.python.filepath.FilePath}
38     @ivar outFile: the file being written
39     @type hash: L{Hash.HashObject}
40     @ivar hash: the hash object for the file
41     @type gzfile: C{file}
42     @ivar gzfile: the open file to write decompressed gzip data to
43     @type gzdec: L{zlib.decompressobj}
44     @ivar gzdec: the decompressor to use for the compressed gzip data
45     @type gzheader: C{boolean}
46     @ivar gzheader: whether the gzip header still needs to be removed from
47         the zlib compressed data
48     @type bz2file: C{file}
49     @ivar bz2file: the open file to write decompressed bz2 data to
50     @type bz2dec: L{bz2.BZ2Decompressor}
51     @ivar bz2dec: the decompressor to use for the compressed bz2 data
52     @type length: C{int}
53     @ivar length: the length of the original (compressed) file
54     @type doneDefer: L{twisted.internet.defer.Deferred}
55     @ivar doneDefer: the deferred that will fire when done streaming
56     
57     @group Stream implementation: read, close
58     
59     """
60     
61     def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
62         """Initializes the proxy.
63         
64         @type stream: L{twisted.web2.stream.IByteStream}
65         @param stream: the input stream to read from
66         @type outFile: L{twisted.python.filepath.FilePath}
67         @param outFile: the file to write to
68         @type hash: L{Hash.HashObject}
69         @param hash: the hash object to use for the file
70         @type decompress: C{string}
71         @param decompress: also decompress the file as this type
72             (currently only '.gz' and '.bz2' are supported)
73         @type decFile: C{twisted.python.FilePath}
74         @param decFile: the file to write the decompressed data to
75         """
76         self.stream = stream
77         self.outFile = outFile.open('w')
78         self.hash = hash
79         self.hash.new()
80         self.gzfile = None
81         self.bz2file = None
82         if decompress == ".gz":
83             self.gzheader = True
84             self.gzfile = decFile.open('w')
85             self.gzdec = decompressobj(-MAX_WBITS)
86         elif decompress == ".bz2":
87             self.bz2file = decFile.open('w')
88             self.bz2dec = BZ2Decompressor()
89         self.length = self.stream.length
90         self.doneDefer = defer.Deferred()
91
92     def _done(self):
93         """Close all the output files, return the result."""
94         if not self.outFile.closed:
95             self.outFile.close()
96             self.hash.digest()
97             if self.gzfile:
98                 # Finish the decompression
99                 data_dec = self.gzdec.flush()
100                 self.gzfile.write(data_dec)
101                 self.gzfile.close()
102                 self.gzfile = None
103             if self.bz2file:
104                 self.bz2file.close()
105                 self.bz2file = None
106     
107     def _error(self, err):
108         """Close all the output files, return the error."""
109         if not self.outFile.closed:
110             self._done()
111             self.stream.close()
112             self.doneDefer.errback(err)
113
114     def read(self):
115         """Read some data from the stream."""
116         if self.outFile.closed:
117             return None
118         
119         # Read data from the stream, deal with the possible deferred
120         data = self.stream.read()
121         if isinstance(data, defer.Deferred):
122             data.addCallbacks(self._write, self._error)
123             return data
124         
125         self._write(data)
126         return data
127     
128     def _write(self, data):
129         """Write the stream data to the file and return it for others to use.
130         
131         Also optionally decompresses it.
132         """
133         if data is None:
134             if not self.outFile.closed:
135                 self._done()
136                 self.doneDefer.callback(self.hash)
137             return data
138         
139         # Write and hash the streamed data
140         self.outFile.write(data)
141         self.hash.update(data)
142         
143         if self.gzfile:
144             # Decompress the zlib portion of the file
145             if self.gzheader:
146                 # Remove the gzip header junk
147                 self.gzheader = False
148                 new_data = self._remove_gzip_header(data)
149                 dec_data = self.gzdec.decompress(new_data)
150             else:
151                 dec_data = self.gzdec.decompress(data)
152             self.gzfile.write(dec_data)
153         if self.bz2file:
154             # Decompress the bz2 file
155             dec_data = self.bz2dec.decompress(data)
156             self.bz2file.write(dec_data)
157
158         return data
159     
160     def _remove_gzip_header(self, data):
161         """Remove the gzip header from the zlib compressed data."""
162         # Read, check & discard the header fields
163         if data[:2] != '\037\213':
164             raise IOError, 'Not a gzipped file'
165         if ord(data[2]) != 8:
166             raise IOError, 'Unknown compression method'
167         flag = ord(data[3])
168         # modtime = self.fileobj.read(4)
169         # extraflag = self.fileobj.read(1)
170         # os = self.fileobj.read(1)
171
172         skip = 10
173         if flag & FEXTRA:
174             # Read & discard the extra field
175             xlen = ord(data[10])
176             xlen = xlen + 256*ord(data[11])
177             skip = skip + 2 + xlen
178         if flag & FNAME:
179             # Read and discard a null-terminated string containing the filename
180             while True:
181                 if not data[skip] or data[skip] == '\000':
182                     break
183                 skip += 1
184             skip += 1
185         if flag & FCOMMENT:
186             # Read and discard a null-terminated string containing a comment
187             while True:
188                 if not data[skip] or data[skip] == '\000':
189                     break
190                 skip += 1
191             skip += 1
192         if flag & FHCRC:
193             skip += 2     # Read & discard the 16-bit header CRC
194
195         return data[skip:]
196
197     def close(self):
198         """Clean everything up and return None to future reads."""
199         log.msg('ProxyFileStream was prematurely closed after only %d/%d bytes' % (self.hash.size, self.length))
200         if self.hash.size < self.length:
201             self._error(CacheError('Prematurely closed, all data was not written'))
202         elif not self.outFile.closed:
203             self._done()
204             self.doneDefer.callback(self.hash)
205         self.length = 0
206         self.stream.close()
207
208 class CacheManager:
209     """Manages all downloaded files and requests for cached objects.
210     
211     @type cache_dir: L{twisted.python.filepath.FilePath}
212     @ivar cache_dir: the directory to use for storing all files
213     @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
214     @ivar other_dirs: the other directories that have shared files in them
215     @type all_dirs: C{list} of L{twisted.python.filepath.FilePath}
216     @ivar all_dirs: all the directories that have cached files in them
217     @type db: L{db.DB}
218     @ivar db: the database to use for tracking files and hashes
219     @type manager: L{apt_p2p.AptP2P}
220     @ivar manager: the main program object to send requests to
221     @type scanning: C{list} of L{twisted.python.filepath.FilePath}
222     @ivar scanning: all the directories that are currectly being scanned or waiting to be scanned
223     """
224     
225     def __init__(self, cache_dir, db, manager = None):
226         """Initialize the instance and remove any untracked files from the DB..
227         
228         @type cache_dir: L{twisted.python.filepath.FilePath}
229         @param cache_dir: the directory to use for storing all files
230         @type db: L{db.DB}
231         @param db: the database to use for tracking files and hashes
232         @type manager: L{apt_p2p.AptP2P}
233         @param manager: the main program object to send requests to
234             (optional, defaults to not calling back with cached files)
235         """
236         self.cache_dir = cache_dir
237         self.other_dirs = [FilePath(f) for f in config.getstringlist('DEFAULT', 'OTHER_DIRS')]
238         self.all_dirs = self.other_dirs[:]
239         self.all_dirs.insert(0, self.cache_dir)
240         self.db = db
241         self.manager = manager
242         self.scanning = []
243         
244         # Init the database, remove old files
245         self.db.removeUntrackedFiles(self.all_dirs)
246         
247     #{ Scanning directories
248     def scanDirectories(self):
249         """Scan the cache directories, hashing new and rehashing changed files."""
250         assert not self.scanning, "a directory scan is already under way"
251         self.scanning = self.all_dirs[:]
252         self._scanDirectories()
253
254     def _scanDirectories(self, result = None, walker = None):
255         """Walk each directory looking for cached files.
256         
257         @param result: the result of a DHT store request, not used (optional)
258         @param walker: the walker to use to traverse the current directory
259             (optional, defaults to creating a new walker from the first
260             directory in the L{CacheManager.scanning} list)
261         """
262         # Need to start walking a new directory
263         if walker is None:
264             # If there are any left, get them
265             if self.scanning:
266                 log.msg('started scanning directory: %s' % self.scanning[0].path)
267                 walker = self.scanning[0].walk()
268             else:
269                 log.msg('cache directory scan complete')
270                 return
271             
272         try:
273             # Get the next file in the directory
274             file = walker.next()
275         except StopIteration:
276             # No files left, go to the next directory
277             log.msg('done scanning directory: %s' % self.scanning[0].path)
278             self.scanning.pop(0)
279             reactor.callLater(0, self._scanDirectories)
280             return
281
282         # If it's not a file ignore it
283         if not file.isfile():
284             log.msg('entering directory: %s' % file.path)
285             reactor.callLater(0, self._scanDirectories, None, walker)
286             return
287
288         # If it's already properly in the DB, ignore it
289         db_status = self.db.isUnchanged(file)
290         if db_status:
291             log.msg('file is unchanged: %s' % file.path)
292             reactor.callLater(0, self._scanDirectories, None, walker)
293             return
294         
295         # Don't hash files in the cache that are not in the DB
296         if self.scanning[0] == self.cache_dir:
297             if db_status is None:
298                 log.msg('ignoring unknown cache file: %s' % file.path)
299             else:
300                 log.msg('removing changed cache file: %s' % file.path)
301                 file.remove()
302             reactor.callLater(0, self._scanDirectories, None, walker)
303             return
304
305         # Otherwise hash it
306         log.msg('start hash checking file: %s' % file.path)
307         hash = HashObject()
308         df = hash.hashInThread(file)
309         df.addBoth(self._doneHashing, file, walker)
310         df.addErrback(log.err)
311     
312     def _doneHashing(self, result, file, walker):
313         """If successful, add the hashed file to the DB and inform the main program."""
314         if isinstance(result, HashObject):
315             log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
316             
317             # Only set a URL if this is a downloaded file
318             url = None
319             if self.scanning[0] == self.cache_dir:
320                 url = 'http:/' + file.path[len(self.cache_dir.path):]
321                 
322             # Store the hashed file in the database
323             new_hash = self.db.storeFile(file, result.digest(), True,
324                                          ''.join(result.pieceDigests()))
325             
326             # Tell the main program to handle the new cache file
327             df = self.manager.new_cached_file(file, result, new_hash, url, True)
328             if df is None:
329                 reactor.callLater(0, self._scanDirectories, None, walker)
330             else:
331                 df.addBoth(self._scanDirectories, walker)
332         else:
333             # Must have returned an error
334             log.msg('hash check of %s failed' % file.path)
335             log.err(result)
336             reactor.callLater(0, self._scanDirectories, None, walker)
337
338     #{ Downloading files
339     def save_file(self, response, hash, url):
340         """Save a downloaded file to the cache and stream it.
341         
342         @type response: L{twisted.web2.http.Response}
343         @param response: the response from the download
344         @type hash: L{Hash.HashObject}
345         @param hash: the hash object containing the expected hash for the file
346         @param url: the URI of the actual mirror request
347         @rtype: L{twisted.web2.http.Response}
348         @return: the final response from the download
349         """
350         if response.code != 200:
351             log.msg('File was not found (%r): %s' % (response, url))
352             return response
353         
354         log.msg('Returning file: %s' % url)
355
356         # Set the destination path for the file
357         parsed = urlparse(url)
358         destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
359         log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
360         
361         # Make sure there's a free place for the file
362         if destFile.exists():
363             log.msg('File already exists, removing: %s' % destFile.path)
364             destFile.remove()
365         if not destFile.parent().exists():
366             destFile.parent().makedirs()
367
368         # Determine whether it needs to be decompressed and how
369         root, ext = os.path.splitext(destFile.basename())
370         if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
371             ext = ext.lower()
372             decFile = destFile.sibling(root)
373             log.msg('Decompressing to: %s' % decFile.path)
374             if decFile.exists():
375                 log.msg('File already exists, removing: %s' % decFile.path)
376                 decFile.remove()
377         else:
378             ext = None
379             decFile = None
380             
381         # Create the new stream from the old one.
382         orig_stream = response.stream
383         response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
384         response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
385                                               response.headers.getHeader('Last-Modified'),
386                                               decFile)
387         response.stream.doneDefer.addErrback(self._save_error, url, destFile, decFile)
388
389         # Return the modified response with the new stream
390         return response
391
392     def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
393         """Update the modification time and inform the main program.
394         
395         @type hash: L{Hash.HashObject}
396         @param hash: the hash object containing the expected hash for the file
397         @param url: the URI of the actual mirror request
398         @type destFile: C{twisted.python.FilePath}
399         @param destFile: the file where the download was written to
400         @type modtime: C{int}
401         @param modtime: the modified time of the cached file (seconds since epoch)
402             (optional, defaults to not setting the modification time of the file)
403         @type decFile: C{twisted.python.FilePath}
404         @param decFile: the file where the decompressed download was written to
405             (optional, defaults to the file not having been compressed)
406         """
407         result = hash.verify()
408         if result or result is None:
409             if modtime:
410                 os.utime(destFile.path, (modtime, modtime))
411             
412             if result:
413                 log.msg('Hashes match: %s' % url)
414                 dht = True
415             else:
416                 log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
417                 dht = False
418                 
419             new_hash = self.db.storeFile(destFile, hash.digest(), dht,
420                                          ''.join(hash.pieceDigests()))
421
422             if self.manager:
423                 self.manager.new_cached_file(destFile, hash, new_hash, url)
424
425             if decFile:
426                 # Hash the decompressed file and add it to the DB
427                 decHash = HashObject()
428                 ext_len = len(destFile.path) - len(decFile.path)
429                 df = decHash.hashInThread(decFile)
430                 df.addCallback(self._save_complete, url[:-ext_len], decFile, modtime)
431                 df.addErrback(self._save_error, url[:-ext_len], decFile)
432         else:
433             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
434             destFile.remove()
435             if decFile:
436                 decFile.remove()
437
438     def _save_error(self, failure, url, destFile, decFile = None):
439         """Remove the destination files."""
440         log.msg('Error occurred downloading %s' % url)
441         log.err(failure)
442         destFile.restat(False)
443         if destFile.exists():
444             log.msg('Removing the incomplete file: %s' % destFile.path)
445             destFile.remove()
446         if decFile:
447             decFile.restat(False)
448             if decFile.exists():
449                 log.msg('Removing the incomplete file: %s' % decFile.path)
450                 decFile.remove()
451
452     def save_error(self, failure, url):
453         """An error has occurred in downloading or saving the file"""
454         log.msg('Error occurred downloading %s' % url)
455         log.err(failure)
456         return failure
457
458 class TestMirrorManager(unittest.TestCase):
459     """Unit tests for the mirror manager."""
460     
461     timeout = 20
462     pending_calls = []
463     client = None
464     
465     def setUp(self):
466         self.client = CacheManager(FilePath('/tmp/.apt-p2p'))
467         
468     def tearDown(self):
469         for p in self.pending_calls:
470             if p.active():
471                 p.cancel()
472         self.client = None
473