Fix a traceback generating bug in the refreshing of DHT entries.
authorCameron Dale <camrdale@gmail.com>
Wed, 7 May 2008 07:10:38 +0000 (00:10 -0700)
committerCameron Dale <camrdale@gmail.com>
Wed, 7 May 2008 07:10:38 +0000 (00:10 -0700)
Also change the DB's expiredHashes to return a sorted list.

apt_p2p/DHTManager.py
apt_p2p/db.py

index 42a23e736472a3f9e9627e4055c0998bfd4924db..14e88e47b51550ad83120d58f0a7329ab0f6eb3e 100644 (file)
@@ -32,6 +32,10 @@ class DHT:
     @type my_contact: C{string}
     @ivar my_contact: the 6-byte compact peer representation of this peer's
         download information (IP address and port)
     @type my_contact: C{string}
     @ivar my_contact: the 6-byte compact peer representation of this peer's
         download information (IP address and port)
+    @type nextRefresh: L{twisted.internet.interfaces.IDelayedCall}
+    @ivar nextRefresh: the next delayed call to refreshFiles
+    @type refreshingHashes: C{list} of C{dictionary}
+    @ivar refreshingHashes: the list of hashes that still need to be refreshed
     """
     
     def __init__(self, dhtClass, db):
     """
     
     def __init__(self, dhtClass, db):
@@ -43,6 +47,8 @@ class DHT:
         self.dhtClass = dhtClass
         self.db = db
         self.my_contact = None
         self.dhtClass = dhtClass
         self.db = db
         self.my_contact = None
+        self.nextRefresh = None
+        self.refreshingHashes = []
         
     def start(self):
         self.dht = self.dhtClass()
         
     def start(self):
         self.dht = self.dhtClass()
@@ -63,7 +69,8 @@ class DHT:
         if not my_addr:
             raise RuntimeError, "IP address for this machine could not be found"
         self.my_contact = compact(my_addr, config.getint('DEFAULT', 'PORT'))
         if not my_addr:
             raise RuntimeError, "IP address for this machine could not be found"
         self.my_contact = compact(my_addr, config.getint('DEFAULT', 'PORT'))
-        self.nextRefresh = reactor.callLater(60, self.refreshFiles)
+        if not self.nextRefresh or not self.nextRefresh.active():
+            self.nextRefresh = reactor.callLater(60, self.refreshFiles)
         return (my_addr, config.getint('DEFAULT', 'PORT'))
 
     def joinError(self, failure):
         return (my_addr, config.getint('DEFAULT', 'PORT'))
 
     def joinError(self, failure):
@@ -72,31 +79,30 @@ class DHT:
         log.err(failure)
         return failure
     
         log.err(failure)
         return failure
     
-    def refreshFiles(self, result = None, hashes = {}):
+    def refreshFiles(self, result = None):
         """Refresh any files in the DHT that are about to expire."""
         if result is not None:
             log.msg('Storage resulted in: %r' % result)
 
         """Refresh any files in the DHT that are about to expire."""
         if result is not None:
             log.msg('Storage resulted in: %r' % result)
 
-        if not hashes:
+        if not self.refreshingHashes:
             expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
             expireAfter = config.gettime('DEFAULT', 'KEY_REFRESH')
-            hashes = self.db.expiredHashes(expireAfter)
-            if len(hashes.keys()) > 0:
-                log.msg('Refreshing the keys of %d DHT values' % len(hashes.keys()))
+            self.refreshingHashes = self.db.expiredHashes(expireAfter)
+            if len(self.refreshingHashes) > 0:
+                log.msg('Refreshing the keys of %d DHT values' % len(self.refreshingHashes))
 
         delay = 60
 
         delay = 60
-        if hashes:
+        if self.refreshingHashes:
             delay = 3
             delay = 3
-            raw_hash = hashes.keys()[0]
-            self.db.refreshHash(raw_hash)
-            hash = HashObject(raw_hash, pieces = hashes[raw_hash]['pieces'])
-            del hashes[raw_hash]
+            refresh = self.refreshingHashes.pop(0)
+            self.db.refreshHash(refresh['hash'])
+            hash = HashObject(refresh['hash'], pieces = refresh['pieces'])
             storeDefer = self.store(hash)
             storeDefer = self.store(hash)
