Store piece hashes in the DB.
authorCameron Dale <camrdale@gmail.com>
Wed, 27 Feb 2008 23:58:05 +0000 (15:58 -0800)
committerCameron Dale <camrdale@gmail.com>
Wed, 27 Feb 2008 23:58:05 +0000 (15:58 -0800)
Also broke up single table into files and hashes.
Main code uses the new hash routine.
lookupHash now checks for torrent string hashes.
The HTTPServer still needs to interpret torrent string hashes.

apt_dht/Hash.py
apt_dht/apt_dht.py
apt_dht/db.py

index ec98598..a7a8e40 100644 (file)
@@ -15,6 +15,7 @@ class HashObject:
     
     """The priority ordering of hashes, and how to extract them."""
     ORDER = [ {'name': 'sha1', 
+                   'length': 20,
                    'AptPkgRecord': 'SHA1Hash', 
                    'AptSrcRecord': False, 
                    'AptIndexRecord': 'SHA1',
@@ -22,12 +23,14 @@ class HashObject:
                    'hashlib_func': 'sha1',
                    },
               {'name': 'sha256',
+                   'length': 32,
                    'AptPkgRecord': 'SHA256Hash', 
                    'AptSrcRecord': False, 
                    'AptIndexRecord': 'SHA256',
                    'hashlib_func': 'sha256',
                    },
               {'name': 'md5',
+                   'length': 16,
                    'AptPkgRecord': 'MD5Hash', 
                    'AptSrcRecord': True, 
                    'AptIndexRecord': 'MD5SUM',
@@ -36,8 +39,15 @@ class HashObject:
                    },
             ]
     
-    def __init__(self, digest = None, size = None):
+    def __init__(self, digest = None, size = None, pieces = ''):
         self.hashTypeNum = 0    # Use the first if nothing else matters
+        if sys.version_info < (2, 5):
+            # sha256 is not available in python before 2.5, remove it
+            for hashType in self.ORDER:
+                if hashType['name'] == 'sha256':
+                    del self.ORDER[self.ORDER.index(hashType)]
+                    break
+
         self.expHash = None
         self.expHex = None
         self.expSize = None
@@ -45,18 +55,13 @@ class HashObject:
         self.fileHasher = None
         self.pieceHasher = None
         self.fileHash = digest
-        self.pieceHash = []
+        self.pieceHash = [pieces[x:x+self.ORDER[self.hashTypeNum]['length']]
+                          for x in xrange(0, len(pieces), self.ORDER[self.hashTypeNum]['length'])]
         self.size = size
         self.fileHex = None
         self.fileNormHash = None
         self.done = True
         self.result = None
-        if sys.version_info < (2, 5):
-            # sha256 is not available in python before 2.5, remove it
-            for hashType in self.ORDER:
-                if hashType['name'] == 'sha256':
-                    del self.ORDER[self.ORDER.index(hashType)]
-                    break
         
     def _norm_hash(self, hashString, bits=None, bytes=None):
         if bits is not None:
@@ -173,9 +178,6 @@ class HashObject:
             # Save the last piece hash
             if self.pieceHasher:
                 self.pieceHash.append(self.pieceHasher.digest())
-            else:
-                # If there are no piece hashes, then the file hash is the only piece hash
-                self.pieceHash.append(self.fileHash)
         return self.fileHash
 
     def hexdigest(self):
index 151e3db..94260bd 100644 (file)
@@ -58,17 +58,24 @@ class AptDHT:
     def refreshFiles(self):
         """Refresh any files in the DHT that are about to expire."""
         expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
-        hashes = self.db.expiredFiles(expireAfter)
+        hashes = self.db.expiredHashes(expireAfter)
         if len(hashes.keys()) > 0:
             log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
-        for raw_hash in hashes:
+        self._refreshFiles(None, hashes)
+        
+    def _refreshFiles(self, result, hashes):
+        if result is not None:
+            log.msg('Storage resulted in: %r' % result)
+
+        if hashes:
+            raw_hash = hashes.keys()[0]
             self.db.refreshHash(raw_hash)
