Store piece hashes in the DB.
author Cameron Dale <camrdale@gmail.com>
Wed, 27 Feb 2008 23:58:05 +0000 (15:58 -0800)
committer Cameron Dale <camrdale@gmail.com>
Wed, 27 Feb 2008 23:58:05 +0000 (15:58 -0800)
Also broke up single table into files and hashes.
Main code uses the new hash routine.
lookupHash now checks for torrent string hashes.
The HTTPServer still needs to interpret torrent string hashes.

apt_dht/Hash.py
apt_dht/apt_dht.py
apt_dht/db.py

index ec985989f4eabfcee8f8a2b18f3a6d19eeb7aec3..a7a8e40fcc587038ab18abacc0d571a147a24102 100644 (file)
@@ -15,6 +15,7 @@ class HashObject:
     
     """The priority ordering of hashes, and how to extract them."""
     ORDER = [ {'name': 'sha1', 
     
     """The priority ordering of hashes, and how to extract them."""
     ORDER = [ {'name': 'sha1', 
+                   'length': 20,
                    'AptPkgRecord': 'SHA1Hash', 
                    'AptSrcRecord': False, 
                    'AptIndexRecord': 'SHA1',
                    'AptPkgRecord': 'SHA1Hash', 
                    'AptSrcRecord': False, 
                    'AptIndexRecord': 'SHA1',
@@ -22,12 +23,14 @@ class HashObject:
                    'hashlib_func': 'sha1',
                    },
               {'name': 'sha256',
                    'hashlib_func': 'sha1',
                    },
               {'name': 'sha256',
+                   'length': 32,
                    'AptPkgRecord': 'SHA256Hash', 
                    'AptSrcRecord': False, 
                    'AptIndexRecord': 'SHA256',
                    'hashlib_func': 'sha256',
                    },
               {'name': 'md5',
                    'AptPkgRecord': 'SHA256Hash', 
                    'AptSrcRecord': False, 
                    'AptIndexRecord': 'SHA256',
                    'hashlib_func': 'sha256',
                    },
               {'name': 'md5',
+                   'length': 16,
                    'AptPkgRecord': 'MD5Hash', 
                    'AptSrcRecord': True, 
                    'AptIndexRecord': 'MD5SUM',
                    'AptPkgRecord': 'MD5Hash', 
                    'AptSrcRecord': True, 
                    'AptIndexRecord': 'MD5SUM',
@@ -36,8 +39,15 @@ class HashObject:
                    },
             ]
     
                    },
             ]
     
-    def __init__(self, digest = None, size = None):
+    def __init__(self, digest = None, size = None, pieces = ''):
         self.hashTypeNum = 0    # Use the first if nothing else matters
         self.hashTypeNum = 0    # Use the first if nothing else matters
+        if sys.version_info < (2, 5):
+            # sha256 is not available in python before 2.5, remove it
+            for hashType in self.ORDER:
+                if hashType['name'] == 'sha256':
+                    del self.ORDER[self.ORDER.index(hashType)]
+                    break
+
         self.expHash = None
         self.expHex = None
         self.expSize = None
         self.expHash = None
         self.expHex = None
         self.expSize = None
@@ -45,18 +55,13 @@ class HashObject:
         self.fileHasher = None
         self.pieceHasher = None
         self.fileHash = digest
         self.fileHasher = None
         self.pieceHasher = None
         self.fileHash = digest
-        self.pieceHash = []
+        self.pieceHash = [pieces[x:x+self.ORDER[self.hashTypeNum]['length']]
+                          for x in xrange(0, len(pieces), self.ORDER[self.hashTypeNum]['length'])]
         self.size = size
         self.fileHex = None
         self.fileNormHash = None
         self.done = True
         self.result = None
         self.size = size
         self.fileHex = None
         self.fileNormHash = None
         self.done = True
         self.result = None
-        if sys.version_info < (2, 5):
-            # sha256 is not available in python before 2.5, remove it
-            for hashType in self.ORDER:
-                if hashType['name'] == 'sha256':
-                    del self.ORDER[self.ORDER.index(hashType)]
-                    break
         
     def _norm_hash(self, hashString, bits=None, bytes=None):
         if bits is not None:
         
     def _norm_hash(self, hashString, bits=None, bytes=None):
         if bits is not None:
@@ -173,9 +178,6 @@ class HashObject:
             # Save the last piece hash
             if self.pieceHasher:
                 self.pieceHash.append(self.pieceHasher.digest())
             # Save the last piece hash
             if self.pieceHasher:
                 self.pieceHash.append(self.pieceHasher.digest())
-            else:
-                # If there are no piece hashes, then the file hash is the only piece hash
-                self.pieceHash.append(self.fileHash)
         return self.fileHash
 
     def hexdigest(self):
         return self.fileHash
 
     def hexdigest(self):
index 151e3db623e88fa5afa89378136a676a82d52013..94260bd4c396a4c216e2952c9094ae2a1cdaf0cb 100644 (file)
@@ -58,17 +58,24 @@ class AptDHT:
     def refreshFiles(self):
         """Refresh any files in the DHT that are about to expire."""
         expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
     def refreshFiles(self):
         """Refresh any files in the DHT that are about to expire."""
         expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
-        hashes = self.db.expiredFiles(expireAfter)
+        hashes = self.db.expiredHashes(expireAfter)
         if len(hashes.keys()) > 0:
             log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
         if len(hashes.keys()) > 0:
             log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
-        for raw_hash in hashes:
+        self._refreshFiles(None, hashes)
+        
+    def _refreshFiles(self, result, hashes):
+        if result is not None:
+            log.msg('Storage resulted in: %r' % result)
+
+        if hashes:
+            raw_hash = hashes.keys()[0]
             self.db.refreshHash(raw_hash)
             self.db.refreshHash(raw_hash)
-            hash = HashObject(raw_hash)
-            key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
-            value = {'c': self.my_contact}
-            storeDefer = self.dht.storeValue(key, value)
-            storeDefer.addCallback(self.store_done, hash)
-        reactor.callLater(60, self.refreshFiles)
+            hash = HashObject(raw_hash, pieces = hashes[raw_hash]['pieces'])
+            del hashes[raw_hash]
+            storeDefer = self.store(hash)
+            storeDefer.addBoth(self._refreshFiles, hashes)
+        else:
+            reactor.callLater(60, self.refreshFiles)
 
     def check_freshness(self, req, path, modtime, resp):
         log.msg('Checking if %s is still fresh' % path)
 
     def check_freshness(self, req, path, modtime, resp):
         log.msg('Checking if %s is still fresh' % path)
@@ -107,7 +114,7 @@ class AptDHT:
             log.msg('Found hash %s for %s' % (hash.hexexpected(), path))
             
             # Lookup hash in cache
             log.msg('Found hash %s for %s' % (hash.hexexpected(), path))
             
             # Lookup hash in cache
-            locations = self.db.lookupHash(hash.expected())
+            locations = self.db.lookupHash(hash.expected(), filesOnly = True)
             self.getCachedFile(hash, req, path, d, locations)
 
     def getCachedFile(self, hash, req, path, d, locations):
             self.getCachedFile(hash, req, path, d, locations)
 
     def getCachedFile(self, hash, req, path, d, locations):
@@ -173,33 +180,37 @@ class AptDHT:
         return response
         
     def new_cached_file(self, file_path, hash, new_hash, url = None, forceDHT = False):
         return response
         
     def new_cached_file(self, file_path, hash, new_hash, url = None, forceDHT = False):
-        """Add a newly cached file to the DHT.
+        """Add a newly cached file to the appropriate places.
         
         If the file was downloaded, set url to the path it was downloaded for.
         
         If the file was downloaded, set url to the path it was downloaded for.
-        Don't add a file to the DHT unless a hash was found for it
-        (but do add it anyway if forceDHT is True).
+        Doesn't add a file to the DHT unless a hash was found for it
+        (but does add it anyway if forceDHT is True).
         """
         if url:
             self.mirrors.updatedFile(url, file_path)
         
         if self.my_contact and hash and new_hash and (hash.expected() is not None or forceDHT):
         """
         if url:
             self.mirrors.updatedFile(url, file_path)
         
         if self.my_contact and hash and new_hash and (hash.expected() is not None or forceDHT):