-            storeDefer.addBoth(self.refreshFiles, hashes)
+            storeDefer.addBoth(self.refreshFiles)
 
         if self.nextRefresh.active():
             self.nextRefresh.reset(delay)
         else:
 
         if self.nextRefresh.active():
             self.nextRefresh.reset(delay)
         else:
-            self.nextRefresh = reactor.callLater(delay, self.refreshFiles, None, hashes)
+            self.nextRefresh = reactor.callLater(delay, self.refreshFiles)
     
     def getStats(self):
         """Retrieve the formatted statistics for the DHT.
     
     def getStats(self):
         """Retrieve the formatted statistics for the DHT.
index 26409875f3339c37bb74fb086d7cc07fc8ae972f..4bb5a754b498c32e2cac9b53ae21c1558f7f8c82 100644 (file)
@@ -240,26 +240,28 @@ class DB:
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
         For each hash that needs refreshing, finds all the files with that hash.
         If the file has changed or is missing, it is removed from the table.
         
-        @return: dictionary with keys the hashes, values a list of FilePaths
+        @return: a list of dictionaries of each hash needing refreshing, sorted by age
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
         """
         t = datetime.now() - timedelta(seconds=expireAfter)
         
         # Find all the hashes that need refreshing
         c = self.conn.cursor()
-        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ?", (t, ))
+        c.execute("SELECT hashID, hash, pieces FROM hashes WHERE refreshed < ? ORDER BY refreshed", (t, ))
         row = c.fetchone()
         row = c.fetchone()
-        expired = {}
+        expired = []
         while row:
         while row:
-            res = expired.setdefault(row['hash'], {})
-            res['hashID'] = row['hashID']
+            res = {}
             res['hash'] = row['hash']
             res['hash'] = row['hash']
+            res['hashID'] = row['hashID']
             res['pieces'] = row['pieces']
             res['pieces'] = row['pieces']
+            expired.append(res)
             row = c.fetchone()
 
         # Make sure there are still valid DHT files for each hash
             row = c.fetchone()
 
         # Make sure there are still valid DHT files for each hash
-        for hash in expired.values():
+        for i in xrange(len(expired)-1, -1, -1):
             dht = False
             non_dht = False
             dht = False
             non_dht = False
+            hash = expired[i]
             c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
             c.execute("SELECT path, dht, size, mtime FROM files WHERE hashID = ?", (hash['hashID'], ))
             row = c.fetchone()
             while row:
@@ -270,7 +272,7 @@ class DB:
                 row = c.fetchone()
             if not dht:
                 # Remove hashes for which no DHT files are still available
                 row = c.fetchone()
             if not dht:
                 # Remove hashes for which no DHT files are still available
-                del expired[hash['hash']]
+                del expired[i]
                 if not non_dht:
                     # Remove hashes for which no files are still available
                     c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
                 if not non_dht:
                     # Remove hashes for which no files are still available
                     c.execute("DELETE FROM hashes WHERE hashID = ?", (hash['hashID'], ))
@@ -427,14 +429,14 @@ class TestDB(unittest.TestCase):
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
     def test_expiry(self):
         """Tests retrieving the files from the database that have expired."""
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         sleep(2)
         res = self.store.expiredHashes(1)
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         
     def build_dirs(self):
         for dir in self.dirs:
         
     def build_dirs(self):
         for dir in self.dirs:
@@ -449,7 +451,7 @@ class TestDB(unittest.TestCase):
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
         """Tests looking up a hash with multiple files in the database."""
         self.build_dirs()
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
         res = self.store.lookupHash(self.hash)
         self.failUnless(res)
         self.failUnlessEqual(len(res), 4)
@@ -458,11 +460,11 @@ class TestDB(unittest.TestCase):
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
         self.failUnlessEqual(res[0]['refreshed'], res[3]['refreshed'])
         sleep(2)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 1)
-        self.failUnlessEqual(res.keys()[0], self.hash)
+        self.failUnlessEqual(len(res), 1)
+        self.failUnlessEqual(res[0]['hash'], self.hash)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
         self.store.refreshHash(self.hash)
         res = self.store.expiredHashes(1)
-        self.failUnlessEqual(len(res.keys()), 0)
+        self.failUnlessEqual(len(res), 0)
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""
     
     def test_removeUntracked(self):
         """Tests removing untracked files from the database."""