Move all streams to new Streams module and replace ProxyFileStream with GrowingFileSt...
[quix0rs-apt-p2p.git] / apt_p2p / CacheManager.py
index b991093f61c3266c50f27cb93188854abb7960bf..42f89d926c69ae6d1d96acda2f4c8f70b56ebce2 100644 (file)
@@ -5,9 +5,6 @@
 @var DECOMPRESS_FILES: a list of file names that need to be decompressed
 """
 
-from bz2 import BZ2Decompressor
-from zlib import decompressobj, MAX_WBITS
-from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
 from urlparse import urlparse
 import os
 
@@ -15,9 +12,9 @@ from twisted.python import log
 from twisted.python.filepath import FilePath
 from twisted.internet import defer, reactor
 from twisted.trial import unittest
-from twisted.web2 import stream
 from twisted.web2.http import splitHostPort
 
+from Streams import GrowingFileStream, StreamToFile
 from Hash import HashObject
 from apt_p2p_conf import config
 
@@ -27,184 +24,6 @@ DECOMPRESS_FILES = ['release', 'sources', 'packages']
 class CacheError(Exception):
     """Error occurred downloading a file to the cache."""
 
-class ProxyFileStream(stream.SimpleStream):
-    """Saves a stream to a file while providing a new stream.
-    
-    Also optionally decompresses the file while it is being downloaded.
-    
-    @type stream: L{twisted.web2.stream.IByteStream}
-    @ivar stream: the input stream being read
-    @type outFile: L{twisted.python.filepath.FilePath}
-    @ivar outFile: the file being written
-    @type hash: L{Hash.HashObject}
-    @ivar hash: the hash object for the file
-    @type gzfile: C{file}
-    @ivar gzfile: the open file to write decompressed gzip data to
-    @type gzdec: L{zlib.decompressobj}
-    @ivar gzdec: the decompressor to use for the compressed gzip data
-    @type gzheader: C{boolean}
-    @ivar gzheader: whether the gzip header still needs to be removed from
-        the zlib compressed data
-    @type bz2file: C{file}
-    @ivar bz2file: the open file to write decompressed bz2 data to
-    @type bz2dec: L{bz2.BZ2Decompressor}
-    @ivar bz2dec: the decompressor to use for the compressed bz2 data
-    @type length: C{int}
-    @ivar length: the length of the original (compressed) file
-    @type doneDefer: L{twisted.internet.defer.Deferred}
-    @ivar doneDefer: the deferred that will fire when done streaming
-    
-    @group Stream implementation: read, close
-    
-    """
-    
-    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
-        """Initializes the proxy.
-        
-        @type stream: L{twisted.web2.stream.IByteStream}
-        @param stream: the input stream to read from
-        @type outFile: L{twisted.python.filepath.FilePath}
-        @param outFile: the file to write to
-        @type hash: L{Hash.HashObject}
-        @param hash: the hash object to use for the file
-        @type decompress: C{string}
-        @param decompress: also decompress the file as this type
-            (currently only '.gz' and '.bz2' are supported)
-        @type decFile: C{twisted.python.FilePath}
-        @param decFile: the file to write the decompressed data to
-        """
-        self.stream = stream
-        self.outFile = outFile.open('w')
-        self.hash = hash
-        self.hash.new()
-        self.gzfile = None
-        self.bz2file = None
-        if decompress == ".gz":
-            self.gzheader = True
-            self.gzfile = decFile.open('w')
-            self.gzdec = decompressobj(-MAX_WBITS)
-        elif decompress == ".bz2":
-            self.bz2file = decFile.open('w')
-            self.bz2dec = BZ2Decompressor()
-        self.length = self.stream.length
-        self.doneDefer = defer.Deferred()
-
-    def _done(self):
-        """Close all the output files, return the result."""
-        if not self.outFile.closed:
-            self.outFile.close()
-            self.hash.digest()
-            if self.gzfile:
-                # Finish the decompression
-                data_dec = self.gzdec.flush()
-                self.gzfile.write(data_dec)
-                self.gzfile.close()
-                self.gzfile = None
-            if self.bz2file:
-                self.bz2file.close()
-                self.bz2file = None
-    
-    def _error(self, err):
-        """Close all the output files, return the error."""
-        if not self.outFile.closed:
-            self._done()
-            self.stream.close()
-            self.doneDefer.errback(err)
-
-    def read(self):
-        """Read some data from the stream."""
-        if self.outFile.closed:
-            return None
-        
-        # Read data from the stream, deal with the possible deferred
-        data = self.stream.read()
-        if isinstance(data, defer.Deferred):
-            data.addCallbacks(self._write, self._error)
-            return data
-        
-        self._write(data)
-        return data
-    
-    def _write(self, data):
-        """Write the stream data to the file and return it for others to use.
-        
-        Also optionally decompresses it.
-        """
-        if data is None:
-            if not self.outFile.closed:
-                self._done()
-                self.doneDefer.callback(self.hash)
-            return data
-        
-        # Write and hash the streamed data
-        self.outFile.write(data)
-        self.hash.update(data)
-        
-        if self.gzfile:
-            # Decompress the zlib portion of the file
-            if self.gzheader:
-                # Remove the gzip header junk
-                self.gzheader = False
-                new_data = self._remove_gzip_header(data)
-                dec_data = self.gzdec.decompress(new_data)
-            else:
-                dec_data = self.gzdec.decompress(data)
-            self.gzfile.write(dec_data)
-        if self.bz2file:
-            # Decompress the bz2 file
-            dec_data = self.bz2dec.decompress(data)
-            self.bz2file.write(dec_data)
-
-        return data
-    
-    def _remove_gzip_header(self, data):
-        """Remove the gzip header from the zlib compressed data."""
-        # Read, check & discard the header fields
-        if data[:2] != '\037\213':
-            raise IOError, 'Not a gzipped file'
-        if ord(data[2]) != 8:
-            raise IOError, 'Unknown compression method'
-        flag = ord(data[3])
-        # modtime = self.fileobj.read(4)
-        # extraflag = self.fileobj.read(1)
-        # os = self.fileobj.read(1)
-
-        skip = 10
-        if flag & FEXTRA:
-            # Read & discard the extra field
-            xlen = ord(data[10])
-            xlen = xlen + 256*ord(data[11])
-            skip = skip + 2 + xlen
-        if flag & FNAME:
-            # Read and discard a null-terminated string containing the filename
-            while True:
-                if not data[skip] or data[skip] == '\000':
-                    break
-                skip += 1
-            skip += 1
-        if flag & FCOMMENT:
-            # Read and discard a null-terminated string containing a comment
-            while True:
-                if not data[skip] or data[skip] == '\000':
-                    break
-                skip += 1
-            skip += 1
-        if flag & FHCRC:
-            skip += 2     # Read & discard the 16-bit header CRC
-
-        return data[skip:]
-
-    def close(self):
-        """Clean everything up and return None to future reads."""
-        log.msg('ProxyFileStream was prematurely closed after only %d/%d bytes' % (self.hash.size, self.length))
-        if self.hash.size < self.length:
-            self._error(CacheError('Prematurely closed, all data was not written'))
-        elif not self.outFile.closed:
-            self._done()
-            self.doneDefer.callback(self.hash)
-        self.length = 0
-        self.stream.close()
-
 class CacheManager:
     """Manages all downloaded files and requests for cached objects.
     
