]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - apt_dht_Khashmir/khashmir.py
Improve the stopping of the krpc protocol so no timeouts are left.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
index 1baffa17dc25072e47198f903de8e8a9fe26783d..c4d018c1a1980a5c5c10f0a4ad70c6dedaba9298 100644 (file)
@@ -4,25 +4,22 @@
 import warnings
 warnings.simplefilter("ignore", DeprecationWarning)
 
-from time import time
+from datetime import datetime, timedelta
 from random import randrange
 from sha import sha
 import os
-import sqlite  ## find this at http://pysqlite.sourceforge.net/
 
 from twisted.internet.defer import Deferred
 from twisted.internet import protocol, reactor
 from twisted.trial import unittest
 
+from db import DB
 from ktable import KTable
 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
 from khash import newID, newIDInRange
 from actions import FindNode, GetValue, KeyExpirer, StoreValue
 import krpc
 
-class KhashmirDBExcept(Exception):
-    pass
-
 # this is the base class, has base functionality and find node, no key-value mappings
 class KhashmirBase(protocol.Factory):
     _Node = KNodeBase
@@ -33,14 +30,13 @@ class KhashmirBase(protocol.Factory):
     def setup(self, config, cache_dir):
         self.config = config
         self.port = config['PORT']
-        self._findDB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
+        self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
         self.node = self._loadSelfNode('', self.port)
         self.table = KTable(self.node, config)
         #self.app = service.Application("krpc")
         self.udp = krpc.hostbroker(self)
         self.udp.protocol = krpc.KRPC
         self.listenport = reactor.listenUDP(self.port, self.udp)
-        self.last = time()
         self._loadRoutingTable()
         self.expirer = KeyExpirer(self.store, config)
         self.refreshTable(force=1)
@@ -55,84 +51,36 @@ class KhashmirBase(protocol.Factory):
         self.listenport.stopListening()
         
     def _loadSelfNode(self, host, port):
-        c = self.store.cursor()
-        c.execute('select id from self where num = 0;')
-        if c.rowcount > 0:
-            id = c.fetchone()[0]
-        else:
+        id = self.store.getSelfNode()
+        if not id:
             id = newID()
         return self._Node().init(id, host, port)
         
-    def _saveSelfNode(self):
-        c = self.store.cursor()
-        c.execute('delete from self where num = 0;')
-        c.execute("insert into self values (0, %s);", sqlite.encode(self.node.id))
-        self.store.commit()
-        
     def checkpoint(self, auto=0):
-        self._saveSelfNode()
-        self._dumpRoutingTable()
+        self.store.saveSelfNode(self.node.id)
+        self.store.dumpRoutingTable(self.table.buckets)
         self.refreshTable()
         if auto:
             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
                               self.checkpoint, (1,))
         
-    def _findDB(self, db):
-        self.db = db
-        try:
-            os.stat(db)
-        except OSError:
-            self._createNewDB(db)
-        else:
-            self._loadDB(db)
-        
-    def _loadDB(self, db):
-        try:
-            self.store = sqlite.connect(db=db)
-            #self.store.autocommit = 0
-        except:
-            import traceback
-            raise KhashmirDBExcept, "Couldn't open DB", traceback.format_exc()
-        
-    def _createNewDB(self, db):
-        self.store = sqlite.connect(db=db)
-        s = """
-            create table kv (key binary, value binary, time timestamp, primary key (key, value));
-            create index kv_key on kv(key);
-            create index kv_timestamp on kv(time);
-            
-            create table nodes (id binary primary key, host text, port number);
-            
-            create table self (num number primary key, id binary);
-            """
-        c = self.store.cursor()
-        c.execute(s)
-        self.store.commit()
-
-    def _dumpRoutingTable(self):
-        """
-            save routing table nodes to the database
-        """
-        c = self.store.cursor()
-        c.execute("delete from nodes where id not NULL;")
-        for bucket in self.table.buckets:
-            for node in bucket.l:
-                c.execute("insert into nodes values (%s, %s, %s);", (sqlite.encode(node.id), node.host, node.port))
-        self.store.commit()
-        
     def _loadRoutingTable(self):
         """
             load routing table nodes from database
             it's usually a good idea to call refreshTable(force=1) after loading the table
         """
-        c = self.store.cursor()
-        c.execute("select * from nodes;")
-        for rec in c.fetchall():
+        nodes = self.store.getRoutingTable()
+        for rec in nodes:
             n = self.Node().initWithDict({'id':rec[0], 'host':rec[1], 'port':int(rec[2])})
             n.conn = self.udp.connectionForAddr((n.host, n.port))
             self.table.insertNode(n, contacted=0)
             
+    def _update_node(self, id, host, port):
+        n = self.Node().init(id, host, port)
+        n.conn = self.udp.connectionForAddr((host, port))
+        self.insertNode(n, contacted=0)
+    
 
     #######
     #######  LOCAL INTERFACE    - use these methods!
