ProxyFileStream also calculates hash while downloading.
[quix0rs-apt-p2p.git] / apt_dht / MirrorManager.py
index c89a5915bc6622fdddd9fc4394656aa77732420f..d56e16da16520f259fb1d12b8f44a7f6bac2b537 100644 (file)
@@ -3,8 +3,8 @@ from bz2 import BZ2Decompressor
 from zlib import decompressobj, MAX_WBITS
 from gzip import FCOMMENT, FEXTRA, FHCRC, FNAME, FTEXT
 from urlparse import urlparse
-from binascii import a2b_hex
-import os
+from binascii import a2b_hex, b2a_hex
+import os, sha, md5
 
 from twisted.python import log, filepath
 from twisted.internet import defer
@@ -25,13 +25,16 @@ class MirrorError(Exception):
 class ProxyFileStream(stream.SimpleStream):
     """Saves a stream to a file while providing a new stream."""
     
-    def __init__(self, stream, outFile, decompress = None, decFile = None):
+    def __init__(self, stream, outFile, hashType = "sha1", decompress = None, decFile = None):
         """Initializes the proxy.
         
         @type stream: C{twisted.web2.stream.IByteStream}
         @param stream: the input stream to read from
         @type outFile: C{twisted.python.filepath.FilePath}
         @param outFile: the file to write to
+        @type hashType: C{string}
+        @param hashType: also hash the file using this hashing function
+            (currently only 'sha1' and 'md5' are supported)
         @type decompress: C{string}
         @param decompress: also decompress the file as this type
             (currently only '.gz' and '.bz2' are supported)
@@ -40,6 +43,11 @@ class ProxyFileStream(stream.SimpleStream):
         """
         self.stream = stream
         self.outFile = outFile.open('w')
+        self.hasher = None
+        if hashType == "sha1":
+            self.hasher = sha.new()
+        elif hashType == "md5":
+            self.hasher = md5.new()
         self.gzfile = None
         self.bz2file = None
         if decompress == ".gz":
@@ -57,6 +65,9 @@ class ProxyFileStream(stream.SimpleStream):
         """Close the output file."""
         if not self.outFile.closed:
             self.outFile.close()
+            fileHash = None
+            if self.hasher:
+                fileHash = self.hasher.digest()
             if self.gzfile:
                 data_dec = self.gzdec.flush()
                 self.gzfile.write(data_dec)
@@ -66,7 +77,7 @@ class ProxyFileStream(stream.SimpleStream):
                 self.bz2file.close()
                 self.bz2file = None
                 
-            self.doneDefer.callback(1)
+            self.doneDefer.callback(fileHash)
     
     def read(self):
         """Read some data from the stream."""
@@ -88,6 +99,8 @@ class ProxyFileStream(stream.SimpleStream):
             return data
         
         self.outFile.write(data)
+        if self.hasher:
+            self.hasher.update(data)
         if self.gzfile:
             if self.gzheader:
                 self.gzheader = False
@@ -233,25 +246,38 @@ class MirrorManager:
         else:
             ext = None
             decFile = None
+            
+        if hash and len(hash) == 16:
+            hashType = "md5"
+        else:
+            hashType = "sha1"
         
         orig_stream = response.stream
-        response.stream = ProxyFileStream(orig_stream, destFile, ext, decFile)
-        response.stream.doneDefer.addCallback(self.save_complete, url, destFile,
+        response.stream = ProxyFileStream(orig_stream, destFile, hashType, ext, decFile)
+        response.stream.doneDefer.addCallback(self.save_complete, hash, size, url, destFile,
                                               response.headers.getHeader('Last-Modified'),
                                               ext, decFile)
         response.stream.doneDefer.addErrback(self.save_error, url)
         return response
 
-    def save_complete(self, result, url, destFile, modtime = None, ext = None, decFile = None):
+    def save_complete(self, result, hash, size, url, destFile, modtime = None, ext = None, decFile = None):
         """Update the modification time and AptPackages."""
         if modtime:
             os.utime(destFile.path, (modtime, modtime))
             if ext:
                 os.utime(decFile.path, (modtime, modtime))
-            
-        self.updatedFile(url, destFile.path)
-        if ext:
-            self.updatedFile(url[:-len(ext)], decFile.path)
+        
+        if not hash or result == hash:
+            if hash:
+                log.msg('Hashes match: %s' % url)
+            else:
+                log.msg('Hashed file to %s: %s' % (b2a_hex(result), url))
+                
+            self.updatedFile(url, destFile.path)
+            if ext:
+                self.updatedFile(url[:-len(ext)], decFile.path)
+        else:
+            log.msg("Hashes don't match %s != %s: %s" % (b2a_hex(hash), b2a_hex(result), url))
 
     def save_error(self, failure, url):
         """An error has occurred in downloadign or saving the file."""