git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht/Hash.py
Added piece hashing to the HashObject.
[quix0rs-apt-p2p.git] / apt_dht / Hash.py
index 270e61150ecf7cdc3bd71d5e0384041ae1620119..ec985989f4eabfcee8f8a2b18f3a6d19eeb7aec3 100644 (file)
@@ -2,8 +2,14 @@
 from binascii import b2a_hex, a2b_hex
 import sys
 
+from twisted.internet import threads, defer
 from twisted.trial import unittest
 
+# Number of bytes of file data covered by each piece hash (512 KiB)
+PIECE_SIZE = 512*1024
+
+class HashError(ValueError):
+    """An error has occurred while hashing a file.
+    
+    Raised by HashObject when it is used incorrectly, for example when
+    data is added after the hash has been finished or before a hasher
+    has been created. It is also the failure given to the deferred
+    when a file to be hashed does not exist.
+    """
+    
 class HashObject:
     """Manages hashes and hashing for a file."""
     
@@ -24,23 +30,27 @@ class HashObject:
               {'name': 'md5',
                    'AptPkgRecord': 'MD5Hash', 
                    'AptSrcRecord': True, 
-                   'AptIndexRecord': 'MD5Sum',
+                   'AptIndexRecord': 'MD5SUM',
                    'old_module': 'md5',
                    'hashlib_func': 'md5',
                    },
             ]
     
-    def __init__(self):
+    def __init__(self, digest = None, size = None):
         self.hashTypeNum = 0    # Use the first if nothing else matters
         self.expHash = None
         self.expHex = None
         self.expSize = None
         self.expNormHash = None
         self.fileHasher = None
-        self.fileHash = None
+        self.pieceHasher = None
+        self.fileHash = digest
+        self.pieceHash = []
+        self.size = size
         self.fileHex = None
         self.fileNormHash = None
         self.done = True
+        self.result = None
         if sys.version_info < (2, 5):
             # sha256 is not available in python before 2.5, remove it
             for hashType in self.ORDER:
@@ -52,7 +62,8 @@ class HashObject:
         if bits is not None:
             bytes = (bits - 1) // 8 + 1
         else:
-            assert(bytes is not None)
+            if bytes is None:
+                raise HashError, "you must specify one of bits or bytes"
         if len(hashString) < bytes:
             hashString = hashString + '\000'*(bytes - len(hashString))
         elif len(hashString) > bytes:
@@ -80,32 +91,91 @@ class HashObject:
         return self.expNormHash
 
     #### Methods for hashing data
-    def new(self):
-        """Generate a new hashing object suitable for hashing a file."""
-        self.size = 0
-        self.done = False
+    def new(self, force = False):
+        """Generate a new hashing object suitable for hashing a file.
+        
+        Resets the file hasher, the piece hashes, the size and the saved
+        digests so that hashing starts again from the beginning. If a
+        verification result has already been stored then nothing is
+        reset, unless C{force} is set.
+        
+        @param force: set to True to force creating a new hasher even if
+            the hash has been verified already
+        """
+        if self.result is None or force == True:
+            self.result = None
+            self.done = False
+            self.fileHasher = self._new()
+            # The piece hasher is created by update() only when it is needed
+            self.pieceHasher = None
+            self.fileHash = None
+            self.pieceHash = []
+            self.size = 0
+            self.fileHex = None
+            self.fileNormHash = None
+
+    def _new(self):
+        """Create a new hashing object according to the hash type.
+        
+        @return: a new, empty hasher for the hash type selected by
+            C{self.hashTypeNum}
+        """
         if sys.version_info < (2, 5):
+            # Before python 2.5, import the module named by 'old_module'
+            # instead of using hashlib
             mod = __import__(self.ORDER[self.hashTypeNum]['old_module'], globals(), locals(), [])
-            self.fileHasher = mod.new()
+            return mod.new()
         else:
             import hashlib
             func = getattr(hashlib, self.ORDER[self.hashTypeNum]['hashlib_func'])
