Better handling and logging for intermittent HTTP client submission errors.
[quix0rs-apt-p2p.git] / apt_p2p / CacheManager.py
1
2 """Manage a cache of downloaded files.
3
4 @var DECOMPRESS_EXTS: a list of file extensions that need to be decompressed
5 @var DECOMPRESS_FILES: a list of file names that need to be decompressed
6 """
7
8 from bz2 import BZ2Decompressor
9 from zlib import decompressobj, MAX_WBITS
10 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
11 from urlparse import urlparse
12 import os
13
14 from twisted.python import log
15 from twisted.python.filepath import FilePath
16 from twisted.internet import defer, reactor
17 from twisted.trial import unittest
18 from twisted.web2 import stream
19 from twisted.web2.http import splitHostPort
20
21 from Hash import HashObject
22 from apt_p2p_conf import config
23
24 DECOMPRESS_EXTS = ['.gz', '.bz2']
25 DECOMPRESS_FILES = ['release', 'sources', 'packages']
26
27 class CacheError(Exception):
28     """Error occurred downloading a file to the cache."""
29
30 class ProxyFileStream(stream.SimpleStream):
31     """Saves a stream to a file while providing a new stream.
32     
33     Also optionally decompresses the file while it is being downloaded.
34     
35     @type stream: L{twisted.web2.stream.IByteStream}
36     @ivar stream: the input stream being read
37     @type outFile: L{twisted.python.filepath.FilePath}
38     @ivar outFile: the file being written
39     @type hash: L{Hash.HashObject}
40     @ivar hash: the hash object for the file
41     @type gzfile: C{file}
42     @ivar gzfile: the open file to write decompressed gzip data to
43     @type gzdec: L{zlib.decompressobj}
44     @ivar gzdec: the decompressor to use for the compressed gzip data
45     @type gzheader: C{boolean}
46     @ivar gzheader: whether the gzip header still needs to be removed from
47         the zlib compressed data
48     @type bz2file: C{file}
49     @ivar bz2file: the open file to write decompressed bz2 data to
50     @type bz2dec: L{bz2.BZ2Decompressor}
51     @ivar bz2dec: the decompressor to use for the compressed bz2 data
52     @type length: C{int}
53     @ivar length: the length of the original (compressed) file
54     @type doneDefer: L{twisted.internet.defer.Deferred}
55     @ivar doneDefer: the deferred that will fire when done streaming
56     
57     @group Stream implementation: read, close
58     
59     """
60     
61     def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
62         """Initializes the proxy.
63         
64         @type stream: L{twisted.web2.stream.IByteStream}
65         @param stream: the input stream to read from
66         @type outFile: L{twisted.python.filepath.FilePath}
67         @param outFile: the file to write to
68         @type hash: L{Hash.HashObject}
69         @param hash: the hash object to use for the file
70         @type decompress: C{string}
71         @param decompress: also decompress the file as this type
72             (currently only '.gz' and '.bz2' are supported)
73         @type decFile: C{twisted.python.FilePath}
74         @param decFile: the file to write the decompressed data to
75         """
76         self.stream = stream
77         self.outFile = outFile.open('w')
78         self.hash = hash
79         self.hash.new()
80         self.gzfile = None
81         self.bz2file = None
82         if decompress == ".gz":
83             self.gzheader = True
84             self.gzfile = decFile.open('w')
85             self.gzdec = decompressobj(-MAX_WBITS)
86         elif decompress == ".bz2":
87             self.bz2file = decFile.open('w')
88             self.bz2dec = BZ2Decompressor()
89         self.length = self.stream.length
90         self.doneDefer = defer.Deferred()
91
92     def _done(self):
93         """Close all the output files, return the result."""
94         if not self.outFile.closed:
95             self.outFile.close()
96             self.hash.digest()
97             if self.gzfile:
98                 # Finish the decompression
99                 data_dec = self.gzdec.flush()
100                 self.gzfile.write(data_dec)
101                 self.gzfile.close()
102                 self.gzfile = None
103             if self.bz2file:
104                 self.bz2file.close()
105                 self.bz2file = None
106     
107     def _error(self, err):
108         """Close all the output files, return the error."""
109         if not self.outFile.closed:
110             self._done()
111             self.stream.close()
112             self.doneDefer.errback(err)
113
114     def read(self):
115         """Read some data from the stream."""
116         if self.outFile.closed:
117             return None
118         
119         # Read data from the stream, deal with the possible deferred
120         data = self.stream.read()
121         if isinstance(data, defer.Deferred):
122             data.addCallbacks(self._write, self._error)
123             return data
124         
125         self._write(data)
126         return data
127     
128     def _write(self, data):
129         """Write the stream data to the file and return it for others to use.
130         
131         Also optionally decompresses it.
132         """
133         if data is None:
134             if not self.outFile.closed:
135                 self._done()
136                 self.doneDefer.callback(self.hash)
137             return data
138         
139         # Write and hash the streamed data
140         self.outFile.write(data)
141         self.hash.update(data)
142         
143         if self.gzfile:
144             # Decompress the zlib portion of the file
145             if self.gzheader:
146                 # Remove the gzip header junk
147                 self.gzheader = False
148                 new_data = self._remove_gzip_header(data)
149                 dec_data = self.gzdec.decompress(new_data)
150             else:
151                 dec_data = self.gzdec.decompress(data)
152             self.gzfile.write(dec_data)
153         if self.bz2file:
154             # Decompress the bz2 file
155             dec_data = self.bz2dec.decompress(data)
156             self.bz2file.write(dec_data)
157
158         return data
159     
160     def _remove_gzip_header(self, data):
161         """Remove the gzip header from the zlib compressed data."""
162         # Read, check & discard the header fields
163         if data[:2] != '\037\213':
164             raise IOError, 'Not a gzipped file'
165         if ord(data[2]) != 8:
166             raise IOError, 'Unknown compression method'
167         flag = ord(data[3])
168         # modtime = self.fileobj.read(4)
169         # extraflag = self.fileobj.read(1)
170         # os = self.fileobj.read(1)
171
172         skip = 10
173         if flag & FEXTRA:
174             # Read & discard the extra field
175             xlen = ord(data[10])
176             xlen = xlen + 256*ord(data[11])
177             skip = skip + 2 + xlen
178         if flag & FNAME:
179             # Read and discard a null-terminated string containing the filename
180             while True:
181                 if not data[skip] or data[skip] == '\000':
182                     break
183                 skip += 1
184             skip += 1
185         if flag & FCOMMENT:
186             # Read and discard a null-terminated string containing a comment
187             while True:
188                 if not data[skip] or data[skip] == '\000':
189                     break
190                 skip += 1
191             skip += 1
192         if flag & FHCRC:
193             skip += 2     # Read & discard the 16-bit header CRC
194
195         return data[skip:]
196
197     def close(self):
198         """Clean everything up and return None to future reads."""
199         log.msg('ProxyFileStream was prematurely closed after only %d/%d bytes' % (self.hash.size, self.length))
200         if self.hash.size < self.length:
201             self._error(CacheError('Prematurely closed, all data was not written'))
202         elif not self.outFile.closed:
203             self._done()
204             self.doneDefer.callback(self.hash)
205         self.length = 0
206         self.stream.close()
207
208 class CacheManager:
209     """Manages all downloaded files and requests for cached objects.
210     
211     @type cache_dir: L{twisted.python.filepath.FilePath}
212     @ivar cache_dir: the directory to use for storing all files
213     @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
214     @ivar other_dirs: the other directories that have shared files in them
215     @type all_dirs: C{list} of L{twisted.python.filepath.FilePath}
216     @ivar all_dirs: all the directories that have cached files in them
217     @type db: L{db.DB}
218     @ivar db: the database to use for tracking files and hashes
219     @type manager: L{apt_p2p.AptP2P}
220     @ivar manager: the main program object to send requests to
221     @type scanning: C{list} of L{twisted.python.filepath.FilePath}
222     @ivar scanning: all the directories that are currectly being scanned or waiting to be scanned
223     """
224     
225     def __init__(self, cache_dir, db, manager = None):
226         """Initialize the instance and remove any untracked files from the DB..
227         
228         @type cache_dir: L{twisted.python.filepath.FilePath}
229         @param cache_dir: the directory to use for storing all files
230         @type db: L{db.DB}
231         @param db: the database to use for tracking files and hashes
232         @type manager: L{apt_p2p.AptP2P}
233         @param manager: the main program object to send requests to
234             (optional, defaults to not calling back with cached files)
235         """
236         self.cache_dir = cache_dir
237         self.other_dirs = [FilePath(f) for f in config.getstringlist('DEFAULT', 'OTHER_DIRS')]
238         self.all_dirs = self.other_dirs[:]
239         self.all_dirs.insert(0, self.cache_dir)
240         self.db = db
241         self.manager = manager
242         self.scanning = []
243         
244         # Init the database, remove old files
245         self.db.removeUntrackedFiles(self.all_dirs)
246         
247     #{ Scanning directories
248     def scanDirectories(self):
249         """Scan the cache directories, hashing new and rehashing changed files."""
250         assert not self.scanning, "a directory scan is already under way"
251         self.scanning = self.all_dirs[:]
252         self._scanDirectories()
253
254     def _scanDirectories(self, result = None, walker = None):
255         """Walk each directory looking for cached files.
256         
257         @param result: the result of a DHT store request, not used (optional)
258         @param walker: the walker to use to traverse the current directory
259             (optional, defaults to creating a new walker from the first
260             directory in the L{CacheManager.scanning} list)
261         """
262         # Need to start walking a new directory
263         if walker is None:
264             # If there are any left, get them
265             if self.scanning:
266                 log.msg('started scanning directory: %s' % self.scanning[0].path)
267                 walker = self.scanning[0].walk()
268             else:
269                 log.msg('cache directory scan complete')
270                 return
271             
272         try:
273             # Get the next file in the directory
274             file = walker.next()
275         except StopIteration:
276             # No files left, go to the next directory
277             log.msg('done scanning directory: %s' % self.scanning[0].path)
278             self.scanning.pop(0)
279             reactor.callLater(0, self._scanDirectories)
280             return
281
282         # If it's not a file ignore it
283         if not file.isfile():
284             log.msg('entering directory: %s' % file.path)
285             reactor.callLater(0, self._scanDirectories, None, walker)
286             return
287
288         # If it's already properly in the DB, ignore it
289         db_status = self.db.isUnchanged(file)
290         if db_status:
291             log.msg('file is unchanged: %s' % file.path)
292             reactor.callLater(0, self._scanDirectories, None, walker)
293             return
294         
295         # Don't hash files in the cache that are not in the DB
296         if self.scanning[0] == self.cache_dir:
297             if db_status is None:
298                 log.msg('ignoring unknown cache file: %s' % file.path)
299             else:
300                 log.msg('removing changed cache file: %s' % file.path)
301                 file.remove()
302             reactor.callLater(0, self._scanDirectories, None, walker)
303             return
304
305         # Otherwise hash it
306         log.msg('start hash checking file: %s' % file.path)
307         hash = HashObject()
308         df = hash.hashInThread(file)
309         df.addBoth(self._doneHashing, file, walker)
310     
311     def _doneHashing(self, result, file, walker):
312         """If successful, add the hashed file to the DB and inform the main program."""
313         if isinstance(result, HashObject):
314             log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
315             
316             # Only set a URL if this is a downloaded file
317             url = None
318             if self.scanning[0] == self.cache_dir:
319                 url = 'http:/' + file.path[len(self.cache_dir.path):]
320                 
321             # Store the hashed file in the database
322             new_hash = self.db.storeFile(file, result.digest(), True,
323                                          ''.join(result.pieceDigests()))
324             
325             # Tell the main program to handle the new cache file
326             df = self.manager.new_cached_file(file, result, new_hash, url, True)
327             if df is None:
328                 reactor.callLater(0, self._scanDirectories, None, walker)
329             else:
330                 df.addBoth(self._scanDirectories, walker)
331         else:
332             # Must have returned an error
333             log.msg('hash check of %s failed' % file.path)
334             log.err(result)
335             reactor.callLater(0, self._scanDirectories, None, walker)
336
337     #{ Downloading files
338     def save_file(self, response, hash, url):
339         """Save a downloaded file to the cache and stream it.
340         
341         @type response: L{twisted.web2.http.Response}
342         @param response: the response from the download
343         @type hash: L{Hash.HashObject}
344         @param hash: the hash object containing the expected hash for the file
345         @param url: the URI of the actual mirror request
346         @rtype: L{twisted.web2.http.Response}
347         @return: the final response from the download
348         """
349         if response.code != 200:
350             log.msg('File was not found (%r): %s' % (response, url))
351             return response
352         
353         log.msg('Returning file: %s' % url)
354
355         # Set the destination path for the file
356         parsed = urlparse(url)
357         destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
358         log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
359         
360         # Make sure there's a free place for the file
361         if destFile.exists():
362             log.msg('File already exists, removing: %s' % destFile.path)
363             destFile.remove()
364         if not destFile.parent().exists():
365             destFile.parent().makedirs()
366
367         # Determine whether it needs to be decompressed and how
368         root, ext = os.path.splitext(destFile.basename())
369         if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
370             ext = ext.lower()
371             decFile = destFile.sibling(root)
372             log.msg('Decompressing to: %s' % decFile.path)
373             if decFile.exists():
374                 log.msg('File already exists, removing: %s' % decFile.path)
375                 decFile.remove()
376         else:
377             ext = None
378             decFile = None
379             
380         # Create the new stream from the old one.
381         orig_stream = response.stream
382         response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
383         response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
384                                               response.headers.getHeader('Last-Modified'),
385                                               decFile)
386         response.stream.doneDefer.addErrback(self._save_error, url, destFile, decFile)
387
388         # Return the modified response with the new stream
389         return response
390
391     def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
392         """Update the modification time and inform the main program.
393         
394         @type hash: L{Hash.HashObject}
395         @param hash: the hash object containing the expected hash for the file
396         @param url: the URI of the actual mirror request
397         @type destFile: C{twisted.python.FilePath}
398         @param destFile: the file where the download was written to
399         @type modtime: C{int}
400         @param modtime: the modified time of the cached file (seconds since epoch)
401             (optional, defaults to not setting the modification time of the file)
402         @type decFile: C{twisted.python.FilePath}
403         @param decFile: the file where the decompressed download was written to
404             (optional, defaults to the file not having been compressed)
405         """
406         result = hash.verify()
407         if result or result is None:
408             if modtime:
409                 os.utime(destFile.path, (modtime, modtime))
410             
411             if result:
412                 log.msg('Hashes match: %s' % url)
413                 dht = True
414             else:
415                 log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
416                 dht = False
417                 
418             new_hash = self.db.storeFile(destFile, hash.digest(), dht,
419                                          ''.join(hash.pieceDigests()))
420
421             if self.manager:
422                 self.manager.new_cached_file(destFile, hash, new_hash, url)
423
424             if decFile:
425                 # Hash the decompressed file and add it to the DB
426                 decHash = HashObject()
427                 ext_len = len(destFile.path) - len(decFile.path)
428                 df = decHash.hashInThread(decFile)
429                 df.addCallback(self._save_complete, url[:-ext_len], decFile, modtime)
430                 df.addErrback(self._save_error, url[:-ext_len], decFile)
431         else:
432             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
433             destFile.remove()
434             if decFile:
435                 decFile.remove()
436
437     def _save_error(self, failure, url, destFile, decFile = None):
438         """Remove the destination files."""
439         log.msg('Error occurred downloading %s' % url)
440         log.err(failure)
441         destFile.restat(False)
442         if destFile.exists():
443             log.msg('Removing the incomplete file: %s' % destFile.path)
444             destFile.remove()
445         if decFile:
446             decFile.restat(False)
447             if decFile.exists():
448                 log.msg('Removing the incomplete file: %s' % decFile.path)
449                 decFile.remove()
450
451     def save_error(self, failure, url):
452         """An error has occurred in downloading or saving the file"""
453         log.msg('Error occurred downloading %s' % url)
454         log.err(failure)
455         return failure
456
457 class TestMirrorManager(unittest.TestCase):
458     """Unit tests for the mirror manager."""
459     
460     timeout = 20
461     pending_calls = []
462     client = None
463     
464     def setUp(self):
465         self.client = CacheManager(FilePath('/tmp/.apt-p2p'))
466         
467     def tearDown(self):
468         for p in self.pending_calls:
469             if p.active():
470                 p.cancel()
471         self.client = None
472