Try to rejoin DHT periodically after failures using exponential backoff.
authorCameron Dale <camrdale@gmail.com>
Mon, 14 Apr 2008 19:23:33 +0000 (12:23 -0700)
committerCameron Dale <camrdale@gmail.com>
Mon, 14 Apr 2008 20:15:52 +0000 (13:15 -0700)
Also make the get and store functions call errbacks if not joined.
Also modify the callers of the get and store functions to respond well
to errbacks.
Also add a unittest for the rejoining functionality.

TODO
apt_p2p/PeerManager.py
apt_p2p/apt_p2p.py
apt_p2p_Khashmir/DHT.py

diff --git a/TODO b/TODO
index 9f1e2f7952721862c1d1bc3bc4c4b5bbf5e07a83..2477d2fc0b851ac8161055d33ebe8a4dc4384e84 100644 (file)
--- a/TODO
+++ b/TODO
@@ -1,10 +1,3 @@
-Retry when joining the DHT.
-
-If a join node can not be reached when the program is started, it will
-currently give up and quit. Instead, it should try and join
-periodically every few minutes until it is successful.
-
-
 Add statistics gathering to the peer downloading.
 
 Statistics are needed of how much has been uploaded, downloaded from
 Add statistics gathering to the peer downloading.
 
 Statistics are needed of how much has been uploaded, downloaded from
index e9d1ecb590702de992bec32c14f3803fa0d6b6b9..5df37fe4b11594c4eb272a287b73a090390d5d48 100644 (file)
@@ -356,22 +356,26 @@ class FileDownload:
 
         # Start the DHT lookup
         lookupDefer = self.manager.dht.getValue(key)
 
         # Start the DHT lookup
         lookupDefer = self.manager.dht.getValue(key)
-        lookupDefer.addCallback(self._getDHTPieces, key)
+        lookupDefer.addBoth(self._getDHTPieces, key)
         
     def _getDHTPieces(self, results, key):
         """Check the retrieved values."""
         
     def _getDHTPieces(self, results, key):
         """Check the retrieved values."""
-        for result in results:
-            # Make sure the hash matches the key
-            result_hash = sha.new(result.get('t', '')).digest()
-            if result_hash == key:
-                pieces = result['t']
-                self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
-                log.msg('Retrieved %d piece hashes from the DHT' % len(self.pieces))
-                self.startDownload()
-                return
+        if isinstance(results, list):
+            for result in results:
+                # Make sure the hash matches the key
+                result_hash = sha.new(result.get('t', '')).digest()
+                if result_hash == key:
+                    pieces = result['t']
+                    self.pieces = [pieces[x:x+20] for x in xrange(0, len(pieces), 20)]
+                    log.msg('Retrieved %d piece hashes from the DHT' % len(self.pieces))
+                    self.startDownload()
+                    return
+                
+            log.msg('Could not retrieve the piece hashes from the DHT')
+        else:
+            log.msg('Looking up piece hashes in the DHT resulted in an error: %r' % (result, ))
             
         # Continue without the piece hashes
             
         # Continue without the piece hashes
-        log.msg('Could not retrieve the piece hashes from the DHT')
         self.pieces = [None for x in xrange(0, self.hash.expSize, PIECE_SIZE)]
         self.startDownload()
 
         self.pieces = [None for x in xrange(0, self.hash.expSize, PIECE_SIZE)]
         self.startDownload()
 
index fb47468cb8c3b7d0865863725d224873133360ef..671e2269aea9d5713c020a4dc3a514aa20ddd894 100644 (file)
@@ -269,7 +269,7 @@ class AptP2P:
         log.msg('Looking up hash in DHT for file: %s' % url)
         key = hash.expected()
         lookupDefer = self.dht.getValue(key)
         log.msg('Looking up hash in DHT for file: %s' % url)
         key = hash.expected()
         lookupDefer = self.dht.getValue(key)
-        lookupDefer.addCallback(self.lookupHash_done, hash, url, d)
+        lookupDefer.addBoth(self.lookupHash_done, hash, url, d)
 
     def lookupHash_done(self, values, hash, url, d):
         """Start the download of the file.
 
     def lookupHash_done(self, values, hash, url, d):
         """Start the download of the file.
@@ -281,8 +281,11 @@ class AptP2P:
         @param values: the returned values from the DHT containing peer
             download information
         """
         @param values: the returned values from the DHT containing peer
             download information
         """
-        if not values:
-            log.msg('Peers for %s were not found' % url)
+        if not isinstance(values, list) or not values:
+            if not isinstance(values, list):
+                log.msg('DHT lookup for %s failed with error %r' % (url, values))
+            else:
+                log.msg('Peers for %s were not found' % url)
             getDefer = self.peers.get(hash, url)
             getDefer.addCallback(self.cache.save_file, hash, url)
             getDefer.addErrback(self.cache.save_error, url)
             getDefer = self.peers.get(hash, url)
             getDefer.addCallback(self.cache.save_file, hash, url)
             getDefer.addErrback(self.cache.save_error, url)