-            self.fileHasher = func()
-        return self.fileHasher
+            return func()
 
     def update(self, data):
         """Add more data to the file hasher."""
-        assert self.done == False, "Already done, you can't add more data after calling digest() or verify()"
-        assert self.fileHasher is not None, "file hasher not initialized"
-        self.fileHasher.update(data)
-        self.size += len(data)
+        # Data added after a verification result has been stored is ignored
+        if self.result is None:
+            if self.done:
+                raise HashError, "Already done, you can't add more data after calling digest() or verify()"
+            if self.fileHasher is None:
+                raise HashError, "file hasher not initialized"
+            
+            # The first piece is hashed by the file hasher alone, a separate
+            # piece hasher is only created once the data exceeds one piece
+            if not self.pieceHasher and self.size + len(data) > PIECE_SIZE:
+                # Hash up to the piece size
+                self.fileHasher.update(data[:(PIECE_SIZE - self.size)])
+                data = data[(PIECE_SIZE - self.size):]
+                self.size = PIECE_SIZE
+
+                # Save the first piece digest and initialize a new piece hasher
+                self.pieceHash.append(self.fileHasher.digest())
+                self.pieceHasher = self._new()
+
+            if self.pieceHasher:
+                # Loop in case the data contains multiple pieces
+                # (piece_size is the number of bytes already in the current piece)
+                piece_size = self.size % PIECE_SIZE
+                while piece_size + len(data) > PIECE_SIZE:
+                    # Save the piece hash and start a new one
+                    self.pieceHasher.update(data[:(PIECE_SIZE - piece_size)])
+                    self.pieceHash.append(self.pieceHasher.digest())
+                    self.pieceHasher = self._new()
+                    
+                    # Don't forget to hash the data normally
+                    self.fileHasher.update(data[:(PIECE_SIZE - piece_size)])
+                    data = data[(PIECE_SIZE - piece_size):]
+                    self.size += PIECE_SIZE - piece_size
+                    piece_size = self.size % PIECE_SIZE
+
+                # Hash any remaining data
+                self.pieceHasher.update(data)
+            
+            # Whatever data is left always goes to the file hasher as well
+            self.fileHasher.update(data)
+            self.size += len(data)
         
+    def pieceDigests(self):
+        """Get the piece hashes of the added file data.
+        
+        Calls L{digest} first, which saves the hash of the last piece the
+        first time the file hash is finished.
+        
+        @return: the list of piece digests, in the order the pieces
+            were added
+        """
+        self.digest()
+        return self.pieceHash
+
     def digest(self):
         """Get the hash of the added file data."""
+        # A digest that has already been set is returned without rehashing
         if self.fileHash is None:
-            assert self.fileHasher is not None, "you must hash some data first"
+            if self.fileHasher is None:
+                raise HashError, "you must hash some data first"
             self.fileHash = self.fileHasher.digest()
+            # Mark the hash as finished so that update() rejects further data
             self.done = True
+            
+            # Save the last piece hash
+            if self.pieceHasher:
+                self.pieceHash.append(self.pieceHasher.digest())
+            else:
+                # If there are no piece hashes, then the file hash is the only piece hash
+                self.pieceHash.append(self.fileHash)
         return self.fileHash
 
     def hexdigest(self):
@@ -125,10 +195,32 @@ class HashObject:
 
     def verify(self):
         """Verify that the added file data hash matches the expected hash."""
-        if self.fileHash == None:
-            return None
-        return (self.fileHash == self.expHash and self.size == self.expSize)
+        # The result is calculated once and then saved, it stays None until
+        # both the file hash and the expected hash are available
+        if self.result is None and self.fileHash is not None and self.expHash is not None:
+            self.result = (self.fileHash == self.expHash and self.size == self.expSize)
+        return self.result
+    
+    def hashInThread(self, file):
+        """Hashes a file in a separate thread, callback with the result.
+        
+        @param file: the file to hash, it must provide restat(), exists()
+            and open() (presumably a twisted FilePath, verify against callers)
+        @return: a deferred that is called back with this hash object once
+            the file has been hashed, or that has already failed with a
+            L{HashError} if the file does not exist
+        """
+        # Refresh the cached stat info without raising if the file is missing
+        file.restat(False)
+        if not file.exists():
+            df = defer.Deferred()
+            df.errback(HashError("file not found"))
+            return df
+        
+        df = threads.deferToThread(self._hashInThread, file)
+        return df
     