-            hash = HashObject(raw_hash)
-            key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
-            value = {'c': self.my_contact}
-            storeDefer = self.dht.storeValue(key, value)
-            storeDefer.addCallback(self.store_done, hash)
-        reactor.callLater(60, self.refreshFiles)
+            hash = HashObject(raw_hash, pieces = hashes[raw_hash]['pieces'])
+            del hashes[raw_hash]
+            storeDefer = self.store(hash)
+            storeDefer.addBoth(self._refreshFiles, hashes)
+        else:
+            reactor.callLater(60, self.refreshFiles)
 
     def check_freshness(self, req, path, modtime, resp):
         log.msg('Checking if %s is still fresh' % path)
@@ -107,7 +114,7 @@ class AptDHT:
             log.msg('Found hash %s for %s' % (hash.hexexpected(), path))
             
             # Lookup hash in cache
-            locations = self.db.lookupHash(hash.expected())
+            locations = self.db.lookupHash(hash.expected(), filesOnly = True)
             self.getCachedFile(hash, req, path, d, locations)
 
     def getCachedFile(self, hash, req, path, d, locations):
@@ -173,33 +180,37 @@ class AptDHT:
         return response
         
     def new_cached_file(self, file_path, hash, new_hash, url = None, forceDHT = False):
-        """Add a newly cached file to the DHT.
+        """Add a newly cached file to the appropriate places.
         
         If the file was downloaded, set url to the path it was downloaded for.
