Standardize the number of values retrieved from the DHT.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
index 0196fd228a27342f2abf501a71d0eb74f1348c95..eeaab0acf638215dc2ec707d6c0d62daff48d04c 100644 (file)
 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
 # see LICENSE.txt for license information
 
-from time import time
-from random import randrange
-import sqlite  ## find this at http://pysqlite.sourceforge.net/
+import warnings
+warnings.simplefilter("ignore", DeprecationWarning)
+
+from datetime import datetime, timedelta
+from random import randrange, shuffle
+from sha import sha
+import os
 
 from twisted.internet.defer import Deferred
-from twisted.internet import protocol
-from twisted.internet import reactor
+from twisted.internet import protocol, reactor
+from twisted.trial import unittest
 
-import const
+from db import DB
 from ktable import KTable
-from knode import KNodeBase, KNodeRead, KNodeWrite
+from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
 from khash import newID, newIDInRange
-from actions import FindNode, GetValue, KeyExpirer, StoreValue
+from actions import FindNode, FindValue, GetValue, StoreValue
 import krpc
 
-class KhashmirDBExcept(Exception):
-    pass
-
 # this is the base class, has base functionality and find node, no key-value mappings
 class KhashmirBase(protocol.Factory):
     _Node = KNodeBase
-    def __init__(self, host, port, db='khashmir.db'):
-        self.setup(host, port, db)
+    def __init__(self, config, cache_dir='/tmp'):
+        self.config = None
+        self.setup(config, cache_dir)
         
-    def setup(self, host, port, db='khashmir.db'):
-        self._findDB(db)
-        self.port = port
-        self.node = self._loadSelfNode(host, port)
-        self.table = KTable(self.node)
+    def setup(self, config, cache_dir):
+        self.config = config
+        self.port = config['PORT']
+        self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
+        self.node = self._loadSelfNode('', self.port)
+        self.table = KTable(self.node, config)
+        self.token_secrets = [newID()]
         #self.app = service.Application("krpc")
-        self.udp = krpc.hostbroker(self)
+        self.udp = krpc.hostbroker(self, config)
         self.udp.protocol = krpc.KRPC
-        self.listenport = reactor.listenUDP(port, self.udp)
-        self.last = time()
+        self.listenport = reactor.listenUDP(self.port, self.udp)
         self._loadRoutingTable()
-        KeyExpirer(store=self.store)
         self.refreshTable(force=1)
-        reactor.callLater(60, self.checkpoint, (1,))
+        self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
 
-    def Node(self):
-        n = self._Node()
+    def Node(self, id, host = None, port = None):
+        """Create a new node."""
+        n = self._Node(id, host, port)
         n.table = self.table
+        n.conn = self.udp.connectionForAddr((n.host, n.port))
         return n
     
     def __del__(self):
         self.listenport.stopListening()
         
     def _loadSelfNode(self, host, port):
-        c = self.store.cursor()
-        c.execute('select id from self where num = 0;')
-        if c.rowcount > 0:
-            id = c.fetchone()[0]
-        else:
+        id = self.store.getSelfNode()
+        if not id:
             id = newID()
-        return self._Node().init(id, host, port)
-        
-    def _saveSelfNode(self):
-        c = self.store.cursor()
-        c.execute('delete from self where num = 0;')
-        c.execute("insert into self values (0, %s);", sqlite.encode(self.node.id))
-        self.store.commit()
+        return self._Node(id, host, port)
         
     def checkpoint(self, auto=0):
-        self._saveSelfNode()
-        self._dumpRoutingTable()
+        self.token_secrets.insert(0, newID())
+        if len(self.token_secrets) > 3:
+            self.token_secrets.pop()
+        self.store.saveSelfNode(self.node.id)
+        self.store.dumpRoutingTable(self.table.buckets)
+        self.store.expireValues(self.config['KEY_EXPIRE'])
         self.refreshTable()
         if auto:
