eeaab0acf638215dc2ec707d6c0d62daff48d04c
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from datetime import datetime, timedelta
8 from random import randrange, shuffle
9 from sha import sha
10 import os
11
12 from twisted.internet.defer import Deferred
13 from twisted.internet import protocol, reactor
14 from twisted.trial import unittest
15
16 from db import DB
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, FindValue, GetValue, StoreValue
21 import krpc
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         self.token_secrets = [newID()]
37         #self.app = service.Application("krpc")
38         self.udp = krpc.hostbroker(self, config)
39         self.udp.protocol = krpc.KRPC
40         self.listenport = reactor.listenUDP(self.port, self.udp)
41         self._loadRoutingTable()
42         self.refreshTable(force=1)
43         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
44
45     def Node(self, id, host = None, port = None):
46         """Create a new node."""
47         n = self._Node(id, host, port)
48         n.table = self.table
49         n.conn = self.udp.connectionForAddr((n.host, n.port))
50         return n
51     
52     def __del__(self):
53         self.listenport.stopListening()
54         
55     def _loadSelfNode(self, host, port):
56         id = self.store.getSelfNode()
57         if not id:
58             id = newID()
59         return self._Node(id, host, port)
60         
61     def checkpoint(self, auto=0):
62         self.token_secrets.insert(0, newID())
63         if len(self.token_secrets) > 3:
64             self.token_secrets.pop()
65         self.store.saveSelfNode(self.node.id)
66         self.store.dumpRoutingTable(self.table.buckets)
67         self.store.expireValues(self.config['KEY_EXPIRE'])
68         self.refreshTable()
69         if auto:
70             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
71                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
72                               self.checkpoint, (1,))
73         
74     def _loadRoutingTable(self):
75         """
76             load routing table nodes from database
77             it's usually a good idea to call refreshTable(force=1) after loading the table
78         """
79         nodes = self.store.getRoutingTable()
80         for rec in nodes:
81             n = self.Node(rec[0], rec[1], int(rec[2]))
82             self.table.insertNode(n, contacted=0)
83             
84
85     #######
86     #######  LOCAL INTERFACE    - use these methods!
87     def addContact(self, host, port, callback=None, errback=None):
88         """
89             ping this node and add the contact info to the table on pong!
90         """
91         n = self.Node(NULL_ID, host, port)
92         self.sendJoin(n, callback=callback, errback=errback)
93
94     ## this call is async!
95     def findNode(self, id, callback, errback=None):
96         """ returns the contact info for node, or the k closest nodes, from the global table """
97         # get K nodes out of local table/cache, or the node we want
98         nodes = self.table.findNodes(id)
99         d = Deferred()
100         if errback:
101             d.addCallbacks(callback, errback)
102         else:
103             d.addCallback(callback)
104         if len(nodes) == 1 and nodes[0].id == id :
105             d.callback(nodes)
106         else:
107             # create our search state
108             state = FindNode(self, id, d.callback, self.config)
109             reactor.callLater(0, state.goWithNodes, nodes)
110     
111     def insertNode(self, n, contacted=1):
112         """
113         insert a node in our local table, pinging oldest contact in bucket, if necessary
114         
115         If all you have is a host/port, then use addContact, which calls this method after
116         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
117         a node into the table without it's peer-ID.  That means of course the node passed into this
118         method needs to be a properly formed Node object with a valid ID.
119         """
120         old = self.table.insertNode(n, contacted=contacted)
121         if (old and old.id != self.node.id and
122             (datetime.now() - old.lastSeen) > 
123              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
124             # the bucket is full, check to see if old node is still around and if so, replace it
125             
126             ## these are the callbacks used when we ping the oldest node in a bucket
127             def _staleNodeHandler(oldnode=old, newnode = n):
128                 """ called if the pinged node never responds """
129                 self.table.replaceStaleNode(old, newnode)
130             
131             def _notStaleNodeHandler(dict, old=old):
132                 """ called when we get a pong from the old node """
133                 dict = dict['rsp']
134                 if dict['id'] == old.id:
135                     self.table.justSeenNode(old.id)
136             
137             df = old.ping(self.node.id)
138             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
139
140     def sendJoin(self, node, callback=None, errback=None):
141         """
142             ping a node
143         """
144         df = node.join(self.node.id)
145         ## these are the callbacks we use when we issue a PING
146         def _pongHandler(dict, node=node, self=self, callback=callback):
147             n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
148             self.insertNode(n)
149             if callback:
150                 callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
151         def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
152             table.nodeFailed(node)
153             if errback:
154                 errback()
155             else:
156                 callback(None)
157         
158         df.addCallbacks(_pongHandler,_defaultPong)
159
160     def findCloseNodes(self, callback=lambda a: None, errback = None):
161         """
162             This does a findNode on the ID one away from our own.  
163             This will allow us to populate our table with nodes on our network closest to our own.
164             This is called as soon as we start up with an empty table
165         """
166         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
167         self.findNode(id, callback, errback)
168
169     def refreshTable(self, force=0):
170         """
171             force=1 will refresh table regardless of last bucket access time
172         """
173         def callback(nodes):
174             pass
175     
176         for bucket in self.table.buckets:
177             if force or (datetime.now() - bucket.lastAccessed > 
178                          timedelta(seconds=self.config['BUCKET_STALENESS'])):
179                 id = newIDInRange(bucket.min, bucket.max)
180                 self.findNode(id, callback)
181
182     def stats(self):
183         """
184         Returns (num_contacts, num_nodes)
185         num_contacts: number contacts in our routing table
186         num_nodes: number of nodes estimated in the entire dht
187         """
188         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
189         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
190         return (num_contacts, num_nodes)
191     
192     def shutdown(self):
193         """Closes the port and cancels pending later calls."""
194         self.listenport.stopListening()
195         try:
196             self.next_checkpoint.cancel()
197         except:
198             pass
199         self.store.close()
200
201     #### Remote Interface - called by remote nodes
202     def krpc_ping(self, id, _krpc_sender):
203         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
204         self.insertNode(n, contacted=0)
205         return {"id" : self.node.id}
206         
207     def krpc_join(self, id, _krpc_sender):
208         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
209         self.insertNode(n, contacted=0)
210         return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
211         
212     def krpc_find_node(self, target, id, _krpc_sender):
213         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
214         self.insertNode(n, contacted=0)
215         nodes = self.table.findNodes(target)
216         nodes = map(lambda node: node.contactInfo(), nodes)
217         token = sha(self.token_secrets[0] + _krpc_sender[0]).digest()
218         return {"nodes" : nodes, "token" : token, "id" : self.node.id}
219
220
221 ## This class provides read-only access to the DHT, valueForKey
222 ## you probably want to use this mixin and provide your own write methods
223 class KhashmirRead(KhashmirBase):
224     _Node = KNodeRead
225
226     ## also async
227     def findValue(self, key, callback, errback=None):
228         """ returns the contact info for nodes that have values for the key, from the global table """
229         # get K nodes out of local table/cache
230         nodes = self.table.findNodes(key)
231         d = Deferred()
232         if errback:
233             d.addCallbacks(callback, errback)
234         else:
235             d.addCallback(callback)
236
237         # create our search state
238         state = FindValue(self, key, d.callback, self.config)
239         reactor.callLater(0, state.goWithNodes, nodes)
240
241     def valueForKey(self, key, callback, searchlocal = 1):
242         """ returns the values found for key in global table
243             callback will be called with a list of values for each peer that returns unique values
244             final callback will be an empty list - probably should change to 'more coming' arg
245         """
246         # get locals
247         if searchlocal:
248             l = self.store.retrieveValues(key)
249             if len(l) > 0:
250                 reactor.callLater(0, callback, key, l)
251         else:
252             l = []
253
254         def _getValueForKey(nodes, key=key, local_values=l, response=callback, self=self):
255             # create our search state
256             state = GetValue(self, key, local_values, self.config['RETRIEVE_VALUES'], response, self.config)
257             reactor.callLater(0, state.goWithNodes, nodes)
258             
259         # this call is asynch
260         self.findValue(key, _getValueForKey)
261
262     #### Remote Interface - called by remote nodes
263     def krpc_find_value(self, key, id, _krpc_sender):
264         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
265         self.insertNode(n, contacted=0)
266     
267         nodes = self.table.findNodes(key)
268         nodes = map(lambda node: node.contactInfo(), nodes)
269         num_values = self.store.countValues(key)
270         return {'nodes' : nodes, 'num' : num_values, "id": self.node.id}
271
272     def krpc_get_value(self, key, num, id, _krpc_sender):
273         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
274         self.insertNode(n, contacted=0)
275     
276         l = self.store.retrieveValues(key)
277         if num == 0 or num >= len(l):
278             return {'values' : l, "id": self.node.id}
279         else:
280             shuffle(l)
281             return {'values' : l[:num], "id": self.node.id}
282
283 ###  provides a generic write method, you probably don't want to deploy something that allows
284 ###  arbitrary value storage
285 class KhashmirWrite(KhashmirRead):
286     _Node = KNodeWrite
287     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
288     def storeValueForKey(self, key, value, callback=None):
289         """ stores the value and origination time for key in the global table, returns immediately, no status 
290             in this implementation, peers respond but don't indicate status to storing values
291             a key can have many values
292         """
293         def _storeValueForKey(nodes, key=key, value=value, response=callback, self=self):
294             if not response:
295                 # default callback
296                 def _storedValueHandler(key, value, sender):
297                     pass
298                 response=_storedValueHandler
299             action = StoreValue(self, key, value, self.config['STORE_REDUNDANCY'], response, self.config)
300             reactor.callLater(0, action.goWithNodes, nodes)
301             
302         # this call is asynch
303         self.findNode(key, _storeValueForKey)
304                     
305     #### Remote Interface - called by remote nodes
306     def krpc_store_value(self, key, value, token, id, _krpc_sender):
307         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
308         self.insertNode(n, contacted=0)
309         for secret in self.token_secrets:
310             this_token = sha(secret + _krpc_sender[0]).digest()
311             if token == this_token:
312                 self.store.storeValue(key, value)
313                 return {"id" : self.node.id}
314         raise krpc.KrpcError, (krpc.KRPC_ERROR_INVALID_TOKEN, 'token is invalid, do a find_nodes to get a fresh one')
315
316 # the whole shebang, for testing
317 class Khashmir(KhashmirWrite):
318     _Node = KNodeWrite
319
320 class SimpleTests(unittest.TestCase):
321     
322     timeout = 10
323     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
324                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
325                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
326                     'MAX_FAILURES': 3,
327                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
328                     'KEY_EXPIRE': 3600, 'SPEW': False, }
329
330     def setUp(self):
331         krpc.KRPC.noisy = 0
332         d = self.DHT_DEFAULTS.copy()
333         d['PORT'] = 4044
334         self.a = Khashmir(d)
335         d = self.DHT_DEFAULTS.copy()
336         d['PORT'] = 4045
337         self.b = Khashmir(d)
338         
339     def tearDown(self):
340         self.a.shutdown()
341         self.b.shutdown()
342         os.unlink(self.a.store.db)
343         os.unlink(self.b.store.db)
344
345     def testAddContact(self):
346         self.failUnlessEqual(len(self.a.table.buckets), 1)
347         self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
348
349         self.failUnlessEqual(len(self.b.table.buckets), 1)
350         self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
351
352         self.a.addContact('127.0.0.1', 4045)
353         reactor.iterate()
354         reactor.iterate()
355         reactor.iterate()
356         reactor.iterate()
357
358         self.failUnlessEqual(len(self.a.table.buckets), 1)
359         self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
360         self.failUnlessEqual(len(self.b.table.buckets), 1)
361         self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
362
363     def testStoreRetrieve(self):
364         self.a.addContact('127.0.0.1', 4045)
365         reactor.iterate()
366         reactor.iterate()
367         reactor.iterate()
368         reactor.iterate()
369         self.got = 0
370         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
371         reactor.iterate()
372         reactor.iterate()
373         reactor.iterate()
374         reactor.iterate()
375         reactor.iterate()
376         reactor.iterate()
377         self.a.valueForKey(sha('foo').digest(), self._cb)
378         reactor.iterate()
379         reactor.iterate()
380         reactor.iterate()
381         reactor.iterate()
382         reactor.iterate()
383         reactor.iterate()
384         reactor.iterate()
385
386     def _cb(self, key, val):
387         if not val:
388             self.failUnlessEqual(self.got, 1)
389         elif 'foobar' in val:
390             self.got = 1
391
392
393 class MultiTest(unittest.TestCase):
394     
395     timeout = 30
396     num = 20
397     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
398                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
399                     'STORE_REDUNDANCY': 3, 'RETRIEVE_VALUES': -10000,
400                     'MAX_FAILURES': 3,
401                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
402                     'KEY_EXPIRE': 3600, 'SPEW': False, }
403
404     def _done(self, val):
405         self.done = 1
406         
407     def setUp(self):
408         self.l = []
409         self.startport = 4088
410         for i in range(self.num):
411             d = self.DHT_DEFAULTS.copy()
412             d['PORT'] = self.startport + i
413             self.l.append(Khashmir(d))
414         reactor.iterate()
415         reactor.iterate()
416         
417         for i in self.l:
418             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
419             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
420             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
421             reactor.iterate()
422             reactor.iterate()
423             reactor.iterate() 
424             
425         for i in self.l:
426             self.done = 0
427             i.findCloseNodes(self._done)
428             while not self.done:
429                 reactor.iterate()
430         for i in self.l:
431             self.done = 0
432             i.findCloseNodes(self._done)
433             while not self.done:
434                 reactor.iterate()
435
436     def tearDown(self):
437         for i in self.l:
438             i.shutdown()
439             os.unlink(i.store.db)
440             
441         reactor.iterate()
442         
443     def testStoreRetrieve(self):
444         for i in range(10):
445             K = newID()
446             V = newID()
447             
448             for a in range(3):
449                 self.done = 0
450                 def _scb(key, value, result):
451                     self.done = 1
452                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
453                 while not self.done:
454                     reactor.iterate()
455
456
457                 def _rcb(key, val):
458                     if not val:
459                         self.done = 1
460                         self.failUnlessEqual(self.got, 1)
461                     elif V in val:
462                         self.got = 1
463                 for x in range(3):
464                     self.got = 0
465                     self.done = 0
466                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
467                     while not self.done:
468                         reactor.iterate()