+    def _hashInThread(self, file):
+        """Hashes a file, returning itself as the result.
+        
+        Any previous hashing state is discarded, even if the hash has
+        already been verified.
+        
+        @param file: the file to hash, it must provide an open() method
+            that returns a readable file object
+        @return: this hash object, with the file and piece digests calculated
+        """
+        f = file.open()
+        try:
+            self.new(force = True)
+            # Read in small chunks so the whole file is never held in memory
+            data = f.read(4096)
+            while data:
+                self.update(data)
+                data = f.read(4096)
+        finally:
+            # Always release the file handle, even if reading or hashing fails
+            f.close()
+        self.digest()
+        return self
+
     #### Methods for setting the expected hash
     def set(self, hashType, hashHex, size):
         """Initialize the hash object.
@@ -137,7 +229,7 @@ class HashObject:
         """
         self.hashTypeNum = self.ORDER.index(hashType)    # error if not found
         self.expHex = hashHex
-        self.expSize = size
+        self.expSize = int(size)
         self.expHash = a2b_hex(self.expHex)
         
     def setFromIndexRecord(self, record):
@@ -188,29 +280,49 @@ class TestHashObject(unittest.TestCase):
     
     def test_normalize(self):
+        """Expected hashes are padded with nulls or truncated to the requested length."""
         h = HashObject()
-        h.set(h.ORDER[0], b2a_hex('12345678901234567890'), 0)
+        h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
         self.failUnless(h.normexpected(bits = 160) == '12345678901234567890')
         h = HashObject()
-        h.set(h.ORDER[0], b2a_hex('12345678901234567'), 0)
+        h.set(h.ORDER[0], b2a_hex('12345678901234567'), '0')
         self.failUnless(h.normexpected(bits = 160) == '12345678901234567\000\000\000')
         h = HashObject()
-        h.set(h.ORDER[0], b2a_hex('1234567890123456789012345'), 0)
+        h.set(h.ORDER[0], b2a_hex('1234567890123456789012345'), '0')
         self.failUnless(h.normexpected(bytes = 20) == '12345678901234567890')
         h = HashObject()
-        h.set(h.ORDER[0], b2a_hex('1234567890123456789'), 0)
+        h.set(h.ORDER[0], b2a_hex('1234567890123456789'), '0')
         self.failUnless(h.normexpected(bytes = 20) == '1234567890123456789\000')
         h = HashObject()
-        h.set(h.ORDER[0], b2a_hex('123456789012345678901'), 0)
+        h.set(h.ORDER[0], b2a_hex('123456789012345678901'), '0')
         self.failUnless(h.normexpected(bits = 160) == '12345678901234567890')
 
     def test_failure(self):
+        """Using a hash object before it has been prepared raises HashError."""
         h = HashObject()
-        h.set(h.ORDER[0], b2a_hex('12345678901234567890'), 0)
-        self.failUnlessRaises(AssertionError, h.normexpected)
-        self.failUnlessRaises(AssertionError, h.digest)
-        self.failUnlessRaises(AssertionError, h.hexdigest)
-        self.failUnlessRaises(AssertionError, h.update, 'gfgf')
+        h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
+        # Neither bits nor bytes is given
+        self.failUnlessRaises(HashError, h.normexpected)
+        # new() has not been called, so there is no hasher to use
+        self.failUnlessRaises(HashError, h.digest)
+        self.failUnlessRaises(HashError, h.hexdigest)
+        self.failUnlessRaises(HashError, h.update, 'gfgf')
     