-            reactor.callLater(randrange(int(const.CHECKPOINT_INTERVAL * .9), int(const.CHECKPOINT_INTERVAL * 1.1)), self.checkpoint, (1,))
-        
-    def _findDB(self, db):
-        import os
-        try:
-            os.stat(db)
-        except OSError:
-            self._createNewDB(db)
-        else:
-            self._loadDB(db)
-        
-    def _loadDB(self, db):
-        try:
-            self.store = sqlite.connect(db=db)
-            #self.store.autocommit = 0
-        except:
-            import traceback
-            raise KhashmirDBExcept, "Couldn't open DB", traceback.format_exc()
-        
-    def _createNewDB(self, db):
-        self.store = sqlite.connect(db=db)
-        s = """
-            create table kv (key binary, value binary, time timestamp, primary key (key, value));
-            create index kv_key on kv(key);
-            create index kv_timestamp on kv(time);
-            
-            create table nodes (id binary primary key, host text, port number);
-            
-            create table self (num number primary key, id binary);
-            """
-        c = self.store.cursor()
-        c.execute(s)
-        self.store.commit()
-
-    def _dumpRoutingTable(self):
-        """
-            save routing table nodes to the database
-        """
-        c = self.store.cursor()
-        c.execute("delete from nodes where id not NULL;")
-        for bucket in self.table.buckets:
-            for node in bucket.l:
-                c.execute("insert into nodes values (%s, %s, %s);", (sqlite.encode(node.id), node.host, node.port))
-        self.store.commit()
+            self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
+                                        int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
+                              self.checkpoint, (1,))
         
     def _loadRoutingTable(self):
         """
             load routing table nodes from database
             it's usually a good idea to call refreshTable(force=1) after loading the table
         """
-        c = self.store.cursor()
-        c.execute("select * from nodes;")
-        for rec in c.fetchall():
-            n = self.Node().initWithDict({'id':rec[0], 'host':rec[1], 'port':int(rec[2])})
-            n.conn = self.udp.connectionForAddr((n.host, n.port))
+        nodes = self.store.getRoutingTable()
+        for rec in nodes:
+            n = self.Node(rec[0], rec[1], int(rec[2]))
             self.table.insertNode(n, contacted=0)
             
 
     #######
     #######  LOCAL INTERFACE    - use these methods!
-    def addContact(self, host, port, callback=None):
+    def addContact(self, host, port, callback=None, errback=None):
         """
             ping this node and add the contact info to the table on pong!
         """
-        n =self.Node().init(const.NULL_ID, host, port) 
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
-        self.sendPing(n, callback=callback)
+        n = self.Node(NULL_ID, host, port)
+        self.sendJoin(n, callback=callback, errback=errback)
 
     ## this call is async!
     def findNode(self, id, callback, errback=None):
@@ -150,7 +105,7 @@ class KhashmirBase(protocol.Factory):
             d.callback(nodes)
         else:
             # create our search state
-            state = FindNode(self, id, d.callback)
+            state = FindNode(self, id, d.callback, self.config)
             reactor.callLater(0, state.goWithNodes, nodes)
     
     def insertNode(self, n, contacted=1):
@@ -163,7 +118,9 @@ class KhashmirBase(protocol.Factory):
         method needs to be a properly formed Node object with a valid ID.
         """
         old = self.table.insertNode(n, contacted=contacted)
-        if old and (time() - old.lastSeen) > const.MIN_PING_INTERVAL and old.id != self.node.id:
+        if (old and old.id != self.node.id and
+            (datetime.now() - old.lastSeen) > 
+             timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
             # the bucket is full, check to see if old node is still around and if so, replace it
             
             ## these are the callbacks used when we ping the oldest node in a bucket
@@ -180,38 +137,34 @@ class KhashmirBase(protocol.Factory):
             df = old.ping(self.node.id)
             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
 
-    def sendPing(self, node, callback=None):
+    def sendJoin(self, node, callback=None, errback=None):
         """