@@ -377,16 +196,21 @@ class CacheManager:
             
         # Create the new stream from the old one.
         orig_stream = response.stream
-        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
-        response.stream.doneDefer.addCallback(self._save_complete, url, destFile,
-                                              response.headers.getHeader('Last-Modified'),
-                                              decFile)
-        response.stream.doneDefer.addErrback(self._save_error, url, destFile, decFile)
+        f = destFile.open('w+')
+        new_stream = GrowingFileStream(f, orig_stream.length)
+        hash.new()
+        df = StreamToFile(hash, orig_stream, f, notify = new_stream.updateAvailable,
+                          decompress = ext, decFile = decFile).run()
+        df.addCallback(self._save_complete, url, destFile, new_stream,
+                       response.headers.getHeader('Last-Modified'), decFile)
+        df.addErrback(self._save_error, url, destFile, new_stream, decFile)
+        response.stream = new_stream
 
         # Return the modified response with the new stream
         return response
 
-    def _save_complete(self, hash, url, destFile, modtime = None, decFile = None):
+    def _save_complete(self, hash, url, destFile, destStream = None,
+                       modtime = None, decFile = None):
         """Update the modification time and inform the main program.
         
         @type hash: L{Hash.HashObject}
@@ -394,6 +218,8 @@ class CacheManager:
         @param url: the URI of the actual mirror request
         @type destFile: C{twisted.python.FilePath}
         @param destFile: the file where the download was written to
+        @type destStream: L{Streams.GrowingFileStream}
+        @param destStream: the stream to notify that all data is available
         @type modtime: C{int}
         @param modtime: the modified time of the cached file (seconds since epoch)
             (optional, defaults to not setting the modification time of the file)
@@ -403,6 +229,8 @@ class CacheManager:
         """
         result = hash.verify()
         if result or result is None:
+            if destStream:
+                destStream.allAvailable()
             if modtime:
                 os.utime(destFile.path, (modtime, modtime))
             
@@ -424,22 +252,26 @@ class CacheManager:
                 decHash = HashObject()
                 ext_len = len(destFile.path) - len(decFile.path)
                 df = decHash.hashInThread(decFile)
-                df.addCallback(self._save_complete, url[:-ext_len], decFile, modtime)
+                df.addCallback(self._save_complete, url[:-ext_len], decFile, modtime = modtime)
                 df.addErrback(self._save_error, url[:-ext_len], decFile)
         else:
             log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
-            destFile.remove()
+            if destStream:
+                destStream.allAvailable(remove = True)
             if decFile:
                 decFile.remove()
 
-    def _save_error(self, failure, url, destFile, decFile = None):
+    def _save_error(self, failure, url, destFile, destStream = None, decFile = None):
         """Remove the destination files."""
         log.msg('Error occurred downloading %s' % url)
         log.err(failure)
-        destFile.restat(False)
-        if destFile.exists():
-            log.msg('Removing the incomplete file: %s' % destFile.path)
-            destFile.remove()
+        if destStream:
+            destStream.allAvailable(remove = True)
+        else:
+            destFile.restat(False)
+            if destFile.exists():
+                log.msg('Removing the incomplete file: %s' % destFile.path)
+                destFile.remove()
         if decFile:
             decFile.restat(False)
             if decFile.exists():