]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_p2p_Khashmir/khashmir.py
Use python-debian's new debian package instead of debian_bundle
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / khashmir.py
index 5db5ed8ab48015c9c73073f9c8ceab8eac652701..13bfb5a23352f91d603168945d54d614c351fe3f 100644 (file)
@@ -1,7 +1,9 @@
-## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
-# see LICENSE.txt for license information
 
 
-"""The main Khashmir program."""
+"""The main Khashmir program.
+
+@var isLocal: a compiled regular expression suitable for testing if an
+    IP address is from a known local or private range
+"""
 
 import warnings
 warnings.simplefilter("ignore", DeprecationWarning)
 
 import warnings
 warnings.simplefilter("ignore", DeprecationWarning)
@@ -10,9 +12,10 @@ from datetime import datetime, timedelta
 from random import randrange, shuffle
 from sha import sha
 from copy import copy
 from random import randrange, shuffle
 from sha import sha
 from copy import copy
-import os
+import os, re
 
 from twisted.internet.defer import Deferred
 
 from twisted.internet.defer import Deferred
+from twisted.internet.base import DelayedCall
 from twisted.internet import protocol, reactor
 from twisted.python import log
 from twisted.trial import unittest
 from twisted.internet import protocol, reactor
 from twisted.python import log
 from twisted.trial import unittest
@@ -25,6 +28,13 @@ from actions import FindNode, FindValue, GetValue, StoreValue
 from stats import StatsLogger
 import krpc
 
 from stats import StatsLogger
 import krpc
 
+isLocal = re.compile('^(192\.168\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(10\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?1[6-9]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?2[0-9]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(172\.0?3[0-1]\.[0-9]{1,3}\.[0-9]{1,3})|'+
+                     '(127\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3})$')
+
 class KhashmirBase(protocol.Factory):
     """The base Khashmir class, with base functionality and find node, no key-value mappings.
     
 class KhashmirBase(protocol.Factory):
     """The base Khashmir class, with base functionality and find node, no key-value mappings.
     
@@ -32,6 +42,9 @@ class KhashmirBase(protocol.Factory):
     @ivar _Node: the knode implementation to use for this class of DHT
     @type config: C{dictionary}
     @ivar config: the configuration parameters for the DHT
     @ivar _Node: the knode implementation to use for this class of DHT
     @type config: C{dictionary}
     @ivar config: the configuration parameters for the DHT
+    @type pinging: C{dictionary}
+    @ivar pinging: the node's that are currently being pinged, keys are the
+        node id's, values are the Deferred or DelayedCall objects
     @type port: C{int}
     @ivar port: the port to listen on
     @type store: L{db.DB}
     @type port: C{int}
     @ivar port: the port to listen on
     @type store: L{db.DB}
@@ -64,6 +77,7 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to the /tmp directory)
         """
         self.config = None
             (optional, defaults to the /tmp directory)
         """
         self.config = None
+        self.pinging = {}
         self.setup(config, cache_dir)
         
     def setup(self, config, cache_dir):
         self.setup(config, cache_dir)
         
     def setup(self, config, cache_dir):
@@ -109,8 +123,8 @@ class KhashmirBase(protocol.Factory):
     def _loadSelfNode(self, host, port):
         """Create this node, loading any previously saved one."""
         id = self.store.getSelfNode()
     def _loadSelfNode(self, host, port):
         """Create this node, loading any previously saved one."""
         id = self.store.getSelfNode()
-        if not id:
-            id = newID()
+        if not id or not id.endswith(self.config['VERSION']):
+            id = newID(self.config['VERSION'])
         return self._Node(id, host, port)
         
     def checkpoint(self):
         return self._Node(id, host, port)
         
     def checkpoint(self):
@@ -157,7 +171,7 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
-            (optional, defaults to calling the callback with None)
+            (optional, defaults to calling the callback with the error)
         """
         n = self.Node(NULL_ID, host, port)
         self.sendJoin(n, callback=callback, errback=errback)
         """
         n = self.Node(NULL_ID, host, port)
         self.sendJoin(n, callback=callback, errback=errback)
@@ -171,6 +185,9 @@ class KhashmirBase(protocol.Factory):
         @param callback: the method to call with the results, it must take 1
             parameter, the list of K closest nodes
         """
         @param callback: the method to call with the results, it must take 1
             parameter, the list of K closest nodes
         """
