a0f6b6b48b4fa044252b8bab7fe64afb81a4b453
[quix0rs-apt-p2p.git] / apt_dht / CacheManager.py
1
2 """Manage a cache of downloaded files.
3
4 @var DECOMPRESS_EXTS: a list of file extensions that need to be decompressed
5 @var DECOMPRESS_FILES: a list of file names that need to be decompressed
6 """
7
8 from bz2 import BZ2Decompressor
9 from zlib import decompressobj, MAX_WBITS
10 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
11 from urlparse import urlparse
12 import os
13
14 from twisted.python import log
15 from twisted.python.filepath import FilePath
16 from twisted.internet import defer, reactor
17 from twisted.trial import unittest
18 from twisted.web2 import stream
19 from twisted.web2.http import splitHostPort
20
21 from Hash import HashObject
22
23 DECOMPRESS_EXTS = ['.gz', '.bz2']
24 DECOMPRESS_FILES = ['release', 'sources', 'packages']
25
26 class ProxyFileStream(stream.SimpleStream):
27     """Saves a stream to a file while providing a new stream.
28     
29     Also optionally decompresses the file while it is being downloaded.
30     
31     @type stream: L{twisted.web2.stream.IByteStream}
32     @ivar stream: the input stream being read
33     @type outFile: L{twisted.python.filepath.FilePath}
34     @ivar outFile: the file being written
35     @type hash: L{Hash.HashObject}
36     @ivar hash: the hash object for the file
37     @type gzfile: C{file}
38     @ivar gzfile: the open file to write decompressed gzip data to
39     @type gzdec: L{zlib.decompressobj}
40     @ivar gzdec: the decompressor to use for the compressed gzip data
41     @type gzheader: C{boolean}
42     @ivar gzheader: whether the gzip header still needs to be removed from
43         the zlib compressed data
44     @type bz2file: C{file}
45     @ivar bz2file: the open file to write decompressed bz2 data to
46     @type bz2dec: L{bz2.BZ2Decompressor}
47     @ivar bz2dec: the decompressor to use for the compressed bz2 data
48     @type length: C{int}
49     @ivar length: the length of the original (compressed) file
50     @type doneDefer: L{twisted.internet.defer.Deferred}
51     @ivar doneDefer: the deferred that will fire when done streaming
52     
53     @group Stream implementation: read, close
54     
55     """
56     
57     def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
58         """Initializes the proxy.
59         
60         @type stream: L{twisted.web2.stream.IByteStream}
61         @param stream: the input stream to read from
62         @type outFile: L{twisted.python.filepath.FilePath}
63         @param outFile: the file to write to
64         @type hash: L{Hash.HashObject}
65         @param hash: the hash object to use for the file
66         @type decompress: C{string}
67         @param decompress: also decompress the file as this type
68             (currently only '.gz' and '.bz2' are supported)
69         @type decFile: C{twisted.python.FilePath}
70         @param decFile: the file to write the decompressed data to
71         """
72         self.stream = stream
73         self.outFile = outFile.open('w')
74         self.hash = hash
75         self.hash.new()
76         self.gzfile = None
77         self.bz2file = None
78         if decompress == ".gz":
79             self.gzheader = True
80             self.gzfile = decFile.open('w')
81             self.gzdec = decompressobj(-MAX_WBITS)
82         elif decompress == ".bz2":
83             self.bz2file = decFile.open('w')
84             self.bz2dec = BZ2Decompressor()
85         self.length = self.stream.length
86         self.doneDefer = defer.Deferred()
87
88     def _done(self):
89         """Close all the output files, return the result."""
90         if not self.outFile.closed:
91             self.outFile.close()
92             self.hash.digest()
93             if self.gzfile:
94                 # Finish the decompression
95                 data_dec = self.gzdec.flush()
96                 self.gzfile.write(data_dec)
97                 self.gzfile.close()
98                 self.gzfile = None
99             if self.bz2file:
100                 self.bz2file.close()
101                 self.bz2file = None
102                 
103             self.doneDefer.callback(self.hash)
104     
105     def read(self):
106         """Read some data from the stream."""
107         if self.outFile.closed:
108             return None
109         
110         # Read data from the stream, deal with the possible deferred
111         data = self.stream.read()
112         if isinstance(data, defer.Deferred):
113             data.addCallbacks(self._write, self._done)
114             return data
115         
116         self._write(data)
117         return data
118     
119     def _write(self, data):
120         """Write the stream data to the file and return it for others to use.
121         
122         Also optionally decompresses it.
123         """
124         if data is None:
125             self._done()
126             return data
127         
128         # Write and hash the streamed data
129         self.outFile.write(data)
130         self.hash.update(data)
131         
132         if self.gzfile:
133             # Decompress the zlib portion of the file
134             if self.gzheader:
135                 # Remove the gzip header junk
136                 self.gzheader = False
137                 new_data = self._remove_gzip_header(data)
138                 dec_data = self.gzdec.decompress(new_data)
139             else:
140                 dec_data = self.gzdec.decompress(data)
141             self.gzfile.write(dec_data)
142         if self.bz2file:
143             # Decompress the bz2 file
144             dec_data = self.bz2dec.decompress(data)
145             self.bz2file.write(dec_data)
146
147         return data
148     
149     def _remove_gzip_header(self, data):
150         """Remove the gzip header from the zlib compressed data."""
151         # Read, check & discard the header fields
152         if data[:2] != '\037\213':
153             raise IOError, 'Not a gzipped file'
154         if ord(data[2]) != 8:
155             raise IOError, 'Unknown compression method'
156         flag = ord(data[3])
157         # modtime = self.fileobj.read(4)
158         # extraflag = self.fileobj.read(1)
159         # os = self.fileobj.read(1)
160
161         skip = 10
162         if flag & FEXTRA:
163             # Read & discard the extra field
164             xlen = ord(data[10])
165             xlen = xlen + 256*ord(data[11])
166             skip = skip + 2 + xlen
167         if flag & FNAME:
168             # Read and discard a null-terminated string containing the filename
169             while True:
170                 if not data[skip] or data[skip] == '\000':
171                     break
172                 skip += 1
173             skip += 1
174         if flag & FCOMMENT:
175             # Read and discard a null-terminated string containing a comment
176             while True:
177                 if not data[skip] or data[skip] == '\000':
178                     break
179                 skip += 1
180             skip += 1
181         if flag & FHCRC:
182             skip += 2     # Read & discard the 16-bit header CRC
183
184         return data[skip:]
185
186     def close(self):
187         """Clean everything up and return None to future reads."""
188         self.length = 0
189         self._done()
190         self.stream.close()
191
192 class CacheManager:
193     """Manages all downloaded files and requests for cached objects.
194     
195     @type cache_dir: L{twisted.python.filepath.FilePath}
196     @ivar cache_dir: the directory to use for storing all files
197     @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
198     @ivar other_dirs: the other directories that have shared files in them
199     @type all_dirs: C{list} of L{twisted.python.filepath.FilePath}
200     @ivar all_dirs: all the directories that have cached files in them
201     @type db: L{db.DB}
202     @ivar db: the database to use for tracking files and hashes
203     @type manager: L{apt_dht.AptDHT}
204     @ivar manager: the main program object to send requests to
205     @type scanning: C{list} of L{twisted.python.filepath.FilePath}
206     @ivar scanning: all the directories that are currectly being scanned or waiting to be scanned
207     """
208     
209     def __init__(self, cache_dir, db, other_dirs = [], manager = None):
210         """Initialize the instance and remove any untracked files from the DB..
211         
212         @type cache_dir: L{twisted.python.filepath.FilePath}
213         @param cache_dir: the directory to use for storing all files
214         @type db: L{db.DB}
215         @param db: the database to use for tracking files and hashes
216         @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
217         @param other_dirs: the other directories that have shared files in them
218             (optional, defaults to only using the cache directory)
219         @type manager: L{apt_dht.AptDHT}
220         @param manager: the main program object to send requests to
221             (optional, defaults to not calling back with cached files)
222         """
223         self.cache_dir = cache_dir
224         self.other_dirs = other_dirs
225         self.all_dirs = self.other_dirs[:]
226         self.all_dirs.insert(0, self.cache_dir)
227         self.db = db
228         self.manager = manager
229         self.scanning = []
230         
231         # Init the database, remove old files
232         self.db.removeUntrackedFiles(self.all_dirs)
233         
234     #{ Scanning directories
235     def scanDirectories(self):
236         """Scan the cache directories, hashing new and rehashing changed files."""
237         assert not self.scanning, "a directory scan is already under way"
238         self.scanning = self.all_dirs[:]
239         self._scanDirectories()
240
241     def _scanDirectories(self, result = None, walker = None):
242         """Walk each directory looking for cached files.
243         
244         @param result: the result of a DHT store request, not used (optional)
245         @param walker: the walker to use to traverse the current directory
246             (optional, defaults to creating a new walker from the first
247             directory in the L{CacheManager.scanning} list)
248         """
249         # Need to start walking a new directory
250         if walker is None:
251             # If there are any left, get them
252             if self.scanning:
253                 log.msg('started scanning directory: %s' % self.scanning[0].path)
254                 walker = self.scanning[0].walk()
255             else:
256                 log.msg('cache directory scan complete')
257                 return
258             
259         try:
260             # Get the next file in the directory
261             file = walker.next()
262         except StopIteration:
263             # No files left, go to the next directory
264             log.msg('done scanning directory: %s' % self.scanning[0].path)
265             self.scanning.pop(0)
266             reactor.callLater(0, self._scanDirectories)
267             return
268
269         # If it's not a file ignore it
270         if not file.isfile():
271             log.msg('entering directory: %s' % file.path)
272             reactor.callLater(0, self._scanDirectories, None, walker)
273             return
274
275         # If it's already properly in the DB, ignore it
276         db_status = self.db.isUnchanged(file)
277         if db_status:
278             log.msg('file is unchanged: %s' % file.path)
279             reactor.callLater(0, self._scanDirectories, None, walker)
280             return
281         
282         # Don't hash files in the cache that are not in the DB
283         if self.scanning[0] == self.cache_dir:
284             if db_status is None:
285                 log.msg('ignoring unknown cache file: %s' % file.path)
286             else:
287                 log.msg('removing changed cache file: %s' % file.path)
288                 file.remove()
289             reactor.callLater(0, self._scanDirectories, None, walker)
290             return
291
292         # Otherwise hash it
293         log.msg('start hash checking file: %s' % file.path)
294         hash = HashObject()
295         df = hash.hashInThread(file)
296         df.addBoth(self._doneHashing, file, walker)
297         df.addErrback(log.err)
298     
299     def _doneHashing(self, result, file, walker):
300         """If successful, add the hashed file to the DB and inform the main program."""
301         if isinstance(result, HashObject):
302             log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
303             
304             # Only set a URL if this is a downloaded file
305             url = None
306             if self.scanning[0] == self.cache_dir:
307                 url = 'http:/' + file.path[len(self.cache_dir.path):]
308                 
309             # Store the hashed file in the database
310             new_hash = self.db.storeFile(file, result.digest())
311             
312             # Tell the main program to handle the new cache file
313             df = self.manager.new_cached_file(file, result, new_hash, url, True)
314             if df is None:
315                 reactor.callLater(0, self._scanDirectories, None, walker)
316             else:
317                 df.addBoth(self._scanDirectories, walker)
318         else:
319             # Must have returned an error
320             log.msg('hash check of %s failed' % file.path)
321             log.err(result)
322             reactor.callLater(0, self._scanDirectories, None, walker)
323
324     #{ Downloading files
325     def save_file(self, response, hash, url):
326         """Save a downloaded file to the cache and stream it.
327         
328         @type response: L{twisted.web2.http.Response}
329         @param response: the response from the download
330         @type hash: L{Hash.HashObject}
331         @param hash: the hash object containing the expected hash for the file
332         @param url: the URI of the actual mirror request
333         @rtype: L{twisted.web2.http.Response}
334         @return: the final response from the download
335         """
336         if response.code != 200:
337             log.msg('File was not found (%r): %s' % (response, url))
338             return response
339         
340         log.msg('Returning file: %s' % url)
341
342         # Set the destination path for the file
343         parsed = urlparse(url)
344         destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
345         log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
346         
347         # Make sure there's a free place for the file
348         if destFile.exists():
349             log.msg('File already exists, removing: %s' % destFile.path)
350             destFile.remove()
351         elif not destFile.parent().exists():
352             destFile.parent().makedirs()
353
354         # Determine whether it needs to be decompressed and how
355         root, ext = os.path.splitext(destFile.basename())
356         if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
357             ext = ext.lower()
358             decFile = destFile.sibling(root)
359             log.msg('Decompressing to: %s' % decFile.path)
360             if decFile.exists():
361                 log.msg('File already exists, removing: %s' % decFile.path)
362                 decFile.remove()
363         else:
364             ext = None
365             decFile = None
366             
367         # Create the new stream from the old one.
368         orig_stream = response.stream
369         response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
370         response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
371                                               response.headers.getHeader('Last-Modified'),
372                                               decFile)
373         response.stream.doneDefer.addErrback(self.save_error, url)
374
375         # Return the modified response with the new stream
376         return response
377
378     def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
379         """Update the modification time and inform the main program.
380         
381         @type hash: L{Hash.HashObject}
382         @param hash: the hash object containing the expected hash for the file
383         @param url: the URI of the actual mirror request
384         @type destFile: C{twisted.python.FilePath}
385         @param destFile: the file where the download was written to
386         @type modtime: C{int}
387         @param modtime: the modified time of the cached file (seconds since epoch)
388             (optional, defaults to not setting the modification time of the file)
389         @type decFile: C{twisted.python.FilePath}
390         @param decFile: the file where the decompressed download was written to
391             (optional, defaults to the file not having been compressed)
392         """
393         if modtime:
394             os.utime(destFile.path, (modtime, modtime))
395             if decFile:
396                 os.utime(decFile.path, (modtime, modtime))
397         
398         result = hash.verify()
399         if result or result is None:
400             if result:
401                 log.msg('Hashes match: %s' % url)
402             else:
403                 log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
404                 
405             new_hash = self.db.storeFile(destFile, hash.digest())
406             log.msg('now avaliable: %s' % (url))
407
408             if self.manager:
409                 self.manager.new_cached_file(destFile, hash, new_hash, url)
410                 if decFile:
411                     ext_len = len(destFile.path) - len(decFile.path)
412                     self.manager.new_cached_file(decFile, None, False, url[:-ext_len])
413         else:
414             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
415             destFile.remove()
416             if decFile:
417                 decFile.remove()
418
419     def save_error(self, failure, url):
420         """An error has occurred in downloadign or saving the file."""
421         log.msg('Error occurred downloading %s' % url)
422         log.err(failure)
423         return failure
424
425 class TestMirrorManager(unittest.TestCase):
426     """Unit tests for the mirror manager."""
427     
428     timeout = 20
429     pending_calls = []
430     client = None
431     
432     def setUp(self):
433         self.client = CacheManager(FilePath('/tmp/.apt-dht'))
434         
435     def tearDown(self):
436         for p in self.pending_calls:
437             if p.active():
438                 p.cancel()
439         self.client = None
440