Increase the concurrency of DHT requests to 8.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / khashmir.py
index 126a30ea585b86b45cff8f867ddaa0daa9e99a5f..f083da4763dc25f5a96fea8b50be9ca7488a89a7 100644 (file)
@@ -1,7 +1,9 @@
-## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
-# see LICENSE.txt for license information
 
-"""The main Khashmir program."""
+"""The main Khashmir program.
+
+@var isLocal: a compiled regular expression suitable for testing if an
+    IP address is from a known local or private range
+"""
 
 import warnings
 warnings.simplefilter("ignore", DeprecationWarning)
@@ -9,10 +11,12 @@ warnings.simplefilter("ignore", DeprecationWarning)
 from datetime import datetime, timedelta
 from random import randrange, shuffle
 from sha import sha
-import os
+from copy import copy
+import os, re
 
 from twisted.internet.defer import Deferred
 from twisted.internet import protocol, reactor
+from twisted.python import log
 from twisted.trial import unittest
 
 from db import DB
@@ -20,8 +24,16 @@ from ktable import KTable
 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
 from khash import newID, newIDInRange
 from actions import FindNode, FindValue, GetValue, StoreValue
+from stats import StatsLogger
 import krpc
 
+isLocal = re.compile('^(192\.168\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(10\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?1[6-9]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?2[0-9]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?3[0-1]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(127\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})$')
+
 class KhashmirBase(protocol.Factory):
     """The base Khashmir class, with base functionality and find node, no key-value mappings.
     
@@ -39,6 +51,8 @@ class KhashmirBase(protocol.Factory):
     @ivar table: the routing table
     @type token_secrets: C{list} of C{string}
     @ivar token_secrets: the current secrets to use to create tokens
+    @type stats: L{stats.StatsLogger}
+    @ivar stats: the statistics gatherer
     @type udp: L{krpc.hostbroker}
     @ivar udp: the factory for the KRPC protocol
     @type listenport: L{twisted.internet.interfaces.IListeningPort}
@@ -75,9 +89,10 @@ class KhashmirBase(protocol.Factory):
         self.node = self._loadSelfNode('', self.port)
         self.table = KTable(self.node, config)
         self.token_secrets = [newID()]
+        self.stats = StatsLogger(self.table, self.store)
         
         # Start listening
-        self.udp = krpc.hostbroker(self, config)
+        self.udp = krpc.hostbroker(self, self.stats, config)
         self.udp.protocol = krpc.KRPC
         self.listenport = reactor.listenUDP(self.port, self.udp)
         
@@ -151,12 +166,12 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
-            (optional, defaults to calling the callback with None)
+            (optional, defaults to calling the callback with the error)
         """
         n = self.Node(NULL_ID, host, port)
         self.sendJoin(n, callback=callback, errback=errback)
 
-    def findNode(self, id, callback, errback=None):
+    def findNode(self, id, callback):
         """Find the contact info for the K closest nodes in the global table.
         
         @type id: C{string}
@@ -164,32 +179,20 @@ class KhashmirBase(protocol.Factory):
         @type callback: C{method}
         @param callback: the method to call with the results, it must take 1
             parameter, the list of K closest nodes
-        @type errback: C{method}
-        @param errback: the method to call if an error occurs
-            (optional, defaults to doing nothing when an error occurs)
         """
-        # Get K nodes out of local table/cache
-        nodes = self.table.findNodes(id)
-        d = Deferred()
-        if errback:
-            d.addCallbacks(callback, errback)
-        else:
-            d.addCallback(callback)
+        # Start with our node
+        nodes = [copy(self.node)]
 