-            key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
-            value = {'c': self.my_contact}
-            pieces = hash.pieceDigests()
-            if len(pieces) <= 1:
-                pass
-            elif len(pieces) <= DHT_PIECES:
-                value['t'] = {'t': ''.join(pieces)}
-            elif len(pieces) <= TORRENT_PIECES:
-                s = sha.new().update(''.join(pieces))
-                value['h'] = s.digest()
-            else:
-                s = sha.new().update(''.join(pieces))
-                value['l'] = s.digest()
-            storeDefer = self.dht.storeValue(key, value)
-            storeDefer.addCallback(self.store_done, hash)
-            return storeDefer
+            return self.store(hash)
         return None
         return None
+            
+    def store(self, hash):
+        """Add a file to the DHT."""
+        key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
+        value = {'c': self.my_contact}
+        pieces = hash.pieceDigests()
+        if len(pieces) <= 1:
+            pass
+        elif len(pieces) <= DHT_PIECES:
+            value['t'] = {'t': ''.join(pieces)}
+        elif len(pieces) <= TORRENT_PIECES:
+            s = sha.new(''.join(pieces))
+            value['h'] = s.digest()
+        else:
+            s = sha.new(''.join(pieces))
+            value['l'] = s.digest()
+        storeDefer = self.dht.storeValue(key, value)
+        storeDefer.addCallback(self.store_done, hash)
+        return storeDefer
 
     def store_done(self, result, hash):
         log.msg('Added %s to the DHT: %r' % (hash.hexdigest(), result))
 
     def store_done(self, result, hash):
         log.msg('Added %s to the DHT: %r' % (hash.hexdigest(), result))
index 1ab2b37cac96df1bc8d6a7331326ef5b1d252dfe..e1d6d7b474f7f69bce218b85af50235e5d6f1dcb 100644 (file)
@@ -46,9 +46,13 @@ class DB:
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
-        c.execute("CREATE INDEX files_hash ON files(hash)")
-        c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
+        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
+                                      "size NUMBER, mtime NUMBER)")
+        c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
+                                       "hash KHASH UNIQUE, pieces KHASH, " +
+                                       "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
+        c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
         c.close()
         self.conn.commit()
 
@@ -65,32 +69,34 @@ class DB:
                 c.close()
         return res
         
                 c.close()
         return res
         
