Move the normalization of key lengths from the HashObject to the DHT.
[quix0rs-apt-p2p.git] / apt_dht / Hash.py
1
2 from binascii import b2a_hex, a2b_hex
3 import sys
4
5 from twisted.internet import threads, defer
6 from twisted.trial import unittest
7
8 PIECE_SIZE = 512*1024
9
10 class HashError(ValueError):
11     """An error has occurred while hashing a file."""
12     
13 class HashObject:
14     """Manages hashes and hashing for a file."""
15     
16     """The priority ordering of hashes, and how to extract them."""
17     ORDER = [ {'name': 'sha1', 
18                    'length': 20,
19                    'AptPkgRecord': 'SHA1Hash', 
20                    'AptSrcRecord': False, 
21                    'AptIndexRecord': 'SHA1',
22                    'old_module': 'sha',
23                    'hashlib_func': 'sha1',
24                    },
25               {'name': 'sha256',
26                    'length': 32,
27                    'AptPkgRecord': 'SHA256Hash', 
28                    'AptSrcRecord': False, 
29                    'AptIndexRecord': 'SHA256',
30                    'hashlib_func': 'sha256',
31                    },
32               {'name': 'md5',
33                    'length': 16,
34                    'AptPkgRecord': 'MD5Hash', 
35                    'AptSrcRecord': True, 
36                    'AptIndexRecord': 'MD5SUM',
37                    'old_module': 'md5',
38                    'hashlib_func': 'md5',
39                    },
40             ]
41     
42     def __init__(self, digest = None, size = None, pieces = ''):
43         self.hashTypeNum = 0    # Use the first if nothing else matters
44         if sys.version_info < (2, 5):
45             # sha256 is not available in python before 2.5, remove it
46             for hashType in self.ORDER:
47                 if hashType['name'] == 'sha256':
48                     del self.ORDER[self.ORDER.index(hashType)]
49                     break
50
51         self.expHash = None
52         self.expHex = None
53         self.expSize = None
54         self.expNormHash = None
55         self.fileHasher = None
56         self.pieceHasher = None
57         self.fileHash = digest
58         self.pieceHash = [pieces[x:x+self.ORDER[self.hashTypeNum]['length']]
59                           for x in xrange(0, len(pieces), self.ORDER[self.hashTypeNum]['length'])]
60         self.size = size
61         self.fileHex = None
62         self.fileNormHash = None
63         self.done = True
64         self.result = None
65         
66     #### Methods for returning the expected hash
67     def expected(self):
68         """Get the expected hash."""
69         return self.expHash
70     
71     def hexexpected(self):
72         """Get the expected hash in hex format."""
73         if self.expHex is None and self.expHash is not None:
74             self.expHex = b2a_hex(self.expHash)
75         return self.expHex
76     
77     #### Methods for hashing data
78     def new(self, force = False):
79         """Generate a new hashing object suitable for hashing a file.
80         
81         @param force: set to True to force creating a new hasher even if
82             the hash has been verified already
83         """
84         if self.result is None or force == True:
85             self.result = None
86             self.done = False
87             self.fileHasher = self._new()
88             self.pieceHasher = None
89             self.fileHash = None
90             self.pieceHash = []
91             self.size = 0
92             self.fileHex = None
93             self.fileNormHash = None
94
95     def _new(self):
96         """Create a new hashing object according to the hash type."""
97         if sys.version_info < (2, 5):
98             mod = __import__(self.ORDER[self.hashTypeNum]['old_module'], globals(), locals(), [])
99             return mod.new()
100         else:
101             import hashlib
102             func = getattr(hashlib, self.ORDER[self.hashTypeNum]['hashlib_func'])
103             return func()
104
105     def update(self, data):
106         """Add more data to the file hasher."""
107         if self.result is None:
108             if self.done:
109                 raise HashError, "Already done, you can't add more data after calling digest() or verify()"
110             if self.fileHasher is None:
111                 raise HashError, "file hasher not initialized"
112             
113             if not self.pieceHasher and self.size + len(data) > PIECE_SIZE:
114                 # Hash up to the piece size
115                 self.fileHasher.update(data[:(PIECE_SIZE - self.size)])
116                 data = data[(PIECE_SIZE - self.size):]
117                 self.size = PIECE_SIZE
118
119                 # Save the first piece digest and initialize a new piece hasher
120                 self.pieceHash.append(self.fileHasher.digest())
121                 self.pieceHasher = self._new()
122
123             if self.pieceHasher:
124                 # Loop in case the data contains multiple pieces
125                 piece_size = self.size % PIECE_SIZE
126                 while piece_size + len(data) > PIECE_SIZE:
127                     # Save the piece hash and start a new one
128                     self.pieceHasher.update(data[:(PIECE_SIZE - piece_size)])
129                     self.pieceHash.append(self.pieceHasher.digest())
130                     self.pieceHasher = self._new()
131                     
132                     # Don't forget to hash the data normally
133                     self.fileHasher.update(data[:(PIECE_SIZE - piece_size)])
134                     data = data[(PIECE_SIZE - piece_size):]
135                     self.size += PIECE_SIZE - piece_size
136                     piece_size = self.size % PIECE_SIZE
137
138                 # Hash any remaining data
139                 self.pieceHasher.update(data)
140             
141             self.fileHasher.update(data)
142             self.size += len(data)
143         
144     def pieceDigests(self):
145         """Get the piece hashes of the added file data."""
146         self.digest()
147         return self.pieceHash
148
149     def digest(self):
150         """Get the hash of the added file data."""
151         if self.fileHash is None:
152             if self.fileHasher is None:
153                 raise HashError, "you must hash some data first"
154             self.fileHash = self.fileHasher.digest()
155             self.done = True
156             
157             # Save the last piece hash
158             if self.pieceHasher:
159                 self.pieceHash.append(self.pieceHasher.digest())
160         return self.fileHash
161
162     def hexdigest(self):
163         """Get the hash of the added file data in hex format."""
164         if self.fileHex is None:
165             self.fileHex = b2a_hex(self.digest())
166         return self.fileHex
167         
168     def verify(self):
169         """Verify that the added file data hash matches the expected hash."""
170         if self.result is None and self.fileHash is not None and self.expHash is not None:
171             self.result = (self.fileHash == self.expHash and self.size == self.expSize)
172         return self.result
173     
174     def hashInThread(self, file):
175         """Hashes a file in a separate thread, callback with the result."""
176         file.restat(False)
177         if not file.exists():
178             df = defer.Deferred()
179             df.errback(HashError("file not found"))
180             return df
181         
182         df = threads.deferToThread(self._hashInThread, file)
183         return df
184     
185     def _hashInThread(self, file):
186         """Hashes a file, returning itself as the result."""
187         f = file.open()
188         self.new(force = True)
189         data = f.read(4096)
190         while data:
191             self.update(data)
192             data = f.read(4096)
193         self.digest()
194         return self
195
196     #### Methods for setting the expected hash
197     def set(self, hashType, hashHex, size):
198         """Initialize the hash object.
199         
200         @param hashType: must be one of the dictionaries from L{ORDER}
201         """
202         self.hashTypeNum = self.ORDER.index(hashType)    # error if not found
203         self.expHex = hashHex
204         self.expSize = int(size)
205         self.expHash = a2b_hex(self.expHex)
206         
207     def setFromIndexRecord(self, record):
208         """Set the hash from the cache of index file records.
209         
210         @type record: C{dictionary}
211         @param record: keys are hash types, values are tuples of (hash, size)
212         """
213         for hashType in self.ORDER:
214             result = record.get(hashType['AptIndexRecord'], None)
215             if result:
216                 self.set(hashType, result[0], result[1])
217                 return True
218         return False
219
220     def setFromPkgRecord(self, record, size):
221         """Set the hash from Apt's binary packages cache.
222         
223         @param record: whatever is returned by apt_pkg.GetPkgRecords()
224         """
225         for hashType in self.ORDER:
226             hashHex = getattr(record, hashType['AptPkgRecord'], None)
227             if hashHex:
228                 self.set(hashType, hashHex, size)
229                 return True
230         return False
231     
232     def setFromSrcRecord(self, record):
233         """Set the hash from Apt's source package records cache.
234         
235         Currently very simple since Apt only tracks MD5 hashes of source files.
236         
237         @type record: (C{string}, C{int}, C{string})
238         @param record: the hash, size and path of the source file
239         """
240         for hashType in self.ORDER:
241             if hashType['AptSrcRecord']:
242                 self.set(hashType, record[0], record[1])
243                 return True
244         return False
245
246 class TestHashObject(unittest.TestCase):
247     """Unit tests for the hash objects."""
248     
249     timeout = 5
250     if sys.version_info < (2, 4):
251         skip = "skippingme"
252     
253     def test_failure(self):
254         h = HashObject()
255         h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
256         self.failUnlessRaises(HashError, h.digest)
257         self.failUnlessRaises(HashError, h.hexdigest)
258         self.failUnlessRaises(HashError, h.update, 'gfgf')
259     
260     def test_pieces(self):
261         h = HashObject()
262         h.new()
263         h.update('1234567890'*120*1024)
264         self.failUnless(h.digest() == '1(j\xd2q\x0b\n\x91\xd2\x13\x90\x15\xa3E\xcc\xb0\x8d.\xc3\xc5')
265         pieces = h.pieceDigests()
266         self.failUnless(len(pieces) == 3)
267         self.failUnless(pieces[0] == ',G \xd8\xbbPl\xf1\xa3\xa0\x0cW\n\xe6\xe6a\xc9\x95/\xe5')
268         self.failUnless(pieces[1] == '\xf6V\xeb/\xa8\xad[\x07Z\xf9\x87\xa4\xf5w\xdf\xe1|\x00\x8e\x93')
269         self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
270         h.new(True)
271         for i in xrange(120*1024):
272             h.update('1234567890')
273         pieces = h.pieceDigests()
274         self.failUnless(h.digest() == '1(j\xd2q\x0b\n\x91\xd2\x13\x90\x15\xa3E\xcc\xb0\x8d.\xc3\xc5')
275         self.failUnless(len(pieces) == 3)
276         self.failUnless(pieces[0] == ',G \xd8\xbbPl\xf1\xa3\xa0\x0cW\n\xe6\xe6a\xc9\x95/\xe5')
277         self.failUnless(pieces[1] == '\xf6V\xeb/\xa8\xad[\x07Z\xf9\x87\xa4\xf5w\xdf\xe1|\x00\x8e\x93')
278         self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
279         
280     def test_sha1(self):
281         h = HashObject()
282         found = False
283         for hashType in h.ORDER:
284             if hashType['name'] == 'sha1':
285                 found = True
286                 break
287         self.failUnless(found == True)
288         h.set(hashType, 'c722df87e1acaa64b27aac4e174077afc3623540', '19')
289         h.new()
290         h.update('apt-dht is the best')
291         self.failUnless(h.hexdigest() == 'c722df87e1acaa64b27aac4e174077afc3623540')
292         self.failUnlessRaises(HashError, h.update, 'gfgf')
293         self.failUnless(h.verify() == True)
294         
295     def test_md5(self):
296         h = HashObject()
297         found = False
298         for hashType in h.ORDER:
299             if hashType['name'] == 'md5':
300                 found = True
301                 break
302         self.failUnless(found == True)
303         h.set(hashType, '2a586bcd1befc5082c872dcd96a01403', '19')
304         h.new()
305         h.update('apt-dht is the best')
306         self.failUnless(h.hexdigest() == '2a586bcd1befc5082c872dcd96a01403')
307         self.failUnlessRaises(HashError, h.update, 'gfgf')
308         self.failUnless(h.verify() == True)
309         
310     def test_sha256(self):
311         h = HashObject()
312         found = False
313         for hashType in h.ORDER:
314             if hashType['name'] == 'sha256':
315                 found = True
316                 break
317         self.failUnless(found == True)
318         h.set(hashType, '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7', '19')
319         h.new()
320         h.update('apt-dht is the best')
321         self.failUnless(h.hexdigest() == '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7')
322         self.failUnlessRaises(HashError, h.update, 'gfgf')
323         self.failUnless(h.verify() == True)
324
325     if sys.version_info < (2, 5):
326         test_sha256.skip = "SHA256 hashes are not supported by Python until version 2.5"