-        # If the target ID was found
-        if len(nodes) == 1 and nodes[0].id == id:
-            d.callback(nodes)
-        else:
-            # Start the finding nodes action
-            state = FindNode(self, id, d.callback, self.config)
-            reactor.callLater(0, state.goWithNodes, nodes)
+        # Start the finding nodes action
+        state = FindNode(self, id, callback, self.config, self.stats)
+        reactor.callLater(0, state.goWithNodes, nodes)
     
     def insertNode(self, node, contacted = True):
         """Try to insert a node in our local table, pinging oldest contact if necessary.
         
         If all you have is a host/port, then use L{addContact}, which calls this
         method after receiving the PONG from the remote node. The reason for
-        the seperation is we can't insert a node into the table without its
+        the separation is we can't insert a node into the table without its
         node ID. That means of course the node passed into this method needs
         to be a properly formed Node object with a valid ID.
 
@@ -199,25 +202,59 @@ class KhashmirBase(protocol.Factory):
         @param contacted: whether the new node is known to be good, i.e.
             responded to a request (optional, defaults to True)
         """
+        # Don't add any local nodes to the routing table
+        if not self.config['LOCAL_OK'] and isLocal.match(node.host):
+            log.msg('Not adding local node to table: %s/%s' % (node.host, node.port))
+            return
+        
         old = self.table.insertNode(node, contacted=contacted)