-            ping a node
+            send a join request to a node; callback gets (ip_addr, port) from the reply, or None on failure when no errback is given
         """
-        df = node.ping(self.node.id)
+        df = node.join(self.node.id)
-        ## these are the callbacks we use when we issue a PING
+        ## these are the callbacks we use when we issue a JOIN
-        def _pongHandler(dict, node=node, table=self.table, callback=callback):
-            _krpc_sender = dict['_krpc_sender']
-            dict = dict['rsp']
-            sender = {'id' : dict['id']}
-            sender['host'] = _krpc_sender[0]
-            sender['port'] = _krpc_sender[1]
-            n = self.Node().initWithDict(sender)
-            n.conn = self.udp.connectionForAddr((n.host, n.port))
-            table.insertNode(n)
+        def _pongHandler(dict, node=node, self=self, callback=callback):
+            n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
+            self.insertNode(n)
             if callback:
-                callback()
-        def _defaultPong(err, node=node, table=self.table, callback=callback):
+                callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
+        def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
             table.nodeFailed(node)
-            if callback:
-                callback()
+            if errback:
+                errback()
+            elif callback:  # addContact() may be called with neither handler; don't call None
+                callback(None)
         
         df.addCallbacks(_pongHandler,_defaultPong)
 
-    def findCloseNodes(self, callback=lambda a: None):
+    def findCloseNodes(self, callback=lambda a: None, errback = None):
         """
             This does a findNode on the ID one away from our own.  
             This will allow us to populate our table with nodes on our network closest to our own.
             This is called as soon as we start up with an empty table
         """
         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
-        self.findNode(id, callback)
+        self.findNode(id, callback, errback)
 
     def refreshTable(self, force=0):
         """
@@ -221,7 +174,8 @@ class KhashmirBase(protocol.Factory):
             pass
     
         for bucket in self.table.buckets:
-            if force or (time() - bucket.lastAccessed >= const.BUCKET_STALENESS):
+            if force or (datetime.now() - bucket.lastAccessed > 
+                         timedelta(seconds=self.config['BUCKET_STALENESS'])):
                 id = newIDInRange(bucket.min, bucket.max)
                 self.findNode(id, callback)
 
