Refresh DHT values just before they are due to expire.
[quix0rs-apt-p2p.git] / apt_dht / Hash.py
1
2 from binascii import b2a_hex, a2b_hex
3 import sys
4
5 from twisted.internet import threads, defer
6 from twisted.trial import unittest
7
8 class HashError(ValueError):
9     """An error has occurred while hashing a file."""
10     
11 class HashObject:
12     """Manages hashes and hashing for a file."""
13     
14     """The priority ordering of hashes, and how to extract them."""
15     ORDER = [ {'name': 'sha1', 
16                    'AptPkgRecord': 'SHA1Hash', 
17                    'AptSrcRecord': False, 
18                    'AptIndexRecord': 'SHA1',
19                    'old_module': 'sha',
20                    'hashlib_func': 'sha1',
21                    },
22               {'name': 'sha256',
23                    'AptPkgRecord': 'SHA256Hash', 
24                    'AptSrcRecord': False, 
25                    'AptIndexRecord': 'SHA256',
26                    'hashlib_func': 'sha256',
27                    },
28               {'name': 'md5',
29                    'AptPkgRecord': 'MD5Hash', 
30                    'AptSrcRecord': True, 
31                    'AptIndexRecord': 'MD5SUM',
32                    'old_module': 'md5',
33                    'hashlib_func': 'md5',
34                    },
35             ]
36     
37     def __init__(self, digest = None, size = None):
38         self.hashTypeNum = 0    # Use the first if nothing else matters
39         self.expHash = None
40         self.expHex = None
41         self.expSize = None
42         self.expNormHash = None
43         self.fileHasher = None
44         self.fileHash = digest
45         self.size = size
46         self.fileHex = None
47         self.fileNormHash = None
48         self.done = True
49         self.result = None
50         if sys.version_info < (2, 5):
51             # sha256 is not available in python before 2.5, remove it
52             for hashType in self.ORDER:
53                 if hashType['name'] == 'sha256':
54                     del self.ORDER[self.ORDER.index(hashType)]
55                     break
56         
57     def _norm_hash(self, hashString, bits=None, bytes=None):
58         if bits is not None:
59             bytes = (bits - 1) // 8 + 1
60         else:
61             if bytes is None:
62                 raise HashError, "you must specify one of bits or bytes"
63         if len(hashString) < bytes:
64             hashString = hashString + '\000'*(bytes - len(hashString))
65         elif len(hashString) > bytes:
66             hashString = hashString[:bytes]
67         return hashString
68
69     #### Methods for returning the expected hash
70     def expected(self):
71         """Get the expected hash."""
72         return self.expHash
73     
74     def hexexpected(self):
75         """Get the expected hash in hex format."""
76         if self.expHex is None and self.expHash is not None:
77             self.expHex = b2a_hex(self.expHash)
78         return self.expHex
79     
80     def normexpected(self, bits=None, bytes=None):
81         """Normalize the binary hash for the given length.
82         
83         You must specify one of bits or bytes.
84         """
85         if self.expNormHash is None and self.expHash is not None:
86             self.expNormHash = self._norm_hash(self.expHash, bits, bytes)
87         return self.expNormHash
88
89     #### Methods for hashing data
90     def new(self, force = False):
91         """Generate a new hashing object suitable for hashing a file.
92         
93         @param force: set to True to force creating a new hasher even if
94             the hash has been verified already
95         """
96         if self.result is None or force == True:
97             self.result = None
98             self.size = 0
99             self.done = False
100             if sys.version_info < (2, 5):
101                 mod = __import__(self.ORDER[self.hashTypeNum]['old_module'], globals(), locals(), [])
102                 self.fileHasher = mod.new()
103             else:
104                 import hashlib
105                 func = getattr(hashlib, self.ORDER[self.hashTypeNum]['hashlib_func'])
106                 self.fileHasher = func()
107
108     def update(self, data):
109         """Add more data to the file hasher."""
110         if self.result is None:
111             if self.done:
112                 raise HashError, "Already done, you can't add more data after calling digest() or verify()"
113             if self.fileHasher is None:
114                 raise HashError, "file hasher not initialized"
115             self.fileHasher.update(data)
116             self.size += len(data)
117         
118     def digest(self):
119         """Get the hash of the added file data."""
120         if self.fileHash is None:
121             if self.fileHasher is None:
122                 raise HashError, "you must hash some data first"
123             self.fileHash = self.fileHasher.digest()
124             self.done = True
125         return self.fileHash
126
127     def hexdigest(self):
128         """Get the hash of the added file data in hex format."""
129         if self.fileHex is None:
130             self.fileHex = b2a_hex(self.digest())
131         return self.fileHex
132         
133     def norm(self, bits=None, bytes=None):
134         """Normalize the binary hash for the given length.
135         
136         You must specify one of bits or bytes.
137         """
138         if self.fileNormHash is None:
139             self.fileNormHash = self._norm_hash(self.digest(), bits, bytes)
140         return self.fileNormHash
141
142     def verify(self):
143         """Verify that the added file data hash matches the expected hash."""
144         if self.result is None and self.fileHash is not None and self.expHash is not None:
145             self.result = (self.fileHash == self.expHash and self.size == self.expSize)
146         return self.result
147     
148     def hashInThread(self, file):
149         """Hashes a file in a separate thread, callback with the result."""
150         file.restat(False)
151         if not file.exists():
152             df = defer.Deferred()
153             df.errback(HashError("file not found"))
154             return df
155         
156         df = threads.deferToThread(self._hashInThread, file)
157         return df
158     
159     def _hashInThread(self, file):
160         """Hashes a file, returning itself as the result."""
161         f = file.open()
162         self.new(force = True)
163         data = f.read(4096)
164         while data:
165             self.update(data)
166             data = f.read(4096)
167         self.digest()
168         return self
169
170     #### Methods for setting the expected hash
171     def set(self, hashType, hashHex, size):
172         """Initialize the hash object.
173         
174         @param hashType: must be one of the dictionaries from L{ORDER}
175         """
176         self.hashTypeNum = self.ORDER.index(hashType)    # error if not found
177         self.expHex = hashHex
178         self.expSize = int(size)
179         self.expHash = a2b_hex(self.expHex)
180         
181     def setFromIndexRecord(self, record):
182         """Set the hash from the cache of index file records.
183         
184         @type record: C{dictionary}
185         @param record: keys are hash types, values are tuples of (hash, size)
186         """
187         for hashType in self.ORDER:
188             result = record.get(hashType['AptIndexRecord'], None)
189             if result:
190                 self.set(hashType, result[0], result[1])
191                 return True
192         return False
193
194     def setFromPkgRecord(self, record, size):
195         """Set the hash from Apt's binary packages cache.
196         
197         @param record: whatever is returned by apt_pkg.GetPkgRecords()
198         """
199         for hashType in self.ORDER:
200             hashHex = getattr(record, hashType['AptPkgRecord'], None)
201             if hashHex:
202                 self.set(hashType, hashHex, size)
203                 return True
204         return False
205     
206     def setFromSrcRecord(self, record):
207         """Set the hash from Apt's source package records cache.
208         
209         Currently very simple since Apt only tracks MD5 hashes of source files.
210         
211         @type record: (C{string}, C{int}, C{string})
212         @param record: the hash, size and path of the source file
213         """
214         for hashType in self.ORDER:
215             if hashType['AptSrcRecord']:
216                 self.set(hashType, record[0], record[1])
217                 return True
218         return False
219
220 class TestHashObject(unittest.TestCase):
221     """Unit tests for the hash objects."""
222     
223     timeout = 5
224     if sys.version_info < (2, 4):
225         skip = "skippingme"
226     
227     def test_normalize(self):
228         h = HashObject()
229         h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
230         self.failUnless(h.normexpected(bits = 160) == '12345678901234567890')
231         h = HashObject()
232         h.set(h.ORDER[0], b2a_hex('12345678901234567'), '0')
233         self.failUnless(h.normexpected(bits = 160) == '12345678901234567\000\000\000')
234         h = HashObject()
235         h.set(h.ORDER[0], b2a_hex('1234567890123456789012345'), '0')
236         self.failUnless(h.normexpected(bytes = 20) == '12345678901234567890')
237         h = HashObject()
238         h.set(h.ORDER[0], b2a_hex('1234567890123456789'), '0')
239         self.failUnless(h.normexpected(bytes = 20) == '1234567890123456789\000')
240         h = HashObject()
241         h.set(h.ORDER[0], b2a_hex('123456789012345678901'), '0')
242         self.failUnless(h.normexpected(bits = 160) == '12345678901234567890')
243
244     def test_failure(self):
245         h = HashObject()
246         h.set(h.ORDER[0], b2a_hex('12345678901234567890'), '0')
247         self.failUnlessRaises(HashError, h.normexpected)
248         self.failUnlessRaises(HashError, h.digest)
249         self.failUnlessRaises(HashError, h.hexdigest)
250         self.failUnlessRaises(HashError, h.update, 'gfgf')
251     
252     def test_sha1(self):
253         h = HashObject()
254         found = False
255         for hashType in h.ORDER:
256             if hashType['name'] == 'sha1':
257                 found = True
258                 break
259         self.failUnless(found == True)
260         h.set(hashType, 'c722df87e1acaa64b27aac4e174077afc3623540', '19')
261         h.new()
262         h.update('apt-dht is the best')
263         self.failUnless(h.hexdigest() == 'c722df87e1acaa64b27aac4e174077afc3623540')
264         self.failUnlessRaises(HashError, h.update, 'gfgf')
265         self.failUnless(h.verify() == True)
266         
267     def test_md5(self):
268         h = HashObject()
269         found = False
270         for hashType in h.ORDER:
271             if hashType['name'] == 'md5':
272                 found = True
273                 break
274         self.failUnless(found == True)
275         h.set(hashType, '2a586bcd1befc5082c872dcd96a01403', '19')
276         h.new()
277         h.update('apt-dht is the best')
278         self.failUnless(h.hexdigest() == '2a586bcd1befc5082c872dcd96a01403')
279         self.failUnlessRaises(HashError, h.update, 'gfgf')
280         self.failUnless(h.verify() == True)
281         
282     def test_sha256(self):
283         h = HashObject()
284         found = False
285         for hashType in h.ORDER:
286             if hashType['name'] == 'sha256':
287                 found = True
288                 break
289         self.failUnless(found == True)
290         h.set(hashType, '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7', '19')
291         h.new()
292         h.update('apt-dht is the best')
293         self.failUnless(h.hexdigest() == '55b971f64d9772f733de03f23db39224f51a455cc5ad4c2db9d5740d2ab259a7')
294         self.failUnlessRaises(HashError, h.update, 'gfgf')
295         self.failUnless(h.verify() == True)
296
297     if sys.version_info < (2, 5):
298         test_sha256.skip = "SHA256 hashes are not supported by Python until version 2.5"