Use the version number in the Khashmir node ID.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / khashmir.py
index 580f0f0e16b5482ef0ee1149b41135c325a8a4b5..f711a6a62a315964a339b03085d809079b0a2b6d 100644 (file)
@@ -1,7 +1,9 @@
-## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
-# see LICENSE.txt for license information
 
-"""The main Khashmir program."""
+"""The main Khashmir program.
+
+@var isLocal: a compiled regular expression suitable for testing if an
+    IP address is from a known local or private range
+"""
 
 import warnings
 warnings.simplefilter("ignore", DeprecationWarning)
@@ -10,9 +12,10 @@ from datetime import datetime, timedelta
 from random import randrange, shuffle
 from sha import sha
 from copy import copy
-import os
+import os, re
 
 from twisted.internet.defer import Deferred
+from twisted.internet.base import DelayedCall
 from twisted.internet import protocol, reactor
 from twisted.python import log
 from twisted.trial import unittest
@@ -25,6 +28,13 @@ from actions import FindNode, FindValue, GetValue, StoreValue
 from stats import StatsLogger
 import krpc
 
+isLocal = re.compile('^(192\.168\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(10\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?1[6-9]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?2[0-9]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?3[0-1]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(127\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})$')
+
 class KhashmirBase(protocol.Factory):
     """The base Khashmir class, with base functionality and find node, no key-value mappings.
     
@@ -32,6 +42,9 @@ class KhashmirBase(protocol.Factory):
     @ivar _Node: the knode implementation to use for this class of DHT
     @type config: C{dictionary}
     @ivar config: the configuration parameters for the DHT
+    @type pinging: C{dictionary}
+    @ivar pinging: the node's that are currently being pinged, keys are the
+        node id's, values are the Deferred or DelayedCall objects
     @type port: C{int}
     @ivar port: the port to listen on
     @type store: L{db.DB}
@@ -64,6 +77,7 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to the /tmp directory)
         """
         self.config = None
+        self.pinging = {}
         self.setup(config, cache_dir)
         
     def setup(self, config, cache_dir):
@@ -109,8 +123,8 @@ class KhashmirBase(protocol.Factory):
     def _loadSelfNode(self, host, port):
         """Create this node, loading any previously saved one."""
         id = self.store.getSelfNode()
-        if not id:
-            id = newID()
+        if not id or not id.endswith(self.config['VERSION']):
+            id = newID(self.config['VERSION'])
         return self._Node(id, host, port)
         
     def checkpoint(self):
@@ -157,7 +171,7 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
-            (optional, defaults to calling the callback with None)
+            (optional, defaults to calling the callback with the error)
         """
         n = self.Node(NULL_ID, host, port)
         self.sendJoin(n, callback=callback, errback=errback)
@@ -171,6 +185,9 @@ class KhashmirBase(protocol.Factory):
         @param callback: the method to call with the results, it must take 1
             parameter, the list of K closest nodes
         """
+        # Mark the bucket as having been accessed
+        self.table.touch(id)
+        
         # Start with our node
         nodes = [copy(self.node)]
 
@@ -183,7 +200,7 @@ class KhashmirBase(protocol.Factory):
         
         If all you have is a host/port, then use L{addContact}, which calls this
         method after receiving the PONG from the remote node. The reason for
-        the seperation is we can't insert a node into the table without its
+        the separation is we can't insert a node into the table without its
         node ID. That means of course the node passed into this method needs
         to be a properly formed Node object with a valid ID.
 