+    def test_pieces(self):
+        """Piece hashes are the same whether data is added all at once or in small chunks."""
+        h = HashObject()
+        h.new()
+        # 1228800 bytes of data: two full pieces of 512 KiB and one partial piece
+        h.update('1234567890'*120*1024)
+        self.failUnless(h.digest() == '1(j\xd2q\x0b\n\x91\xd2\x13\x90\x15\xa3E\xcc\xb0\x8d.\xc3\xc5')
+        pieces = h.pieceDigests()
+        self.failUnless(len(pieces) == 3)
+        self.failUnless(pieces[0] == ',G \xd8\xbbPl\xf1\xa3\xa0\x0cW\n\xe6\xe6a\xc9\x95/\xe5')
+        self.failUnless(pieces[1] == '\xf6V\xeb/\xa8\xad[\x07Z\xf9\x87\xa4\xf5w\xdf\xe1|\x00\x8e\x93')
+        self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
+        # Hash the same data again from scratch, this time 10 bytes at a time
+        h.new(True)
+        for i in xrange(120*1024):
+            h.update('1234567890')
+        pieces = h.pieceDigests()
+        self.failUnless(h.digest() == '1(j\xd2q\x0b\n\x91\xd2\x13\x90\x15\xa3E\xcc\xb0\x8d.\xc3\xc5')
+        self.failUnless(len(pieces) == 3)
+        self.failUnless(pieces[0] == ',G \xd8\xbbPl\xf1\xa3\xa0\x0cW\n\xe6\xe6a\xc9\x95/\xe5')
+        self.failUnless(pieces[1] == '\xf6V\xeb/\xa8\xad[\x07Z\xf9\x87\xa4\xf5w\xdf\xe1|\x00\x8e\x93')
+        self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
+        
     def test_sha1(self):
         h = HashObject()
         found = False
@@ -219,11 +331,11 @@ class TestHashObject(unittest.TestCase):
                 found = True
                 break
         self.failUnless(found == True)
-        h.set(hashType, 'c722df87e1acaa64b27aac4e174077afc3623540', 19)
+        h.set(hashType, 'c722df87e1acaa64b27aac4e174077afc3623540', '19')
         h.new()
         h.update('apt-dht is the best')
         self.failUnless(h.hexdigest() == 'c722df87e1acaa64b27aac4e174077afc3623540')
-        self.failUnlessRaises(AssertionError, h.update, 'gfgf')
+        self.failUnlessRaises(HashError, h.update, 'gfgf')
         self.failUnless(h.verify() == True)
         
     def test_md5(self):
@@ -234,11 +346,11 @@ class TestHashObject(unittest.TestCase):
                 found = True
                 break
         self.failUnless(found == True)
-        h.set(hashType, '2a586bcd1befc5082c872dcd96a01403', 19)
+        h.set(hashType, '2a586bcd1befc5082c872dcd96a01403', '19')
         h.new()
         h.update('apt-dht is the best')
         self.failUnless(h.hexdigest() == '2a586bcd1befc5082c872dcd96a01403')
-        self.failUnlessRaises(AssertionError, h.update, 'gfgf')
+        self.failUnlessRaises(HashError, h.update, 'gfgf')
         self.failUnless(h.verify() == True)
         
     def test_sha256(self):
@@ -249,12 +361,12 @@ class TestHashObject(unittest.TestCase):
                 found = True
                 break
         self.failUnless(found == True)
-        h.set(hashType, '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7', 19)
+        h.set(hashType, '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7', '19')
         h.new()
         h.update('apt-dht is the best')
         self.failUnless(h.hexdigest() == '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7')
-        self.failUnlessRaises(AssertionError, h.update, 'gfgf')
+        self.failUnlessRaises(HashError, h.update, 'gfgf')
         self.failUnless(h.verify() == True)
 
     if sys.version_info < (2, 5):
-        test_sha256.skip = "SHA256 hashes are not supported on python until version 2.5"
+        test_sha256.skip = "SHA256 hashes are not supported by Python until version 2.5"