@@ -232,78 +186,99 @@ class KhashmirBase(protocol.Factory):
         num_nodes: number of nodes estimated in the entire dht
         """
         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
-        num_nodes = const.K * (2**(len(self.table.buckets) - 1))
+        num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
         return (num_contacts, num_nodes)
+    
+    def shutdown(self):
+        """Stop listening on the UDP port, cancel the pending checkpoint call and close the DB store."""
+        self.listenport.stopListening()
+        try:
+            self.next_checkpoint.cancel()
+        except Exception:  # best effort (e.g. call already ran); don't swallow KeyboardInterrupt/SystemExit
+            pass
+        self.store.close()
 
+    #### Remote Interface - called by remote nodes
     def krpc_ping(self, id, _krpc_sender):
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
         self.insertNode(n, contacted=0)
         return {"id" : self.node.id}
         
+    def krpc_join(self, id, _krpc_sender):
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+        self.insertNode(n, contacted=0)
+        return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
+        
     def krpc_find_node(self, target, id, _krpc_sender):
-        nodes = self.table.findNodes(target)
-        nodes = map(lambda node: node.senderDict(), nodes)
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
         self.insertNode(n, contacted=0)
-        return {"nodes" : nodes, "id" : self.node.id}
+        nodes = self.table.findNodes(target)
+        nodes = map(lambda node: node.contactInfo(), nodes)
+        token = sha(self.token_secrets[0] + _krpc_sender[0]).digest()
+        return {"nodes" : nodes, "token" : token, "id" : self.node.id}
 
 
 ## This class provides read-only access to the DHT, valueForKey
 ## you probably want to use this mixin and provide your own write methods
 class KhashmirRead(KhashmirBase):
     _Node = KNodeRead
-    def retrieveValues(self, key):
-        c = self.store.cursor()
-        c.execute("select value from kv where key = %s;", sqlite.encode(key))
-        t = c.fetchone()
-        l = []
-        while t:
-            l.append(t['value'])
-            t = c.fetchone()
-        return l
+
     ## also async
+    def findValue(self, key, callback, errback=None):
+        """ returns the contact info for nodes that have values for the key, from the global table """
+        # get K nodes out of local table/cache
+        nodes = self.table.findNodes(key)
+        d = Deferred()
+        if errback:
+            d.addCallbacks(callback, errback)
+        else:
+            d.addCallback(callback)
+
+        # create our search state
+        state = FindValue(self, key, d.callback, self.config)
+        reactor.callLater(0, state.goWithNodes, nodes)
+
     def valueForKey(self, key, callback, searchlocal = 1):
         """ returns the values found for key in global table
             callback will be called with a list of values for each peer that returns unique values
             final callback will be an empty list - probably should change to 'more coming' arg
         """
-        nodes = self.table.findNodes(key)
-        
         # get locals
         if searchlocal:
-            l = self.retrieveValues(key)
+            l = self.store.retrieveValues(key)
             if len(l) > 0:
-                reactor.callLater(0, callback, (l))
+                reactor.callLater(0, callback, key, l)
         else:
             l = []
-        
-        # create our search state
-        state = GetValue(self, key, callback)
-        reactor.callLater(0, state.goWithNodes, nodes, l)
 
+        def _getValueForKey(nodes, key=key, local_values=l, response=callback, self=self):
+            # create our search state
+            state = GetValue(self, key, local_values, self.config['RETRIEVE_VALUES'], response, self.config)
+            reactor.callLater(0, state.goWithNodes, nodes)
+            
+        # this call is asynch
+        self.findValue(key, _getValueForKey)
+
+    #### Remote Interface - called by remote nodes
     def krpc_find_value(self, key, id, _krpc_sender):
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
         self.insertNode(n, contacted=0)
     
-        l = self.retrieveValues(key)
-        if len(l) > 0:
+        nodes = self.table.findNodes(key)
+        nodes = map(lambda node: node.contactInfo(), nodes)
+        num_values = self.store.countValues(key)
+        return {'nodes' : nodes, 'num' : num_values, "id": self.node.id}
+
+    def krpc_get_value(self, key, num, id, _krpc_sender):
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
+        self.insertNode(n, contacted=0)
+    
+        l = self.store.retrieveValues(key)
+        if num == 0 or num >= len(l):
             return {'values' : l, "id": self.node.id}
         else:
-            nodes = self.table.findNodes(key)
-            nodes = map(lambda node: node.senderDict(), nodes)
-            return {'nodes' : nodes, "id": self.node.id}
+            shuffle(l)
+            return {'values' : l[:num], "id": self.node.id}
 
 ###  provides a generic write method, you probably don't want to deploy something that allows
 ###  arbitrary value storage
@@ -311,39 +286,183 @@ class KhashmirWrite(KhashmirRead):
     _Node = KNodeWrite
     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
     def storeValueForKey(self, key, value, callback=None):
-        """ stores the value for key in the global table, returns immediately, no status 
+        """ stores the value and origination time for key in the global table, returns immediately, no status 
             in this implementation, peers respond but don't indicate status to storing values
             a key can have many values
         """
-        def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
+        def _storeValueForKey(nodes, key=key, value=value, response=callback, self=self):
             if not response:
                 # default callback
-                def _storedValueHandler(sender):
+                def _storedValueHandler(key, value, sender):
                     pass
                 response=_storedValueHandler
-            action = StoreValue(self.table, key, value, response)
+            action = StoreValue(self, key, value, self.config['STORE_REDUNDANCY'], response, self.config)
             reactor.callLater(0, action.goWithNodes, nodes)
             
         # this call is asynch
         self.findNode(key, _storeValueForKey)
                     
-    def krpc_store_value(self, key, value, id, _krpc_sender):
-        t = "%0.6f" % time()
-        c = self.store.cursor()
-        try:
-            c.execute("insert into kv values (%s, %s, %s);", (sqlite.encode(key), sqlite.encode(value), t))
-        except sqlite.IntegrityError, reason:
-            # update last insert time
-            c.execute("update kv set time = %s where key = %s and value = %s;", (t, sqlite.encode(key), sqlite.encode(value)))
-        self.store.commit()
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
+    #### Remote Interface - called by remote nodes
+    def krpc_store_value(self, key, value, token, id, _krpc_sender):
+        n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
         self.insertNode(n, contacted=0)
