git.mxchange.org Git - quix0rs-apt-p2p.git/commitdiff
Pass the new HashObjects around everywhere (untested).
author     Cameron Dale <camrdale@gmail.com>    Thu, 10 Jan 2008 00:54:20 +0000 (16:54 -0800)
committer  Cameron Dale <camrdale@gmail.com>    Thu, 10 Jan 2008 00:54:20 +0000 (16:54 -0800)
apt_dht/AptPackages.py
apt_dht/MirrorManager.py
apt_dht/apt_dht.py
apt_dht_Khashmir/DHT.py
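
The new Hash.py module itself is not part of this commit, so the HashObject interface has to be inferred from how the diff below calls it. The following is a rough sketch of that interface (in the codebase's Python 2 style); the method names come from the diff, while the attribute names, hex decoding and algorithm selection are assumptions. normexpected is sketched separately after the DHT.py changes at the end.

    # Rough sketch of the HashObject interface implied by the call sites in
    # this commit.  The real Hash.py is not part of this diff; everything
    # beyond the method names is an assumption.
    from binascii import a2b_hex, b2a_hex
    import hashlib

    class HashObject:
        def __init__(self):
            self.expected_hash = None    # expected raw digest, None if unknown
            self.expected_size = None
            self.hasher = None
            self.found_hash = None

        # setters used by AptPackages._findHash
        def setFromIndexRecord(self, record):
            # record maps hash names to (hex digest, size), as in record['SHA1']
            hex_hash, size = record['SHA1']
            self.expected_hash = a2b_hex(hex_hash)
            self.expected_size = size

        def setFromPkgRecord(self, records, size):
            # records is apt_pkg's package records parser, already positioned
            self.expected_hash = a2b_hex(records.SHA1Hash)
            self.expected_size = size

        def setFromSrcRecord(self, f):
            # f is one apt_pkg source-record file entry: (md5sum, size, path, ...)
            self.expected_hash = a2b_hex(f[0])
            self.expected_size = f[1]

        # hashing interface used by ProxyFileStream
        def new(self):
            # pick the algorithm to match the expected digest, defaulting to
            # SHA1 (the same md5-if-16-bytes rule the old MirrorManager used)
            if self.expected_hash is not None and len(self.expected_hash) == 16:
                self.hasher = hashlib.md5()
            else:
                self.hasher = hashlib.sha1()

        def update(self, data):
            self.hasher.update(data)

        def digest(self):
            self.found_hash = self.hasher.digest()
            return self.found_hash

        def hexdigest(self):
            return b2a_hex(self.found_hash)

        # verification interface used by MirrorManager and AptDHT
        def expected(self):
            return self.expected_hash

        def hexexpected(self):
            return b2a_hex(self.expected_hash) if self.expected_hash else None

        def verify(self):
            # True/False once the download has been hashed and an expected
            # hash is known, None when there was nothing to check against
            if self.expected_hash is None:
                return None
            return self.found_hash == self.expected_hash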

diff --git a/apt_dht/AptPackages.py b/apt_dht/AptPackages.py
index a2e743e60fde74f07c9c9bed7d6132b6e66f9d4a..47a17dc49d608482de596fcb7ac9d5d3a2b3d41a 100644
--- a/apt_dht/AptPackages.py
+++ b/apt_dht/AptPackages.py
@@ -15,6 +15,8 @@ from twisted.trial import unittest
 import apt_pkg, apt_inst
 from apt import OpProgress
 
+from Hash import HashObject
+
 apt_pkg.init()
 
 TRACKED_FILES = ['release', 'sources', 'packages']
@@ -290,7 +292,7 @@ class AptPackages:
         """An error occurred while trying to find a hash."""
         log.msg('An error occurred while looking up a hash for: %s' % path)
         log.err(failure)
-        d.callback((None, None))
+        d.callback(HashObject())
 
     def _findHash(self, loadResult, path, d):
         """Really find the hash for a path.
@@ -299,7 +301,7 @@ class AptPackages:
         function are pending.
         """
         if not loadResult:
-            d.callback((None, None))
+            d.callback(HashObject())
             return loadResult
         
         # First look for the path in the cache of index files
@@ -307,7 +309,9 @@ class AptPackages:
             if path.startswith(release[:-7]):
                 for indexFile in self.indexrecords[release]:
                     if release[:-7] + indexFile == path:
-                        d.callback(self.indexrecords[release][indexFile]['SHA1'])
+                        h = HashObject()
+                        h.setFromIndexRecord(self.indexrecords[release][indexFile])
+                        d.callback(h)
                         return loadResult
         
         package = path.split('/')[-1].split('_')[0]
@@ -319,7 +323,9 @@ class AptPackages:
                 for verFile in version.FileList:
                     if self.records.Lookup(verFile):
                         if '/' + self.records.FileName == path:
-                            d.callback((self.records.SHA1Hash, size))
+                            h = HashObject()
+                            h.setFromPkgRecord(self.records, size)
+                            d.callback(h)
                             return loadResult
         except KeyError:
             pass
@@ -330,10 +336,12 @@ class AptPackages:
             if self.srcrecords.Lookup(package):
                 for f in self.srcrecords.Files:
                     if path == '/' + f[2]:
-                        d.callback((f[0], f[1]))
+                        h = HashObject()
+                        h.setFromSrcRecord(f)
+                        d.callback(h)
                         return loadResult
         
-        d.callback((None, None))
+        d.callback(HashObject())
         return loadResult
 
 class TestAptPackages(unittest.TestCase):
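
After the AptPackages.py changes above, findHash always fires its deferred with a HashObject: an empty one when nothing was found, otherwise one seeded from the index, package or source records. A hypothetical caller, assuming the sketch at the top of this page and an already-loaded AptPackages instance named apt_packages:

    # Hypothetical caller of the new AptPackages.findHash; apt_packages is an
    # assumed, already-loaded AptPackages instance.
    from twisted.python import log

    def found_hash(hash, path):
        if hash.expected() is None:
            log.msg('no hash known for %s' % path)
        else:
            log.msg('%s -> %s' % (path, hash.hexexpected()))
        return hash

    d = apt_packages.findHash('/dists/unstable/Release')
    d.addCallback(found_hash, '/dists/unstable/Release')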
diff --git a/apt_dht/MirrorManager.py b/apt_dht/MirrorManager.py
index d56e16da16520f259fb1d12b8f44a7f6bac2b537..98d46a53ed71b9a6f5915400f983dc279fff38c2 100644
--- a/apt_dht/MirrorManager.py
+++ b/apt_dht/MirrorManager.py
@@ -25,16 +25,15 @@ class MirrorError(Exception):
 class ProxyFileStream(stream.SimpleStream):
     """Saves a stream to a file while providing a new stream."""
     
-    def __init__(self, stream, outFile, hashType = "sha1", decompress = None, decFile = None):
+    def __init__(self, stream, outFile, hash, decompress = None, decFile = None):
         """Initializes the proxy.
         
         @type stream: C{twisted.web2.stream.IByteStream}
         @param stream: the input stream to read from
         @type outFile: C{twisted.python.filepath.FilePath}
         @param outFile: the file to write to
-        @type hashType: C{string}
-        @param hashType: also hash the file using this hashing function
-            (currently only 'sha1' and 'md5' are supported)
+        @type hash: L{Hash.HashObject}
+        @param hash: the hash object to use for the file
         @type decompress: C{string}
         @param decompress: also decompress the file as this type
             (currently only '.gz' and '.bz2' are supported)
@@ -43,11 +42,8 @@ class ProxyFileStream(stream.SimpleStream):
         """
         self.stream = stream
         self.outFile = outFile.open('w')
-        self.hasher = None
-        if hashType == "sha1":
-            self.hasher = sha.new()
-        elif hashType == "md5":
-            self.hasher = md5.new()
+        self.hash = hash
+        self.hash.new()
         self.gzfile = None
         self.bz2file = None
         if decompress == ".gz":
@@ -65,9 +61,7 @@ class ProxyFileStream(stream.SimpleStream):
         """Close the output file."""
         if not self.outFile.closed:
             self.outFile.close()
-            fileHash = None
-            if self.hasher:
-                fileHash = self.hasher.digest()
+            self.hash.digest()
             if self.gzfile:
                 data_dec = self.gzdec.flush()
                 self.gzfile.write(data_dec)
@@ -77,7 +71,7 @@ class ProxyFileStream(stream.SimpleStream):
                 self.bz2file.close()
                 self.bz2file = None
                 
-            self.doneDefer.callback(fileHash)
+            self.doneDefer.callback(self.hash)
     
     def read(self):
         """Read some data from the stream."""
@@ -99,8 +93,7 @@ class ProxyFileStream(stream.SimpleStream):
             return data
         
         self.outFile.write(data)
-        if self.hasher:
-            self.hasher.update(data)
+        self.hash.update(data)
         if self.gzfile:
             if self.gzheader:
                 self.gzheader = False
@@ -208,20 +201,12 @@ class MirrorManager:
     def findHash(self, url):
         site, baseDir, path = self.extractPath(url)
         if site in self.apt_caches and baseDir in self.apt_caches[site]:
-            d = self.apt_caches[site][baseDir].findHash(path)
-            d.addCallback(self.translateHash)
-            return d
+            return self.apt_caches[site][baseDir].findHash(path)
         d = defer.Deferred()
         d.errback(MirrorError("Site Not Found"))
         return d
     
-    def translateHash(self, (hash, size)):
-        """Translate a hash from apt's hex encoding to a string."""
-        if hash:
-            hash = a2b_hex(hash)
-        return (hash, size)
-
-    def save_file(self, response, hash, size, url):
+    def save_file(self, response, hash, url):
         """Save a downloaded file to the cache and stream it."""
         log.msg('Returning file: %s' % url)
         
@@ -247,28 +232,24 @@ class MirrorManager:
             ext = None
             decFile = None
             
-        if hash and len(hash) == 16:
-            hashType = "md5"
-        else:
-            hashType = "sha1"
-        
         orig_stream = response.stream
-        response.stream = ProxyFileStream(orig_stream, destFile, hashType, ext, decFile)
-        response.stream.doneDefer.addCallback(self.save_complete, hash, size, url, destFile,
+        response.stream = ProxyFileStream(orig_stream, destFile, hash, ext, decFile)
+        response.stream.doneDefer.addCallback(self.save_complete, url, destFile,
                                               response.headers.getHeader('Last-Modified'),
                                               ext, decFile)
         response.stream.doneDefer.addErrback(self.save_error, url)
         return response
 
-    def save_complete(self, result, hash, size, url, destFile, modtime = None, ext = None, decFile = None):
+    def save_complete(self, hash, url, destFile, modtime = None, ext = None, decFile = None):
         """Update the modification time and AptPackages."""
         if modtime:
             os.utime(destFile.path, (modtime, modtime))
             if ext:
                 os.utime(decFile.path, (modtime, modtime))
         
-        if not hash or result == hash:
-            if hash:
+        result = hash.verify()
+        if result or result is None:
+            if result:
                 log.msg('Hashes match: %s' % url)
             else:
                 log.msg('Hashed file to %s: %s' % (b2a_hex(result), url))
@@ -277,7 +258,7 @@ class MirrorManager:
             if ext:
                 self.updatedFile(url[:-len(ext)], decFile.path)
         else:
-            log.msg("Hashes don't match %s != %s: %s" % (b2a_hex(hash), b2a_hex(result), url))
+            log.msg("Hashes don't match %s != %s: %s" % (hash.hexexpected(), hash.hexdigest(), url))
 
     def save_error(self, failure, url):
         """An error has occurred in downloadign or saving the file."""
@@ -312,7 +293,7 @@ class TestMirrorManager(unittest.TestCase):
         self.failUnless(path == "/dists/unstable/Release", "no match: %s" % path)
 
     def verifyHash(self, found_hash, path, true_hash):
-        self.failUnless(found_hash[0] == true_hash, 
+        self.failUnless(found_hash.hexexpected() == true_hash, 
                     "%s hashes don't match: %s != %s" % (path, found_hash[0], true_hash))
 
     def test_findHash(self):
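
The rewritten save_complete above depends on verify() distinguishing three cases: True (hashes match), False (mismatch) and None (no expected hash, so the file is cached without checking). A toy illustration of those branches, assuming the HashObject sketch from the top of this page:

    # Toy illustration of the tri-state verify() used by save_complete
    # (assumes the HashObject sketch above; the data values are made up).
    import hashlib

    h = HashObject()                                        # no expected hash found
    h.new(); h.update('some data'); h.digest()
    assert h.verify() is None                               # nothing to check against

    h = HashObject()
    h.expected_hash = hashlib.sha1('some data').digest()    # pretend apt knew the hash
    h.new(); h.update('some data'); h.digest()
    assert h.verify() is True                               # "Hashes match" branch

    h = HashObject()
    h.expected_hash = hashlib.sha1('other data').digest()
    h.new(); h.update('some data'); h.digest()
    assert h.verify() is False                              # "Hashes don't match" branch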
diff --git a/apt_dht/apt_dht.py b/apt_dht/apt_dht.py
index 13f2337ee6e00a21a6b212530af8ac57cf5d9f8c..51ab5825e2e4464c9137bd62f1ae876ab3d2319f 100644
--- a/apt_dht/apt_dht.py
+++ b/apt_dht/apt_dht.py
@@ -50,27 +50,28 @@ class AptDHT:
         log.err(failure)
         self.findHash_done((None, None), path, d)
         
-    def findHash_done(self, (hash, size), path, d):
-        if hash is None:
+    def findHash_done(self, hash, path, d):
+        if hash.expected() is None:
             log.msg('Hash for %s was not found' % path)
-            self.download_file([path], hash, size, path, d)
+            self.download_file([path], hash, path, d)
         else:
-            log.msg('Found hash %s for %s' % (b2a_hex(hash), path))
+            log.msg('Found hash %s for %s' % (hash.hexexpected(), path))
             # Lookup hash from DHT
-            lookupDefer = self.dht.getValue(hash)
-            lookupDefer.addCallback(self.lookupHash_done, hash, size, path, d)
+            key = hash.normexpected(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
+            lookupDefer = self.dht.getValue(key)
+            lookupDefer.addCallback(self.lookupHash_done, hash, path, d)
             
-    def lookupHash_done(self, locations, hash, size, path, d):
+    def lookupHash_done(self, locations, hash, path, d):
         if not locations:
             log.msg('Peers for %s were not found' % path)
-            self.download_file([path], hash, size, path, d)
+            self.download_file([path], hash, path, d)
         else:
             log.msg('Found peers for %s: %r' % (path, locations))
             # Download from the found peers
-            self.download_file(locations, hash, size, path, d)
+            self.download_file(locations, hash, path, d)
             
-    def download_file(self, locations, hash, size, path, d):
+    def download_file(self, locations, hash, path, d):
         getDefer = self.peers.get(locations)
-        getDefer.addCallback(self.mirrors.save_file, hash, size, path)
+        getDefer.addCallback(self.mirrors.save_file, hash, path)
         getDefer.addErrback(self.mirrors.save_error, path)
         getDefer.addCallbacks(d.callback, d.errback)
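
In apt_dht.py the separate hash and size arguments collapse into the single HashObject, which is threaded from findHash through the DHT lookup (keyed on normexpected of the expected hash) and on to MirrorManager.save_file. A hypothetical entry point that would start this chain; only the names that appear in the diff are taken from the code:

    # Hypothetical AptDHT method starting the new chain; everything except
    # the names that appear in the diff above is illustrative.
    from twisted.internet import defer

    def get_resp(self, path):
        d = defer.Deferred()
        findDefer = self.mirrors.findHash(path)        # fires with a HashObject
        findDefer.addCallbacks(self.findHash_done, self.findHash_error,
                               callbackArgs = (path, d), errbackArgs = (path, d))
        return d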
diff --git a/apt_dht_Khashmir/DHT.py b/apt_dht_Khashmir/DHT.py
index 36d7231fe75e48024d5fa3208054f861f97cd79b..590a77cb224a59026e41ca41631340daf6362902 100644
--- a/apt_dht_Khashmir/DHT.py
+++ b/apt_dht_Khashmir/DHT.py
@@ -98,15 +98,6 @@ class DHT:
             self.joined = False
             self.khashmir.shutdown()
         
-    def normalizeKey(self, key):
-        """Normalize a key's length suitable for insertion in the DHT."""
-        key_bytes = (self.config['HASH_LENGTH'] - 1) // 8 + 1
-        if len(key) < key_bytes:
-            key = key + '\000'*(key_bytes - len(key))
-        elif len(key) > key_bytes:
-            key = key[:key_bytes]
-        return key
-    
     def getValue(self, key):
         """See L{apt_dht.interfaces.IDHT}."""
         if self.config is None:
@@ -115,7 +106,6 @@ class DHT:
             raise DHTError, "have not joined a network yet"
 
         d = defer.Deferred()
-        key = self.normalizeKey(key)
         if key not in self.retrieving:
             self.khashmir.valueForKey(key, self._getValue)
         self.retrieving.setdefault(key, []).append(d)
@@ -141,7 +131,6 @@ class DHT:
         if not self.joined:
             raise DHTError, "have not joined a network yet"
 
-        key = self.normalizeKey(key)
         if key in self.storing and value in self.storing[key]:
             raise DHTError, "already storing that key with the same value"
 
@@ -184,13 +173,6 @@ class TestSimpleDHT(unittest.TestCase):
         self.b.bootstrap = ["127.0.0.1:4044"]
         self.b.cache_dir = '/tmp'
         
-    def test_normalizeKey(self):
-        self.failUnless(self.a.normalizeKey('12345678901234567890') == '12345678901234567890')
-        self.failUnless(self.a.normalizeKey('12345678901234567') == '12345678901234567\000\000\000')
-        self.failUnless(self.a.normalizeKey('1234567890123456789012345') == '12345678901234567890')
-        self.failUnless(self.a.normalizeKey('1234567890123456789') == '1234567890123456789\000')
-        self.failUnless(self.a.normalizeKey('123456789012345678901') == '12345678901234567890')
-    
     def test_bootstrap_join(self):
         d = self.a.join()
         return d
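
With normalizeKey and its tests removed from DHT.py, the padding and truncation of keys to HASH_LENGTH bits presumably moves into HashObject.normexpected, which findHash_done now calls before the DHT lookup. A sketch of that method, reusing the removed logic and the HashObject assumptions from the top of this page:

    # Presumed HashObject.normexpected, taking over the logic of the removed
    # DHT.normalizeKey (the bits-to-bytes arithmetic is copied from it).
    def normexpected(self, bits):
        key_bytes = (bits - 1) // 8 + 1
        key = self.expected_hash
        if len(key) < key_bytes:
            key = key + '\000' * (key_bytes - len(key))
        elif len(key) > key_bytes:
            key = key[:key_bytes]
        return key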