@@ -193,29 +210,88 @@ class KhashmirBase(protocol.Factory):
         @param contacted: whether the new node is known to be good, i.e.
             responded to a request (optional, defaults to True)
         """
+        # Don't add any local nodes to the routing table
+        if not self.config['LOCAL_OK'] and isLocal.match(node.host):
+            log.msg('Not adding local node to table: %s/%s' % (node.host, node.port))
+            return
+        
         old = self.table.insertNode(node, contacted=contacted)
-        if (old and old.id != self.node.id and
+
+        if (isinstance(old, self._Node) and old.id != self.node.id and
             (datetime.now() - old.lastSeen) > 
              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
             
-            def _staleNodeHandler(err, oldnode = old, newnode = node, self = self, start = datetime.now()):
-                """The pinged node never responded, so replace it."""
-                log.msg("ping failed (%s) %s/%s" % (self.config['PORT'], oldnode.host, oldnode.port))
-                log.err(err)
-                self.stats.completedAction('ping', start)
-                self.table.replaceStaleNode(oldnode, newnode)
-            
-            def _notStaleNodeHandler(dict, old = old, self = self, start = datetime.now()):
-                """Got a pong from the old node, so update it."""
-                self.stats.completedAction('ping', start)
-                if dict['id'] == old.id:
-                    self.table.justSeenNode(old.id)
-            
             # Bucket is full, check to see if old node is still available
-            self.stats.startedAction('ping')
-            df = old.ping(self.node.id)
-            df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
+            df = self.sendPing(old)
+            df.addErrback(self._staleNodeHandler, old, node, contacted)
+        elif not old and not contacted:
+            # There's room, we just need to contact the node first
+            df = self.sendPing(node)
+            # Also schedule a future ping to make sure the node works
+            def rePing(newnode, self = self):
+                if newnode.id not in self.pinging:
+                    self.pinging[newnode.id] = reactor.callLater(self.config['MIN_PING_INTERVAL'],
+                                                                 self.sendPing, newnode)
+                return newnode
+            df.addCallback(rePing)
+
+    def _staleNodeHandler(self, err, old, node, contacted):
+        """The pinged node never responded, so replace it."""
+        self.table.invalidateNode(old)
+        self.insertNode(node, contacted)
+        return err
+    
+    def nodeFailed(self, node):
+        """Mark a node as having failed a request and schedule a future check.
+        
+        @type node: L{node.Node}
+        @param node: the new node to try and insert
+        """
+        exists = self.table.nodeFailed(node)
+        
+        # If in the table, schedule a ping, if one isn't already sent/scheduled
+        if exists and node.id not in self.pinging:
+            self.pinging[node.id] = reactor.callLater(self.config['MIN_PING_INTERVAL'],
+                                                      self.sendPing, node)
+    
+    def sendPing(self, node):
+        """Ping the node to see if it's still alive.
+        
+        @type node: L{node.Node}
+        @param node: the node to send the join to
+        """
+        # Check for a ping already underway
+        if (isinstance(self.pinging.get(node.id, None), DelayedCall) and
+            self.pinging[node.id].active()):
+            self.pinging[node.id].cancel()
+        elif isinstance(self.pinging.get(node.id, None), Deferred):
+            return self.pinging[node.id]
+
+        self.stats.startedAction('ping')
+        df = node.ping(self.node.id)
+        self.pinging[node.id] = df
+        df.addCallbacks(self._pingHandler, self._pingError,
+                        callbackArgs = (node, datetime.now()),
+                        errbackArgs = (node, datetime.now()))
+        return df
+
+    def _pingHandler(self, dict, node, start):
+        """Node responded properly, update it and return the node object."""
+        self.stats.completedAction('ping', start)
+        del self.pinging[node.id]
+        # Create the node using the returned contact info
+        n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+        reactor.callLater(0, self.insertNode, n)
+        return n
 
+    def _pingError(self, err, node, start):
+        """Error occurred, fail node."""
+        log.msg("action ping failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
+        self.stats.completedAction('ping', start)
+        del self.pinging[node.id]
+        self.nodeFailed(node)
+        return err
+        
     def sendJoin(self, node, callback=None, errback=None):
         """Join the DHT by pinging a bootstrap node.
         
@@ -227,32 +303,33 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
-            (optional, defaults to calling the callback with None)
+            (optional, defaults to calling the callback with the error)
         """
-
-        def _pongHandler(dict, node=node, self=self, callback=callback, start = datetime.now()):
-            """Node responded properly, callback with response."""
-            n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
-            self.stats.completedAction('join', start)
-            self.insertNode(n)
-            if callback:
-                callback((dict['ip_addr'], dict['port']))
-
-        def _defaultPong(err, node=node, self=self, callback=callback, errback=errback, start = datetime.now()):
-            """Error occurred, fail node and errback or callback with error."""
-            log.msg("join failed (%s) %s/%s" % (self.config['PORT'], node.host, node.port))
-            log.err(err)
-            self.stats.completedAction('join', start)
-            self.table.nodeFailed(node)
-            if errback:
-                errback()
-            elif callback:
-                callback(None)
-        
+        if errback is None:
+            errback = callback
         self.stats.startedAction('join')
         df = node.join(self.node.id)
-        df.addCallbacks(_pongHandler, _defaultPong)
-
+        df.addCallbacks(self._joinHandler, self._joinError,
+                        callbackArgs = (node, datetime.now()),
+                        errbackArgs = (node, datetime.now()))
+        if callback:
+            df.addCallbacks(callback, errback)
+
+    def _joinHandler(self, dict, node, start):
+        """Node responded properly, extract the response."""
+        self.stats.completedAction('join', start)
+        # Create the node using the returned contact info
+        n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+        reactor.callLater(0, self.insertNode, n)
+        return (dict['ip_addr'], dict['port'])
+
+    def _joinError(self, err, node, start):
+        """Error occurred, fail node."""
+        log.msg("action join failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
+        self.stats.completedAction('join', start)
+        self.nodeFailed(node)
+        return err
+        
     def findCloseNodes(self, callback=lambda a: None):
         """Perform a findNode on the ID one away from our own.
 
@@ -291,6 +368,10 @@ class KhashmirBase(protocol.Factory):
             self.next_checkpoint.cancel()
         except:
             pass
+        for nodeid in self.pinging.keys():
+            if isinstance(self.pinging[nodeid], DelayedCall) and self.pinging[nodeid].active():
+                self.pinging[nodeid].cancel()
+                del self.pinging[nodeid]
         self.store.close()
     
     def getStats(self):