+        # Mark the bucket as having been accessed
+        self.table.touch(id)
+        
         # Start with our node
         nodes = [copy(self.node)]
 
         # Start with our node
         nodes = [copy(self.node)]
 
@@ -183,7 +200,7 @@ class KhashmirBase(protocol.Factory):
         
         If all you have is a host/port, then use L{addContact}, which calls this
         method after receiving the PONG from the remote node. The reason for
         
         If all you have is a host/port, then use L{addContact}, which calls this
         method after receiving the PONG from the remote node. The reason for
-        the seperation is we can't insert a node into the table without its
+        the separation is we can't insert a node into the table without its
         node ID. That means of course the node passed into this method needs
         to be a properly formed Node object with a valid ID.
 
         node ID. That means of course the node passed into this method needs
         to be a properly formed Node object with a valid ID.
 
@@ -193,28 +210,92 @@ class KhashmirBase(protocol.Factory):
         @param contacted: whether the new node is known to be good, i.e.
             responded to a request (optional, defaults to True)
         """
         @param contacted: whether the new node is known to be good, i.e.
             responded to a request (optional, defaults to True)
         """
+        # Don't add any local nodes to the routing table
+        if not self.config['LOCAL_OK'] and isLocal.match(node.host):
+            log.msg('Not adding local node to table: %s/%s' % (node.host, node.port))
+            return
+        
         old = self.table.insertNode(node, contacted=contacted)
         old = self.table.insertNode(node, contacted=contacted)