@@ -358,7 +361,8 @@ class AptP2P:
             value['l'] = sha.new(''.join(pieces)).digest()
 
         storeDefer = self.dht.storeValue(key, value)
             value['l'] = sha.new(''.join(pieces)).digest()
 
         storeDefer = self.dht.storeValue(key, value)
-        storeDefer.addCallback(self.store_done, hash)
+        storeDefer.addCallbacks(self.store_done, self.store_error,
+                                callbackArgs = (hash, ), errbackArgs = (hash.digest(), ))
         return storeDefer
 
     def store_done(self, result, hash):
         return storeDefer
 
     def store_done(self, result, hash):
@@ -371,7 +375,8 @@ class AptP2P:
             value = {'t': ''.join(pieces)}
 
             storeDefer = self.dht.storeValue(key, value)
             value = {'t': ''.join(pieces)}
 
             storeDefer = self.dht.storeValue(key, value)
-            storeDefer.addCallback(self.store_torrent_done, key)
+            storeDefer.addCallbacks(self.store_torrent_done, self.store_error,
+                                    callbackArgs = (key, ), errbackArgs = (key, ))
             return storeDefer
         return result
 
             return storeDefer
         return result
 
@@ -379,4 +384,9 @@ class AptP2P:
         """Adding the file to the DHT is complete, and so is the workflow."""
         log.msg('Added torrent string %s to the DHT: %r' % (b2a_hex(key), result))
         return result
         """Adding the file to the DHT is complete, and so is the workflow."""
         log.msg('Added torrent string %s to the DHT: %r' % (b2a_hex(key), result))
         return result
+
+    def store_error(self, err, key):
+        """Adding to the DHT failed."""
+        log.msg('An error occurred adding %s to the DHT: %r' % (b2a_hex(key), err))
+        return err
     
\ No newline at end of file
     
\ No newline at end of file
index edb626d7b338116bad8493bb9a267d6ec44eb7b1..f11eefc8305f87f12c10a85b5703d942ee3ed08e 100644 (file)
@@ -45,6 +45,8 @@ class DHT:
     @ivar joined: whether the DHT network has been successfully joined
     @type outstandingJoins: C{int}
     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
     @ivar joined: whether the DHT network has been successfully joined
     @type outstandingJoins: C{int}
     @ivar outstandingJoins: the number of bootstrap nodes that have yet to respond
+    @type next_rejoin: C{int}
+    @ivar next_rejoin: the number of seconds before retrying the next join
     @type foundAddrs: C{list} of (C{string}, C{int})
     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
     @type storing: C{dictionary}
     @type foundAddrs: C{list} of (C{string}, C{int})
     @ivar foundAddrs: the IP address an port that were returned by bootstrap nodes
     @type storing: C{dictionary}
@@ -79,8 +81,10 @@ class DHT:
         self.bootstrap = []
         self.bootstrap_node = False
         self.joining = None
         self.bootstrap = []
         self.bootstrap_node = False
         self.joining = None
+        self.khashmir = None
         self.joined = False
         self.outstandingJoins = 0
         self.joined = False
         self.outstandingJoins = 0
+        self.next_rejoin = 20
         self.foundAddrs = []
         self.storing = {}
         self.retrieving = {}
         self.foundAddrs = []
         self.storing = {}
         self.retrieving = {}
@@ -115,17 +119,33 @@ class DHT:
             else:
                 self.config[k] = self.config_parser.get(section, k)
     
             else:
                 self.config[k] = self.config_parser.get(section, k)
     
-    def join(self):
-        """See L{apt_p2p.interfaces.IDHT}."""
-        if self.config is None:
-            raise DHTError, "configuration not loaded"
+    def join(self, deferred = None):
+        """See L{apt_p2p.interfaces.IDHT}.
+        
+        @param deferred: the deferred to callback when the join is complete
+            (optional, defaults to creating a new deferred and returning it)
+        """
+        # Check for multiple simultaneous joins 
         if self.joining:
         if self.joining:
-            raise DHTError, "a join is already in progress"
+            if deferred:
+                deferred.errback(DHTError("a join is already in progress"))
+                return
+            else:
+                raise DHTError, "a join is already in progress"
+
+        if deferred:
+            self.joining = deferred
+        else:
+            self.joining = defer.Deferred()
+
+        if self.config is None:
+            self.joining.errback(DHTError("configuration not loaded"))
+            return self.joining
 
         # Create the new khashmir instance
 
         # Create the new khashmir instance
-        self.khashmir = Khashmir(self.config, self.cache_dir)
-        
-        self.joining = defer.Deferred()
+        if not self.khashmir:
+            self.khashmir = Khashmir(self.config, self.cache_dir)
+
         for node in self.bootstrap:
             host, port = node.rsplit(':', 1)
             port = int(port)
         for node in self.bootstrap:
             host, port = node.rsplit(':', 1)
             port = int(port)
