# Reset the HTTPServer subdirectories when a new cache directory is created.
# [quix0rs-apt-p2p.git] / apt_dht / CacheManager.py
2 from bz2 import BZ2Decompressor
3 from zlib import decompressobj, MAX_WBITS
4 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
5 from urlparse import urlparse
6 import os
7
8 from twisted.python import log
9 from twisted.python.filepath import FilePath
10 from twisted.internet import defer
11 from twisted.trial import unittest
12 from twisted.web2 import stream
13 from twisted.web2.http import splitHostPort
14
15 from AptPackages import AptPackages
16
17 aptpkg_dir='apt-packages'
18
19 DECOMPRESS_EXTS = ['.gz', '.bz2']
20 DECOMPRESS_FILES = ['release', 'sources', 'packages']
21
22 class ProxyFileStream(stream.SimpleStream):
23     """Saves a stream to a file while providing a new stream."""
24     
25     def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
26         """Initializes the proxy.
27         
28         @type stream: C{twisted.web2.stream.IByteStream}
29         @param stream: the input stream to read from
30         @type outFile: C{twisted.python.FilePath}
31         @param outFile: the file to write to
32         @type hash: L{Hash.HashObject}
33         @param hash: the hash object to use for the file
34         @type decompress: C{string}
35         @param decompress: also decompress the file as this type
36             (currently only '.gz' and '.bz2' are supported)
37         @type decFile: C{twisted.python.FilePath}
38         @param decFile: the file to write the decompressed data to
39         """
40         self.stream = stream
41         self.outFile = outFile.open('w')
42         self.hash = hash
43         self.hash.new()
44         self.gzfile = None
45         self.bz2file = None
46         if decompress == ".gz":
47             self.gzheader = True
48             self.gzfile = decFile.open('w')
49             self.gzdec = decompressobj(-MAX_WBITS)
50         elif decompress == ".bz2":
51             self.bz2file = decFile.open('w')
52             self.bz2dec = BZ2Decompressor()
53         self.length = self.stream.length
54         self.start = 0
55         self.doneDefer = defer.Deferred()
56
57     def _done(self):
58         """Close the output file."""
59         if not self.outFile.closed:
60             self.outFile.close()
61             self.hash.digest()
62             if self.gzfile:
63                 data_dec = self.gzdec.flush()
64                 self.gzfile.write(data_dec)
65                 self.gzfile.close()
66                 self.gzfile = None
67             if self.bz2file:
68                 self.bz2file.close()
69                 self.bz2file = None
70                 
71             self.doneDefer.callback(self.hash)
72     
73     def read(self):
74         """Read some data from the stream."""
75         if self.outFile.closed:
76             return None
77         
78         data = self.stream.read()
79         if isinstance(data, defer.Deferred):
80             data.addCallbacks(self._write, self._done)
81             return data
82         
83         self._write(data)
84         return data
85     
86     def _write(self, data):
87         """Write the stream data to the file and return it for others to use."""
88         if data is None:
89             self._done()
90             return data
91         
92         self.outFile.write(data)
93         self.hash.update(data)
94         if self.gzfile:
95             if self.gzheader:
96                 self.gzheader = False
97                 new_data = self._remove_gzip_header(data)
98                 dec_data = self.gzdec.decompress(new_data)
99             else:
100                 dec_data = self.gzdec.decompress(data)
101             self.gzfile.write(dec_data)
102         if self.bz2file:
103             dec_data = self.bz2dec.decompress(data)
104             self.bz2file.write(dec_data)
105         return data
106     
107     def _remove_gzip_header(self, data):
108         if data[:2] != '\037\213':
109             raise IOError, 'Not a gzipped file'
110         if ord(data[2]) != 8:
111             raise IOError, 'Unknown compression method'
112         flag = ord(data[3])
113         # modtime = self.fileobj.read(4)
114         # extraflag = self.fileobj.read(1)
115         # os = self.fileobj.read(1)
116
117         skip = 10
118         if flag & FEXTRA:
119             # Read & discard the extra field, if present
120             xlen = ord(data[10])
121             xlen = xlen + 256*ord(data[11])
122             skip = skip + 2 + xlen
123         if flag & FNAME:
124             # Read and discard a null-terminated string containing the filename
125             while True:
126                 if not data[skip] or data[skip] == '\000':
127                     break
128                 skip += 1
129             skip += 1
130         if flag & FCOMMENT:
131             # Read and discard a null-terminated string containing a comment
132             while True:
133                 if not data[skip] or data[skip] == '\000':
134                     break
135                 skip += 1
136             skip += 1
137         if flag & FHCRC:
138             skip += 2     # Read & discard the 16-bit header CRC
139         return data[skip:]
140
141     def close(self):
142         """Clean everything up and return None to future reads."""
143         self.length = 0
144         self._done()
145         self.stream.close()
146
147 class CacheManager:
148     """Manages all requests for cached objects."""
149     
150     def __init__(self, cache_dir, db, manager = None):
151         self.cache_dir = cache_dir
152         self.db = db
153         self.manager = manager
154     
155     def save_file(self, response, hash, url):
156         """Save a downloaded file to the cache and stream it."""
157         if response.code != 200:
158             log.msg('File was not found (%r): %s' % (response, url))
159             return response
160         
161         log.msg('Returning file: %s' % url)
162         
163         parsed = urlparse(url)
164         destFile = self.cache_dir.preauthChild(parsed[1] + parsed[2])
165         log.msg('Saving returned %r byte file to cache: %s' % (response.stream.length, destFile.path))
166         
167         if destFile.exists():
168             log.msg('File already exists, removing: %s' % destFile.path)
169             destFile.remove()
170         elif not destFile.parent().exists():
171             destFile.parent().makedirs()
172             
173         root, ext = os.path.splitext(destFile.basename())
174         if root.lower() in DECOMPRESS_FILES and ext.lower() in DECOMPRESS_EXTS:
175             ext = ext.lower()
176             decFile = destFile.sibling(root)
177             log.msg('Decompressing to: %s' % decFile.path)
178             if decFile.exists():
179                 log.msg('File already exists, removing: %s' % decFile.path)
180                 decFile.remove()
181         else:
182             ext = None
183             decFile = None
184             
185         orig_stream = response.stream
186         response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
187         response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
188                                               response.headers.getHeader('Last-Modified'),
189                                               ext, decFile)
190         response.stream.doneDefer.addErrback(self.save_error, url)
191         return response
192
193     def _save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
194         """Update the modification time and AptPackages."""
195         if modtime:
196             os.utime(destFile.path, (modtime, modtime))
197             if ext:
198                 os.utime(decFile.path, (modtime, modtime))
199         
200         result = hash.verify()
201         if result or result is None:
202             if result:
203                 log.msg('Hashes match: %s' % url)
204             else:
205                 log.msg('Hashed file to %s: %s' % (hash.hexdigest(), url))
206                 
207             urlpath, newdir = self.db.storeFile(destFile, hash.digest(), self.cache_dir)
208             log.msg('now avaliable at %s: %s' % (urlpath, url))
209             if newdir and self.manager:
210                 log.msg('A new web directory was created, so enable it')
211                 self.manager.setDirectories(self.db.getAllDirectories())
212
213             if self.manager:
214                 self.manager.new_cached_file(url, destFile, hash, urlpath)
215                 if ext:
216                     self.manager.new_cached_file(url[:-len(ext)], decFile, None, urlpath)
217         else:
218             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
219             destFile.remove()
220             if ext:
221                 decFile.remove()
222
223     def save_error(self, failure, url):
224         """An error has occurred in downloadign or saving the file."""
225         log.msg('Error occurred downloading %s' % url)
226         log.err(failure)
227         return failure
228
229 class TestMirrorManager(unittest.TestCase):
230     """Unit tests for the mirror manager."""
231     
232     timeout = 20
233     pending_calls = []
234     client = None
235     
236     def setUp(self):
237         self.client = CacheManager(FilePath('/tmp/.apt-dht'))
238         
239     def tearDown(self):
240         for p in self.pending_calls:
241             if p.active():
242                 p.cancel()
243         self.client = None
244