Add back the updatedFile method to the MirrorManager.
[quix0rs-apt-p2p.git] / apt_dht / MirrorManager.py
1
2 from bz2 import BZ2Decompressor
3 from zlib import decompressobj, MAX_WBITS
4 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
5 from urlparse import urlparse
6 from binascii import a2b_hex
7 import os
8
9 from twisted.python import log, filepath
10 from twisted.internet import defer
11 from twisted.trial import unittest
12 from twisted.web2 import stream
13 from twisted.web2.http import splitHostPort
14
15 from AptPackages import AptPackages
16
# Subdirectory of the cache directory used for apt-dht's own package data.
aptpkg_dir='.apt-dht'

# Compressed index files with one of these extensions AND one of these base
# names are also decompressed while being saved (see MirrorManager.save_file).
DECOMPRESS_EXTS = ['.gz', '.bz2']
DECOMPRESS_FILES = ['release', 'sources', 'packages']
21
class MirrorError(Exception):
    """Exception raised when there's a problem with the mirror.

    Used as the errback value of L{MirrorManager.findHash} when no cache
    is known for the requested site/baseDir.
    """
24
class ProxyFileStream(stream.SimpleStream):
    """Saves a stream to a file while providing a new stream.

    Every chunk read from the wrapped input stream is written to an output
    file (and optionally decompressed into a second file) before being
    returned to the reader unchanged.  When the input stream is exhausted
    the files are closed and C{doneDefer} is fired with 1; if reading the
    input stream fails, the files are closed and C{doneDefer} is errbacked.
    """
    
    def __init__(self, stream, outFile, decompress = None, decFile = None):
        """Initializes the proxy.
        
        @type stream: C{twisted.web2.stream.IByteStream}
        @param stream: the input stream to read from
        @type outFile: C{twisted.python.filepath.FilePath}
        @param outFile: the file to write to
        @type decompress: C{string}
        @param decompress: also decompress the file as this type
            (currently only '.gz' and '.bz2' are supported)
        @type decFile: C{twisted.python.filepath.FilePath}
        @param decFile: the file to write the decompressed data to
        """
        self.stream = stream
        self.outFile = outFile.open('w')
        self.gzfile = None
        self.bz2file = None
        if decompress == ".gz":
            # The gzip header must be stripped by hand before feeding data
            # to the raw zlib decompressor (negative wbits = no zlib header).
            self.gzheader = True
            self.gzfile = decFile.open('w')
            self.gzdec = decompressobj(-MAX_WBITS)
        elif decompress == ".bz2":
            self.bz2file = decFile.open('w')
            self.bz2dec = BZ2Decompressor()
        self.length = self.stream.length
        self.start = 0
        # Fired with 1 on successful completion, or errbacked on stream error.
        self.doneDefer = defer.Deferred()

    def _done(self):
        """Close the output file(s) and signal successful completion."""
        if not self.outFile.closed:
            self.outFile.close()
            if self.gzfile:
                # Flush any data still buffered in the decompressor.
                data_dec = self.gzdec.flush()
                self.gzfile.write(data_dec)
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None
                
            self.doneDefer.callback(1)
    
    def _error(self, failure):
        """Close the output file(s) and propagate a stream read failure.
        
        Unlike L{_done}, no decompressor flush is attempted (the data is
        incomplete) and C{doneDefer} is errbacked instead of fired.
        
        @param failure: the failure from reading the input stream
        @return: the failure, so it also propagates to our reader
        """
        if not self.outFile.closed:
            self.outFile.close()
            if self.gzfile:
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None
            self.doneDefer.errback(failure)
        return failure
    
    def read(self):
        """Read some data from the stream."""
        if self.outFile.closed:
            return None
        
        data = self.stream.read()
        if isinstance(data, defer.Deferred):
            # BUG FIX: the errback used to be self._done, which takes no
            # failure argument and raised a TypeError on stream errors,
            # leaving the files open and doneDefer unfired.  _error closes
            # the files and errbacks doneDefer properly.
            data.addCallbacks(self._write, self._error)
            return data
        
        self._write(data)
        return data
    
    def _write(self, data):
        """Write the stream data to the file and return it for others to use."""
        if data is None:
            # End of the input stream.
            self._done()
            return data
        
        self.outFile.write(data)
        if self.gzfile:
            if self.gzheader:
                # The gzip header is assumed to arrive entirely in the
                # first chunk; strip it before raw-inflating.
                self.gzheader = False
                new_data = self._remove_gzip_header(data)
                dec_data = self.gzdec.decompress(new_data)
            else:
                dec_data = self.gzdec.decompress(data)
            self.gzfile.write(dec_data)
        if self.bz2file:
            dec_data = self.bz2dec.decompress(data)
            self.bz2file.write(dec_data)
        return data
    
    def _remove_gzip_header(self, data):
        """Strip the gzip member header (RFC 1952) from the start of C{data}.
        
        @raise IOError: if the data does not start with a valid gzip header
        @return: the data with the header removed
        """
        if data[:2] != '\037\213':
            raise IOError('Not a gzipped file')
        if ord(data[2]) != 8:
            raise IOError('Unknown compression method')
        flag = ord(data[3])
        # modtime = self.fileobj.read(4)
        # extraflag = self.fileobj.read(1)
        # os = self.fileobj.read(1)

        skip = 10
        if flag & FEXTRA:
            # Read & discard the extra field, if present
            xlen = ord(data[10])
            xlen = xlen + 256*ord(data[11])
            skip = skip + 2 + xlen
        if flag & FNAME:
            # Read and discard a null-terminated string containing the filename
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FCOMMENT:
            # Read and discard a null-terminated string containing a comment
            while True:
                if not data[skip] or data[skip] == '\000':
                    break
                skip += 1
            skip += 1
        if flag & FHCRC:
            skip += 2     # Read & discard the 16-bit header CRC
        return data[skip:]

    def close(self):
        """Clean everything up and return None to future reads."""
        self.length = 0
        self._done()
        self.stream.close()