@@ -168,7 +188,7 @@ class DHT:
 
     def _join_complete(self, result):
         """End the joining process and return the addresses found for this node."""
 
     def _join_complete(self, result):
         """End the joining process and return the addresses found for this node."""
-        if not self.joined and len(result) > 0:
+        if not self.joined and len(result) > 1:
             self.joined = True
         if self.joining and self.outstandingJoins <= 0:
             df = self.joining
             self.joined = True
         if self.joining and self.outstandingJoins <= 0:
             df = self.joining
@@ -177,7 +197,10 @@ class DHT:
                 self.joined = True
                 df.callback(self.foundAddrs)
             else:
                 self.joined = True
                 df.callback(self.foundAddrs)
             else:
-                df.errback(DHTError('could not find any nodes to bootstrap to'))
+                # Try to join later using exponential backoff delays
+                log.msg('Join failed, retrying in %d seconds' % self.next_rejoin)
+                reactor.callLater(self.next_rejoin, self.join, df)
+                self.next_rejoin *= 2
         
     def getAddrs(self):
         """Get the list of addresses returned by bootstrap nodes for this node."""
         
     def getAddrs(self):
         """Get the list of addresses returned by bootstrap nodes for this node."""
@@ -214,14 +237,17 @@ class DHT:
 
     def getValue(self, key):
         """See L{apt_p2p.interfaces.IDHT}."""
 
     def getValue(self, key):
         """See L{apt_p2p.interfaces.IDHT}."""
+        d = defer.Deferred()
+
         if self.config is None:
         if self.config is None:
-            raise DHTError, "configuration not loaded"
+            d.errback(DHTError("configuration not loaded"))
+            return d
         if not self.joined:
         if not self.joined:
-            raise DHTError, "have not joined a network yet"
+            d.errback(DHTError("have not joined a network yet"))
+            return d
         
         key = self._normKey(key)
 
         
         key = self._normKey(key)
 
-        d = defer.Deferred()
         if key not in self.retrieving:
             self.khashmir.valueForKey(key, self._getValue)
         self.retrieving.setdefault(key, []).append(d)
         if key not in self.retrieving:
             self.khashmir.valueForKey(key, self._getValue)
         self.retrieving.setdefault(key, []).append(d)
@@ -245,10 +271,14 @@ class DHT:
 
     def storeValue(self, key, value):
         """See L{apt_p2p.interfaces.IDHT}."""
 
     def storeValue(self, key, value):
         """See L{apt_p2p.interfaces.IDHT}."""
+        d = defer.Deferred()
+
         if self.config is None:
         if self.config is None:
-            raise DHTError, "configuration not loaded"
+            d.errback(DHTError("configuration not loaded"))
+            return d
         if not self.joined:
         if not self.joined:
-            raise DHTError, "have not joined a network yet"
+            d.errback(DHTError("have not joined a network yet"))
+            return d
 
         key = self._normKey(key)
         bvalue = bencode(value)
 
         key = self._normKey(key)
         bvalue = bencode(value)
@@ -256,7 +286,6 @@ class DHT:
         if key in self.storing and bvalue in self.storing[key]:
             raise DHTError, "already storing that key with the same value"
 
         if key in self.storing and bvalue in self.storing[key]:
             raise DHTError, "already storing that key with the same value"
 
-        d = defer.Deferred()
         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
         self.storing.setdefault(key, {})[bvalue] = d
         return d
         self.khashmir.storeValueForKey(key, bvalue, self._storeValue)
         self.storing.setdefault(key, {})[bvalue] = d
         return d
@@ -301,7 +330,7 @@ class DHT:
 class TestSimpleDHT(unittest.TestCase):
     """Simple 2-node unit tests for the DHT."""
     
 class TestSimpleDHT(unittest.TestCase):
     """Simple 2-node unit tests for the DHT."""
     
-    timeout = 2
+    timeout = 50
     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
@@ -325,6 +354,16 @@ class TestSimpleDHT(unittest.TestCase):
     def test_bootstrap_join(self):
         d = self.a.join()
         return d
     def test_bootstrap_join(self):
         d = self.a.join()
         return d
+
+    def test_failed_join(self):
+        from krpc import KrpcError
+        d = self.b.join()
+        reactor.callLater(30, self.a.join)
+        def no_errors(result, self = self):
+            self.flushLoggedErrors(KrpcError)
+            return result
+        d.addCallback(no_errors)
+        return d
         
     def node_join(self, result):
         d = self.b.join()
         
     def node_join(self, result):
         d = self.b.join()
@@ -406,7 +445,7 @@ class TestSimpleDHT(unittest.TestCase):
 class TestMultiDHT(unittest.TestCase):
     """More complicated 20-node tests for the DHT."""
     
 class TestMultiDHT(unittest.TestCase):
     """More complicated 20-node tests for the DHT."""
     
-    timeout = 60
+    timeout = 80
     num = 20
     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
     num = 20
     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,