from bz2 import BZ2Decompressor
from zlib import decompressobj, MAX_WBITS
from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME
from urlparse import urlparse
import os

from twisted.python import log, filepath
from twisted.internet import defer
from twisted.trial import unittest
from twisted.web2 import stream
from twisted.web2.http import splitHostPort

from AptPackages import AptPackages

aptpkg_dir = '.apt-dht'

DECOMPRESS_EXTS = ['.gz', '.bz2']
DECOMPRESS_FILES = ['release', 'sources', 'packages']
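
# A file qualifies for on-the-fly decompression when its stem is in
# DECOMPRESS_FILES and its extension is in DECOMPRESS_EXTS. For example:
#   os.path.splitext('Packages.bz2') -> ('Packages', '.bz2')
#   'packages' in DECOMPRESS_FILES and '.bz2' in DECOMPRESS_EXTS -> decompress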

class MirrorError(Exception):
    """Exception raised when there's a problem with the mirror."""

class ProxyFileStream(stream.SimpleStream):
    """Saves a stream to a file while providing a new stream."""
    
    def __init__(self, stream, outFile, hash, decompress=None, decFile=None):
        """Initializes the proxy.
        
        @type stream: C{twisted.web2.stream.IByteStream}
        @param stream: the input stream to read from
        @type outFile: C{twisted.python.filepath.FilePath}
        @param outFile: the file to write to
        @type hash: L{Hash.HashObject}
        @param hash: the hash object to use for the file
        @type decompress: C{string}
        @param decompress: also decompress the file as this type
            (currently only '.gz' and '.bz2' are supported)
        @type decFile: C{twisted.python.filepath.FilePath}
        @param decFile: the file to write the decompressed data to
        """
        self.stream = stream
        self.outFile = outFile.open('w')
        self.hash = hash
        self.hash.new()
        self.gzfile = None
        self.bz2file = None
        if decompress == ".gz":
            self.gzheader = True
            self.gzfile = decFile.open('w')
            # Negative wbits tells zlib to expect a raw deflate stream;
            # the gzip header is stripped manually in _remove_gzip_header.
            self.gzdec = decompressobj(-MAX_WBITS)
        elif decompress == ".bz2":
            self.bz2file = decFile.open('w')
            self.bz2dec = BZ2Decompressor()
        self.length = self.stream.length
        self.start = 0
        self.doneDefer = defer.Deferred()
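
    # How it works: a consumer calls read() just as it would on the wrapped
    # stream; each chunk is teed to outFile, fed to the hash object, and
    # (optionally) decompressed to decFile before being passed through.
    # Once the upstream is exhausted, _done() closes the files, finalizes
    # the hash, and fires doneDefer with it.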

    def _done(self):
        """Close the output files, finalize the hash, and fire the done deferred."""
        if not self.outFile.closed:
            self.outFile.close()
            self.hash.digest()
            if self.gzfile:
                # Flush any remaining decompressed data out of the zlib buffer.
                data_dec = self.gzdec.flush()
                self.gzfile.write(data_dec)
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None
                
            self.doneDefer.callback(self.hash)
    
    def read(self):
        """Read some data from the stream."""
        if self.outFile.closed:
            return None
        
        data = self.stream.read()
        if isinstance(data, defer.Deferred):
            # _done takes no arguments, so wrap it to absorb the failure
            # before passing it on if the upstream read errs out.
            def _finish(err):
                self._done()
                return err
            data.addCallbacks(self._write, _finish)
            return data
        
        self._write(data)
        return data
    
    def _write(self, data):
        """Write the stream data to the file and return it for others to use."""
        if data is None:
            self._done()
            return data
        
        self.outFile.write(data)
        self.hash.update(data)
        if self.gzfile:
            if self.gzheader:
                self.gzheader = False
                new_data = self._remove_gzip_header(data)
                dec_data = self.gzdec.decompress(new_data)
            else:
                dec_data = self.gzdec.decompress(data)
            self.gzfile.write(dec_data)
        if self.bz2file:
            dec_data = self.bz2dec.decompress(data)
            self.bz2file.write(dec_data)
        return data
    
    def _remove_gzip_header(self, data):
        """Strip the gzip member header so zlib can handle the raw deflate data."""
        if data[:2] != '\037\213':
            raise IOError('Not a gzipped file')
        if ord(data[2]) != 8:
            raise IOError('Unknown compression method')
        flag = ord(data[3])
        # Bytes 4-7 are the modification time, byte 8 the extra flags, and
        # byte 9 the OS; none of them are needed here.

        skip = 10
        if flag & FEXTRA:
            # Skip the extra field, if present
            xlen = ord(data[10])
            xlen = xlen + 256*ord(data[11])
            skip = skip + 2 + xlen
        if flag & FNAME:
            # Skip a null-terminated string containing the filename
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FCOMMENT:
            # Skip a null-terminated string containing a comment
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FHCRC:
            skip += 2     # Skip the 16-bit header CRC
        return data[skip:]
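
    # Worked example (hypothetical header, not from this module): for a member
    # beginning '\037\213\010\010' (magic, deflate, FNAME flag) + 4-byte mtime
    # + XFL + OS + 'Packages\0', skip advances past the 10 fixed bytes and the
    # 9-byte filename field, so the deflate data starts at data[19:].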

    def close(self):
        """Clean everything up and return None to future reads."""
        self.length = 0
        self._done()
        self.stream.close()

class MirrorManager:
    """Manages all requests for mirror objects."""
    
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        self.cache = filepath.FilePath(self.cache_dir)
        self.apt_caches = {}
    
    def extractPath(self, url):
        """Split a URL into the mirror site, base directory, and remaining path."""
        parsed = urlparse(url)
        host, port = splitHostPort(parsed[0], parsed[1])
        site = host + ":" + str(port)
        path = parsed[2]
            
        i = max(path.rfind('/dists/'), path.rfind('/pool/'))
        if i >= 0:
            baseDir = path[:i]
            path = path[i:]
        else:
            # Uh oh, this is not good
            log.msg("Couldn't find a good base directory for path: %s" % (site + path))
            baseDir = ''
            if site in self.apt_caches:
                longest_match = 0
                for base in self.apt_caches[site]:
                    base_match = ''
                    for dirs in path.split('/'):
                        if base.startswith(base_match + '/' + dirs):
                            base_match += '/' + dirs
                        else:
                            break
                    if len(base_match) > longest_match:
                        longest_match = len(base_match)
                        baseDir = base_match
            log.msg("Settled on baseDir: %s" % baseDir)
        
        return site, baseDir, path
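
    # Worked example (mirrors the tests below):
    #   extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
    #     -> ('ftp.us.debian.org:80', '/debian', '/dists/unstable/Release')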

    def init(self, site, baseDir):
        """Make sure an AptPackages cache exists for this site and base directory."""
        if site not in self.apt_caches:
            self.apt_caches[site] = {}
            
        if baseDir not in self.apt_caches[site]:
            site_cache = os.path.join(self.cache_dir, aptpkg_dir, 'mirrors', site + baseDir.replace('/', '_'))
            self.apt_caches[site][baseDir] = AptPackages(site_cache)
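
    # For example (a sketch following the path join above), site
    # 'ftp.us.debian.org:80' with baseDir '/debian' gets its cache at
    # <cache_dir>/.apt-dht/mirrors/ftp.us.debian.org:80_debian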

    def updatedFile(self, url, file_path):
        """Pass a downloaded file on to the appropriate AptPackages cache."""
        site, baseDir, path = self.extractPath(url)
        self.init(site, baseDir)
        self.apt_caches[site][baseDir].file_updated(path, file_path)

    def findHash(self, url):
        """Find the expected hash for a file, returned via a deferred."""
        site, baseDir, path = self.extractPath(url)
        if site in self.apt_caches and baseDir in self.apt_caches[site]:
            return self.apt_caches[site][baseDir].findHash(path)
        d = defer.Deferred()
        d.errback(MirrorError("Site Not Found"))
        return d
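
    # Usage sketch (hypothetical caller, following the deferred interface above):
    #   d = manager.findHash('http://ftp.us.debian.org/debian/dists/unstable/Release')
    #   d.addCallback(lambda h: ...)    # fires with the found Hash.HashObject
    #   d.addErrback(log.err)           # fires with MirrorError for unknown sites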

    def save_file(self, response, hash, url):
        """Save a downloaded file to the cache and stream it."""
        log.msg('Returning file: %s' % url)
        
        parsed = urlparse(url)
        destFile = self.cache.preauthChild(parsed[1] + parsed[2])
        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
        
        if destFile.exists():
            log.msg('File already exists, removing: %s' % destFile.path)
            destFile.remove()
        else:
            destFile.parent().makedirs()
            
        root, ext = os.path.splitext(destFile.basename())
        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
            ext = ext.lower()
            decFile = destFile.sibling(root)
            log.msg('Decompressing to: %s' % decFile.path)
            if decFile.exists():
                log.msg('File already exists, removing: %s' % decFile.path)
                decFile.remove()
        else:
            ext = None
            decFile = None
            
        orig_stream = response.stream
        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
        response.stream.doneDefer.addCallback(self.save_complete, url, destFile,
                                              response.headers.getHeader('Last-Modified'),
                                              ext, decFile)
        response.stream.doneDefer.addErrback(self.save_error, url)
        return response
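
    # Flow sketch: the response is returned with its stream replaced by a
    # ProxyFileStream, so the body is cached, hashed, and (for index files)
    # decompressed as it is relayed to the requester; doneDefer then runs
    # save_complete on success or save_error on failure.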

    def save_complete(self, hash, url, destFile, modtime=None, ext=None, decFile=None):
        """Update the modification time and AptPackages."""
        if modtime:
            os.utime(destFile.path, (modtime, modtime))
            if ext:
                os.utime(decFile.path, (modtime, modtime))
        
        result = hash.verify()
        if result or result is None:
            if result:
                log.msg('Hashes match: %s' % url)
            else:
                # No expected hash was known, so just record what was computed.
                log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
                
            self.updatedFile(url, destFile.path)
            if ext:
                self.updatedFile(url[:-len(ext)], decFile.path)
        else:
            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))

    def save_error(self, failure, url):
        """An error has occurred in downloading or saving the file."""
        log.msg('Error occurred downloading %s' % url)
        log.err(failure)
        return failure

class TestMirrorManager(unittest.TestCase):
    """Unit tests for the mirror manager."""
    
    timeout = 20
    pending_calls = []
    client = None
    
    def setUp(self):
        self.client = MirrorManager('/tmp')
        
    def test_extractPath(self):
        site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
        self.failUnless(site == "ftp.us.debian.org:80", "no match: %s" % site)
        self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
        self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)

        site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org:16999/debian/pool/d/dpkg/dpkg_1.2.1-1.tar.gz')
        self.failUnless(site == "ftp.us.debian.org:16999", "no match: %s" % site)
        self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
        self.failUnless(path == "/pool/d/dpkg/dpkg_1.2.1-1.tar.gz", "no match: %s" % path)

        site, baseDir, path = self.client.extractPath('http://debian.camrdale.org/dists/unstable/Release')
        self.failUnless(site == "debian.camrdale.org:80", "no match: %s" % site)
        self.failUnless(baseDir == "", "no match: %s" % baseDir)
        self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)

    def verifyHash(self, found_hash, path, true_hash):
        self.failUnless(found_hash.hexexpected() == true_hash, 
                    "%s hashes don't match: %s != %s" % (path, found_hash.hexexpected(), true_hash))

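    # Note: test_findHash assumes a Debian host with a populated
    # /var/lib/apt/lists/; the expected hashes are scraped from the local
    # Release, Packages, and Sources index files rather than hard-coded.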
    def test_findHash(self):
        self.packagesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Packages$" | tail -n 1').read().rstrip('\n')
        self.sourcesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Sources$" | tail -n 1').read().rstrip('\n')
        for f in os.walk('/var/lib/apt/lists').next()[2]:
            if f[-7:] == "Release" and self.packagesFile.startswith(f[:-7]):
                self.releaseFile = f
                break
        
        self.client.updatedFile('http://' + self.releaseFile.replace('_','/'), 
                                '/var/lib/apt/lists/' + self.releaseFile)
        self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                self.packagesFile[self.packagesFile.find('_dists_')+1:].replace('_','/'), 
                                '/var/lib/apt/lists/' + self.packagesFile)
        self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                self.sourcesFile[self.sourcesFile.find('_dists_')+1:].replace('_','/'), 
                                '/var/lib/apt/lists/' + self.sourcesFile)

        lastDefer = defer.Deferred()
        
        # verifyHash compares hex digests, so the scraped hashes are passed
        # through as hex strings rather than converted to binary.
        idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' + 
                            '/var/lib/apt/lists/' + self.releaseFile + 
                            ' | grep -E " main/binary-i386/Packages.bz2$"'
                            ' | head -n 1 | cut -d\  -f 2').read().rstrip('\n')
        idx_path = 'http://' + self.releaseFile.replace('_','/')[:-7] + 'main/binary-i386/Packages.bz2'

        d = self.client.findHash(idx_path)
        d.addCallback(self.verifyHash, idx_path, idx_hash)

        pkg_hash = os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.packagesFile + 
                            ' | grep -E "^SHA1:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')
        pkg_path = 'http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') + \
                   os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.packagesFile + 
                            ' | grep -E "^Filename:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')

        d = self.client.findHash(pkg_path)
        d.addCallback(self.verifyHash, pkg_path, pkg_hash)

        src_dir = os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -E "^Directory:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')
        src_hashes = os.popen('grep -A 20 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -A 4 -E "^Files:" | grep -E "^ " ' + 
                            ' | cut -d\  -f 2').read().split('\n')[:-1]
        src_paths = os.popen('grep -A 20 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -A 4 -E "^Files:" | grep -E "^ " ' + 
                            ' | cut -d\  -f 4').read().split('\n')[:-1]

        for i in range(len(src_hashes)):
            src_path = 'http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') + src_dir + '/' + src_paths[i]
            d = self.client.findHash(src_path)
            d.addCallback(self.verifyHash, src_path, src_hashes[i])
            
        idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' + 
                            '/var/lib/apt/lists/' + self.releaseFile + 
                            ' | grep -E " main/source/Sources.bz2$"'
                            ' | head -n 1 | cut -d\  -f 2').read().rstrip('\n')
        idx_path = 'http://' + self.releaseFile.replace('_','/')[:-7] + 'main/source/Sources.bz2'

        d = self.client.findHash(idx_path)
        d.addCallback(self.verifyHash, idx_path, idx_hash)

        d.addBoth(lastDefer.callback)
        return lastDefer

    def tearDown(self):
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.client = None