# apt_dht/MirrorManager.py  (from quix0rs-apt-p2p.git)
# Commit: ProxyFileStream also calculates hash while downloading.

from bz2 import BZ2Decompressor
from zlib import decompressobj, MAX_WBITS
from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
from urlparse import urlparse
from binascii import a2b_hex, b2a_hex
import os, sha, md5

from twisted.python import log, filepath
from twisted.internet import defer
from twisted.trial import unittest
from twisted.web2 import stream
from twisted.web2.http import splitHostPort

from AptPackages import AptPackages

aptpkg_dir='.apt-dht'

DECOMPRESS_EXTS = ['.gz', '.bz2']
DECOMPRESS_FILES = ['release', 'sources', 'packages']
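
# Apt index files (Release, Sources, Packages) are also saved decompressed,
# so their metadata can be parsed by AptPackages (see save_file below).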

class MirrorError(Exception):
    """Exception raised when there's a problem with the mirror."""

class ProxyFileStream(stream.SimpleStream):
    """Saves a stream to a file while providing a new stream."""

    def __init__(self, stream, outFile, hashType = "sha1", decompress = None, decFile = None):
        """Initializes the proxy.
        
        @type stream: C{twisted.web2.stream.IByteStream}
        @param stream: the input stream to read from
        @type outFile: C{twisted.python.filepath.FilePath}
        @param outFile: the file to write to
        @type hashType: C{string}
        @param hashType: also hash the file using this hashing function
            (currently only 'sha1' and 'md5' are supported)
        @type decompress: C{string}
        @param decompress: also decompress the file as this type
            (currently only '.gz' and '.bz2' are supported)
        @type decFile: C{twisted.python.filepath.FilePath}
        @param decFile: the file to write the decompressed data to
        """
        self.stream = stream
        self.outFile = outFile.open('w')
        self.hasher = None
        if hashType == "sha1":
            self.hasher = sha.new()
        elif hashType == "md5":
            self.hasher = md5.new()
        self.gzfile = None
        self.bz2file = None
        if decompress == ".gz":
            self.gzheader = True
            self.gzfile = decFile.open('w')
            self.gzdec = decompressobj(-MAX_WBITS)
        elif decompress == ".bz2":
            self.bz2file = decFile.open('w')
            self.bz2dec = BZ2Decompressor()
        self.length = self.stream.length
        self.start = 0
        self.doneDefer = defer.Deferred()

    def _done(self, failure = None):
        """Close the output files and fire the done deferred.
        
        Also usable as an errback from the input stream's read, in which
        case the failure is passed on through doneDefer instead of a hash.
        """
        if not self.outFile.closed:
            self.outFile.close()
            fileHash = None
            if self.hasher:
                fileHash = self.hasher.digest()
            if self.gzfile:
                data_dec = self.gzdec.flush()
                self.gzfile.write(data_dec)
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None
                
            if failure is None:
                self.doneDefer.callback(fileHash)
            else:
                self.doneDefer.errback(failure)
    
    def read(self):
        """Read some data from the stream."""
        if self.outFile.closed:
            return None
        
        data = self.stream.read()
        if isinstance(data, defer.Deferred):
            data.addCallbacks(self._write, self._done)
            return data
        
        self._write(data)
        return data
    
    def _write(self, data):
        """Write the stream data to the file and return it for others to use."""
        if data is None:
            self._done()
            return data
        
        self.outFile.write(data)
        if self.hasher:
            self.hasher.update(data)
        if self.gzfile:
            if self.gzheader:
                self.gzheader = False
                new_data = self._remove_gzip_header(data)
                dec_data = self.gzdec.decompress(new_data)
            else:
                dec_data = self.gzdec.decompress(data)
            self.gzfile.write(dec_data)
        if self.bz2file:
            dec_data = self.bz2dec.decompress(data)
            self.bz2file.write(dec_data)
        return data
    
    def _remove_gzip_header(self, data):
        """Strip the gzip header from the start of the data.
        
        Note: this assumes the entire header is contained in the first
        chunk read from the stream.
        """
        if data[:2] != '\037\213':
            raise IOError, 'Not a gzipped file'
        if ord(data[2]) != 8:
            raise IOError, 'Unknown compression method'
        flag = ord(data[3])
        # The rest of the fixed header (modtime, extra flags, OS) is skipped below

        skip = 10
        if flag & FEXTRA:
            # Read & discard the extra field, if present
            xlen = ord(data[10])
            xlen = xlen + 256*ord(data[11])
            skip = skip + 2 + xlen
        if flag & FNAME:
            # Read and discard a null-terminated string containing the filename
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FCOMMENT:
            # Read and discard a null-terminated string containing a comment
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FHCRC:
            skip += 2     # Read & discard the 16-bit header CRC
        return data[skip:]
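
    # For reference, the gzip header parsed above (RFC 1952): a fixed 10-byte
    # header (magic '\037\213', method, flags, 4-byte mtime, extra flags, OS),
    # optionally followed by FEXTRA (2-byte little-endian length plus data),
    # FNAME and FCOMMENT (null-terminated strings), and FHCRC (2-byte CRC).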

    def close(self):
        """Clean everything up and return None to future reads."""
        self.length = 0
        self._done()
        self.stream.close()

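# A minimal usage sketch (cf. MirrorManager.save_file below): splice a
# ProxyFileStream into a twisted.web2 response so the data is cached and
# hashed in transit, then act on the binary digest when it completes.
#
#     orig_stream = response.stream
#     response.stream = ProxyFileStream(orig_stream, destFile, "sha1")
#     response.stream.doneDefer.addCallback(
#         lambda fileHash: log.msg('hashed to %s' % b2a_hex(fileHash)))
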
class MirrorManager:
    """Manages all requests for mirror objects."""
    
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        self.cache = filepath.FilePath(self.cache_dir)
        self.apt_caches = {}
    
    def extractPath(self, url):
        """Break a URL into the mirror's site, base directory, and file path."""
        parsed = urlparse(url)
        host, port = splitHostPort(parsed[0], parsed[1])
        site = host + ":" + str(port)
        path = parsed[2]
        
        i = max(path.rfind('/dists/'), path.rfind('/pool/'))
        if i >= 0:
            baseDir = path[:i]
            path = path[i:]
        else:
            # Uh oh, this is not good
            log.msg("Couldn't find a good base directory for path: %s" % (site + path))
            baseDir = ''
            if site in self.apt_caches:
                longest_match = 0
                for base in self.apt_caches[site]:
                    base_match = ''
                    for dirs in path.split('/'):
                        if base.startswith(base_match + '/' + dirs):
                            base_match += '/' + dirs
                        else:
                            break
                    if len(base_match) > longest_match:
                        longest_match = len(base_match)
                        baseDir = base_match
            log.msg("Settled on baseDir: %s" % baseDir)
        
        return site, baseDir, path
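
    # For example (see test_extractPath below):
    #   extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
    #   returns ('ftp.us.debian.org:80', '/debian', '/dists/unstable/Release')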
        
    def init(self, site, baseDir):
        """Make sure an AptPackages cache exists for this site and baseDir."""
        if site not in self.apt_caches:
            self.apt_caches[site] = {}
            
        if baseDir not in self.apt_caches[site]:
            site_cache = os.path.join(self.cache_dir, aptpkg_dir, 'mirrors', site + baseDir.replace('/', '_'))
            self.apt_caches[site][baseDir] = AptPackages(site_cache)
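
    # e.g. init('ftp.us.debian.org:80', '/debian') places its AptPackages
    # cache at <cache_dir>/.apt-dht/mirrors/ftp.us.debian.org:80_debian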
    
    def updatedFile(self, url, file_path):
        """Notify the mirror's AptPackages cache that a file has been updated."""
        site, baseDir, path = self.extractPath(url)
        self.init(site, baseDir)
        self.apt_caches[site][baseDir].file_updated(path, file_path)

    def findHash(self, url):
        """Find the hash for a URL's file, returning a deferred (hash, size) tuple."""
        site, baseDir, path = self.extractPath(url)
        if site in self.apt_caches and baseDir in self.apt_caches[site]:
            d = self.apt_caches[site][baseDir].findHash(path)
            d.addCallback(self.translateHash)
            return d
        d = defer.Deferred()
        d.errback(MirrorError("Site Not Found"))
        return d
    
    def translateHash(self, (hash, size)):
        """Translate a hash from apt's hex encoding to a binary string."""
        if hash:
            hash = a2b_hex(hash)
        return (hash, size)
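
    # e.g. a 40-character hex SHA1 becomes the 20-byte binary digest that
    # ProxyFileStream's hasher produces, so save_complete can compare them.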

    def save_file(self, response, hash, size, url):
        """Save a downloaded file to the cache and stream it."""
        log.msg('Returning file: %s' % url)
        
        parsed = urlparse(url)
        destFile = self.cache.preauthChild(parsed[1] + parsed[2])
        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
        
        if destFile.exists():
            log.msg('File already exists, removing: %s' % destFile.path)
            destFile.remove()
        elif not destFile.parent().exists():
            # makedirs raises an error if the directory is already there
            destFile.parent().makedirs()
            
        root, ext = os.path.splitext(destFile.basename())
        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
            ext = ext.lower()
            decFile = destFile.sibling(root)
            log.msg('Decompressing to: %s' % decFile.path)
            if decFile.exists():
                log.msg('File already exists, removing: %s' % decFile.path)
                decFile.remove()
        else:
            ext = None
            decFile = None
            
        if hash and len(hash) == 16:
            hashType = "md5"
        else:
            hashType = "sha1"
        
        orig_stream = response.stream
        response.stream = ProxyFileStream(orig_stream, destFile, hashType, ext, decFile)
        response.stream.doneDefer.addCallback(self.save_complete, hash, size, url, destFile,
                                              response.headers.getHeader('Last-Modified'),
                                              ext, decFile)
        response.stream.doneDefer.addErrback(self.save_error, url)
        return response
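
    # The overall flow: save_file splices a ProxyFileStream into the response,
    # the client reads the response as usual, and once the stream finishes
    # save_complete verifies the hash and registers the cached file(s).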

    def save_complete(self, result, hash, size, url, destFile, modtime = None, ext = None, decFile = None):
        """Update the modification time and AptPackages."""
        if modtime:
            os.utime(destFile.path, (modtime, modtime))
            if ext:
                os.utime(decFile.path, (modtime, modtime))
        
        if not hash or result == hash:
            if hash:
                log.msg('Hashes match: %s' % url)
            else:
                log.msg('Hashed file to %s: %s' % (b2a_hex(result), url))
                
            self.updatedFile(url, destFile.path)
            if ext:
                self.updatedFile(url[:-len(ext)], decFile.path)
        else:
            log.msg("Hashes don't match %s != %s: %s" % (b2a_hex(hash), b2a_hex(result), url))

    def save_error(self, failure, url):
        """An error occurred in downloading or saving the file."""
        log.msg('Error occurred downloading %s' % url)
        log.err(failure)
        return failure

class TestMirrorManager(unittest.TestCase):
    """Unit tests for the mirror manager."""
    
    timeout = 20
    pending_calls = []
    client = None
    
    def setUp(self):
        self.client = MirrorManager('/tmp')
        
    def test_extractPath(self):
        site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
        self.failUnless(site == "ftp.us.debian.org:80", "no match: %s" % site)
        self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
        self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)

        site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org:16999/debian/pool/d/dpkg/dpkg_1.2.1-1.tar.gz')
        self.failUnless(site == "ftp.us.debian.org:16999", "no match: %s" % site)
        self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
        self.failUnless(path == "/pool/d/dpkg/dpkg_1.2.1-1.tar.gz", "no match: %s" % path)

        site, baseDir, path = self.client.extractPath('http://debian.camrdale.org/dists/unstable/Release')
        self.failUnless(site == "debian.camrdale.org:80", "no match: %s" % site)
        self.failUnless(baseDir == "", "no match: %s" % baseDir)
        self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)

    def verifyHash(self, found_hash, path, true_hash):
        self.failUnless(found_hash[0] == true_hash, 
                    "%s hashes don't match: %s != %s" % (path, found_hash[0], true_hash))

    def test_findHash(self):
        self.packagesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Packages$" | tail -n 1').read().rstrip('\n')
        self.sourcesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Sources$" | tail -n 1').read().rstrip('\n')
        for f in os.walk('/var/lib/apt/lists').next()[2]:
            if f[-7:] == "Release" and self.packagesFile.startswith(f[:-7]):
                self.releaseFile = f
                break
        
        self.client.updatedFile('http://' + self.releaseFile.replace('_','/'), 
                                '/var/lib/apt/lists/' + self.releaseFile)
        self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                self.packagesFile[self.packagesFile.find('_dists_')+1:].replace('_','/'), 
                                '/var/lib/apt/lists/' + self.packagesFile)
        self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                self.sourcesFile[self.sourcesFile.find('_dists_')+1:].replace('_','/'), 
                                '/var/lib/apt/lists/' + self.sourcesFile)

        lastDefer = defer.Deferred()
        
        idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' + 
                            '/var/lib/apt/lists/' + self.releaseFile + 
                            ' | grep -E " main/binary-i386/Packages.bz2$"'
                            ' | head -n 1 | cut -d\  -f 2').read().rstrip('\n')
        idx_path = 'http://' + self.releaseFile.replace('_','/')[:-7] + 'main/binary-i386/Packages.bz2'

        d = self.client.findHash(idx_path)
        d.addCallback(self.verifyHash, idx_path, a2b_hex(idx_hash))

        pkg_hash = os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.packagesFile + 
                            ' | grep -E "^SHA1:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')
        pkg_path = 'http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') + \
                   os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.packagesFile + 
                            ' | grep -E "^Filename:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')

        d = self.client.findHash(pkg_path)
        d.addCallback(self.verifyHash, pkg_path, a2b_hex(pkg_hash))

        src_dir = os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -E "^Directory:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')
        src_hashes = os.popen('grep -A 20 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -A 4 -E "^Files:" | grep -E "^ " ' + 
                            ' | cut -d\  -f 2').read().split('\n')[:-1]
        src_paths = os.popen('grep -A 20 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -A 4 -E "^Files:" | grep -E "^ " ' + 
                            ' | cut -d\  -f 4').read().split('\n')[:-1]

        for i in range(len(src_hashes)):
            src_path = 'http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') + src_dir + '/' + src_paths[i]
            d = self.client.findHash(src_path)
            d.addCallback(self.verifyHash, src_path, a2b_hex(src_hashes[i]))
            
        idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' + 
                            '/var/lib/apt/lists/' + self.releaseFile + 
                            ' | grep -E " main/source/Sources.bz2$"'
                            ' | head -n 1 | cut -d\  -f 2').read().rstrip('\n')
        idx_path = 'http://' + self.releaseFile.replace('_','/')[:-7] + 'main/source/Sources.bz2'

        d = self.client.findHash(idx_path)
        d.addCallback(self.verifyHash, idx_path, a2b_hex(idx_hash))

        d.addBoth(lastDefer.callback)
        return lastDefer

    def tearDown(self):
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.client = None