a990768dfc3d7f7d906ace58b67f53f303bef1b8
[quix0rs-apt-p2p.git] / apt_p2p / CacheManager.py
1
2 """Manage a cache of downloaded files.
3
4 @var DECOMPRESS_EXTS: a list of file extensions that need to be decompressed
5 @var DECOMPRESS_FILES: a list of file names that need to be decompressed
6 """
7
8 from bz2 import BZ2Decompressor
9 from zlib import decompressobj, MAX_WBITS
10 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
11 from urlparse import urlparse
12 import os
13
14 from twisted.python import log
15 from twisted.python.filepath import FilePath
16 from twisted.internet import defer, reactor
17 from twisted.trial import unittest
18 from twisted.web2 import stream
19 from twisted.web2.http import splitHostPort
20
21 from Hash import HashObject
22
23 DECOMPRESS_EXTS = ['.gz', '.bz2']
24 DECOMPRESS_FILES = ['release', 'sources', 'packages']
25
26 class ProxyFileStream(stream.SimpleStream):
27     """Saves a stream to a file while providing a new stream.
28     
29     Also optionally decompresses the file while it is being downloaded.
30     
31     @type stream: L{twisted.web2.stream.IByteStream}
32     @ivar stream: the input stream being read
33     @type outFile: L{twisted.python.filepath.FilePath}
34     @ivar outFile: the file being written
35     @type hash: L{Hash.HashObject}
36     @ivar hash: the hash object for the file
37     @type gzfile: C{file}
38     @ivar gzfile: the open file to write decompressed gzip data to
39     @type gzdec: L{zlib.decompressobj}
40     @ivar gzdec: the decompressor to use for the compressed gzip data
41     @type gzheader: C{boolean}
42     @ivar gzheader: whether the gzip header still needs to be removed from
43         the zlib compressed data
44     @type bz2file: C{file}
45     @ivar bz2file: the open file to write decompressed bz2 data to
46     @type bz2dec: L{bz2.BZ2Decompressor}
47     @ivar bz2dec: the decompressor to use for the compressed bz2 data
48     @type length: C{int}
49     @ivar length: the length of the original (compressed) file
50     @type doneDefer: L{twisted.internet.defer.Deferred}
51     @ivar doneDefer: the deferred that will fire when done streaming
52     
53     @group Stream implementation: read, close
54     
55     """
56     
57     def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
58         """Initializes the proxy.
59         
60         @type stream: L{twisted.web2.stream.IByteStream}
61         @param stream: the input stream to read from
62         @type outFile: L{twisted.python.filepath.FilePath}
63         @param outFile: the file to write to
64         @type hash: L{Hash.HashObject}
65         @param hash: the hash object to use for the file
66         @type decompress: C{string}
67         @param decompress: also decompress the file as this type
68             (currently only '.gz' and '.bz2' are supported)
69         @type decFile: C{twisted.python.FilePath}
70         @param decFile: the file to write the decompressed data to
71         """
72         self.stream = stream
73         self.outFile = outFile.open('w')
74         self.hash = hash
75         self.hash.new()
76         self.gzfile = None
77         self.bz2file = None
78         if decompress == ".gz":
79             self.gzheader = True
80             self.gzfile = decFile.open('w')
81             self.gzdec = decompressobj(-MAX_WBITS)
82         elif decompress == ".bz2":
83             self.bz2file = decFile.open('w')
84             self.bz2dec = BZ2Decompressor()
85         self.length = self.stream.length
86         self.doneDefer = defer.Deferred()
87
88     def _done(self):
89         """Close all the output files, return the result."""
90         if not self.outFile.closed:
91             self.outFile.close()
92             self.hash.digest()
93             if self.gzfile:
94                 # Finish the decompression
95                 data_dec = self.gzdec.flush()
96                 self.gzfile.write(data_dec)
97                 self.gzfile.close()
98                 self.gzfile = None
99             if self.bz2file:
100                 self.bz2file.close()
101                 self.bz2file = None
102                 
103             self.doneDefer.callback(self.hash)
104     
105     def read(self):
106         """Read some data from the stream."""
107         if self.outFile.closed:
108             return None
109         
110         # Read data from the stream, deal with the possible deferred
111         data = self.stream.read()
112         if isinstance(data, defer.Deferred):
113             data.addCallbacks(self._write, self._done)
114             return data
115         
116         self._write(data)
117         return data
118     
119     def _write(self, data):
120         """Write the stream data to the file and return it for others to use.
121         
122         Also optionally decompresses it.
123         """
124         if data is None:
125             self._done()
126             return data
127         
128         # Write and hash the streamed data
129         self.outFile.write(data)
130         self.hash.update(data)
131         
132         if self.gzfile:
133             # Decompress the zlib portion of the file
134             if self.gzheader:
135                 # Remove the gzip header junk
136                 self.gzheader = False
137                 new_data = self._remove_gzip_header(data)
138                 dec_data = self.gzdec.decompress(new_data)
139             else:
140                 dec_data = self.gzdec.decompress(data)
141             self.gzfile.write(dec_data)
142         if self.bz2file:
143             # Decompress the bz2 file
144             dec_data = self.bz2dec.decompress(data)
145             self.bz2file.write(dec_data)
146
147         return data
148     
149     def _remove_gzip_header(self, data):
150         """Remove the gzip header from the zlib compressed data."""
151         # Read, check & discard the header fields
152         if data[:2] != '\037\213':
153             raise IOError, 'Not a gzipped file'
154         if ord(data[2]) != 8:
155             raise IOError, 'Unknown compression method'
156         flag = ord(data[3])
157         # modtime = self.fileobj.read(4)
158         # extraflag = self.fileobj.read(1)
159         # os = self.fileobj.read(1)
160
161         skip = 10
162         if flag & FEXTRA:
163             # Read & discard the extra field
164             xlen = ord(data[10])
165             xlen = xlen + 256*ord(data[11])
166             skip = skip + 2 + xlen
167         if flag & FNAME:
168             # Read and discard a null-terminated string containing the filename
169             while True:
170                 if not data[skip] or data[skip] == '\000':
171                     break
172                 skip += 1
173             skip += 1
174         if flag & FCOMMENT:
175             # Read and discard a null-terminated string containing a comment
176             while True:
177                 if not data[skip] or data[skip] == '\000':
178                     break
179                 skip += 1
180             skip += 1
181         if flag & FHCRC:
182             skip += 2     # Read & discard the 16-bit header CRC
183
184         return data[skip:]
185
186     def close(self):
187         """Clean everything up and return None to future reads."""
188         self.length = 0
189         self._done()
190         self.stream.close()
191
192 class CacheManager:
193     """Manages all downloaded files and requests for cached objects.
194     
195     @type cache_dir: L{twisted.python.filepath.FilePath}
196     @ivar cache_dir: the directory to use for storing all files
197     @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
198     @ivar other_dirs: the other directories that have shared files in them
199     @type all_dirs: C{list} of L{twisted.python.filepath.FilePath}
200     @ivar all_dirs: all the directories that have cached files in them
201     @type db: L{db.DB}
202     @ivar db: the database to use for tracking files and hashes
203     @type manager: L{apt_p2p.AptP2P}
204     @ivar manager: the main program object to send requests to
205     @type scanning: C{list} of L{twisted.python.filepath.FilePath}
206     @ivar scanning: all the directories that are currectly being scanned or waiting to be scanned
207     """
208     
209     def __init__(self, cache_dir, db, other_dirs = [], manager = None):
210         """Initialize the instance and remove any untracked files from the DB..
211         
212         @type cache_dir: L{twisted.python.filepath.FilePath}
213         @param cache_dir: the directory to use for storing all files
214         @type db: L{db.DB}
215         @param db: the database to use for tracking files and hashes
216         @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
217         @param other_dirs: the other directories that have shared files in them
218             (optional, defaults to only using the cache directory)
219         @type manager: L{apt_p2p.AptP2P}
220         @param manager: the main program object to send requests to
221             (optional, defaults to not calling back with cached files)
222         """
223         self.cache_dir = cache_dir
224         self.other_dirs = other_dirs
225         self.all_dirs = self.other_dirs[:]
226         self.all_dirs.insert(0, self.cache_dir)
227         self.db = db
228         self.manager = manager
229         self.scanning = []
230         
231         # Init the database, remove old files
232         self.db.removeUntrackedFiles(self.all_dirs)
233         
234     #{ Scanning directories
235     def scanDirectories(self):
236         """Scan the cache directories, hashing new and rehashing changed files."""
237         assert not self.scanning, "a directory scan is already under way"
238         self.scanning = self.all_dirs[:]
239         self._scanDirectories()
240
241     def _scanDirectories(self, result = None, walker = None):
242         """Walk each directory looking for cached files.
243         
244         @param result: the result of a DHT store request, not used (optional)
245         @param walker: the walker to use to traverse the current directory
246             (optional, defaults to creating a new walker from the first
247             directory in the L{CacheManager.scanning} list)
248         """
249         # Need to start walking a new directory
250         if walker is None:
251             # If there are any left, get them
252             if self.scanning:
253                 log.msg('started scanning directory: %s' % self.scanning[0].path)
254                 walker = self.scanning[0].walk()
255             else:
256                 log.msg('cache directory scan complete')
257                 return
258             
259         try:
260             # Get the next file in the directory
261             file = walker.next()
262         except StopIteration:
263             # No files left, go to the next directory
264             log.msg('done scanning directory: %s' % self.scanning[0].path)
265             self.scanning.pop(0)
266             reactor.callLater(0, self._scanDirectories)
267             return
268
269         # If it's not a file ignore it
270         if not file.isfile():
271             log.msg('entering directory: %s' % file.path)
272             reactor.callLater(0, self._scanDirectories, None, walker)
273             return
274
275         # If it's already properly in the DB, ignore it
276         db_status = self.db.isUnchanged(file)
277         if db_status:
278             log.msg('file is unchanged: %s' % file.path)
279             reactor.callLater(0, self._scanDirectories, None, walker)
280             return
281         
282         # Don't hash files in the cache that are not in the DB
283         if self.scanning[0] == self.cache_dir:
284             if db_status is None:
285                 log.msg('ignoring unknown cache file: %s' % file.path)
286             else:
287                 log.msg('removing changed cache file: %s' % file.path)
288                 file.remove()
289             reactor.callLater(0, self._scanDirectories, None, walker)
290             return
291
292         # Otherwise hash it
293         log.msg('start hash checking file: %s' % file.path)
294         hash = HashObject()
295         df = hash.hashInThread(file)
296         df.addBoth(self._doneHashing, file, walker)
297         df.addErrback(log.err)
298     
299     def _doneHashing(self, result, file, walker):
300         """If successful, add the hashed file to the DB and inform the main program."""
301         if isinstance(result, HashObject):
302             log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
303             
304             # Only set a URL if this is a downloaded file
305             url = None
306             if self.scanning[0] == self.cache_dir:
307                 url = 'http:/' + file.path[len(self.cache_dir.path):]
308                 
309             # Store the hashed file in the database
310             new_hash = self.db.storeFile(file, result.digest(),
311                                          ''.join(result.pieceDigests()))
312             
313             # Tell the main program to handle the new cache file
314             df = self.manager.new_cached_file(file, result, new_hash, url, True)
315             if df is None:
316                 reactor.callLater(0, self._scanDirectories, None, walker)
317             else:
318                 df.addBoth(self._scanDirectories, walker)
319         else:
320             # Must have returned an error
321             log.msg('hash check of %s failed' % file.path)
322             log.err(result)
323             reactor.callLater(0, self._scanDirectories, None, walker)
324
325     #{ Downloading files
326     def save_file(self, response, hash, url):
327         """Save a downloaded file to the cache and stream it.
328         
329         @type response: L{twisted.web2.http.Response}
330         @param response: the response from the download
331         @type hash: L{Hash.HashObject}
332         @param hash: the hash object containing the expected hash for the file
333         @param url: the URI of the actual mirror request
334         @rtype: L{twisted.web2.http.Response}
335         @return: the final response from the download
336         """
337         if response.code != 200:
338             log.msg('File was not found (%r): %s' % (response, url))
339             return response
340         
341         log.msg('Returning file: %s' % url)
342
343         # Set the destination path for the file
344         parsed = urlparse(url)
345         destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
346         log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
347         
348         # Make sure there's a free place for the file
349         if destFile.exists():
350             log.msg('File already exists, removing: %s' % destFile.path)
351             destFile.remove()
352         elif not destFile.parent().exists():
353             destFile.parent().makedirs()
354
355         # Determine whether it needs to be decompressed and how
356         root, ext = os.path.splitext(destFile.basename())
357         if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
358             ext = ext.lower()
359             decFile = destFile.sibling(root)
360             log.msg('Decompressing to: %s' % decFile.path)
361             if decFile.exists():
362                 log.msg('File already exists, removing: %s' % decFile.path)
363                 decFile.remove()
364         else:
365             ext = None
366             decFile = None
367             
368         # Create the new stream from the old one.
369         orig_stream = response.stream
370         response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
371         response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
372                                               response.headers.getHeader('Last-Modified'),
373                                               decFile)
374         response.stream.doneDefer.addErrback(self.save_error, url)
375
376         # Return the modified response with the new stream
377         return response
378
379     def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
380         """Update the modification time and inform the main program.
381         
382         @type hash: L{Hash.HashObject}
383         @param hash: the hash object containing the expected hash for the file
384         @param url: the URI of the actual mirror request
385         @type destFile: C{twisted.python.FilePath}
386         @param destFile: the file where the download was written to
387         @type modtime: C{int}
388         @param modtime: the modified time of the cached file (seconds since epoch)
389             (optional, defaults to not setting the modification time of the file)
390         @type decFile: C{twisted.python.FilePath}
391         @param decFile: the file where the decompressed download was written to
392             (optional, defaults to the file not having been compressed)
393         """
394         if modtime:
395             os.utime(destFile.path, (modtime, modtime))
396             if decFile:
397                 os.utime(decFile.path, (modtime, modtime))
398         
399         result = hash.verify()
400         if result or result is None:
401             if result:
402                 log.msg('Hashes match: %s' % url)
403             else:
404                 log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
405                 
406             new_hash = self.db.storeFile(destFile, hash.digest(),
407                                          ''.join(hash.pieceDigests()))
408             log.msg('now avaliable: %s' % (url))
409
410             if self.manager:
411                 self.manager.new_cached_file(destFile, hash, new_hash, url)
412                 if decFile:
413                     ext_len = len(destFile.path) - len(decFile.path)
414                     self.manager.new_cached_file(decFile, None, False, url[:-ext_len])
415         else:
416             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
417             destFile.remove()
418             if decFile:
419                 decFile.remove()
420
421     def save_error(self, failure, url):
422         """An error has occurred in downloadign or saving the file."""
423         log.msg('Error occurred downloading %s' % url)
424         log.err(failure)
425         return failure
426
427 class TestMirrorManager(unittest.TestCase):
428     """Unit tests for the mirror manager."""
429     
430     timeout = 20
431     pending_calls = []
432     client = None
433     
434     def setUp(self):
435         self.client = CacheManager(FilePath('/tmp/.apt-p2p'))
436         
437     def tearDown(self):
438         for p in self.pending_calls:
439             if p.active():
440                 p.cancel()
441         self.client = None
442