-    def storeFile(self, file, hash):
+    def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
         """Store or update a file in the database.
         
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
-        new_hash = True
-        refreshTime = datetime.now()
+        piecehash = ''
+        if pieces:
+            s = sha.new(pieces)
+            piecehash = s.digest()
         c = self.conn.cursor()
         c = self.conn.cursor()
-        c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         row = c.fetchone()
-        if row and row['max_refresh']:
+        if row:
+            assert piecehash == row['piecehash']
             new_hash = False
             new_hash = False
-            refreshTime = row['max_refresh']
-        c.close()
+            hashID = row['hashID']
+        else:
+            c = self.conn.cursor()
+            c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?, ?)",
+                      (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
+            self.conn.commit()
+            new_hash = True
+            hashID = c.lastrowid
         
         file.restat()
         
         file.restat()
-        c = self.conn.cursor()
-        c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
-        row = c.fetchone()
-        if row:
-            c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
-                      (khash(hash), file.getsize(), file.getmtime(), refreshTime))
-        else:
-            c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
-                      (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
+                  (file.path, hashID, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
         self.conn.commit()
         c.close()
         
@@ -105,7 +111,7 @@ class DB:
             None if not in database or missing
         """
         c = self.conn.cursor()
             None if not in database or missing
         """
         c = self.conn.cursor()
-        c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
+        c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
         row = c.fetchone()
         res = None
         if row:
         row = c.fetchone()
         res = None
         if row:
@@ -114,19 +120,21 @@ class DB:
                 res = {}
                 res['hash'] = row['hash']
                 res['size'] = row['size']
                 res = {}
                 res['hash'] = row['hash']
                 res['size'] = row['size']
+                res['pieces'] = row['pieces']
         c.close()
         return res
         
         c.close()
         return res
         
-    def lookupHash(self, hash):
+    def lookupHash(self, hash, filesOnly = False):
         """Find a file by hash in the database.
         
         If any found files have changed or are missing, they are removed
         """Find a file by hash in the database.
         
         If any found files have changed or are missing, they are removed
-        from the database.
+        from the database. If filesOnly is False then it will also look for
+        piece string hashes if no files can be found.
         
         @return: list of dictionaries of info for the found files
         """
         c = self.conn.cursor()
         
         @return: list of dictionaries of info for the found files
         """
         c = self.conn.cursor()
-        c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
         row = c.fetchone()
         files = []
         while row:
@@ -137,8 +145,19 @@ class DB:
                 res['path'] = file
                 res['size'] = row['size']
                 res['refreshed'] = row['refreshed']
                 res['path'] = file
                 res['size'] = row['size']
                 res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
                 files.append(res)
             row = c.fetchone()
                 files.append(res)
             row = c.fetchone()
+            
+        if not filesOnly and not files:
+            c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
+            row = c.fetchone()
+            if row:
+                res = {}
+                res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
+                files.append(res)
+
         c.close()
         return files
         
         c.close()
         return files
         
@@ -156,12 +175,11 @@ class DB:
 
     def refreshHash(self, hash):
         """Refresh the publishing time all files with a hash."""
 
     def refreshHash(self, hash):
         """Refresh the publishing time all files with a hash."""
-        refreshTime = datetime.now()
         c = self.conn.cursor()
         c = self.conn.cursor()
-        c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash)))
+        c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
     
         c.close()
     
-    def expiredFiles(self, expireAfter):
+    def expiredHashes(self, expireAfter):
         """Find files that need refreshing after expireAfter seconds.
         
         For each hash that needs refreshing, finds all the files with that hash.
         """Find files that need refreshing after expireAfter seconds.
         
         For each hash that needs refreshing, finds all the files with that hash.
@@ -173,27 +191,32 @@ class DB:
         
         # First find the hashes that need refreshing
         c = self.conn.cursor()
         
         # First find the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
         expired = {}
         while row:
         row = c.fetchone()
         expired = {}
         while row:
-            expired.setdefault(row['hash'], [])
+            res = expired.setdefault(row['hash'], {})
+            res['hashID'] = row['hashID']
+            res['hash'] = row['hash']
+            res['pieces'] = row['pieces']
             row = c.fetchone()
             row = c.fetchone()
-        c.close()
 
 
-        # Now find the files for each hash
-        for hash in expired.keys():
-            c = self.conn.cursor()
-            c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
+        # Make sure there are still valid files for each hash
+        for hash in expired.values():
+            valid = False
+            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
                 res = self._removeChanged(FilePath(row['path']), row)
                 if res:
             row = c.fetchone()
             while row:
                 res = self._removeChanged(FilePath(row['path']), row)
                 if res:
-                    expired[hash].append(FilePath(row['path']))
+                    valid = True
                 row = c.fetchone()
                 row = c.fetchone()
-            if len(expired[hash]) == 0:
-                del expired[hash]
-            c.close()
+            if not valid:
+                del expired[hash['hash']]
+                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                
+        self.conn.commit()
+        c.close()
         
         return expired
         
         
         return expired
         
@@ -249,7 +272,7 @@ class TestDB(unittest.TestCase):
         self.store = DB(self.db)
         self.store.storeFile(self.file, self.hash)
 
         self.store = DB(self.db)
         self.store.storeFile(self.file, self.hash)
 
-    def test_openExistsingDB(self):
+    def test_openExistingDB(self):
         self.store.close()
         self.store = None
         sleep(1)
         self.store.close()
         self.store = None
         sleep(1)
@@ -275,20 +298,18 @@ class TestDB(unittest.TestCase):
         self.file.touch()
         res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
         self.file.touch()
         res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
-        self.file.remove()
         res = self.store.isUnchanged(self.file)
         res = self.store.isUnchanged(self.file)
-        self.failUnless(res == None)
+        self.failUnless(res is None)
         
     def test_expiry(self):
         
     def test_expiry(self):
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 1)
         self.store.refreshHash(self.hash)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
@@ -302,7 +323,7 @@ class TestDB(unittest.TestCase):
     
     def test_multipleHashes(self):
         self.build_dirs()
     
     def test_multipleHashes(self):
         self.build_dirs()
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res.keys()), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
@@ -311,12 +332,11 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 4)
         self.store.refreshHash(self.hash)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
@@ -338,3 +358,4 @@ class TestDB(unittest.TestCase):
         self.directory.remove()
         self.store.close()
         self.db.remove()
         self.directory.remove()
         self.store.close()
         self.db.remove()
+