Fix a traceback generating bug in the refreshing of DHT entries.
authorCameron Dale <camrdale@gmail.com>
Wed, 7 May 2008 07:10:38 +0000 (00:10 -0700)
committerCameron Dale <camrdale@gmail.com>
Wed, 7 May 2008 07:10:38 +0000 (00:10 -0700)
Also change the DB's expiredHashes to return a sorted list.

apt_p2p/DHTManager.py
apt_p2p/db.py

index 42a23e7..14e88e4 100644 (file)
@@ -32,6 +32,10 @@ class DHT:
     @type my_contact: C{string}
     @ivar my_contact: the 6-byte compact peer representation of this peer's
         download information (IP address and port)
+    @type nextRefresh: L{twisted.internet.interfaces.IDelayedCall}
+    @ivar nextRefresh: the next delayed call to refreshFiles
+    @type refreshingHashes: C{list} of C{dictionary}
+    @ivar refreshingHashes: the list of hashes that still need to be refreshed
     """
     
     def __init__(self, dhtClass, db):
@@ -43,6 +47,8 @@ class DHT:
         self.dhtClass = dhtClass
         self.db = db
         self.my_contact = None
+        self.nextRefresh = None
+        self.refreshingHashes = []
         
     def start(self):
         self.dht = self.dhtClass()
@@ -63,7 +69,8 @@ class DHT:
         if not my_addr:
             raise RuntimeError, "IP address for this machine could not be found"
         self.my_contact = compact(my_addr, config.getint('DEFAULT', 'PORT'))
-        self.nextRefresh = reactor.callLater(60, self.refreshFiles)
+        if not self.nextRefresh or not self.nextRefresh.active():
+            self.nextRefresh = reactor.callLater(60, self.refreshFiles)
         return (my_addr, config.getint('DEFAULT', 'PORT'))
 
     def joinError(self, failure):
@@ -72,31 +79,30 @@ class DHT:
         log.err(failure)
         return failure
     
-    def refreshFiles(self, result = None, hashes = {}):
+    def refreshFiles(self, result = None):
         """Refresh any files in the DHT that are about to expire."""
         if result is not None:
             log.msg('Storage resulted in: %r' % result)
 
-        if not hashes:
+        if not self.refreshingHashes:
             expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
-            hashes = self.db.expiredHashes(expireAfter)
-            if len(hashes.keys()) > 0:
-                log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
+            self.refreshingHashes = self.db.expiredHashes(expireAfter)
+            if len(self.refreshingHashes) > 0:
+                log.msg('Refreshing the keys of %d DHT values' % len(self.refreshingHashes))
 
         delay = 60
-        if hashes:
+        if self.refreshingHashes:
             delay = 3
-            raw_hash = hashes.keys()[0]
-            self.db.refreshHash(raw_hash)
-            hash = HashObject(raw_hash, pieces = hashes[raw_hash]['pieces'])
-            del hashes[raw_hash]
+            refresh = self.refreshingHashes.pop(0)
+            self.db.refreshHash(refresh['hash'])
+            hash = HashObject(refresh['hash'], pieces = refresh['pieces'])
             storeDefer = self.store(hash)
-            storeDefer.addBoth(self.refreshFiles, hashes)
+            storeDefer.addBoth(self.refreshFiles)
 
         if self.nextRefresh.active():
             self.nextRefresh.reset(delay)
         else:
-            self.nextRefresh = reactor.callLater(delay, self.refreshFiles, None, hashes)
+            self.nextRefresh = reactor.callLater(delay, self.refreshFiles)
     
     def getStats(self):
         """Retrieve the formatted statistics for the DHT.
index 2640987..4bb5a75 100644 (file)
@@ -240,26 +240,28 @@ class DB:
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
-        @return: dictionary with keys the hashes, values a list of FilePaths
+        @return: a list of dictionaries of each hash needing refreshing, sorted by age
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ? ORDER BY refreshed", (t, ))
         row = c.fetchone()
-        expired = {}
+        expired = []
         while row:
-            res = expired.setdefault(row['hash'], {})
-            res['hashID'] = row['hashID']
+            res = {}
             res['hash'] = row['hash']
+            res['hashID'] = row['hashID']
             res['pieces'] = row['pieces']
+            expired.append(res)
             row = c.fetchone()
 
         # Make sure there are still valid DHT files for each hash
-        for hash in expired.values():
+        for i in xrange(len(expired)-1, -1, -1):
             dht = False
             non_dht = False
+            hash = expired[i]
             c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
@@ -270,7 +272,7 @@ class DB:
                 row = c.fetchone()
             if not dht:
                 # Remove hashes for which no DHT files are still available
-                del expired[hash['hash']]
+                del expired[i]
                 if not non_dht:
                     # Remove hashes for which no files are still available
                     c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
@@ -427,14 +429,14 @@ class TestDB(unittest.TestCase):
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
@@ -449,7 +451,7 @@ class TestDB(unittest.TestCase):
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
@@ -458,11 +460,11 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""