-        Don't add a file to the DHT unless a hash was found for it
-        (but do add it anyway if forceDHT is True).
+        Doesn't add a file to the DHT unless a hash was found for it
+        (but does add it anyway if forceDHT is True).
         """
         if url:
             self.mirrors.updatedFile(url, file_path)
         
         if self.my_contact and hash and new_hash and (hash.expected() is not None or forceDHT):
-            key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
-            value = {'c': self.my_contact}
-            pieces = hash.pieceDigests()
-            if len(pieces) <= 1:
-                pass
-            elif len(pieces) <= DHT_PIECES:
-                value['t'] = {'t': ''.join(pieces)}
-            elif len(pieces) <= TORRENT_PIECES:
-                s = sha.new().update(''.join(pieces))
-                value['h'] = s.digest()
-            else:
-                s = sha.new().update(''.join(pieces))
-                value['l'] = s.digest()
-            storeDefer = self.dht.storeValue(key, value)
-            storeDefer.addCallback(self.store_done, hash)
-            return storeDefer
+            return self.store(hash)
         return None
+            
+    def store(self, hash):
+        """Add a file to the DHT."""
+        key = hash.norm(bits = config.getint(config.get('DEFAULT', 'DHT'), 'HASH_LENGTH'))
+        value = {'c': self.my_contact}
+        pieces = hash.pieceDigests()
+        if len(pieces) <= 1:
+            pass
+        elif len(pieces) <= DHT_PIECES:
+            value['t'] = {'t': ''.join(pieces)}
+        elif len(pieces) <= TORRENT_PIECES:
+            s = sha.new().update(''.join(pieces))
+            value['h'] = s.digest()
+        else:
+            s = sha.new().update(''.join(pieces))
+            value['l'] = s.digest()
+        storeDefer = self.dht.storeValue(key, value)
+        storeDefer.addCallback(self.store_done, hash)
+        return storeDefer
 
     def store_done(self, result, hash):
         log.msg('Added %s to the DHT: %r' % (hash.hexdigest(), result))
index 1ab2b37..e1d6d7b 100644 (file)
@@ -46,9 +46,13 @@ class DB:
             self.db.parent().makedirs()
         self.conn = sqlite.connect(database=self.db.path, detect_types=sqlite.PARSE_DECLTYPES)
         c = self.conn.cursor()
-        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY, hash KHASH, size NUMBER, mtime NUMBER, refreshed TIMESTAMP)")
-        c.execute("CREATE INDEX files_hash ON files(hash)")
-        c.execute("CREATE INDEX files_refreshed ON files(refreshed)")
+        c.execute("CREATE TABLE files (path TEXT PRIMARY KEY UNIQUE, hashID INTEGER, " +
+                                      "size NUMBER, mtime NUMBER)")
+        c.execute("CREATE TABLE hashes (hashID INTEGER PRIMARY KEY AUTOINCREMENT, " +
+                                       "hash KHASH UNIQUE, pieces KHASH, " +
+                                       "piecehash KHASH, refreshed TIMESTAMP)")
+        c.execute("CREATE INDEX hashes_refreshed ON hashes(refreshed)")
+        c.execute("CREATE INDEX hashes_piecehash ON hashes(piecehash)")
         c.close()
         self.conn.commit()
 
@@ -65,32 +69,34 @@ class DB:
                 c.close()
         return res
         
-    def storeFile(self, file, hash):
+    def storeFile(self, file, hash, pieces = ''):
         """Store or update a file in the database.
         
         @return: True if the hash was not in the database before
             (so it needs to be added to the DHT)
         """
-        new_hash = True
-        refreshTime = datetime.now()
+        piecehash = ''
+        if pieces:
+            s = sha.new().update(pieces)
+            piecehash = sha.digest()
         c = self.conn.cursor()
-        c.execute("SELECT MAX(refreshed) AS max_refresh FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT hashID, piecehash FROM hashes WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
-        if row and row['max_refresh']:
+        if row:
+            assert piecehash == row['piecehash']
             new_hash = False
-            refreshTime = row['max_refresh']
-        c.close()
+            hashID = row['hashID']
+        else:
+            c = self.conn.cursor()
+            c.execute("INSERT OR REPLACE INTO hashes (hash, pieces, piecehash, refreshed) VALUES (?, ?, ?)",
+                      (khash(hash), khash(pieces), khash(piecehash), datetime.now()))
+            self.conn.commit()
+            new_hash = True
+            hashID = c.lastrowid
         
         file.restat()
-        c = self.conn.cursor()
-        c.execute("SELECT path FROM files WHERE path = ?", (file.path, ))
-        row = c.fetchone()
-        if row:
-            c.execute("UPDATE files SET hash = ?, size = ?, mtime = ?, refreshed = ?", 
-                      (khash(hash), file.getsize(), file.getmtime(), refreshTime))
-        else:
-            c.execute("INSERT OR REPLACE INTO files VALUES(?, ?, ?, ?, ?)",
-                      (file.path, khash(hash), file.getsize(), file.getmtime(), refreshTime))
+        c.execute("INSERT OR REPLACE INTO files (path, hashID, size, mtime) VALUES (?, ?, ?, ?)",
+                  (file.path, hashID, file.getsize(), file.getmtime()))
         self.conn.commit()
         c.close()
         
@@ -105,7 +111,7 @@ class DB:
             None if not in database or missing
         """
         c = self.conn.cursor()
-        c.execute("SELECT hash, size, mtime FROM files WHERE path = ?", (file.path, ))
+        c.execute("SELECT hash, size, mtime, pieces FROM files JOIN hashes USING (hashID) WHERE path = ?", (file.path, ))
         row = c.fetchone()
         res = None
         if row:
@@ -114,19 +120,21 @@ class DB:
                 res = {}
                 res['hash'] = row['hash']
                 res['size'] = row['size']
+                res['pieces'] = row['pieces']
         c.close()
         return res
         
-    def lookupHash(self, hash):
+    def lookupHash(self, hash, filesOnly = False):
         """Find a file by hash in the database.
         
         If any found files have changed or are missing, they are removed
-        from the database.
+        from the database. If filesOnly is False then it will also look for
+        piece string hashes if no files can be found.
         
         @return: list of dictionaries of info for the found files
         """
         c = self.conn.cursor()
-        c.execute("SELECT path, size, mtime, refreshed FROM files WHERE hash = ?", (khash(hash), ))
+        c.execute("SELECT path, size, mtime, refreshed, pieces FROM files JOIN hashes USING (hashID) WHERE hash = ?", (khash(hash), ))
         row = c.fetchone()
         files = []
         while row:
@@ -137,8 +145,19 @@ class DB:
                 res['path'] = file
                 res['size'] = row['size']
                 res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
                 files.append(res)
             row = c.fetchone()
+            
+        if not filesOnly and not files:
+            c.execute("SELECT refreshed, pieces, piecehash FROM hashes WHERE piecehash = ?", (khash(hash), ))
+            row = c.fetchone()
+            if row:
+                res = {}
+                res['refreshed'] = row['refreshed']
+                res['pieces'] = row['pieces']
+                files.append(res)
+
         c.close()
         return files
         
@@ -156,12 +175,11 @@ class DB:
 
     def refreshHash(self, hash):
         """Refresh the publishing time all files with a hash."""
-        refreshTime = datetime.now()
         c = self.conn.cursor()
-        c.execute("UPDATE files SET refreshed = ? WHERE hash = ?", (refreshTime, khash(hash)))
+        c.execute("UPDATE hashes SET refreshed = ? WHERE hash = ?", (datetime.now(), khash(hash)))
         c.close()
     
-    def expiredFiles(self, expireAfter):
+    def expiredHashes(self, expireAfter):
         """Find files that need refreshing after expireAfter seconds.
         
         For each hash that needs refreshing, finds all the files with that hash.
@@ -173,27 +191,32 @@ class DB:
         
         # First find the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT DISTINCT hash FROM files WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
         row = c.fetchone()
         expired = {}
         while row:
-            expired.setdefault(row['hash'], [])
+            res = expired.setdefault(row['hash'], {})
+            res['hashID'] = row['hashID']
+            res['hash'] = row['hash']
+            res['pieces'] = row['pieces']
             row = c.fetchone()
-        c.close()
 
-        # Now find the files for each hash
-        for hash in expired.keys():
-            c = self.conn.cursor()
-            c.execute("SELECT path, size, mtime FROM files WHERE hash = ?", (khash(hash), ))
+        # Make sure there are still valid files for each hash
+        for hash in expired.values():
+            valid = False
+            c.execute("SELECT path, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
                 res = self._removeChanged(FilePath(row['path']), row)
                 if res:
-                    expired[hash].append(FilePath(row['path']))
+                    valid = True
                 row = c.fetchone()
-            if len(expired[hash]) == 0:
-                del expired[hash]
-            c.close()
+            if not valid:
+                del expired[hash['hash']]
+                c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
+                
+        self.conn.commit()
+        c.close()
         
         return expired
         
@@ -249,7 +272,7 @@ class TestDB(unittest.TestCase):
         self.store = DB(self.db)
         self.store.storeFile(self.file, self.hash)
 
-    def test_openExistsingDB(self):
+    def test_openExistingDB(self):
         self.store.close()
         self.store = None
         sleep(1)
@@ -275,20 +298,18 @@ class TestDB(unittest.TestCase):
         self.file.touch()
         res = self.store.isUnchanged(self.file)
         self.failUnless(res == False)
-        self.file.remove()
         res = self.store.isUnchanged(self.file)
-        self.failUnless(res == None)
+        self.failUnless(res is None)
         
     def test_expiry(self):
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 1)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         
     def build_dirs(self):
@@ -302,7 +323,7 @@ class TestDB(unittest.TestCase):
     
     def test_multipleHashes(self):
         self.build_dirs()
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
@@ -311,12 +332,11 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[2]['refreshed'])
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 1)
         self.failUnlessEqual(res.keys()[0], self.hash)
-        self.failUnlessEqual(len(res[self.hash]), 4)
         self.store.refreshHash(self.hash)
-        res = self.store.expiredFiles(1)
+        res = self.store.expiredHashes(1)
         self.failUnlessEqual(len(res.keys()), 0)
     
     def test_removeUntracked(self):
@@ -338,3 +358,4 @@ class TestDB(unittest.TestCase):
         self.directory.remove()
         self.store.close()
         self.db.remove()
+