-        return {"id" : self.node.id}
+        for secret in self.token_secrets:
+            this_token = sha(secret + _krpc_sender[0]).digest()
+            if token == this_token:
+                self.store.storeValue(key, value)
+                return {"id" : self.node.id}
+        raise krpc.KrpcError, (krpc.KRPC_ERROR_INVALID_TOKEN, 'token is invalid, do a find_nodes to get a fresh one')
 
 # the whole shebang, for testing
 class Khashmir(KhashmirWrite):
     _Node = KNodeWrite
+
+class SimpleTests(unittest.TestCase):
+    
+    timeout = 10
+    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
+                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
+                    'MAX_FAILURES': 3,
+                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
+                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+
+    def setUp(self):
+        krpc.KRPC.noisy = 0
+        d = self.DHT_DEFAULTS.copy()
+        d['PORT'] = 4044
+        self.a = Khashmir(d)
+        d = self.DHT_DEFAULTS.copy()
+        d['PORT'] = 4045
+        self.b = Khashmir(d)
+        
+    def tearDown(self):
+        self.a.shutdown()
+        self.b.shutdown()
+        os.unlink(self.a.store.db)
+        os.unlink(self.b.store.db)
+
+    def testAddContact(self):
+        self.failUnlessEqual(len(self.a.table.buckets), 1)
+        self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
+
+        self.failUnlessEqual(len(self.b.table.buckets), 1)
+        self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
+
+        self.a.addContact('127.0.0.1', 4045)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+
+        self.failUnlessEqual(len(self.a.table.buckets), 1)
+        self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
+        self.failUnlessEqual(len(self.b.table.buckets), 1)
+        self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
+
+    def testStoreRetrieve(self):
+        self.a.addContact('127.0.0.1', 4045)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.got = 0
+        self.a.storeValueForKey(sha('foo').digest(), 'foobar')
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.a.valueForKey(sha('foo').digest(), self._cb)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+
+    def _cb(self, key, val):
+        if not val:
+            self.failUnlessEqual(self.got, 1)
+        elif 'foobar' in val:
+            self.got = 1
+
+
+class MultiTest(unittest.TestCase):
+    
+    timeout = 30
+    num = 20
+    DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
+                    'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
+                    'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
+                    'MAX_FAILURES': 3,
+                    'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
+                    'KEY_EXPIRE': 3600, 'SPEW': False, }
+
+    def _done(self, val):
+        self.done = 1
+        
+    def setUp(self):
+        self.l = []
+        self.startport = 4088
+        for i in range(self.num):
+            d = self.DHT_DEFAULTS.copy()
+            d['PORT'] = self.startport + i
+            self.l.append(Khashmir(d))
+        reactor.iterate()
+        reactor.iterate()
+        
+        for i in self.l:
+            i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
+            i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
+            i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
+            reactor.iterate()
+            reactor.iterate()
+            reactor.iterate() 
+            
+        for i in self.l:
+            self.done = 0
+            i.findCloseNodes(self._done)
+            while not self.done:
+                reactor.iterate()
+        for i in self.l:
+            self.done = 0
+            i.findCloseNodes(self._done)
+            while not self.done:
+                reactor.iterate()
+
+    def tearDown(self):
+        for i in self.l:
+            i.shutdown()
+            os.unlink(i.store.db)
+            
+        reactor.iterate()
+        
+    def testStoreRetrieve(self):
+        for i in range(10):
+            K = newID()
+            V = newID()
+            
+            for a in range(3):
+                self.done = 0
+                def _scb(key, value, result):
+                    self.done = 1
+                self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
+                while not self.done:
+                    reactor.iterate()
+
+
+                def _rcb(key, val):
+                    if not val:
+                        self.done = 1
+                        self.failUnlessEqual(self.got, 1)
+                    elif V in val:
+                        self.got = 1
+                for x in range(3):
+                    self.got = 0
+                    self.done = 0
+                    self.l[randrange(0, self.num)].valueForKey(K, _rcb)
+                    while not self.done:
+                        reactor.iterate()