@@ -308,7 +389,7 @@ class KhashmirBase(protocol.Factory):
         """
         if _krpc_sender is not None:
             n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-            self.insertNode(n, contacted = False)
+            reactor.callLater(0, self.insertNode, n, False)
 
         return {"id" : self.node.id}
         
@@ -322,7 +403,7 @@ class KhashmirBase(protocol.Factory):
         """
         if _krpc_sender is not None:
             n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-            self.insertNode(n, contacted = False)
+            reactor.callLater(0, self.insertNode, n, False)
         else:
             _krpc_sender = ('127.0.0.1', self.port)
 
@@ -340,7 +421,7 @@ class KhashmirBase(protocol.Factory):
         """
         if _krpc_sender is not None:
             n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-            self.insertNode(n, contacted = False)
+            reactor.callLater(0, self.insertNode, n, False)
         else:
             _krpc_sender = ('127.0.0.1', self.port)
 
@@ -365,6 +446,9 @@ class KhashmirRead(KhashmirBase):
         @param callback: the method to call with the results, it must take 1
             parameter, the list of nodes with values
         """
+        # Mark the bucket as having been accessed
+        self.table.touch(key)
+        
         # Start with ourself
         nodes = [copy(self.node)]
         
@@ -416,7 +500,7 @@ class KhashmirRead(KhashmirBase):
         """
         if _krpc_sender is not None:
             n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-            self.insertNode(n, contacted = False)
+            reactor.callLater(0, self.insertNode, n, False)
     
         nodes = self.table.findNodes(key)
         nodes = map(lambda node: node.contactInfo(), nodes)
@@ -438,7 +522,7 @@ class KhashmirRead(KhashmirBase):
         """
         if _krpc_sender is not None:
             n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-            self.insertNode(n, contacted = False)
+            reactor.callLater(0, self.insertNode, n, False)
     
         l = self.store.retrieveValues(key)
         if num == 0 or num >= len(l):
@@ -498,7 +582,7 @@ class KhashmirWrite(KhashmirRead):
         """
         if _krpc_sender is not None:
             n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
-            self.insertNode(n, contacted = False)
+            reactor.callLater(0, self.insertNode, n, False)
         else:
             _krpc_sender = ('127.0.0.1', self.port)
 
@@ -518,12 +602,13 @@ class Khashmir(KhashmirWrite):
 class SimpleTests(unittest.TestCase):
     
     timeout = 10
-    DHT_DEFAULTS = {'PORT': 9977,
-                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
-                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
-                    'MAX_FAILURES': 3,
+    DHT_DEFAULTS = {'VERSION': 'A000', 'PORT': 9977,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 8,
+                    'STORE_REDUNDANCY': 6, 'RETRIEVE_VALUES': -10000,
+                    'MAX_FAILURES': 3, 'LOCAL_OK': True,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
-                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+                    'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
+                    'KEY_EXPIRE': 3600, 'SPEW': True, }
 
     def setUp(self):
         d = self.DHT_DEFAULTS.copy()
@@ -541,21 +626,25 @@ class SimpleTests(unittest.TestCase):
 
     def testAddContact(self):
         self.failUnlessEqual(len(self.a.table.buckets), 1)
-        self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
+        self.failUnlessEqual(len(self.a.table.buckets[0].nodes), 0)
 
         self.failUnlessEqual(len(self.b.table.buckets), 1)
-        self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
+        self.failUnlessEqual(len(self.b.table.buckets[0].nodes), 0)
 
         self.a.addContact('127.0.0.1', 4045)
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
 
         self.failUnlessEqual(len(self.a.table.buckets), 1)
-        self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.a.table.buckets[0].nodes), 1)
         self.failUnlessEqual(len(self.b.table.buckets), 1)
-        self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.b.table.buckets[0].nodes), 1)
 
     def testStoreRetrieve(self):
         self.a.addContact('127.0.0.1', 4045)
@@ -579,6 +668,20 @@ class SimpleTests(unittest.TestCase):
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
 
     def _cb(self, key, val):
         if not val:
@@ -591,12 +694,13 @@ class MultiTest(unittest.TestCase):
     
     timeout = 30
     num = 20
-    DHT_DEFAULTS = {'PORT': 9977,
-                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
-                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
-                    'MAX_FAILURES': 3,
+    DHT_DEFAULTS = {'VERSION': 'A000', 'PORT': 9977,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 8,
+                    'STORE_REDUNDANCY': 6, 'RETRIEVE_VALUES': -10000,
+                    'MAX_FAILURES': 3, 'LOCAL_OK': True,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
-                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+                    'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
+                    'KEY_EXPIRE': 3600, 'SPEW': True, }
 
     def _done(self, val):
         self.done = 1
@@ -618,6 +722,9 @@ class MultiTest(unittest.TestCase):
             reactor.iterate()
             reactor.iterate()
             reactor.iterate() 
+            reactor.iterate()
+            reactor.iterate()
+            reactor.iterate() 
             
         for i in self.l:
             self.done = 0