143
class MirrorManager:
    """Manages all requests for mirror objects.

    Keeps one L{AptPackages} cache per (site, baseDir) pair, routes
    updated files and hash lookups to the right cache, and saves
    downloaded files into the local cache directory.
    """
    
    def __init__(self, cache_dir):
        """Initialize the manager.
        
        @type cache_dir: C{string}
        @param cache_dir: the directory to use for caching downloaded files
        """
        self.cache_dir = cache_dir
        self.cache = filepath.FilePath(self.cache_dir)
        # Maps site -> baseDir -> AptPackages instance.
        self.apt_caches = {}
    
    def extractPath(self, url):
        """Break a mirror URL into site, base directory and mirror path.
        
        The base directory is everything up to the '/dists/' or '/pool/'
        component; if neither is present, the longest matching known
        baseDir for the site is used (empty string if none match).
        
        @return: a (site, baseDir, path) tuple, where site is 'host:port'
        """
        parsed = urlparse(url)
        host, port = splitHostPort(parsed[0], parsed[1])
        site = host + ":" + str(port)
        path = parsed[2]
            
        i = max(path.rfind('/dists/'), path.rfind('/pool/'))
        if i >= 0:
            baseDir = path[:i]
            path = path[i:]
        else:
            # Uh oh, this is not good
            log.msg("Couldn't find a good base directory for path: %s" % (site + path))
            baseDir = ''
            if site in self.apt_caches:
                longest_match = 0
                for base in self.apt_caches[site]:
                    base_match = ''
                    for dirs in path.split('/'):
                        if base.startswith(base_match + '/' + dirs):
                            base_match += '/' + dirs
                        else:
                            break
                    if len(base_match) > longest_match:
                        longest_match = len(base_match)
                        baseDir = base_match
            log.msg("Settled on baseDir: %s" % baseDir)
        
        return site, baseDir, path
        
    def init(self, site, baseDir):
        """Make sure an L{AptPackages} cache exists for (site, baseDir)."""
        if site not in self.apt_caches:
            self.apt_caches[site] = {}
            
        if baseDir not in self.apt_caches[site]:
            site_cache = os.path.join(self.cache_dir, aptpkg_dir, 'mirrors', site + baseDir.replace('/', '_'))
            self.apt_caches[site][baseDir] = AptPackages(site_cache)
    
    def updatedFile(self, url, file_path):
        """Notify the appropriate cache that a mirror file was updated.
        
        @param url: the mirror URL the file was downloaded from
        @param file_path: the local path of the downloaded file
        """
        site, baseDir, path = self.extractPath(url)
        self.init(site, baseDir)
        self.apt_caches[site][baseDir].file_updated(path, file_path)

    def findHash(self, url):
        """Find the hash and size apt expects for a mirror URL.
        
        @return: a deferred (hash, size) tuple, with the hash as a byte
            string; errbacks with L{MirrorError} if the site is unknown
        """
        site, baseDir, path = self.extractPath(url)
        if site in self.apt_caches and baseDir in self.apt_caches[site]:
            d = self.apt_caches[site][baseDir].findHash(path)
            d.addCallback(self.translateHash)
            return d
        # defer.fail is the idiom for an already-failed Deferred.
        return defer.fail(MirrorError("Site Not Found"))
    
    def translateHash(self, hash_and_size):
        """Translate a hash from apt's hex encoding to a byte string.
        
        Unpacks the tuple in the body instead of in the signature: the
        old tuple-parameter form was Python-2-only and shadowed the
        builtin C{hash}.
        """
        found_hash, size = hash_and_size
        if found_hash:
            found_hash = a2b_hex(found_hash)
        return (found_hash, size)

    def save_file(self, response, hash, size, url):
        """Save a downloaded file to the cache and stream it."""
        log.msg('Returning file: %s' % url)
        
        parsed = urlparse(url)
        destFile = self.cache.preauthChild(parsed[1] + parsed[2])
        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
        
        if destFile.exists():
            log.msg('File already exists, removing: %s' % destFile.path)
            destFile.remove()
        else:
            destFile.parent().makedirs()
            
        root, ext = os.path.splitext(destFile.basename())
        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
            # Also decompress index files so their contents can be parsed.
            ext = ext.lower()
            decFile = destFile.sibling(root)
            log.msg('Decompressing to: %s' % decFile.path)
            if decFile.exists():
                log.msg('File already exists, removing: %s' % decFile.path)
                decFile.remove()
        else:
            ext = None
            decFile = None
        
        orig_stream = response.stream
        response.stream = ProxyFileStream(orig_stream, destFile, ext, decFile)
        response.stream.doneDefer.addCallback(self.save_complete, url, destFile,
                                              response.headers.getHeader('Last-Modified'),
                                              ext, decFile)
        response.stream.doneDefer.addErrback(self.save_error, url)
        return response

    def save_complete(self, result, url, destFile, modtime = None, ext = None, decFile = None):
        """Update the modification time and AptPackages."""
        if modtime:
            os.utime(destFile.path, (modtime, modtime))
            if ext:
                os.utime(decFile.path, (modtime, modtime))
            
        self.updatedFile(url, destFile.path)
        if ext:
            # Register the decompressed copy under the uncompressed URL.
            self.updatedFile(url[:-len(ext)], decFile.path)

    def save_error(self, failure, url):
        """An error has occurred in downloading or saving the file."""
        log.msg('Error occurred downloading %s' % url)
        log.err(failure)
        return failure
