apt_dht/CacheManager.py
CacheManager scans the cache directory during initialization.

from bz2 import BZ2Decompressor
from zlib import decompressobj, MAX_WBITS
from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME
from urlparse import urlparse
import os

from twisted.python import log
from twisted.python.filepath import FilePath
from twisted.internet import defer, reactor
from twisted.trial import unittest
from twisted.web2 import stream

from Hash import HashObject

aptpkg_dir = 'apt-packages'

DECOMPRESS_EXTS = ['.gz', '.bz2']
DECOMPRESS_FILES = ['release', 'sources', 'packages']
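# Only apt index files are decompressed: a download whose basename (less
# extension, case-insensitive) is in DECOMPRESS_FILES and whose extension is
# in DECOMPRESS_EXTS, e.g. 'Packages.gz' or 'Sources.bz2', is also saved
# uncompressed (see save_file below).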

class ProxyFileStream(stream.SimpleStream):
    """Saves a stream to a file while providing a new stream.
    
    Data read from the input stream is written to the output file and fed
    to the hash object as it passes through, and is optionally decompressed
    to a second file, while still being returned unchanged to the reader.
    """
    
    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
        """Initializes the proxy.
        
        @type stream: C{twisted.web2.stream.IByteStream}
        @param stream: the input stream to read from
        @type outFile: C{twisted.python.FilePath}
        @param outFile: the file to write to
        @type hash: L{Hash.HashObject}
        @param hash: the hash object to use for the file
        @type decompress: C{string}
        @param decompress: also decompress the file as this type
            (currently only '.gz' and '.bz2' are supported)
        @type decFile: C{twisted.python.FilePath}
        @param decFile: the file to write the decompressed data to
        """
        self.stream = stream
        self.outFile = outFile.open('w')
        self.hash = hash
        self.hash.new()
        self.gzfile = None
        self.bz2file = None
        if decompress == ".gz":
            # The gzip header is stripped by hand, so use a raw deflate
            # decompressor (negative window bits means no zlib/gzip header)
            self.gzheader = True
            self.gzfile = decFile.open('w')
            self.gzdec = decompressobj(-MAX_WBITS)
        elif decompress == ".bz2":
            self.bz2file = decFile.open('w')
            self.bz2dec = BZ2Decompressor()
        self.length = self.stream.length
        self.start = 0
        self.doneDefer = defer.Deferred()
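        # doneDefer fires with the completed HashObject once the input
        # stream is exhausted and the files are closed; CacheManager.save_file
        # below attaches its completion handling to it.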

    def _done(self, result = None):
        """Close the output file.
        
        Also used as an errback for the stream's deferred read, so it
        accepts (and ignores) an optional failure argument.
        """
        if not self.outFile.closed:
            self.outFile.close()
            self.hash.digest()
            if self.gzfile:
                # Flush any data still buffered in the decompressor
                data_dec = self.gzdec.flush()
                self.gzfile.write(data_dec)
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None
                
            self.doneDefer.callback(self.hash)
    
    def read(self):
        """Read some data from the stream."""
        if self.outFile.closed:
            return None
        
        # Read from the stream, handling data that arrives via a deferred;
        # a None result (or an error) means the stream is finished
        data = self.stream.read()
        if isinstance(data, defer.Deferred):
            data.addCallbacks(self._write, self._done)
            return data
        
        self._write(data)
        return data
    
    def _write(self, data):
        """Write the stream data to the file and return it for others to use."""
        if data is None:
            self._done()
            return data
        
        self.outFile.write(data)
        self.hash.update(data)
        if self.gzfile:
            if self.gzheader:
                # The first chunk contains the gzip header, which the raw
                # deflate decompressor must not see
                self.gzheader = False
                new_data = self._remove_gzip_header(data)
                dec_data = self.gzdec.decompress(new_data)
            else:
                dec_data = self.gzdec.decompress(data)
            self.gzfile.write(dec_data)
        if self.bz2file:
            dec_data = self.bz2dec.decompress(data)
            self.bz2file.write(dec_data)
        return data
    
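    # A gzip member starts with a 10-byte fixed header (magic, compression
    # method, flags, mtime, extra flags, OS), optionally followed by an extra
    # field, a filename, a comment and a header CRC, depending on the flag
    # bits (see RFC 1952 and the stdlib gzip module this is adapted from).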
    def _remove_gzip_header(self, data):
        """Strip the gzip header from the first chunk of data.
        
        This assumes the entire header is contained in the first chunk.
        """
        if data[:2] != '\037\213':
            raise IOError, 'Not a gzipped file'
        if ord(data[2]) != 8:
            raise IOError, 'Unknown compression method'
        flag = ord(data[3])
        # The modification time, extra flags and OS bytes are ignored

        skip = 10
        if flag & FEXTRA:
            # Skip the extra field, if present
            xlen = ord(data[10])
            xlen = xlen + 256*ord(data[11])
            skip = skip + 2 + xlen
        if flag & FNAME:
            # Skip a null-terminated string containing the filename
            while data[skip] != '\000':
                skip += 1
            skip += 1
        if flag & FCOMMENT:
            # Skip a null-terminated string containing a comment
            while data[skip] != '\000':
                skip += 1
            skip += 1
        if flag & FHCRC:
            skip += 2     # Skip the 16-bit header CRC
        return data[skip:]

    def close(self):
        """Clean everything up and return None to future reads."""
        self.length = 0
        self._done()
        self.stream.close()

class CacheManager:
    """Manages all requests for cached objects."""
    
    def __init__(self, cache_dir, db, manager = None):
        self.cache_dir = cache_dir
        self.db = db
        self.manager = manager
        self.scanning = []
        
        # Init the database, remove old files, init the HTTP dirs
        self.db.removeUntrackedFiles([self.cache_dir])
        self.db.reconcileDirectories()
        if self.manager:
            self.manager.setDirectories(self.db.getAllDirectories())
        
    def scanDirectories(self):
        """Scan the cache directories, hashing new and rehashing changed files."""
        assert not self.scanning, "a directory scan is already under way"
        self.scanning.append(self.cache_dir)
        self._scanDirectories()

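    # The scan processes one file per reactor iteration: each step handles a
    # single file, then reschedules itself with reactor.callLater(0, ...),
    # so hashing a large cache doesn't block the event loop.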
    def _scanDirectories(self, walker = None):
        # Need to start walking a new directory
        if walker is None:
            # If there are any left, get them
            if self.scanning:
                log.msg('started scanning directory: %s' % self.scanning[0].path)
                walker = self.scanning[0].walk()
            else:
                # Done, just check if the HTTP directories need updating
                log.msg('cache directory scan complete')
                if self.db.reconcileDirectories():
                    self.manager.setDirectories(self.db.getAllDirectories())
                return
            
        try:
            # Get the next file in the directory
            file = walker.next()
        except StopIteration:
            # No files left, go to the next directory
            log.msg('done scanning directory: %s' % self.scanning[0].path)
            self.scanning.pop(0)
            reactor.callLater(0, self._scanDirectories)
            return

        # If it's not a file, or it's already properly in the DB, ignore it
        if not file.isfile() or self.db.isUnchanged(file):
            if not file.isfile():
                log.msg('entering directory: %s' % file.path)
            else:
                log.msg('file is unchanged: %s' % file.path)
            reactor.callLater(0, self._scanDirectories, walker)
            return

        # Otherwise hash it
        log.msg('start hash checking file: %s' % file.path)
        hash = HashObject()
        df = hash.hashInThread(file)
        df.addBoth(self._doneHashing, file, walker)
        df.addErrback(log.err)
    
    def _doneHashing(self, result, file, walker):
        # Schedule the scan of the next file before processing this result
        reactor.callLater(0, self._scanDirectories, walker)
    
        if isinstance(result, HashObject):
            log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
            if self.scanning[0] == self.cache_dir:
                # Cached downloads are tracked by their top-level mirror
                # directory within the cache
                mirror_dir = self.cache_dir.child(file.path[len(self.cache_dir.path)+1:].split('/', 1)[0])
                urlpath, newdir = self.db.storeFile(file, result.digest(), mirror_dir)
                # The path relative to the cache dir starts with '/', so
                # this yields 'http://host/path'
                url = 'http:/' + file.path[len(self.cache_dir.path):]
            else:
                urlpath, newdir = self.db.storeFile(file, result.digest(), self.scanning[0])
                url = None
            if newdir:
                self.manager.setDirectories(self.db.getAllDirectories())
            self.manager.new_cached_file(file, result, urlpath, url)
        else:
            log.msg('hash check of %s failed' % file.path)
            log.err(result)

    def save_file(self, response, hash, url):
        """Save a downloaded file to the cache while streaming it to the requester."""
        if response.code != 200:
            log.msg('File was not found (%r): %s' % (response, url))
            return response
        
        log.msg('Returning file: %s' % url)
        
        # The file's place in the cache mirrors the URL's host and path
        parsed = urlparse(url)
        destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
        
        if destFile.exists():
            log.msg('File already exists, removing: %s' % destFile.path)
            destFile.remove()
        elif not destFile.parent().exists():
            destFile.parent().makedirs()
            
        root, ext = os.path.splitext(destFile.basename())
        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
            ext = ext.lower()
            decFile = destFile.sibling(root)
            log.msg('Decompressing to: %s' % decFile.path)
            if decFile.exists():
                log.msg('File already exists, removing: %s' % decFile.path)
                decFile.remove()
        else:
            ext = None
            decFile = None
            
        # Replace the response's stream with one that writes through to the
        # cache file as the requester reads it
        orig_stream = response.stream
        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
        response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
                                              response.headers.getHeader('Last-Modified'),
                                              ext, decFile)
        response.stream.doneDefer.addErrback(self.save_error, url)
        return response

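    # hash.verify() is tri-state: True when the digest matches a known
    # expected hash, None when there was no expected hash to check against,
    # and False on a mismatch.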
    def _save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
        """Update the modification time and AptPackages."""
        if modtime:
            os.utime(destFile.path, (modtime, modtime))
            if ext:
                os.utime(decFile.path, (modtime, modtime))
        
        result = hash.verify()
        if result or result is None:
            if result:
                log.msg('Hashes match: %s' % url)
            else:
                log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
                
            mirror_dir = self.cache_dir.child(destFile.path[len(self.cache_dir.path)+1:].split('/', 1)[0])
            urlpath, newdir = self.db.storeFile(destFile, hash.digest(), mirror_dir)
            log.msg('now available at %s: %s' % (urlpath, url))

            if self.manager:
                if newdir:
                    log.msg('A new web directory was created, so enable it')
                    self.manager.setDirectories(self.db.getAllDirectories())
    
                self.manager.new_cached_file(destFile, hash, urlpath, url)
                if ext:
                    # The decompressed copy is published under the URL
                    # without the compression extension
                    self.manager.new_cached_file(decFile, None, urlpath, url[:-len(ext)])
        else:
            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
            destFile.remove()
            if ext:
                decFile.remove()

    def save_error(self, failure, url):
        """An error occurred in downloading or saving the file."""
        log.msg('Error occurred downloading %s' % url)
        log.err(failure)
        return failure

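# A minimal stub standing in for the real database object (the actual class
# lives in the project's db module); it provides just the methods that
# CacheManager.__init__ calls, so the test below can construct a client.
class FakeDB:
    def removeUntrackedFiles(self, dirs):
        pass
    def reconcileDirectories(self):
        return False
    def getAllDirectories(self):
        return {}
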
class TestMirrorManager(unittest.TestCase):
    """Unit tests for the CacheManager."""
    
    timeout = 20
    pending_calls = []
    client = None
    
    def setUp(self):
        self.client = CacheManager(FilePath('/tmp/.apt-dht'), FakeDB())
        
    def tearDown(self):
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.client = None