]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - apt_p2p/CacheManager.py
24c821eda49a9d220bec462c0914122500c39b1e
[quix0rs-apt-p2p.git] / apt_p2p / CacheManager.py
1
2 """Manage a cache of downloaded files.
3
4 @var DECOMPRESS_EXTS: a list of file extensions that need to be decompressed
5 @var DECOMPRESS_FILES: a list of file names that need to be decompressed
6 """
7
8 from bz2 import BZ2Decompressor
9 from zlib import decompressobj, MAX_WBITS
10 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
11 from urlparse import urlparse
12 import os
13
14 from twisted.python import log
15 from twisted.python.filepath import FilePath
16 from twisted.internet import defer, reactor
17 from twisted.trial import unittest
18 from twisted.web2 import stream
19 from twisted.web2.http import splitHostPort
20
21 from Hash import HashObject
22 from apt_p2p_conf import config
23
24 DECOMPRESS_EXTS = ['.gz', '.bz2']
25 DECOMPRESS_FILES = ['release', 'sources', 'packages']
26
27 class ProxyFileStream(stream.SimpleStream):
28     """Saves a stream to a file while providing a new stream.
29     
30     Also optionally decompresses the file while it is being downloaded.
31     
32     @type stream: L{twisted.web2.stream.IByteStream}
33     @ivar stream: the input stream being read
34     @type outFile: L{twisted.python.filepath.FilePath}
35     @ivar outFile: the file being written
36     @type hash: L{Hash.HashObject}
37     @ivar hash: the hash object for the file
38     @type gzfile: C{file}
39     @ivar gzfile: the open file to write decompressed gzip data to
40     @type gzdec: L{zlib.decompressobj}
41     @ivar gzdec: the decompressor to use for the compressed gzip data
42     @type gzheader: C{boolean}
43     @ivar gzheader: whether the gzip header still needs to be removed from
44         the zlib compressed data
45     @type bz2file: C{file}
46     @ivar bz2file: the open file to write decompressed bz2 data to
47     @type bz2dec: L{bz2.BZ2Decompressor}
48     @ivar bz2dec: the decompressor to use for the compressed bz2 data
49     @type length: C{int}
50     @ivar length: the length of the original (compressed) file
51     @type doneDefer: L{twisted.internet.defer.Deferred}
52     @ivar doneDefer: the deferred that will fire when done streaming
53     
54     @group Stream implementation: read, close
55     
56     """
57     
58     def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
59         """Initializes the proxy.
60         
61         @type stream: L{twisted.web2.stream.IByteStream}
62         @param stream: the input stream to read from
63         @type outFile: L{twisted.python.filepath.FilePath}
64         @param outFile: the file to write to
65         @type hash: L{Hash.HashObject}
66         @param hash: the hash object to use for the file
67         @type decompress: C{string}
68         @param decompress: also decompress the file as this type
69             (currently only '.gz' and '.bz2' are supported)
70         @type decFile: C{twisted.python.FilePath}
71         @param decFile: the file to write the decompressed data to
72         """
73         self.stream = stream
74         self.outFile = outFile.open('w')
75         self.hash = hash
76         self.hash.new()
77         self.gzfile = None
78         self.bz2file = None
79         if decompress == ".gz":
80             self.gzheader = True
81             self.gzfile = decFile.open('w')
82             self.gzdec = decompressobj(-MAX_WBITS)
83         elif decompress == ".bz2":
84             self.bz2file = decFile.open('w')
85             self.bz2dec = BZ2Decompressor()
86         self.length = self.stream.length
87         self.doneDefer = defer.Deferred()
88
89     def _done(self):
90         """Close all the output files, return the result."""
91         if not self.outFile.closed:
92             self.outFile.close()
93             self.hash.digest()
94             if self.gzfile:
95                 # Finish the decompression
96                 data_dec = self.gzdec.flush()
97                 self.gzfile.write(data_dec)
98                 self.gzfile.close()
99                 self.gzfile = None
100             if self.bz2file:
101                 self.bz2file.close()
102                 self.bz2file = None
103                 
104             self.doneDefer.callback(self.hash)
105     
106     def read(self):
107         """Read some data from the stream."""
108         if self.outFile.closed:
109             return None
110         
111         # Read data from the stream, deal with the possible deferred
112         data = self.stream.read()
113         if isinstance(data, defer.Deferred):
114             data.addCallbacks(self._write, self._done)
115             return data
116         
117         self._write(data)
118         return data
119     
120     def _write(self, data):
121         """Write the stream data to the file and return it for others to use.
122         
123         Also optionally decompresses it.
124         """
125         if data is None:
126             self._done()
127             return data
128         
129         # Write and hash the streamed data
130         self.outFile.write(data)
131         self.hash.update(data)
132         
133         if self.gzfile:
134             # Decompress the zlib portion of the file
135             if self.gzheader:
136                 # Remove the gzip header junk
137                 self.gzheader = False
138                 new_data = self._remove_gzip_header(data)
139                 dec_data = self.gzdec.decompress(new_data)
140             else:
141                 dec_data = self.gzdec.decompress(data)
142             self.gzfile.write(dec_data)
143         if self.bz2file:
144             # Decompress the bz2 file
145             dec_data = self.bz2dec.decompress(data)
146             self.bz2file.write(dec_data)
147
148         return data
149     
150     def _remove_gzip_header(self, data):
151         """Remove the gzip header from the zlib compressed data."""
152         # Read, check & discard the header fields
153         if data[:2] != '\037\213':
154             raise IOError, 'Not a gzipped file'
155         if ord(data[2]) != 8:
156             raise IOError, 'Unknown compression method'
157         flag = ord(data[3])
158         # modtime = self.fileobj.read(4)
159         # extraflag = self.fileobj.read(1)
160         # os = self.fileobj.read(1)
161
162         skip = 10
163         if flag & FEXTRA:
164             # Read & discard the extra field
165             xlen = ord(data[10])
166             xlen = xlen + 256*ord(data[11])
167             skip = skip + 2 + xlen
168         if flag & FNAME:
169             # Read and discard a null-terminated string containing the filename
170             while True:
171                 if not data[skip] or data[skip] == '\000':
172                     break
173                 skip += 1
174             skip += 1
175         if flag & FCOMMENT:
176             # Read and discard a null-terminated string containing a comment
177             while True:
178                 if not data[skip] or data[skip] == '\000':
179                     break
180                 skip += 1
181             skip += 1
182         if flag & FHCRC:
183             skip += 2     # Read & discard the 16-bit header CRC
184
185         return data[skip:]
186
187     def close(self):
188         """Clean everything up and return None to future reads."""
189         self.length = 0
190         self._done()
191         self.stream.close()
192
193 class CacheManager:
194     """Manages all downloaded files and requests for cached objects.
195     
196     @type cache_dir: L{twisted.python.filepath.FilePath}
197     @ivar cache_dir: the directory to use for storing all files
198     @type other_dirs: C{list} of L{twisted.python.filepath.FilePath}
199     @ivar other_dirs: the other directories that have shared files in them
200     @type all_dirs: C{list} of L{twisted.python.filepath.FilePath}
201     @ivar all_dirs: all the directories that have cached files in them
202     @type db: L{db.DB}
203     @ivar db: the database to use for tracking files and hashes
204     @type manager: L{apt_p2p.AptP2P}
205     @ivar manager: the main program object to send requests to
206     @type scanning: C{list} of L{twisted.python.filepath.FilePath}
207     @ivar scanning: all the directories that are currectly being scanned or waiting to be scanned
208     """
209     
210     def __init__(self, cache_dir, db, manager = None):
211         """Initialize the instance and remove any untracked files from the DB..
212         
213         @type cache_dir: L{twisted.python.filepath.FilePath}
214         @param cache_dir: the directory to use for storing all files
215         @type db: L{db.DB}
216         @param db: the database to use for tracking files and hashes
217         @type manager: L{apt_p2p.AptP2P}
218         @param manager: the main program object to send requests to
219             (optional, defaults to not calling back with cached files)
220         """
221         self.cache_dir = cache_dir
222         self.other_dirs = [FilePath(f) for f in config.getstringlist('DEFAULT', 'OTHER_DIRS')]
223         self.all_dirs = self.other_dirs[:]
224         self.all_dirs.insert(0, self.cache_dir)
225         self.db = db
226         self.manager = manager
227         self.scanning = []
228         
229         # Init the database, remove old files
230         self.db.removeUntrackedFiles(self.all_dirs)
231         
232     #{ Scanning directories
233     def scanDirectories(self):
234         """Scan the cache directories, hashing new and rehashing changed files."""
235         assert not self.scanning, "a directory scan is already under way"
236         self.scanning = self.all_dirs[:]
237         self._scanDirectories()
238
239     def _scanDirectories(self, result = None, walker = None):
240         """Walk each directory looking for cached files.
241         
242         @param result: the result of a DHT store request, not used (optional)
243         @param walker: the walker to use to traverse the current directory
244             (optional, defaults to creating a new walker from the first
245             directory in the L{CacheManager.scanning} list)
246         """
247         # Need to start walking a new directory
248         if walker is None:
249             # If there are any left, get them
250             if self.scanning:
251                 log.msg('started scanning directory: %s' % self.scanning[0].path)
252                 walker = self.scanning[0].walk()
253             else:
254                 log.msg('cache directory scan complete')
255                 return
256             
257         try:
258             # Get the next file in the directory
259             file = walker.next()
260         except StopIteration:
261             # No files left, go to the next directory
262             log.msg('done scanning directory: %s' % self.scanning[0].path)
263             self.scanning.pop(0)
264             reactor.callLater(0, self._scanDirectories)
265             return
266
267         # If it's not a file ignore it
268         if not file.isfile():
269             log.msg('entering directory: %s' % file.path)
270             reactor.callLater(0, self._scanDirectories, None, walker)
271             return
272
273         # If it's already properly in the DB, ignore it
274         db_status = self.db.isUnchanged(file)
275         if db_status:
276             log.msg('file is unchanged: %s' % file.path)
277             reactor.callLater(0, self._scanDirectories, None, walker)
278             return
279         
280         # Don't hash files in the cache that are not in the DB
281         if self.scanning[0] == self.cache_dir:
282             if db_status is None:
283                 log.msg('ignoring unknown cache file: %s' % file.path)
284             else:
285                 log.msg('removing changed cache file: %s' % file.path)
286                 file.remove()
287             reactor.callLater(0, self._scanDirectories, None, walker)
288             return
289
290         # Otherwise hash it
291         log.msg('start hash checking file: %s' % file.path)
292         hash = HashObject()
293         df = hash.hashInThread(file)
294         df.addBoth(self._doneHashing, file, walker)
295         df.addErrback(log.err)
296     
297     def _doneHashing(self, result, file, walker):
298         """If successful, add the hashed file to the DB and inform the main program."""
299         if isinstance(result, HashObject):
300             log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
301             
302             # Only set a URL if this is a downloaded file
303             url = None
304             if self.scanning[0] == self.cache_dir:
305                 url = 'http:/' + file.path[len(self.cache_dir.path):]
306                 
307             # Store the hashed file in the database
308             new_hash = self.db.storeFile(file, result.digest(),
309                                          ''.join(result.pieceDigests()))
310             
311             # Tell the main program to handle the new cache file
312             df = self.manager.new_cached_file(file, result, new_hash, url, True)
313             if df is None:
314                 reactor.callLater(0, self._scanDirectories, None, walker)
315             else:
316                 df.addBoth(self._scanDirectories, walker)
317         else:
318             # Must have returned an error
319             log.msg('hash check of %s failed' % file.path)
320             log.err(result)
321             reactor.callLater(0, self._scanDirectories, None, walker)
322
323     #{ Downloading files
324     def save_file(self, response, hash, url):
325         """Save a downloaded file to the cache and stream it.
326         
327         @type response: L{twisted.web2.http.Response}
328         @param response: the response from the download
329         @type hash: L{Hash.HashObject}
330         @param hash: the hash object containing the expected hash for the file
331         @param url: the URI of the actual mirror request
332         @rtype: L{twisted.web2.http.Response}
333         @return: the final response from the download
334         """
335         if response.code != 200:
336             log.msg('File was not found (%r): %s' % (response, url))
337             return response
338         
339         log.msg('Returning file: %s' % url)
340
341         # Set the destination path for the file
342         parsed = urlparse(url)
343         destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
344         log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
345         
346         # Make sure there's a free place for the file
347         if destFile.exists():
348             log.msg('File already exists, removing: %s' % destFile.path)
349             destFile.remove()
350         elif not destFile.parent().exists():
351             destFile.parent().makedirs()
352
353         # Determine whether it needs to be decompressed and how
354         root, ext = os.path.splitext(destFile.basename())
355         if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
356             ext = ext.lower()
357             decFile = destFile.sibling(root)
358             log.msg('Decompressing to: %s' % decFile.path)
359             if decFile.exists():
360                 log.msg('File already exists, removing: %s' % decFile.path)
361                 decFile.remove()
362         else:
363             ext = None
364             decFile = None
365             
366         # Create the new stream from the old one.
367         orig_stream = response.stream
368         response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
369         response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
370                                               response.headers.getHeader('Last-Modified'),
371                                               decFile)
372         response.stream.doneDefer.addErrback(self.save_error, url)
373
374         # Return the modified response with the new stream
375         return response
376
377     def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
378         """Update the modification time and inform the main program.
379         
380         @type hash: L{Hash.HashObject}
381         @param hash: the hash object containing the expected hash for the file
382         @param url: the URI of the actual mirror request
383         @type destFile: C{twisted.python.FilePath}
384         @param destFile: the file where the download was written to
385         @type modtime: C{int}
386         @param modtime: the modified time of the cached file (seconds since epoch)
387             (optional, defaults to not setting the modification time of the file)
388         @type decFile: C{twisted.python.FilePath}
389         @param decFile: the file where the decompressed download was written to
390             (optional, defaults to the file not having been compressed)
391         """
392         if modtime:
393             os.utime(destFile.path, (modtime, modtime))
394             if decFile:
395                 os.utime(decFile.path, (modtime, modtime))
396         
397         result = hash.verify()
398         if result or result is None:
399             if result:
400                 log.msg('Hashes match: %s' % url)
401             else:
402                 log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
403                 
404             new_hash = self.db.storeFile(destFile, hash.digest(),
405                                          ''.join(hash.pieceDigests()))
406             log.msg('now avaliable: %s' % (url))
407
408             if self.manager:
409                 self.manager.new_cached_file(destFile, hash, new_hash, url)
410                 if decFile:
411                     ext_len = len(destFile.path) - len(decFile.path)
412                     self.manager.new_cached_file(decFile, None, False, url[:-ext_len])
413         else:
414             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
415             destFile.remove()
416             if decFile:
417                 decFile.remove()
418
419     def save_error(self, failure, url):
420         """An error has occurred in downloadign or saving the file."""
421         log.msg('Error occurred downloading %s' % url)
422         log.err(failure)
423         return failure
424
425 class TestMirrorManager(unittest.TestCase):
426     """Unit tests for the mirror manager."""
427     
428     timeout = 20
429     pending_calls = []
430     client = None
431     
432     def setUp(self):
433         self.client = CacheManager(FilePath('/tmp/.apt-p2p'))
434         
435     def tearDown(self):
436         for p in self.pending_calls:
437             if p.active():
438                 p.cancel()
439         self.client = None
440