-        if (old and old.id != self.node.id and
+
+        if (isinstance(old, self._Node) and old.id != self.node.id and
             (datetime.now() - old.lastSeen) > 
              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
             
             (datetime.now() - old.lastSeen) > 
              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
             
-            def _staleNodeHandler(err, oldnode = old, newnode = node, self = self, start = datetime.now()):
-                """The pinged node never responded, so replace it."""
-                log.msg("action ping failed on %s/%s: %s" % (oldnode.host, oldnode.port, err.getErrorMessage()))
-                self.stats.completedAction('ping', start)
-                self.table.replaceStaleNode(oldnode, newnode)
-            
-            def _notStaleNodeHandler(dict, old = old, self = self, start = datetime.now()):
-                """Got a pong from the old node, so update it."""
-                self.stats.completedAction('ping', start)
-                if dict['id'] == old.id:
-                    self.table.justSeenNode(old.id)
-            
             # Bucket is full, check to see if old node is still available
             # Bucket is full, check to see if old node is still available
-            self.stats.startedAction('ping')
-            df = old.ping(self.node.id)
-            df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
+            df = self.sendPing(old)
+            df.addErrback(self._staleNodeHandler, old, node, contacted)
+        elif not old and not contacted:
+            # There's room, we just need to contact the node first
+            df = self.sendPing(node)
+            # Also schedule a future ping to make sure the node works
+            def rePing(newnode, self = self):
+                if newnode.id not in self.pinging:
+                    self.pinging[newnode.id] = reactor.callLater(self.config['MIN_PING_INTERVAL'],
+                                                                 self.sendPing, newnode)
+                return newnode
+            df.addCallback(rePing)
+
+    def _staleNodeHandler(self, err, old, node, contacted):
+        """The pinged node never responded, so replace it."""
+        self.table.invalidateNode(old)
+        self.insertNode(node, contacted)
+        return err
+    
+    def nodeFailed(self, node):
+        """Mark a node as having failed a request and schedule a future check.
+        
+        @type node: L{node.Node}
+        @param node: the new node to try and insert
+        """
+        exists = self.table.nodeFailed(node)
+        
+        # If in the table, schedule a ping, if one isn't already sent/scheduled
+        if exists and node.id not in self.pinging:
+            self.pinging[node.id] = reactor.callLater(self.config['MIN_PING_INTERVAL'],
+                                                      self.sendPing, node)
+    
+    def sendPing(self, node):
+        """Ping the node to see if it's still alive.
+        
+        @type node: L{node.Node}
+        @param node: the node to send the join to
+        """
+        # Check for a ping already underway
+        if (isinstance(self.pinging.get(node.id, None), DelayedCall) and
+            self.pinging[node.id].active()):
+            self.pinging[node.id].cancel()
+        elif isinstance(self.pinging.get(node.id, None), Deferred):
+            return self.pinging[node.id]
+
+        self.stats.startedAction('ping')
+        df = node.ping(self.node.id)
+        self.pinging[node.id] = df
+        df.addCallbacks(self._pingHandler, self._pingError,
+                        callbackArgs = (node, datetime.now()),
+                        errbackArgs = (node, datetime.now()))
+        return df
+
+    def _pingHandler(self, dict, node, start):
+        """Node responded properly, update it and return the node object."""
+        self.stats.completedAction('ping', start)
+        del self.pinging[node.id]
+        # Create the node using the returned contact info
+        n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+        reactor.callLater(0, self.insertNode, n)
+        return n
 
 
+    def _pingError(self, err, node, start):
+        """Error occurred, fail node."""
+        log.msg("action ping failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
+        self.stats.completedAction('ping', start)
+        
+        # Consume unhandled errors
+        self.pinging[node.id].addErrback(lambda ping_err: None)
+        del self.pinging[node.id]
+        
+        self.nodeFailed(node)
+        return err
+        
     def sendJoin(self, node, callback=None, errback=None):
         """Join the DHT by pinging a bootstrap node.
         
     def sendJoin(self, node, callback=None, errback=None):
         """Join the DHT by pinging a bootstrap node.
         
@@ -226,31 +307,33 @@ class KhashmirBase(protocol.Factory):
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
             (optional, defaults to doing nothing with the results)
         @type errback: C{method}
         @param errback: the method to call if an error occurs
-            (optional, defaults to calling the callback with None)
+            (optional, defaults to calling the callback with the error)
         """
         """
-
-        def _pongHandler(dict, node=node, self=self, callback=callback, start = datetime.now()):
-            """Node responded properly, callback with response."""
-            n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
-            self.stats.completedAction('join', start)
-            reactor.callLater(0, self.insertNode, n)
-            if callback:
-                callback((dict['ip_addr'], dict['port']))
-
-        def _defaultPong(err, node=node, self=self, callback=callback, errback=errback, start = datetime.now()):
-            """Error occurred, fail node and errback or callback with error."""
-            log.msg("action join failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
-            self.stats.completedAction('join', start)
-            self.table.nodeFailed(node)
-            if errback:
-                errback()
-            elif callback:
-                callback(None)
-        
+        if errback is None:
+            errback = callback
         self.stats.startedAction('join')
         df = node.join(self.node.id)
         self.stats.startedAction('join')
         df = node.join(self.node.id)
-        df.addCallbacks(_pongHandler, _defaultPong)
-
+        df.addCallbacks(self._joinHandler, self._joinError,
+                        callbackArgs = (node, datetime.now()),
+                        errbackArgs = (node, datetime.now()))
+        if callback:
+            df.addCallbacks(callback, errback)
+
+    def _joinHandler(self, dict, node, start):
+        """Node responded properly, extract the response."""
+        self.stats.completedAction('join', start)
+        # Create the node using the returned contact info
+        n = self.Node(dict['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+        reactor.callLater(0, self.insertNode, n)
+        return (dict['ip_addr'], dict['port'])
+
+    def _joinError(self, err, node, start):
+        """Error occurred, fail node."""
+        log.msg("action join failed on %s/%s: %s" % (node.host, node.port, err.getErrorMessage()))
+        self.stats.completedAction('join', start)
+        self.nodeFailed(node)
+        return err
+        
     def findCloseNodes(self, callback=lambda a: None):
         """Perform a findNode on the ID one away from our own.
 
     def findCloseNodes(self, callback=lambda a: None):
         """Perform a findNode on the ID one away from our own.
 
@@ -289,6 +372,10 @@ class KhashmirBase(protocol.Factory):
             self.next_checkpoint.cancel()
         except:
             pass
             self.next_checkpoint.cancel()
         except:
             pass
+        for nodeid in self.pinging.keys():
+            if isinstance(self.pinging[nodeid], DelayedCall) and self.pinging[nodeid].active():
+                self.pinging[nodeid].cancel()
+                del self.pinging[nodeid]
         self.store.close()
     
     def getStats(self):
         self.store.close()
     
     def getStats(self):
@@ -363,6 +450,9 @@ class KhashmirRead(KhashmirBase):
         @param callback: the method to call with the results, it must take 1
             parameter, the list of nodes with values
         """
         @param callback: the method to call with the results, it must take 1
             parameter, the list of nodes with values
         """
+        # Mark the bucket as having been accessed
+        self.table.touch(key)
+        
         # Start with ourself
         nodes = [copy(self.node)]
         
         # Start with ourself
         nodes = [copy(self.node)]
         
@@ -516,13 +606,13 @@ class Khashmir(KhashmirWrite):
 class SimpleTests(unittest.TestCase):
     
     timeout = 10
 class SimpleTests(unittest.TestCase):
     
     timeout = 10
-    DHT_DEFAULTS = {'PORT': 9977,
-                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
-                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
-                    'MAX_FAILURES': 3,
+    DHT_DEFAULTS = {'VERSION': 'A000', 'PORT': 9977,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 8,
+                    'STORE_REDUNDANCY': 6, 'RETRIEVE_VALUES': -10000,
+                    'MAX_FAILURES': 3, 'LOCAL_OK': True,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
-                    'KRPC_TIMEOUT': 14, 'KRPC_INITIAL_DELAY': 2,
-                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+                    'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
+                    'KEY_EXPIRE': 3600, 'SPEW': True, }
 
     def setUp(self):
         d = self.DHT_DEFAULTS.copy()
 
     def setUp(self):
         d = self.DHT_DEFAULTS.copy()
@@ -540,21 +630,25 @@ class SimpleTests(unittest.TestCase):
 
     def testAddContact(self):
         self.failUnlessEqual(len(self.a.table.buckets), 1)
 
     def testAddContact(self):
         self.failUnlessEqual(len(self.a.table.buckets), 1)
-        self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
+        self.failUnlessEqual(len(self.a.table.buckets[0].nodes), 0)
 
         self.failUnlessEqual(len(self.b.table.buckets), 1)
 
         self.failUnlessEqual(len(self.b.table.buckets), 1)
-        self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
+        self.failUnlessEqual(len(self.b.table.buckets[0].nodes), 0)
 
         self.a.addContact('127.0.0.1', 4045)
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
 
         self.a.addContact('127.0.0.1', 4045)
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
 
         self.failUnlessEqual(len(self.a.table.buckets), 1)
 
         self.failUnlessEqual(len(self.a.table.buckets), 1)
-        self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.a.table.buckets[0].nodes), 1)
         self.failUnlessEqual(len(self.b.table.buckets), 1)
         self.failUnlessEqual(len(self.b.table.buckets), 1)
-        self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.b.table.buckets[0].nodes), 1)
 
     def testStoreRetrieve(self):
         self.a.addContact('127.0.0.1', 4045)
 
     def testStoreRetrieve(self):
         self.a.addContact('127.0.0.1', 4045)
@@ -578,6 +672,20 @@ class SimpleTests(unittest.TestCase):
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
 
     def _cb(self, key, val):
         if not val:
 
     def _cb(self, key, val):
         if not val:
@@ -590,13 +698,13 @@ class MultiTest(unittest.TestCase):
     
     timeout = 30
     num = 20
     
     timeout = 30
     num = 20
-    DHT_DEFAULTS = {'PORT': 9977,
-                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
-                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
-                    'MAX_FAILURES': 3,
+    DHT_DEFAULTS = {'VERSION': 'A000', 'PORT': 9977,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 8,
+                    'STORE_REDUNDANCY': 6, 'RETRIEVE_VALUES': -10000,
+                    'MAX_FAILURES': 3, 'LOCAL_OK': True,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
-                    'KRPC_TIMEOUT': 14, 'KRPC_INITIAL_DELAY': 2,
-                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+                    'KRPC_TIMEOUT': 9, 'KRPC_INITIAL_DELAY': 2,
+                    'KEY_EXPIRE': 3600, 'SPEW': True, }
 
     def _done(self, val):
         self.done = 1
 
     def _done(self, val):
         self.done = 1
@@ -618,6 +726,9 @@ class MultiTest(unittest.TestCase):
             reactor.iterate()
             reactor.iterate()
             reactor.iterate() 
             reactor.iterate()
             reactor.iterate()
             reactor.iterate() 
+            reactor.iterate()
+            reactor.iterate()
+            reactor.iterate() 
             
         for i in self.l:
             self.done = 0
             
         for i in self.l:
             self.done = 0