@@ -171,7 +119,9 @@ class KhashmirBase(protocol.Factory):
         method needs to be a properly formed Node object with a valid ID.
         """
         old = self.table.insertNode(n, contacted=contacted)
-        if old and (time() - old.lastSeen) > self.config['MIN_PING_INTERVAL'] and old.id != self.node.id:
+        if (old and old.id != self.node.id and
+            (datetime.now() - old.lastSeen) > 
+             timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
             # the bucket is full, check to see if old node is still around and if so, replace it
             
             ## these are the callbacks used when we ping the oldest node in a bucket
@@ -229,7 +179,8 @@ class KhashmirBase(protocol.Factory):
             pass
     
         for bucket in self.table.buckets:
-            if force or (time() - bucket.lastAccessed >= self.config['BUCKET_STALENESS']):
+            if force or (datetime.now() - bucket.lastAccessed > 
+                         timedelta(seconds=self.config['BUCKET_STALENESS'])):
                 id = newIDInRange(bucket.min, bucket.max)
                 self.findNode(id, callback)
 
@@ -253,24 +204,15 @@ class KhashmirBase(protocol.Factory):
         self.expirer.shutdown()
         self.store.close()
 
+    #### Remote Interface - called by remote nodes
     def krpc_ping(self, id, _krpc_sender):
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
-        self.insertNode(n, contacted=0)
+        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
         return {"id" : self.node.id}
         
     def krpc_find_node(self, target, id, _krpc_sender):
+        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
         nodes = self.table.findNodes(target)
         nodes = map(lambda node: node.senderDict(), nodes)
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
-        self.insertNode(n, contacted=0)
         return {"nodes" : nodes, "id" : self.node.id}
 
 
@@ -278,15 +220,7 @@ class KhashmirBase(protocol.Factory):
 ## you probably want to use this mixin and provide your own write methods
 class KhashmirRead(KhashmirBase):
     _Node = KNodeRead
-    def retrieveValues(self, key):
-        c = self.store.cursor()
-        c.execute("select value from kv where key = %s;", sqlite.encode(key))
-        t = c.fetchone()
-        l = []
-        while t:
-            l.append(t['value'])
-            t = c.fetchone()
-        return l
+
     ## also async
     def valueForKey(self, key, callback, searchlocal = 1):
         """ returns the values found for key in global table
@@ -297,9 +231,9 @@ class KhashmirRead(KhashmirBase):
         
         # get locals
         if searchlocal:
-            l = self.retrieveValues(key)
+            l = self.store.retrieveValues(key)
             if len(l) > 0:
-                reactor.callLater(0, callback, (l))
+                reactor.callLater(0, callback, key, l)
         else:
             l = []
         
@@ -307,15 +241,11 @@ class KhashmirRead(KhashmirBase):
         state = GetValue(self, key, callback, self.config)
         reactor.callLater(0, state.goWithNodes, nodes, l)
 
+    #### Remote Interface - called by remote nodes
     def krpc_find_value(self, key, id, _krpc_sender):
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
-        self.insertNode(n, contacted=0)
+        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
     
-        l = self.retrieveValues(key)
+        l = self.store.retrieveValues(key)
         if len(l) > 0:
             return {'values' : l, "id": self.node.id}
         else:
@@ -336,7 +266,7 @@ class KhashmirWrite(KhashmirRead):
         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
             if not response:
                 # default callback
-                def _storedValueHandler(sender):
+                def _storedValueHandler(key, value, sender):
                     pass
                 response=_storedValueHandler
             action = StoreValue(self.table, key, value, response, self.config)
@@ -345,21 +275,10 @@ class KhashmirWrite(KhashmirRead):
         # this call is asynch
         self.findNode(key, _storeValueForKey)
                     
+    #### Remote Interface - called by remote nodes
     def krpc_store_value(self, key, value, id, _krpc_sender):
-        t = "%0.6f" % time()
-        c = self.store.cursor()
-        try:
-            c.execute("insert into kv values (%s, %s, %s);", (sqlite.encode(key), sqlite.encode(value), t))
-        except sqlite.IntegrityError, reason:
-            # update last insert time
-            c.execute("update kv set time = %s where key = %s and value = %s;", (t, sqlite.encode(key), sqlite.encode(value)))
-        self.store.commit()
-        sender = {'id' : id}
-        sender['host'] = _krpc_sender[0]
-        sender['port'] = _krpc_sender[1]        
-        n = self.Node().initWithDict(sender)
-        n.conn = self.udp.connectionForAddr((n.host, n.port))
-        self.insertNode(n, contacted=0)
+        self._update_node(id, _krpc_sender[0], _krpc_sender[1])
+        self.store.storeValue(key, value)
         return {"id" : self.node.id}
 
 # the whole shebang, for testing
@@ -388,8 +307,8 @@ class SimpleTests(unittest.TestCase):
     def tearDown(self):
         self.a.shutdown()
         self.b.shutdown()
-        os.unlink(self.a.db)
-        os.unlink(self.b.db)
+        os.unlink(self.a.store.db)
+        os.unlink(self.b.store.db)
 
     def testAddContact(self):
         self.assertEqual(len(self.a.table.buckets), 1)
@@ -432,7 +351,7 @@ class SimpleTests(unittest.TestCase):
         reactor.iterate()
         reactor.iterate()
 
-    def _cb(self, val):
+    def _cb(self, key, val):
         if not val:
             self.assertEqual(self.got, 1)
         elif 'foobar' in val:
@@ -485,7 +404,7 @@ class MultiTest(unittest.TestCase):
     def tearDown(self):
         for i in self.l:
             i.shutdown()
-            os.unlink(i.db)
+            os.unlink(i.store.db)
             
         reactor.iterate()
         
@@ -496,14 +415,14 @@ class MultiTest(unittest.TestCase):
             
             for a in range(3):
                 self.done = 0
-                def _scb(val):
+                def _scb(key, value, result):
                     self.done = 1
                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
                 while not self.done:
                     reactor.iterate()
 
 
-                def _rcb(val):
+                def _rcb(key, val):
                     if not val:
                         self.done = 1
                         self.assertEqual(self.got, 1)