Return a token in find_node responses, use it in store_value requests.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from datetime import datetime, timedelta
8 from random import randrange
9 from sha import sha
10 import os
11
12 from twisted.internet.defer import Deferred
13 from twisted.internet import protocol, reactor
14 from twisted.trial import unittest
15
16 from db import DB
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, GetValue, KeyExpirer, StoreValue
21 import krpc
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         self.token_secrets = [newID()]
37         #self.app = service.Application("krpc")
38         self.udp = krpc.hostbroker(self, config)
39         self.udp.protocol = krpc.KRPC
40         self.listenport = reactor.listenUDP(self.port, self.udp)
41         self._loadRoutingTable()
42         self.expirer = KeyExpirer(self.store, config)
43         self.refreshTable(force=1)
44         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
45
46     def Node(self, id, host = None, port = None):
47         """Create a new node."""
48         n = self._Node(id, host, port)
49         n.table = self.table
50         n.conn = self.udp.connectionForAddr((n.host, n.port))
51         return n
52     
53     def __del__(self):
54         self.listenport.stopListening()
55         
56     def _loadSelfNode(self, host, port):
57         id = self.store.getSelfNode()
58         if not id:
59             id = newID()
60         return self._Node(id, host, port)
61         
62     def checkpoint(self, auto=0):
63         self.token_secrets.insert(0, newID())
64         if len(self.token_secrets) > 3:
65             self.token_secrets.pop()
66         self.store.saveSelfNode(self.node.id)
67         self.store.dumpRoutingTable(self.table.buckets)
68         self.refreshTable()
69         if auto:
70             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
71                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
72                               self.checkpoint, (1,))
73         
74     def _loadRoutingTable(self):
75         """
76             load routing table nodes from database
77             it's usually a good idea to call refreshTable(force=1) after loading the table
78         """
79         nodes = self.store.getRoutingTable()
80         for rec in nodes:
81             n = self.Node(rec[0], rec[1], int(rec[2]))
82             self.table.insertNode(n, contacted=0)
83             
84
85     #######
86     #######  LOCAL INTERFACE    - use these methods!
87     def addContact(self, host, port, callback=None, errback=None):
88         """
89             ping this node and add the contact info to the table on pong!
90         """
91         n = self.Node(NULL_ID, host, port)
92         self.sendJoin(n, callback=callback, errback=errback)
93
94     ## this call is async!
95     def findNode(self, id, callback, errback=None):
96         """ returns the contact info for node, or the k closest nodes, from the global table """
97         # get K nodes out of local table/cache, or the node we want
98         nodes = self.table.findNodes(id)
99         d = Deferred()
100         if errback:
101             d.addCallbacks(callback, errback)
102         else:
103             d.addCallback(callback)
104         if len(nodes) == 1 and nodes[0].id == id :
105             d.callback(nodes)
106         else:
107             # create our search state
108             state = FindNode(self, id, d.callback, self.config)
109             reactor.callLater(0, state.goWithNodes, nodes)
110     
111     def insertNode(self, n, contacted=1):
112         """
113         insert a node in our local table, pinging oldest contact in bucket, if necessary
114         
115         If all you have is a host/port, then use addContact, which calls this method after
116         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
117         a node into the table without it's peer-ID.  That means of course the node passed into this
118         method needs to be a properly formed Node object with a valid ID.
119         """
120         old = self.table.insertNode(n, contacted=contacted)
121         if (old and old.id != self.node.id and
122             (datetime.now() - old.lastSeen) > 
123              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
124             # the bucket is full, check to see if old node is still around and if so, replace it
125             
126             ## these are the callbacks used when we ping the oldest node in a bucket
127             def _staleNodeHandler(oldnode=old, newnode = n):
128                 """ called if the pinged node never responds """
129                 self.table.replaceStaleNode(old, newnode)
130             
131             def _notStaleNodeHandler(dict, old=old):
132                 """ called when we get a pong from the old node """
133                 dict = dict['rsp']
134                 if dict['id'] == old.id:
135                     self.table.justSeenNode(old.id)
136             
137             df = old.ping(self.node.id)
138             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
139
140     def sendJoin(self, node, callback=None, errback=None):
141         """
142             ping a node
143         """
144         df = node.join(self.node.id)
145         ## these are the callbacks we use when we issue a PING
146         def _pongHandler(dict, node=node, self=self, callback=callback):
147             n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
148             self.insertNode(n)
149             if callback:
150                 callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
151         def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
152             table.nodeFailed(node)
153             if errback:
154                 errback()
155             else:
156                 callback(None)
157         
158         df.addCallbacks(_pongHandler,_defaultPong)
159
160     def findCloseNodes(self, callback=lambda a: None, errback = None):
161         """
162             This does a findNode on the ID one away from our own.  
163             This will allow us to populate our table with nodes on our network closest to our own.
164             This is called as soon as we start up with an empty table
165         """
166         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
167         self.findNode(id, callback, errback)
168
169     def refreshTable(self, force=0):
170         """
171             force=1 will refresh table regardless of last bucket access time
172         """
173         def callback(nodes):
174             pass
175     
176         for bucket in self.table.buckets:
177             if force or (datetime.now() - bucket.lastAccessed > 
178                          timedelta(seconds=self.config['BUCKET_STALENESS'])):
179                 id = newIDInRange(bucket.min, bucket.max)
180                 self.findNode(id, callback)
181
182     def stats(self):
183         """
184         Returns (num_contacts, num_nodes)
185         num_contacts: number contacts in our routing table
186         num_nodes: number of nodes estimated in the entire dht
187         """
188         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
189         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
190         return (num_contacts, num_nodes)
191     
192     def shutdown(self):
193         """Closes the port and cancels pending later calls."""
194         self.listenport.stopListening()
195         try:
196             self.next_checkpoint.cancel()
197         except:
198             pass
199         self.expirer.shutdown()
200         self.store.close()
201
202     #### Remote Interface - called by remote nodes
203     def krpc_ping(self, id, _krpc_sender):
204         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
205         self.insertNode(n, contacted=0)
206         return {"id" : self.node.id}
207         
208     def krpc_join(self, id, _krpc_sender):
209         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
210         self.insertNode(n, contacted=0)
211         return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
212         
213     def krpc_find_node(self, target, id, _krpc_sender):
214         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
215         self.insertNode(n, contacted=0)
216         nodes = self.table.findNodes(target)
217         nodes = map(lambda node: node.contactInfo(), nodes)
218         token = sha(self.token_secrets[0] + _krpc_sender[0]).digest()
219         return {"nodes" : nodes, "token" : token, "id" : self.node.id}
220
221
222 ## This class provides read-only access to the DHT, valueForKey
223 ## you probably want to use this mixin and provide your own write methods
224 class KhashmirRead(KhashmirBase):
225     _Node = KNodeRead
226
227     ## also async
228     def valueForKey(self, key, callback, searchlocal = 1):
229         """ returns the values found for key in global table
230             callback will be called with a list of values for each peer that returns unique values
231             final callback will be an empty list - probably should change to 'more coming' arg
232         """
233         nodes = self.table.findNodes(key)
234         
235         # get locals
236         if searchlocal:
237             l = self.store.retrieveValues(key)
238             if len(l) > 0:
239                 reactor.callLater(0, callback, key, l)
240         else:
241             l = []
242         
243         # create our search state
244         state = GetValue(self, key, callback, self.config)
245         reactor.callLater(0, state.goWithNodes, nodes, l)
246
247     #### Remote Interface - called by remote nodes
248     def krpc_find_value(self, key, id, _krpc_sender):
249         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
250         self.insertNode(n, contacted=0)
251     
252         l = self.store.retrieveValues(key)
253         if len(l) > 0:
254             return {'values' : l, "id": self.node.id}
255         else:
256             nodes = self.table.findNodes(key)
257             nodes = map(lambda node: node.contactInfo(), nodes)
258             return {'nodes' : nodes, "id": self.node.id}
259
260 ###  provides a generic write method, you probably don't want to deploy something that allows
261 ###  arbitrary value storage
262 class KhashmirWrite(KhashmirRead):
263     _Node = KNodeWrite
264     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
265     def storeValueForKey(self, key, value, callback=None):
266         """ stores the value and origination time for key in the global table, returns immediately, no status 
267             in this implementation, peers respond but don't indicate status to storing values
268             a key can have many values
269         """
270         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
271             if not response:
272                 # default callback
273                 def _storedValueHandler(key, value, sender):
274                     pass
275                 response=_storedValueHandler
276             action = StoreValue(self.table, key, value, response, self.config)
277             reactor.callLater(0, action.goWithNodes, nodes)
278             
279         # this call is asynch
280         self.findNode(key, _storeValueForKey)
281                     
282     #### Remote Interface - called by remote nodes
283     def krpc_store_value(self, key, value, token, id, _krpc_sender):
284         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
285         self.insertNode(n, contacted=0)
286         for secret in self.token_secrets:
287             this_token = sha(secret + _krpc_sender[0]).digest()
288             if token == this_token:
289                 self.store.storeValue(key, value)
290                 break;
291         return {"id" : self.node.id}
292
293 # the whole shebang, for testing
294 class Khashmir(KhashmirWrite):
295     _Node = KNodeWrite
296
297 class SimpleTests(unittest.TestCase):
298     
299     timeout = 10
300     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
301                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
302                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
303                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
304                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
305                     'KE_AGE': 3600, 'SPEW': False, }
306
307     def setUp(self):
308         krpc.KRPC.noisy = 0
309         d = self.DHT_DEFAULTS.copy()
310         d['PORT'] = 4044
311         self.a = Khashmir(d)
312         d = self.DHT_DEFAULTS.copy()
313         d['PORT'] = 4045
314         self.b = Khashmir(d)
315         
316     def tearDown(self):
317         self.a.shutdown()
318         self.b.shutdown()
319         os.unlink(self.a.store.db)
320         os.unlink(self.b.store.db)
321
322     def testAddContact(self):
323         self.failUnlessEqual(len(self.a.table.buckets), 1)
324         self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
325
326         self.failUnlessEqual(len(self.b.table.buckets), 1)
327         self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
328
329         self.a.addContact('127.0.0.1', 4045)
330         reactor.iterate()
331         reactor.iterate()
332         reactor.iterate()
333         reactor.iterate()
334
335         self.failUnlessEqual(len(self.a.table.buckets), 1)
336         self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
337         self.failUnlessEqual(len(self.b.table.buckets), 1)
338         self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
339
340     def testStoreRetrieve(self):
341         self.a.addContact('127.0.0.1', 4045)
342         reactor.iterate()
343         reactor.iterate()
344         reactor.iterate()
345         reactor.iterate()
346         self.got = 0
347         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
348         reactor.iterate()
349         reactor.iterate()
350         reactor.iterate()
351         reactor.iterate()
352         reactor.iterate()
353         reactor.iterate()
354         self.a.valueForKey(sha('foo').digest(), self._cb)
355         reactor.iterate()
356         reactor.iterate()
357         reactor.iterate()
358         reactor.iterate()
359         reactor.iterate()
360         reactor.iterate()
361         reactor.iterate()
362
363     def _cb(self, key, val):
364         if not val:
365             self.failUnlessEqual(self.got, 1)
366         elif 'foobar' in val:
367             self.got = 1
368
369
370 class MultiTest(unittest.TestCase):
371     
372     timeout = 30
373     num = 20
374     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
375                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
376                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
377                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
378                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
379                     'KE_AGE': 3600, 'SPEW': False, }
380
381     def _done(self, val):
382         self.done = 1
383         
384     def setUp(self):
385         self.l = []
386         self.startport = 4088
387         for i in range(self.num):
388             d = self.DHT_DEFAULTS.copy()
389             d['PORT'] = self.startport + i
390             self.l.append(Khashmir(d))
391         reactor.iterate()
392         reactor.iterate()
393         
394         for i in self.l:
395             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
396             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
397             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
398             reactor.iterate()
399             reactor.iterate()
400             reactor.iterate() 
401             
402         for i in self.l:
403             self.done = 0
404             i.findCloseNodes(self._done)
405             while not self.done:
406                 reactor.iterate()
407         for i in self.l:
408             self.done = 0
409             i.findCloseNodes(self._done)
410             while not self.done:
411                 reactor.iterate()
412
413     def tearDown(self):
414         for i in self.l:
415             i.shutdown()
416             os.unlink(i.store.db)
417             
418         reactor.iterate()
419         
420     def testStoreRetrieve(self):
421         for i in range(10):
422             K = newID()
423             V = newID()
424             
425             for a in range(3):
426                 self.done = 0
427                 def _scb(key, value, result):
428                     self.done = 1
429                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
430                 while not self.done:
431                     reactor.iterate()
432
433
434                 def _rcb(key, val):
435                     if not val:
436                         self.done = 1
437                         self.failUnlessEqual(self.got, 1)
438                     elif V in val:
439                         self.got = 1
440                 for x in range(3):
441                     self.got = 0
442                     self.done = 0
443                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
444                     while not self.done:
445                         reactor.iterate()