Add an option to not error out when only a local IP address can be found.
[quix0rs-apt-p2p.git] / apt_dht / CacheManager.py
1
2 from bz2 import BZ2Decompressor
3 from zlib import decompressobj, MAX_WBITS
4 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
5 from urlparse import urlparse
6 import os
7
8 from twisted.python import log
9 from twisted.python.filepath import FilePath
10 from twisted.internet import defer, reactor
11 from twisted.trial import unittest
12 from twisted.web2 import stream
13 from twisted.web2.http import splitHostPort
14
15 from Hash import HashObject
16
# Directory name for apt package data.
# NOTE(review): not referenced anywhere in the visible part of this file --
# confirm it is still needed.
aptpkg_dir = 'apt-packages'

# Compressed-file extensions that can be decompressed while a download is
# being saved to the cache.
DECOMPRESS_EXTS = [
    '.gz',
    '.bz2',
]

# Lower-cased base names (without extension) of the files for which a
# decompressed copy is also written.
DECOMPRESS_FILES = [
    'release',
    'sources',
    'packages',
]
21
class ProxyFileStream(stream.SimpleStream):
    """Saves a stream to a file while providing a new stream.

    Every chunk read from the wrapped stream is written to the output file,
    fed to the hash object and, if requested, decompressed into a second
    file, before being handed unchanged to whoever is reading this stream.

    @ivar doneDefer: fires with the hash object once the input stream is
        exhausted (or the stream is closed) and all the files are closed;
        errbacks with the failure if reading the input stream fails
    """

    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
        """Initializes the proxy.

        @type stream: C{twisted.web2.stream.IByteStream}
        @param stream: the input stream to read from
        @type outFile: C{twisted.python.FilePath}
        @param outFile: the file to write to
        @type hash: L{Hash.HashObject}
        @param hash: the hash object to use for the file
        @type decompress: C{string}
        @param decompress: also decompress the file as this type
            (currently only '.gz' and '.bz2' are supported)
        @type decFile: C{twisted.python.FilePath}
        @param decFile: the file to write the decompressed data to
        """
        self.stream = stream
        self.outFile = outFile.open('w')
        self.hash = hash
        self.hash.new()
        self.gzfile = None
        self.bz2file = None
        if decompress == ".gz":
            # The gzip header is stripped by hand from the first chunk, so
            # the decompressor is set up for a raw deflate stream (negative
            # window bits means "no header").
            self.gzheader = True
            self.gzfile = decFile.open('w')
            self.gzdec = decompressobj(-MAX_WBITS)
        elif decompress == ".bz2":
            self.bz2file = decFile.open('w')
            self.bz2dec = BZ2Decompressor()
        self.length = self.stream.length
        self.start = 0
        self.doneDefer = defer.Deferred()

    def _done(self):
        """Close the output files and fire L{doneDefer} with the hash.

        Does nothing if the output file has already been closed, so it is
        safe to call more than once.
        """
        if not self.outFile.closed:
            self.outFile.close()
            self.hash.digest()
            if self.gzfile:
                # Write out whatever the decompressor is still holding
                data_dec = self.gzdec.flush()
                self.gzfile.write(data_dec)
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None

            self.doneDefer.callback(self.hash)

    def _error(self, failure):
        """Abort the save after reading from the input stream failed.

        Closes all the files without finalizing them and passes the failure
        on to L{doneDefer}, so an incomplete download is never reported as a
        successfully saved file.

        @param failure: the failure the input stream's deferred ended with
        @return: the same failure, so the reader of this stream sees it too
        """
        if not self.outFile.closed:
            self.outFile.close()
            if self.gzfile:
                self.gzfile.close()
                self.gzfile = None
            if self.bz2file:
                self.bz2file.close()
                self.bz2file = None

            self.doneDefer.errback(failure)
        return failure

    def read(self):
        """Read some data from the stream.

        @return: the next chunk of data, a deferred that will fire with it,
            or C{None} once the stream is finished or closed
        """
        if self.outFile.closed:
            return None

        # Read data from the stream, dealing with a possible deferred
        data = self.stream.read()
        if isinstance(data, defer.Deferred):
            # An errback receives the failure as its argument, so it needs
            # its own handler (_done accepts no arguments).
            data.addCallbacks(self._write, self._error)
            return data

        self._write(data)
        return data

    def _write(self, data):
        """Write the stream data to the file and return it for others to use.

        @param data: the chunk that was read, or C{None} at the end of the
            input stream (which finishes the save)
        @return: C{data}, unchanged
        """
        if data is None:
            self._done()
            return data

        self.outFile.write(data)
        self.hash.update(data)
        if self.gzfile:
            if self.gzheader:
                # Only the first chunk carries the gzip header
                self.gzheader = False
                new_data = self._remove_gzip_header(data)
                dec_data = self.gzdec.decompress(new_data)
            else:
                dec_data = self.gzdec.decompress(data)
            self.gzfile.write(dec_data)
        if self.bz2file:
            dec_data = self.bz2dec.decompress(data)
            self.bz2file.write(dec_data)
        return data

    def _remove_gzip_header(self, data):
        """Strip the gzip header from the first chunk of a gzipped stream.

        The whole header has to be contained in this one chunk; a header
        that is split over several chunks is reported as truncated.

        @type data: C{string}
        @param data: the first chunk of the gzipped data
        @rtype: C{string}
        @return: the part of the chunk that follows the header
        @raise IOError: if the data is not gzipped, uses an unknown
            compression method, or ends before the header does
        """
        if data[:2] != '\037\213':
            raise IOError('Not a gzipped file')
        if len(data) < 10:
            raise IOError('Truncated gzip header')
        if ord(data[2]) != 8:
            raise IOError('Unknown compression method')
        flag = ord(data[3])

        # The fixed part of the header is 10 bytes: magic, method, flags,
        # modification time (4 bytes), extra flags and OS.
        skip = 10
        if flag & FEXTRA:
            # Skip the extra field, which is prefixed by its 16-bit
            # little-endian length
            if len(data) < skip + 2:
                raise IOError('Truncated gzip header')
            xlen = ord(data[10]) + 256*ord(data[11])
            skip = skip + 2 + xlen
        for field in (FNAME, FCOMMENT):
            if flag & field:
                # Skip the null-terminated file name, then the comment
                end = data.find('\000', skip)
                if end < 0:
                    raise IOError('Truncated gzip header')
                skip = end + 1
        if flag & FHCRC:
            skip += 2     # Skip the 16-bit header CRC
        if skip > len(data):
            raise IOError('Truncated gzip header')
        return data[skip:]

    def close(self):
        """Clean everything up and return None to future reads."""
        self.length = 0
        self._done()
        self.stream.close()