261
class TestMirrorManager(unittest.TestCase):
    """Unit tests for the mirror manager.
    
    These tests read the real apt lists in /var/lib/apt/lists, so they
    only pass on a system with a configured Debian mirror.
    """
    
    timeout = 20
    pending_calls = []
    client = None
    
    def setUp(self):
        # Initialize per-instance to avoid sharing the mutable class
        # attribute list between test instances.
        self.pending_calls = []
        self.client = MirrorManager('/tmp')
        
    def test_extractPath(self):
        site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org/debian/dists/unstable/Release')
        self.failUnless(site == "ftp.us.debian.org:80", "no match: %s" % site)
        self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
        self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)

        site, baseDir, path = self.client.extractPath('http://ftp.us.debian.org:16999/debian/pool/d/dpkg/dpkg_1.2.1-1.tar.gz')
        self.failUnless(site == "ftp.us.debian.org:16999", "no match: %s" % site)
        self.failUnless(baseDir == "/debian", "no match: %s" % baseDir)
        self.failUnless(path == "/pool/d/dpkg/dpkg_1.2.1-1.tar.gz", "no match: %s" % path)

        site, baseDir, path = self.client.extractPath('http://debian.camrdale.org/dists/unstable/Release')
        self.failUnless(site == "debian.camrdale.org:80", "no match: %s" % site)
        self.failUnless(baseDir == "", "no match: %s" % baseDir)
        self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)

    def verifyHash(self, found_hash, path, true_hash):
        self.failUnless(found_hash[0] == true_hash, 
                    "%s hashes don't match: %s != %s" % (path, found_hash[0], true_hash))

    def test_findHash(self):
        # Find the largest Packages/Sources files and their Release file
        # among the locally cached apt lists.
        self.packagesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Packages$" | tail -n 1').read().rstrip('\n')
        self.sourcesFile = os.popen('ls -Sr /var/lib/apt/lists/ | grep -E "_main_.*Sources$" | tail -n 1').read().rstrip('\n')
        for f in os.walk('/var/lib/apt/lists').next()[2]:
            if f[-7:] == "Release" and self.packagesFile.startswith(f[:-7]):
                self.releaseFile = f
                break
        
        # Feed the index files into the manager; list file names encode
        # the mirror URL with '/' replaced by '_'.
        self.client.updatedFile('http://' + self.releaseFile.replace('_','/'), 
                                '/var/lib/apt/lists/' + self.releaseFile)
        self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                self.packagesFile[self.packagesFile.find('_dists_')+1:].replace('_','/'), 
                                '/var/lib/apt/lists/' + self.packagesFile)
        self.client.updatedFile('http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') +
                                self.sourcesFile[self.sourcesFile.find('_dists_')+1:].replace('_','/'), 
                                '/var/lib/apt/lists/' + self.sourcesFile)

        lastDefer = defer.Deferred()
        
        # An index file's hash, parsed out of the Release file.
        idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' + 
                            '/var/lib/apt/lists/' + self.releaseFile + 
                            ' | grep -E " main/binary-i386/Packages.bz2$"'
                            ' | head -n 1 | cut -d\  -f 2').read().rstrip('\n')
        idx_path = 'http://' + self.releaseFile.replace('_','/')[:-7] + 'main/binary-i386/Packages.bz2'

        d = self.client.findHash(idx_path)
        d.addCallback(self.verifyHash, idx_path, a2b_hex(idx_hash))

        # A binary package's hash, parsed out of the Packages file.
        pkg_hash = os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.packagesFile + 
                            ' | grep -E "^SHA1:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')
        pkg_path = 'http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') + \
                   os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.packagesFile + 
                            ' | grep -E "^Filename:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')

        d = self.client.findHash(pkg_path)
        d.addCallback(self.verifyHash, pkg_path, a2b_hex(pkg_hash))

        # Source package files' hashes, parsed out of the Sources file.
        src_dir = os.popen('grep -A 30 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -E "^Directory:" | head -n 1' + 
                            ' | cut -d\  -f 2').read().rstrip('\n')
        src_hashes = os.popen('grep -A 20 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -A 4 -E "^Files:" | grep -E "^ " ' + 
                            ' | cut -d\  -f 2').read().split('\n')[:-1]
        src_paths = os.popen('grep -A 20 -E "^Package: dpkg$" ' + 
                            '/var/lib/apt/lists/' + self.sourcesFile + 
                            ' | grep -A 4 -E "^Files:" | grep -E "^ " ' + 
                            ' | cut -d\  -f 4').read().split('\n')[:-1]

        for i in range(len(src_hashes)):
            src_path = 'http://' + self.releaseFile[:self.releaseFile.find('_dists_')+1].replace('_','/') + src_dir + '/' + src_paths[i]
            d = self.client.findHash(src_path)
            d.addCallback(self.verifyHash, src_path, a2b_hex(src_hashes[i]))
            
        idx_hash = os.popen('grep -A 3000 -E "^SHA1:" ' + 
                            '/var/lib/apt/lists/' + self.releaseFile + 
                            ' | grep -E " main/source/Sources.bz2$"'
                            ' | head -n 1 | cut -d\  -f 2').read().rstrip('\n')
        idx_path = 'http://' + self.releaseFile.replace('_','/')[:-7] + 'main/source/Sources.bz2'

        d = self.client.findHash(idx_path)
        d.addCallback(self.verifyHash, idx_path, a2b_hex(idx_hash))

        d.addBoth(lastDefer.callback)
        return lastDefer

    def tearDown(self):
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.client = None
368