Added a shutdown method to the khasmir DHT.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from time import time
8 from random import randrange
9 from sha import sha
10 import os
11 import sqlite  ## find this at http://pysqlite.sourceforge.net/
12
13 from twisted.internet.defer import Deferred
14 from twisted.internet import protocol, reactor
15 from twisted.trial import unittest
16
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, GetValue, KeyExpirer, StoreValue
21 import krpc
22
23 class KhashmirDBExcept(Exception):
24     pass
25
26 # this is the base class, has base functionality and find node, no key-value mappings
27 class KhashmirBase(protocol.Factory):
28     _Node = KNodeBase
29     def __init__(self, config, cache_dir='/tmp'):
30         self.config = None
31         self.setup(config, cache_dir)
32         
33     def setup(self, config, cache_dir):
34         self.config = config
35         self.port = config['PORT']
36         self._findDB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
37         self.node = self._loadSelfNode('', self.port)
38         self.table = KTable(self.node, config)
39         #self.app = service.Application("krpc")
40         self.udp = krpc.hostbroker(self)
41         self.udp.protocol = krpc.KRPC
42         self.listenport = reactor.listenUDP(self.port, self.udp)
43         self.last = time()
44         self._loadRoutingTable()
45         self.expirer = KeyExpirer(self.store, config)
46         self.refreshTable(force=1)
47         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
48
49     def Node(self):
50         n = self._Node()
51         n.table = self.table
52         return n
53     
54     def __del__(self):
55         self.listenport.stopListening()
56         
57     def _loadSelfNode(self, host, port):
58         c = self.store.cursor()
59         c.execute('select id from self where num = 0;')
60         if c.rowcount > 0:
61             id = c.fetchone()[0]
62         else:
63             id = newID()
64         return self._Node().init(id, host, port)
65         
66     def _saveSelfNode(self):
67         c = self.store.cursor()
68         c.execute('delete from self where num = 0;')
69         c.execute("insert into self values (0, %s);", sqlite.encode(self.node.id))
70         self.store.commit()
71         
72     def checkpoint(self, auto=0):
73         self._saveSelfNode()
74         self._dumpRoutingTable()
75         self.refreshTable()
76         if auto:
77             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
78                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
79                               self.checkpoint, (1,))
80         
81     def _findDB(self, db):
82         self.db = db
83         try:
84             os.stat(db)
85         except OSError:
86             self._createNewDB(db)
87         else:
88             self._loadDB(db)
89         
90     def _loadDB(self, db):
91         try:
92             self.store = sqlite.connect(db=db)
93             #self.store.autocommit = 0
94         except:
95             import traceback
96             raise KhashmirDBExcept, "Couldn't open DB", traceback.format_exc()
97         
98     def _createNewDB(self, db):
99         self.store = sqlite.connect(db=db)
100         s = """
101             create table kv (key binary, value binary, time timestamp, primary key (key, value));
102             create index kv_key on kv(key);
103             create index kv_timestamp on kv(time);
104             
105             create table nodes (id binary primary key, host text, port number);
106             
107             create table self (num number primary key, id binary);
108             """
109         c = self.store.cursor()
110         c.execute(s)
111         self.store.commit()
112
113     def _dumpRoutingTable(self):
114         """
115             save routing table nodes to the database
116         """
117         c = self.store.cursor()
118         c.execute("delete from nodes where id not NULL;")
119         for bucket in self.table.buckets:
120             for node in bucket.l:
121                 c.execute("insert into nodes values (%s, %s, %s);", (sqlite.encode(node.id), node.host, node.port))
122         self.store.commit()
123         
124     def _loadRoutingTable(self):
125         """
126             load routing table nodes from database
127             it's usually a good idea to call refreshTable(force=1) after loading the table
128         """
129         c = self.store.cursor()
130         c.execute("select * from nodes;")
131         for rec in c.fetchall():
132             n = self.Node().initWithDict({'id':rec[0], 'host':rec[1], 'port':int(rec[2])})
133             n.conn = self.udp.connectionForAddr((n.host, n.port))
134             self.table.insertNode(n, contacted=0)
135             
136
137     #######
138     #######  LOCAL INTERFACE    - use these methods!
139     def addContact(self, host, port, callback=None):
140         """
141             ping this node and add the contact info to the table on pong!
142         """
143         n =self.Node().init(NULL_ID, host, port) 
144         n.conn = self.udp.connectionForAddr((n.host, n.port))
145         self.sendPing(n, callback=callback)
146
147     ## this call is async!
148     def findNode(self, id, callback, errback=None):
149         """ returns the contact info for node, or the k closest nodes, from the global table """
150         # get K nodes out of local table/cache, or the node we want
151         nodes = self.table.findNodes(id)
152         d = Deferred()
153         if errback:
154             d.addCallbacks(callback, errback)
155         else:
156             d.addCallback(callback)
157         if len(nodes) == 1 and nodes[0].id == id :
158             d.callback(nodes)
159         else:
160             # create our search state
161             state = FindNode(self, id, d.callback, self.config)
162             reactor.callLater(0, state.goWithNodes, nodes)
163     
164     def insertNode(self, n, contacted=1):
165         """
166         insert a node in our local table, pinging oldest contact in bucket, if necessary
167         
168         If all you have is a host/port, then use addContact, which calls this method after
169         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
170         a node into the table without it's peer-ID.  That means of course the node passed into this
171         method needs to be a properly formed Node object with a valid ID.
172         """
173         old = self.table.insertNode(n, contacted=contacted)
174         if old and (time() - old.lastSeen) > self.config['MIN_PING_INTERVAL'] and old.id != self.node.id:
175             # the bucket is full, check to see if old node is still around and if so, replace it
176             
177             ## these are the callbacks used when we ping the oldest node in a bucket
178             def _staleNodeHandler(oldnode=old, newnode = n):
179                 """ called if the pinged node never responds """
180                 self.table.replaceStaleNode(old, newnode)
181             
182             def _notStaleNodeHandler(dict, old=old):
183                 """ called when we get a pong from the old node """
184                 dict = dict['rsp']
185                 if dict['id'] == old.id:
186                     self.table.justSeenNode(old.id)
187             
188             df = old.ping(self.node.id)
189             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
190
191     def sendPing(self, node, callback=None):
192         """
193             ping a node
194         """
195         df = node.ping(self.node.id)
196         ## these are the callbacks we use when we issue a PING
197         def _pongHandler(dict, node=node, table=self.table, callback=callback):
198             _krpc_sender = dict['_krpc_sender']
199             dict = dict['rsp']
200             sender = {'id' : dict['id']}
201             sender['host'] = _krpc_sender[0]
202             sender['port'] = _krpc_sender[1]
203             n = self.Node().initWithDict(sender)
204             n.conn = self.udp.connectionForAddr((n.host, n.port))
205             table.insertNode(n)
206             if callback:
207                 callback()
208         def _defaultPong(err, node=node, table=self.table, callback=callback):
209             table.nodeFailed(node)
210             if callback:
211                 callback()
212         
213         df.addCallbacks(_pongHandler,_defaultPong)
214
215     def findCloseNodes(self, callback=lambda a: None):
216         """
217             This does a findNode on the ID one away from our own.  
218             This will allow us to populate our table with nodes on our network closest to our own.
219             This is called as soon as we start up with an empty table
220         """
221         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
222         self.findNode(id, callback)
223
224     def refreshTable(self, force=0):
225         """
226             force=1 will refresh table regardless of last bucket access time
227         """
228         def callback(nodes):
229             pass
230     
231         for bucket in self.table.buckets:
232             if force or (time() - bucket.lastAccessed >= self.config['BUCKET_STALENESS']):
233                 id = newIDInRange(bucket.min, bucket.max)
234                 self.findNode(id, callback)
235
236     def stats(self):
237         """
238         Returns (num_contacts, num_nodes)
239         num_contacts: number contacts in our routing table
240         num_nodes: number of nodes estimated in the entire dht
241         """
242         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
243         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
244         return (num_contacts, num_nodes)
245     
246     def shutdown(self):
247         """Closes the port and cancels pending later calls."""
248         self.listenport.stopListening()
249         try:
250             self.next_checkpoint.cancel()
251         except:
252             pass
253         self.expirer.shutdown()
254         self.store.close()
255
256     def krpc_ping(self, id, _krpc_sender):
257         sender = {'id' : id}
258         sender['host'] = _krpc_sender[0]
259         sender['port'] = _krpc_sender[1]        
260         n = self.Node().initWithDict(sender)
261         n.conn = self.udp.connectionForAddr((n.host, n.port))
262         self.insertNode(n, contacted=0)
263         return {"id" : self.node.id}
264         
265     def krpc_find_node(self, target, id, _krpc_sender):
266         nodes = self.table.findNodes(target)
267         nodes = map(lambda node: node.senderDict(), nodes)
268         sender = {'id' : id}
269         sender['host'] = _krpc_sender[0]
270         sender['port'] = _krpc_sender[1]        
271         n = self.Node().initWithDict(sender)
272         n.conn = self.udp.connectionForAddr((n.host, n.port))
273         self.insertNode(n, contacted=0)
274         return {"nodes" : nodes, "id" : self.node.id}
275
276
277 ## This class provides read-only access to the DHT, valueForKey
278 ## you probably want to use this mixin and provide your own write methods
279 class KhashmirRead(KhashmirBase):
280     _Node = KNodeRead
281     def retrieveValues(self, key):
282         c = self.store.cursor()
283         c.execute("select value from kv where key = %s;", sqlite.encode(key))
284         t = c.fetchone()
285         l = []
286         while t:
287             l.append(t['value'])
288             t = c.fetchone()
289         return l
290     ## also async
291     def valueForKey(self, key, callback, searchlocal = 1):
292         """ returns the values found for key in global table
293             callback will be called with a list of values for each peer that returns unique values
294             final callback will be an empty list - probably should change to 'more coming' arg
295         """
296         nodes = self.table.findNodes(key)
297         
298         # get locals
299         if searchlocal:
300             l = self.retrieveValues(key)
301             if len(l) > 0:
302                 reactor.callLater(0, callback, (l))
303         else:
304             l = []
305         
306         # create our search state
307         state = GetValue(self, key, callback, self.config)
308         reactor.callLater(0, state.goWithNodes, nodes, l)
309
310     def krpc_find_value(self, key, id, _krpc_sender):
311         sender = {'id' : id}
312         sender['host'] = _krpc_sender[0]
313         sender['port'] = _krpc_sender[1]        
314         n = self.Node().initWithDict(sender)
315         n.conn = self.udp.connectionForAddr((n.host, n.port))
316         self.insertNode(n, contacted=0)
317     
318         l = self.retrieveValues(key)
319         if len(l) > 0:
320             return {'values' : l, "id": self.node.id}
321         else:
322             nodes = self.table.findNodes(key)
323             nodes = map(lambda node: node.senderDict(), nodes)
324             return {'nodes' : nodes, "id": self.node.id}
325
326 ###  provides a generic write method, you probably don't want to deploy something that allows
327 ###  arbitrary value storage
328 class KhashmirWrite(KhashmirRead):
329     _Node = KNodeWrite
330     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
331     def storeValueForKey(self, key, value, callback=None):
332         """ stores the value for key in the global table, returns immediately, no status 
333             in this implementation, peers respond but don't indicate status to storing values
334             a key can have many values
335         """
336         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
337             if not response:
338                 # default callback
339                 def _storedValueHandler(sender):
340                     pass
341                 response=_storedValueHandler
342             action = StoreValue(self.table, key, value, response, self.config)
343             reactor.callLater(0, action.goWithNodes, nodes)
344             
345         # this call is asynch
346         self.findNode(key, _storeValueForKey)
347                     
348     def krpc_store_value(self, key, value, id, _krpc_sender):
349         t = "%0.6f" % time()
350         c = self.store.cursor()
351         try:
352             c.execute("insert into kv values (%s, %s, %s);", (sqlite.encode(key), sqlite.encode(value), t))
353         except sqlite.IntegrityError, reason:
354             # update last insert time
355             c.execute("update kv set time = %s where key = %s and value = %s;", (t, sqlite.encode(key), sqlite.encode(value)))
356         self.store.commit()
357         sender = {'id' : id}
358         sender['host'] = _krpc_sender[0]
359         sender['port'] = _krpc_sender[1]        
360         n = self.Node().initWithDict(sender)
361         n.conn = self.udp.connectionForAddr((n.host, n.port))
362         self.insertNode(n, contacted=0)
363         return {"id" : self.node.id}
364
365 # the whole shebang, for testing
366 class Khashmir(KhashmirWrite):
367     _Node = KNodeWrite
368
369 class SimpleTests(unittest.TestCase):
370     
371     timeout = 10
372     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
373                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
374                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
375                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
376                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
377                     'KE_AGE': 3600, }
378
379     def setUp(self):
380         krpc.KRPC.noisy = 0
381         d = self.DHT_DEFAULTS.copy()
382         d['PORT'] = 4044
383         self.a = Khashmir(d)
384         d = self.DHT_DEFAULTS.copy()
385         d['PORT'] = 4045
386         self.b = Khashmir(d)
387         
388     def tearDown(self):
389         self.a.shutdown()
390         self.b.shutdown()
391         os.unlink(self.a.db)
392         os.unlink(self.b.db)
393
394     def testAddContact(self):
395         self.assertEqual(len(self.a.table.buckets), 1)
396         self.assertEqual(len(self.a.table.buckets[0].l), 0)
397
398         self.assertEqual(len(self.b.table.buckets), 1)
399         self.assertEqual(len(self.b.table.buckets[0].l), 0)
400
401         self.a.addContact('127.0.0.1', 4045)
402         reactor.iterate()
403         reactor.iterate()
404         reactor.iterate()
405         reactor.iterate()
406
407         self.assertEqual(len(self.a.table.buckets), 1)
408         self.assertEqual(len(self.a.table.buckets[0].l), 1)
409         self.assertEqual(len(self.b.table.buckets), 1)
410         self.assertEqual(len(self.b.table.buckets[0].l), 1)
411
412     def testStoreRetrieve(self):
413         self.a.addContact('127.0.0.1', 4045)
414         reactor.iterate()
415         reactor.iterate()
416         reactor.iterate()
417         reactor.iterate()
418         self.got = 0
419         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
420         reactor.iterate()
421         reactor.iterate()
422         reactor.iterate()
423         reactor.iterate()
424         reactor.iterate()
425         reactor.iterate()
426         self.a.valueForKey(sha('foo').digest(), self._cb)
427         reactor.iterate()
428         reactor.iterate()
429         reactor.iterate()
430         reactor.iterate()
431         reactor.iterate()
432         reactor.iterate()
433         reactor.iterate()
434
435     def _cb(self, val):
436         if not val:
437             self.assertEqual(self.got, 1)
438         elif 'foobar' in val:
439             self.got = 1
440
441
442 class MultiTest(unittest.TestCase):
443     
444     timeout = 30
445     num = 20
446     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
447                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
448                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
449                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
450                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
451                     'KE_AGE': 3600, }
452
453     def _done(self, val):
454         self.done = 1
455         
456     def setUp(self):
457         self.l = []
458         self.startport = 4088
459         for i in range(self.num):
460             d = self.DHT_DEFAULTS.copy()
461             d['PORT'] = self.startport + i
462             self.l.append(Khashmir(d))
463         reactor.iterate()
464         reactor.iterate()
465         
466         for i in self.l:
467             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
468             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
469             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
470             reactor.iterate()
471             reactor.iterate()
472             reactor.iterate() 
473             
474         for i in self.l:
475             self.done = 0
476             i.findCloseNodes(self._done)
477             while not self.done:
478                 reactor.iterate()
479         for i in self.l:
480             self.done = 0
481             i.findCloseNodes(self._done)
482             while not self.done:
483                 reactor.iterate()
484
485     def tearDown(self):
486         for i in self.l:
487             i.shutdown()
488             os.unlink(i.db)
489             
490         reactor.iterate()
491         
492     def testStoreRetrieve(self):
493         for i in range(10):
494             K = newID()
495             V = newID()
496             
497             for a in range(3):
498                 self.done = 0
499                 def _scb(val):
500                     self.done = 1
501                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
502                 while not self.done:
503                     reactor.iterate()
504
505
506                 def _rcb(val):
507                     if not val:
508                         self.done = 1
509                         self.assertEqual(self.got, 1)
510                     elif V in val:
511                         self.got = 1
512                 for x in range(3):
513                     self.got = 0
514                     self.done = 0
515                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
516                     while not self.done:
517                         reactor.iterate()