146
class CacheManager:
    """Manages all requests for cached objects.

    Keeps the database of cached files in sync with the cache directories
    and tells the manager (if there is one) about new files and directories.
    """

    def __init__(self, cache_dir, db, other_dirs = None, manager = None):
        """Initialize the cache and reconcile the database with it.

        @type cache_dir: C{twisted.python.filepath.FilePath}
        @param cache_dir: the directory downloaded files are saved in
        @param db: the database used to track the cached files
        @type other_dirs: C{list} of C{twisted.python.filepath.FilePath}
        @param other_dirs: additional directories to scan for files
            (optional, defaults to none)
        @param manager: the object to notify of new cached files and
            directories (optional, nothing is notified if not given)
        """
        self.cache_dir = cache_dir
        # Don't use a mutable default argument, it would be shared by all
        # the instances created without other_dirs
        if other_dirs is None:
            other_dirs = []
        self.other_dirs = other_dirs
        self.all_dirs = self.other_dirs[:]
        self.all_dirs.insert(0, self.cache_dir)
        self.db = db
        self.manager = manager
        self.scanning = []

        # Init the database, remove old files, init the HTTP dirs
        self.db.removeUntrackedFiles(self.all_dirs)
        self.db.reconcileDirectories()
        self._updateDirectories()

    def _updateDirectories(self):
        """Give the manager (if any) the current list of directories."""
        if self.manager:
            self.manager.setDirectories(self.db.getAllDirectories())

    def _mirrorDir(self, file):
        """Get the top-level directory of the cache that contains the file.

        @param file: a file somewhere below the cache directory
        @return: the child of the cache directory named by the first
            component of the file's path relative to the cache directory
        """
        relative = file.path[len(self.cache_dir.path)+1:]
        return self.cache_dir.child(relative.split('/', 1)[0])

    def scanDirectories(self):
        """Scan the cache directories, hashing new and rehashing changed files."""
        assert not self.scanning, "a directory scan is already under way"
        self.scanning = self.all_dirs[:]
        self._scanDirectories()

    def _scanDirectories(self, result = None, walker = None):
        """Process the next file of the directory scan.

        Handles a single file per call and reschedules itself, so the
        reactor is not blocked while large directories are scanned.

        @param result: ignored, only there so this can be used as a
            deferred callback
        @param walker: the iterator over the directory currently being
            scanned, or None to start on the next directory
        """
        # Need to start walking a new directory
        if walker is None:
            # If there are any left, get them
            if self.scanning:
                log.msg('started scanning directory: %s' % self.scanning[0].path)
                walker = self.scanning[0].walk()
            else:
                # Done, just check if the HTTP directories need updating
                log.msg('cache directory scan complete')
                if self.db.reconcileDirectories():
                    self._updateDirectories()
                return

        try:
            # Get the next file in the directory
            file = walker.next()
        except StopIteration:
            # No files left, go to the next directory
            log.msg('done scanning directory: %s' % self.scanning[0].path)
            self.scanning.pop(0)
            reactor.callLater(0, self._scanDirectories)
            return

        # If it's not a file, or it's already properly in the DB, ignore it
        if not file.isfile() or self.db.isUnchanged(file):
            if not file.isfile():
                log.msg('entering directory: %s' % file.path)
            else:
                log.msg('file is unchanged: %s' % file.path)
            reactor.callLater(0, self._scanDirectories, None, walker)
            return

        # Otherwise hash it
        log.msg('start hash checking file: %s' % file.path)
        hash = HashObject()
        df = hash.hashInThread(file)
        df.addBoth(self._doneHashing, file, walker)
        df.addErrback(log.err)

    def _doneHashing(self, result, file, walker):
        """Store a scanned file's hash, then continue the scan.

        @param result: the L{Hash.HashObject} of the file if the hashing
            succeeded, otherwise the error that occurred
        @param file: the file that was hashed
        @param walker: the iterator over the directory being scanned
        """
        if isinstance(result, HashObject):
            log.msg('hash check of %s completed with hash: %s' % (file.path, result.hexdigest()))
            if self.scanning[0] == self.cache_dir:
                # Files in the cache directory are stored by mirror, and
                # their path below it is the URL they were downloaded from
                urlpath, newdir = self.db.storeFile(file, result.digest(), self._mirrorDir(file))
                url = 'http:/' + file.path[len(self.cache_dir.path):]
            else:
                urlpath, newdir = self.db.storeFile(file, result.digest(), self.scanning[0])
                url = None

            df = None
            if self.manager:
                if newdir:
                    self.manager.setDirectories(self.db.getAllDirectories())
                df = self.manager.new_cached_file(file, result, urlpath, url)

            # Only wait for the manager if it returned something to wait on
            if df is None:
                reactor.callLater(0, self._scanDirectories, None, walker)
            else:
                df.addBoth(self._scanDirectories, walker)
        else:
            log.msg('hash check of %s failed' % file.path)
            log.err(result)
            reactor.callLater(0, self._scanDirectories, None, walker)

    def save_file(self, response, hash, url):
        """Save a downloaded file to the cache and stream it.

        @param response: the response from the download; anything other
            than a 200 response is returned untouched
        @type hash: L{Hash.HashObject}
        @param hash: the hash object to hash the downloaded data with
        @type url: C{string}
        @param url: the URL the file was downloaded from, which also
            determines where in the cache directory the file is saved
        @return: the response, with its stream replaced by one that saves
            the data to the cache as it is read
        """
        if response.code != 200:
            log.msg('File was not found (%r): %s' % (response, url))
            return response

        log.msg('Returning file: %s' % url)

        # The file is saved under <cache_dir>/<host>/<path>
        parsed = urlparse(url)
        destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
        log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))

        if destFile.exists():
            log.msg('File already exists, removing: %s' % destFile.path)
            destFile.remove()
        elif not destFile.parent().exists():
            destFile.parent().makedirs()

        # Determine whether a decompressed copy is to be saved as well
        root, ext = os.path.splitext(destFile.basename())
        if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
            ext = ext.lower()
            decFile = destFile.sibling(root)
            log.msg('Decompressing to: %s' % decFile.path)
            if decFile.exists():
                log.msg('File already exists, removing: %s' % decFile.path)
                decFile.remove()
        else:
            ext = None
            decFile = None

        orig_stream = response.stream
        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
        response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
                                              response.headers.getHeader('Last-Modified'),
                                              ext, decFile)
        response.stream.doneDefer.addErrback(self.save_error, url)
        return response

    def _save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
        """Update the modification time and AptPackages.

        @type hash: L{Hash.HashObject}
        @param hash: the hash object containing the hash of the saved file
        @param url: the URL the file was downloaded from
        @param destFile: the file the download was saved to
        @param modtime: the modification time to set on the saved files
            (optional, the times are left alone if not given)
        @param ext: the extension of the compressed file, if a
            decompressed copy was also saved
        @param decFile: the file the decompressed copy was saved to
        """
        if modtime:
            os.utime(destFile.path, (modtime, modtime))
            if ext:
                os.utime(decFile.path, (modtime, modtime))

        # The result is None when there was nothing to verify against
        result = hash.verify()
        if result or result is None:
            if result:
                log.msg('Hashes match: %s' % url)
            else:
                log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))

            urlpath, newdir = self.db.storeFile(destFile, hash.digest(), self._mirrorDir(destFile))
            log.msg('now avaliable at %s: %s' % (urlpath, url))

            if self.manager:
                if newdir:
                    log.msg('A new web directory was created, so enable it')
                    self.manager.setDirectories(self.db.getAllDirectories())

                self.manager.new_cached_file(destFile, hash, urlpath, url)
                if ext:
                    # The decompressed copy has the URL without the extension
                    self.manager.new_cached_file(decFile, None, urlpath, url[:-len(ext)])
        else:
            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
            destFile.remove()
            if ext:
                decFile.remove()

    def save_error(self, failure, url):
        """An error has occurred in downloading or saving the file.

        @param failure: the error that occurred
        @param url: the URL of the file that was being downloaded
        @return: the failure, so it is passed on down the errback chain
        """
        log.msg('Error occurred downloading %s' % url)
        log.err(failure)
        return failure
309
class TestMirrorManager(unittest.TestCase):
    """Unit tests for the cache manager.

    There are no test methods yet; this only provides the fixture that
    creates and cleans up a L{CacheManager}.
    """

    timeout = 20
    pending_calls = []
    client = None

    class _StubDB:
        """Stand-in database providing only what creating a CacheManager calls."""

        def removeUntrackedFiles(self, dirs):
            pass

        def reconcileDirectories(self):
            # Report that no directories changed
            return False

        def getAllDirectories(self):
            return []

    class _StubManager:
        """Stand-in manager that ignores the directories it is given."""

        def setDirectories(self, directories):
            pass

    def setUp(self):
        # Use a list of our own, the class-level one is shared by all tests
        self.pending_calls = []
        # CacheManager requires a database, and notifies its manager of
        # the directories as soon as it is created
        self.client = CacheManager(FilePath('/tmp/.apt-dht'), self._StubDB(),
                                   manager = self._StubManager())

    def tearDown(self):
        # Cancel any delayed calls that are still waiting to run
        for p in self.pending_calls:
            if p.active():
                p.cancel()
        self.client = None
325