-        if (old and old.id != self.node.id and
+
+        if (isinstance(old, self._Node) and old.id != self.node.id and
             (datetime.now() - old.lastSeen) > 
              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
             
-            def _staleNodeHandler(oldnode = old, newnode = node):
-                """The pinged node never responded, so replace it."""
-                self.table.replaceStaleNode(oldnode, newnode)
-            
-            def _notStaleNodeHandler(dict, old=old):
-                """Got a pong from the old node, so update it."""
-                dict = dict['rsp']
-                if dict['id'] == old.id:
-                    self.table.justSeenNode(old.id)
-            
             # Bucket is full, check to see if old node is still available
+            self.stats.startedAction('ping')
             df = old.ping(self.node.id)
-            df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
+            df.addCallbacks(self._freshNodeHandler, self._staleNodeHandler,
+                            callbackArgs = (old, datetime.now()),
+                            errbackArgs = (old, datetime.now(), node, contacted))
+        elif not old and not contacted:
+            # There's room, we just need to contact the node first
+            self.stats.startedAction('ping')
+            df = node.ping(self.node.id)
+            # Convert the returned contact info into a node
+            df.addCallback(self._pongHandler, datetime.now())
+            # Try adding the contacted node
+            df.addCallbacks(self.insertNode, self._pongError,
+                            errbackArgs = (node, datetime.now()))
+
+    def _freshNodeHandler(self, dict, old, start):
+        """Got a pong from the old node, so update it."""
+        self.stats.completedAction('ping', start)
+        if dict['id'] == old.id:
+            self.table.justSeenNode(old.id)
+    
+    def _staleNodeHandler(self, err, old, start, node, contacted):
+        """The pinged node never responded, so replace it."""
+        log.msg("action ping failed on %s/%s: %s" % (old.host, old.port, err.getErrorMessage()))
+        self.stats.completedAction('ping', start)
+        self.table.invalidateNode(old)
+        self.insertNode(node, contacted)
+    
+    def _pongHandler(self, dict, start):
+        """Node responded properly, change response into a node to insert."""
+        self.stats.completedAction('ping', start)
+        # Create the node using the returned contact info
+        n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+        return n
 
+    def _pongError(self, err, node, start):
+        """Error occurred, fail node and errback or callback with error."""
+        log.msg("action ping failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
+        self.stats.completedAction('ping', start)
+        self.table.nodeFailed(node)
+    
     def sendJoin(self, node, callback=None, errback=None):
         """Join the DHT by pinging a bootstrap node.
         
@@ -229,28 +266,34 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
-            (optional, defaults to calling the callback with None)
+            (optional, defaults to calling the callback with the error)
         """
-
-        def _pongHandler(dict, node=node, self=self, callback=callback):
-            """Node responded properly, callback with response."""
-            n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
-            self.insertNode(n)
-            if callback:
-                callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
-
-        def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
-            """Error occurred, fail node and errback or callback with error."""
-            table.nodeFailed(node)
-            if errback:
-                errback()
-            elif callback:
-                callback(None)
-        
+        if errback is None:
+            errback = callback
+        self.stats.startedAction('join')
         df = node.join(self.node.id)
-        df.addCallbacks(_pongHandler, _defaultPong)
-
-    def findCloseNodes(self, callback=lambda a: None, errback = None):
+        df.addCallbacks(self._joinHandler, self._joinError,
+                        callbackArgs = (node, datetime.now()),
+                        errbackArgs = (node, datetime.now()))
+        if callback:
+            df.addCallbacks(callback, errback)
+
+    def _joinHandler(self, dict, node, start):
+        """Node responded properly, extract the response."""
+        self.stats.completedAction('join', start)
+        # Create the node using the returned contact info
+        n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+        reactor.callLater(0, self.insertNode, n)
+        return (dict['ip_addr'], dict['port'])
+
+    def _joinError(self, err, node, start):
+        """Error occurred, fail node."""
+        log.msg("action join failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
+        self.stats.completedAction('join', start)
+        self.table.nodeFailed(node)
+        return err
+        
+    def findCloseNodes(self, callback=lambda a: None):
         """Perform a findNode on the ID one away from our own.
 
         This will allow us to populate our table with nodes on our network
@@ -261,12 +304,9 @@ class KhashmirBase(protocol.Factory):
         @param callback: the method to call with the results, it must take 1
             parameter, the list of K closest nodes
             (optional, defaults to doing nothing with the results)
-        @type errback: C{method}
-        @param errback: the method to call if an error occurs
-            (optional, defaults to doing nothing when an error occurs)
         """
         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
-        self.findNode(id, callback, errback)
+        self.findNode(id, callback)
 
     def refreshTable(self, force = False):
         """Check all the buckets for those that need refreshing.
@@ -284,17 +324,6 @@ class KhashmirBase(protocol.Factory):
                 id = newIDInRange(bucket.min, bucket.max)
                 self.findNode(id, callback)
 
-    def stats(self):
-        """Collect some statistics about the DHT.
-        
-        @rtype: (C{int}, C{int})
-        @return: the number contacts in our routing table, and the estimated
-            number of nodes in the entire DHT
-        """
-        num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
-        num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
-        return (num_contacts, num_nodes)
-    
     def shutdown(self):
         """Closes the port and cancels pending later calls."""
         self.listenport.stopListening()
@@ -303,9 +332,13 @@ class KhashmirBase(protocol.Factory):
         except:
             pass
         self.store.close()
+    
+    def getStats(self):
+        """Gather the statistics for the DHT."""
+        return self.stats.formatHTML()
 
     #{ Remote interface
-    def krpc_ping(self, id, _krpc_sender):
+    def krpc_ping(self, id, _krpc_sender = None):
         """Pong with our ID.
         
         @type id: C{string}
@@ -313,12 +346,13 @@ class KhashmirBase(protocol.Factory):
         @type _krpc_sender: (C{string}, C{int})
         @param _krpc_sender: the sender node's IP address and port
         """
-        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-        self.insertNode(n, contacted = False)
+        if _krpc_sender is not None:
+            n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+            reactor.callLater(0, self.insertNode, n, False)
 
         return {"id" : self.node.id}
         
-    def krpc_join(self, id, _krpc_sender):
+    def krpc_join(self, id, _krpc_sender = None):
         """Add the node by responding with its address and port.
         
         @type id: C{string}
@@ -326,12 +360,15 @@ class KhashmirBase(protocol.Factory):
         @type _krpc_sender: (C{string}, C{int})
         @param _krpc_sender: the sender node's IP address and port
         """
-        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-        self.insertNode(n, contacted = False)
+        if _krpc_sender is not None:
+            n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+            reactor.callLater(0, self.insertNode, n, False)
+        else:
+            _krpc_sender = ('127.0.0.1', self.port)
 
         return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
         
-    def krpc_find_node(self, target, id, _krpc_sender):
+    def krpc_find_node(self, id, target, _krpc_sender = None):
         """Find the K closest nodes to the target in the local routing table.
         
         @type target: C{string}
@@ -341,8 +378,11 @@ class KhashmirBase(protocol.Factory):
         @type _krpc_sender: (C{string}, C{int})
         @param _krpc_sender: the sender node's IP address and port
         """
-        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-        self.insertNode(n, contacted = False)
+        if _krpc_sender is not None:
+            n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+            reactor.callLater(0, self.insertNode, n, False)
+        else:
+            _krpc_sender = ('127.0.0.1', self.port)
 
         nodes = self.table.findNodes(target)
         nodes = map(lambda node: node.contactInfo(), nodes)
@@ -356,7 +396,7 @@ class KhashmirRead(KhashmirBase):
     _Node = KNodeRead
 
     #{ Local interface
-    def findValue(self, key, callback, errback=None):
+    def findValue(self, key, callback):
         """Get the nodes that have values for the key from the global table.
         
         @type key: C{string}
@@ -364,20 +404,12 @@ class KhashmirRead(KhashmirBase):
         @type callback: C{method}
         @param callback: the method to call with the results, it must take 1
             parameter, the list of nodes with values
-        @type errback: C{method}
-        @param errback: the method to call if an error occurs
-            (optional, defaults to doing nothing when an error occurs)
         """
-        # Get K nodes out of local table/cache
-        nodes = self.table.findNodes(key)
-        d = Deferred()
-        if errback:
-            d.addCallbacks(callback, errback)
-        else:
-            d.addCallback(callback)
-
+        # Start with ourself
+        nodes = [copy(self.node)]
+        
         # Search for others starting with the locally found ones
-        state = FindValue(self, key, d.callback, self.config)
+        state = FindValue(self, key, callback, self.config, self.stats)
         reactor.callLater(0, state.goWithNodes, nodes)
 
     def valueForKey(self, key, callback, searchlocal = True):
@@ -394,24 +426,25 @@ class KhashmirRead(KhashmirBase):
         @type searchlocal: C{boolean}
         @param searchlocal: whether to also look for any local values
         """
-        # Get any local values
-        if searchlocal:
-            l = self.store.retrieveValues(key)
-            if len(l) > 0:
-                reactor.callLater(0, callback, key, l)
-        else:
-            l = []
 
-        def _getValueForKey(nodes, key=key, local_values=l, response=callback, self=self):
+        def _getValueForKey(nodes, key=key, response=callback, self=self, searchlocal=searchlocal):
             """Use the found nodes to send requests for values to."""
-            state = GetValue(self, key, local_values, self.config['RETRIEVE_VALUES'], response, self.config)
+            # Get any local values
+            if searchlocal:
+                l = self.store.retrieveValues(key)
+                if len(l) > 0:
+                    node = copy(self.node)
+                    node.updateNumValues(len(l))
+                    nodes = nodes + [node]
+
+            state = GetValue(self, key, self.config['RETRIEVE_VALUES'], response, self.config, self.stats)
             reactor.callLater(0, state.goWithNodes, nodes)
             
         # First lookup nodes that have values for the key
         self.findValue(key, _getValueForKey)
 
     #{ Remote interface
-    def krpc_find_value(self, key, id, _krpc_sender):
+    def krpc_find_value(self, id, key, _krpc_sender = None):
         """Find the number of values stored locally for the key, and the K closest nodes.
         
         @type key: C{string}
@@ -421,15 +454,16 @@ class KhashmirRead(KhashmirBase):
         @type _krpc_sender: (C{string}, C{int})
         @param _krpc_sender: the sender node's IP address and port
         """
-        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-        self.insertNode(n, contacted = False)
+        if _krpc_sender is not None:
+            n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+            reactor.callLater(0, self.insertNode, n, False)
     
         nodes = self.table.findNodes(key)
         nodes = map(lambda node: node.contactInfo(), nodes)
         num_values = self.store.countValues(key)
         return {'nodes' : nodes, 'num' : num_values, "id": self.node.id}
 
-    def krpc_get_value(self, key, num, id, _krpc_sender):
+    def krpc_get_value(self, id, key, num, _krpc_sender = None):
         """Retrieve the values stored locally for the key.
         
         @type key: C{string}
@@ -442,8 +476,9 @@ class KhashmirRead(KhashmirBase):
         @type _krpc_sender: (C{string}, C{int})
         @param _krpc_sender: the sender node's IP address and port
         """
-        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-        self.insertNode(n, contacted = False)
+        if _krpc_sender is not None:
+            n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+            reactor.callLater(0, self.insertNode, n, False)
     
         l = self.store.retrieveValues(key)
         if num == 0 or num >= len(l):
@@ -481,14 +516,14 @@ class KhashmirWrite(KhashmirRead):
                     """Default callback that does nothing."""
                     pass
                 response = _storedValueHandler
-            action = StoreValue(self, key, value, self.config['STORE_REDUNDANCY'], response, self.config)
+            action = StoreValue(self, key, value, self.config['STORE_REDUNDANCY'], response, self.config, self.stats)
             reactor.callLater(0, action.goWithNodes, nodes)
             
         # First find the K closest nodes to operate on.
         self.findNode(key, _storeValueForKey)
                     
     #{ Remote interface
-    def krpc_store_value(self, key, value, token, id, _krpc_sender):
+    def krpc_store_value(self, id, key, value, token, _krpc_sender = None):
         """Store the value locally with the key.
         
         @type key: C{string}
@@ -501,8 +536,12 @@ class KhashmirWrite(KhashmirRead):
         @type _krpc_sender: (C{string}, C{int})
         @param _krpc_sender: the sender node's IP address and port
         """
-        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-        self.insertNode(n, contacted = False)
+        if _krpc_sender is not None:
+            n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+            reactor.callLater(0, self.insertNode, n, False)
+        else:
+            _krpc_sender = ('127.0.0.1', self.port)
+
         for secret in self.token_secrets:
             this_token = sha(secret + _krpc_sender[0]).digest()
             if token == this_token:
@@ -519,12 +558,13 @@ class Khashmir(KhashmirWrite):
 class SimpleTests(unittest.TestCase):
     
     timeout = 10
-    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
-                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
+    DHT_DEFAULTS = {'PORT': 9977,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 8,
                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
-                    'MAX_FAILURES': 3,
+                    'MAX_FAILURES': 3, 'LOCAL_OK': True,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
-                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+                    'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
+                    'KEY_EXPIRE': 3600, 'SPEW': True, }
 
     def setUp(self):
         d = self.DHT_DEFAULTS.copy()
@@ -542,10 +582,10 @@ class SimpleTests(unittest.TestCase):
 
     def testAddContact(self):
         self.failUnlessEqual(len(self.a.table.buckets), 1)
-        self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
+        self.failUnlessEqual(len(self.a.table.buckets[0].nodes), 0)
 
         self.failUnlessEqual(len(self.b.table.buckets), 1)
-        self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
+        self.failUnlessEqual(len(self.b.table.buckets[0].nodes), 0)
 
         self.a.addContact('127.0.0.1', 4045)
         reactor.iterate()
@@ -554,9 +594,9 @@ class SimpleTests(unittest.TestCase):
         reactor.iterate()
 
         self.failUnlessEqual(len(self.a.table.buckets), 1)
-        self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.a.table.buckets[0].nodes), 1)
         self.failUnlessEqual(len(self.b.table.buckets), 1)
-        self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.b.table.buckets[0].nodes), 1)
 
     def testStoreRetrieve(self):
         self.a.addContact('127.0.0.1', 4045)
@@ -592,12 +632,13 @@ class MultiTest(unittest.TestCase):
     
     timeout = 30
     num = 20
-    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
-                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
+    DHT_DEFAULTS = {'PORT': 9977,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 8,
                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
-                    'MAX_FAILURES': 3,
+                    'MAX_FAILURES': 3, 'LOCAL_OK': True,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
-                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+                    'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
+                    'KEY_EXPIRE': 3600, 'SPEW': True, }
 
     def _done(self, val):
         self.done = 1