a7a8e40fcc587038ab18abacc0d571a147a24102
[quix0rs-apt-p2p.git] / apt_dht / Hash.py
1
2 from binascii import b2a_hex, a2b_hex
3 import sys
4
5 from twisted.internet import threads, defer
6 from twisted.trial import unittest
7
8 PIECE_SIZE = 512*1024
9
10 class HashError(ValueError):
11     """An error has occurred while hashing a file."""
12     
13 class HashObject:
14     """Manages hashes and hashing for a file."""
15     
16     """The priority ordering of hashes, and how to extract them."""
17     ORDER = [ {'name': 'sha1', 
18                    'length': 20,
19                    'AptPkgRecord': 'SHA1Hash', 
20                    'AptSrcRecord': False, 
21                    'AptIndexRecord': 'SHA1',
22                    'old_module': 'sha',
23                    'hashlib_func': 'sha1',
24                    },
25               {'name': 'sha256',
26                    'length': 32,
27                    'AptPkgRecord': 'SHA256Hash', 
28                    'AptSrcRecord': False, 
29                    'AptIndexRecord': 'SHA256',
30                    'hashlib_func': 'sha256',
31                    },
32               {'name': 'md5',
33                    'length': 16,
34                    'AptPkgRecord': 'MD5Hash', 
35                    'AptSrcRecord': True, 
36                    'AptIndexRecord': 'MD5SUM',
37                    'old_module': 'md5',
38                    'hashlib_func': 'md5',
39                    },
40             ]
41     
42     def __init__(self, digest = None, size = None, pieces = ''):
43         self.hashTypeNum = 0    # Use the first if nothing else matters
44         if sys.version_info < (2, 5):
45             # sha256 is not available in python before 2.5, remove it
46             for hashType in self.ORDER:
47                 if hashType['name'] == 'sha256':
48                     del self.ORDER[self.ORDER.index(hashType)]
49                     break
50
51         self.expHash = None
52         self.expHex = None
53         self.expSize = None
54         self.expNormHash = None
55         self.fileHasher = None
56         self.pieceHasher = None
57         self.fileHash = digest
58         self.pieceHash = [pieces[x:x+self.ORDER[self.hashTypeNum]['length']]
59                           for x in xrange(0, len(pieces), self.ORDER[self.hashTypeNum]['length'])]
60         self.size = size
61         self.fileHex = None
62         self.fileNormHash = None
63         self.done = True
64         self.result = None
65         
66     def _norm_hash(self, hashString, bits=None, bytes=None):
67         if bits is not None:
68             bytes = (bits - 1) // 8 + 1
69         else:
70             if bytes is None:
71                 raise HashError, "you must specify one of bits or bytes"
72         if len(hashString) < bytes:
73             hashString = hashString + '\000'*(bytes - len(hashString))
74         elif len(hashString) > bytes:
75             hashString = hashString[:bytes]
76         return hashString
77
78     #### Methods for returning the expected hash
79     def expected(self):
80         """Get the expected hash."""
81         return self.expHash
82     
83     def hexexpected(self):
84         """Get the expected hash in hex format."""
85         if self.expHex is None and self.expHash is not None:
86             self.expHex = b2a_hex(self.expHash)
87         return self.expHex
88     
89     def normexpected(self, bits=None, bytes=None):
90         """Normalize the binary hash for the given length.
91         
92         You must specify one of bits or bytes.
93         """
94         if self.expNormHash is None and self.expHash is not None:
95             self.expNormHash = self._norm_hash(self.expHash, bits, bytes)
96         return self.expNormHash
97
98     #### Methods for hashing data
99     def new(self, force = False):
100         """Generate a new hashing object suitable for hashing a file.
101         
102         @param force: set to True to force creating a new hasher even if
103             the hash has been verified already
104         """
105         if self.result is None or force == True:
106             self.result = None
107             self.done = False
108             self.fileHasher = self._new()
109             self.pieceHasher = None
110             self.fileHash = None
111             self.pieceHash = []
112             self.size = 0
113             self.fileHex = None
114             self.fileNormHash = None
115
116     def _new(self):
117         """Create a new hashing object according to the hash type."""
118         if sys.version_info < (2, 5):
119             mod = __import__(self.ORDER[self.hashTypeNum]['old_module'], globals(), locals(), [])
120             return mod.new()
121         else:
122             import hashlib
123             func = getattr(hashlib, self.ORDER[self.hashTypeNum]['hashlib_func'])
124             return func()
125
126     def update(self, data):
127         """Add more data to the file hasher."""
128         if self.result is None:
129             if self.done:
130                 raise HashError, "Already done, you can't add more data after calling digest() or verify()"
131             if self.fileHasher is None:
132                 raise HashError, "file hasher not initialized"
133             
134             if not self.pieceHasher and self.size + len(data) > PIECE_SIZE:
135                 # Hash up to the piece size
136                 self.fileHasher.update(data[:(PIECE_SIZE - self.size)])
137                 data = data[(PIECE_SIZE - self.size):]
138                 self.size = PIECE_SIZE
139
140                 # Save the first piece digest and initialize a new piece hasher
141                 self.pieceHash.append(self.fileHasher.digest())
142                 self.pieceHasher = self._new()
143
144             if self.pieceHasher:
145                 # Loop in case the data contains multiple pieces
146                 piece_size = self.size % PIECE_SIZE
147                 while piece_size + len(data) > PIECE_SIZE:
148                     # Save the piece hash and start a new one
149                     self.pieceHasher.update(data[:(PIECE_SIZE - piece_size)])
150                     self.pieceHash.append(self.pieceHasher.digest())
151                     self.pieceHasher = self._new()
152                     
153                     # Don't forget to hash the data normally
154                     self.fileHasher.update(data[:(PIECE_SIZE - piece_size)])
155                     data = data[(PIECE_SIZE - piece_size):]
156                     self.size += PIECE_SIZE - piece_size
157                     piece_size = self.size % PIECE_SIZE
158
159                 # Hash any remaining data
160                 self.pieceHasher.update(data)
161             
162             self.fileHasher.update(data)
163             self.size += len(data)
164         
165     def pieceDigests(self):
166         """Get the piece hashes of the added file data."""
167         self.digest()
168         return self.pieceHash
169
170     def digest(self):
171         """Get the hash of the added file data."""
172         if self.fileHash is None:
173             if self.fileHasher is None:
174                 raise HashError, "you must hash some data first"
175             self.fileHash = self.fileHasher.digest()
176             self.done = True
177             
178             # Save the last piece hash
179             if self.pieceHasher:
180                 self.pieceHash.append(self.pieceHasher.digest())
181         return self.fileHash
182
183     def hexdigest(self):
184         """Get the hash of the added file data in hex format."""
185         if self.fileHex is None:
186             self.fileHex = b2a_hex(self.digest())
187         return self.fileHex
188         
189     def norm(self, bits=None, bytes=None):
190         """Normalize the binary hash for the given length.
191         
192         You must specify one of bits or bytes.
193         """
194         if self.fileNormHash is None:
195             self.fileNormHash = self._norm_hash(self.digest(), bits, bytes)
196         return self.fileNormHash
197
198     def verify(self):
199         """Verify that the added file data hash matches the expected hash."""
200         if self.result is None and self.fileHash is not None and self.expHash is not None:
201             self.result = (self.fileHash == self.expHash and self.size == self.expSize)
202         return self.result
203     
204     def hashInThread(self, file):
205         """Hashes a file in a separate thread, callback with the result."""
206         file.restat(False)
207         if not file.exists():
208             df = defer.Deferred()
209             df.errback(HashError("file not found"))
210             return df
211         
212         df = threads.deferToThread(self._hashInThread, file)
213         return df
214     
215     def _hashInThread(self, file):
216         """Hashes a file, returning itself as the result."""
217         f = file.open()
218         self.new(force = True)
219         data = f.read(4096)
220         while data:
221             self.update(data)
222             data = f.read(4096)
223         self.digest()
224         return self
225
226     #### Methods for setting the expected hash
227     def set(self, hashType, hashHex, size):
228         """Initialize the hash object.
229         
230         @param hashType: must be one of the dictionaries from L{ORDER}
231         """
232         self.hashTypeNum = self.ORDER.index(hashType)    # error if not found
233         self.expHex = hashHex
234         self.expSize = int(size)
235         self.expHash = a2b_hex(self.expHex)
236         
237     def setFromIndexRecord(self, record):
238         """Set the hash from the cache of index file records.
239         
240         @type record: C{dictionary}
241         @param record: keys are hash types, values are tuples of (hash, size)
242         """
243         for hashType in self.ORDER:
244             result = record.get(hashType['AptIndexRecord'], None)
245             if result:
246                 self.set(hashType, result[0], result[1])
247                 return True
248         return False
249
250     def setFromPkgRecord(self, record, size):
251         """Set the hash from Apt's binary packages cache.
252         
253         @param record: whatever is returned by apt_pkg.GetPkgRecords()
254         """
255         for hashType in self.ORDER:
256             hashHex = getattr(record, hashType['AptPkgRecord'], None)
257             if hashHex:
258                 self.set(hashType, hashHex, size)
259                 return True
260         return False
261     
262     def setFromSrcRecord(self, record):
263         """Set the hash from Apt's source package records cache.
264         
265         Currently very simple since Apt only tracks MD5 hashes of source files.
266         
267         @type record: (C{string}, C{int}, C{string})
268         @param record: the hash, size and path of the source file
269         """
270         for hashType in self.ORDER:
271             if hashType['AptSrcRecord']:
272                 self.set(hashType, record[0], record[1])
273                 return True
274         return False
275
276 class TestHashObject(unittest.TestCase):
277     """Unit tests for the hash objects."""
278     
279     timeout = 5
280     if sys.version_info < (2, 4):
281         skip = "skippingme"
282     
283     def test_normalize(self):
284         h = HashObject()
285         h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
286         self.failUnless(h.normexpected(bits = 160) == '12345678901234567890')
287         h = HashObject()
288         h.set(h.ORDER[0], b2a_hex('12345678901234567'), '0')
289         self.failUnless(h.normexpected(bits = 160) == '12345678901234567\000\000\000')
290         h = HashObject()
291         h.set(h.ORDER[0], b2a_hex('1234567890123456789012345'), '0')
292         self.failUnless(h.normexpected(bytes = 20) == '12345678901234567890')
293         h = HashObject()
294         h.set(h.ORDER[0], b2a_hex('1234567890123456789'), '0')
295         self.failUnless(h.normexpected(bytes = 20) == '1234567890123456789\000')
296         h = HashObject()
297         h.set(h.ORDER[0], b2a_hex('123456789012345678901'), '0')
298         self.failUnless(h.normexpected(bits = 160) == '12345678901234567890')
299
300     def test_failure(self):
301         h = HashObject()
302         h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
303         self.failUnlessRaises(HashError, h.normexpected)
304         self.failUnlessRaises(HashError, h.digest)
305         self.failUnlessRaises(HashError, h.hexdigest)
306         self.failUnlessRaises(HashError, h.update, 'gfgf')
307     
308     def test_pieces(self):
309         h = HashObject()
310         h.new()
311         h.update('1234567890'*120*1024)
312         self.failUnless(h.digest() == '1(j\xd2q\x0b\n\x91\xd2\x13\x90\x15\xa3E\xcc\xb0\x8d.\xc3\xc5')
313         pieces = h.pieceDigests()
314         self.failUnless(len(pieces) == 3)
315         self.failUnless(pieces[0] == ',G \xd8\xbbPl\xf1\xa3\xa0\x0cW\n\xe6\xe6a\xc9\x95/\xe5')
316         self.failUnless(pieces[1] == '\xf6V\xeb/\xa8\xad[\x07Z\xf9\x87\xa4\xf5w\xdf\xe1|\x00\x8e\x93')
317         self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
318         h.new(True)
319         for i in xrange(120*1024):
320             h.update('1234567890')
321         pieces = h.pieceDigests()
322         self.failUnless(h.digest() == '1(j\xd2q\x0b\n\x91\xd2\x13\x90\x15\xa3E\xcc\xb0\x8d.\xc3\xc5')
323         self.failUnless(len(pieces) == 3)
324         self.failUnless(pieces[0] == ',G \xd8\xbbPl\xf1\xa3\xa0\x0cW\n\xe6\xe6a\xc9\x95/\xe5')
325         self.failUnless(pieces[1] == '\xf6V\xeb/\xa8\xad[\x07Z\xf9\x87\xa4\xf5w\xdf\xe1|\x00\x8e\x93')
326         self.failUnless(pieces[2] == 'M[\xbf\xee\xaa+\x19\xbaV\xf699\r\x17o\xcb\x8e\xcfP\x19')
327         
328     def test_sha1(self):
329         h = HashObject()
330         found = False
331         for hashType in h.ORDER:
332             if hashType['name'] == 'sha1':
333                 found = True
334                 break
335         self.failUnless(found == True)
336         h.set(hashType, 'c722df87e1acaa64b27aac4e174077afc3623540', '19')
337         h.new()
338         h.update('apt-dht is the best')
339         self.failUnless(h.hexdigest() == 'c722df87e1acaa64b27aac4e174077afc3623540')
340         self.failUnlessRaises(HashError, h.update, 'gfgf')
341         self.failUnless(h.verify() == True)
342         
343     def test_md5(self):
344         h = HashObject()
345         found = False
346         for hashType in h.ORDER:
347             if hashType['name'] == 'md5':
348                 found = True
349                 break
350         self.failUnless(found == True)
351         h.set(hashType, '2a586bcd1befc5082c872dcd96a01403', '19')
352         h.new()
353         h.update('apt-dht is the best')
354         self.failUnless(h.hexdigest() == '2a586bcd1befc5082c872dcd96a01403')
355         self.failUnlessRaises(HashError, h.update, 'gfgf')
356         self.failUnless(h.verify() == True)
357         
358     def test_sha256(self):
359         h = HashObject()
360         found = False
361         for hashType in h.ORDER:
362             if hashType['name'] == 'sha256':
363                 found = True
364                 break
365         self.failUnless(found == True)
366         h.set(hashType, '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7', '19')
367         h.new()
368         h.update('apt-dht is the best')
369         self.failUnless(h.hexdigest() == '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7')
370         self.failUnlessRaises(HashError, h.update, 'gfgf')
371         self.failUnless(h.verify() == True)
372
373     if sys.version_info < (2, 5):
374         test_sha256.skip = "SHA256 hashes are not supported by Python until version 2.5"