6162a6e5e8ac88fe97bfa3ce9fdebe1f6d4b866a
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from datetime import datetime, timedelta
8 from random import randrange, shuffle
9 from sha import sha
10 import os
11
12 from twisted.internet.defer import Deferred
13 from twisted.internet import protocol, reactor
14 from twisted.trial import unittest
15
16 from db import DB
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, FindValue, GetValue, StoreValue
21 import krpc
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         self.token_secrets = [newID()]
37         #self.app = service.Application("krpc")
38         self.udp = krpc.hostbroker(self, config)
39         self.udp.protocol = krpc.KRPC
40         self.listenport = reactor.listenUDP(self.port, self.udp)
41         self._loadRoutingTable()
42         self.refreshTable(force=1)
43         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
44
45     def Node(self, id, host = None, port = None):
46         """Create a new node."""
47         n = self._Node(id, host, port)
48         n.table = self.table
49         n.conn = self.udp.connectionForAddr((n.host, n.port))
50         return n
51     
52     def __del__(self):
53         self.listenport.stopListening()
54         
55     def _loadSelfNode(self, host, port):
56         id = self.store.getSelfNode()
57         if not id:
58             id = newID()
59         return self._Node(id, host, port)
60         
61     def checkpoint(self, auto=0):
62         self.token_secrets.insert(0, newID())
63         if len(self.token_secrets) > 3:
64             self.token_secrets.pop()
65         self.store.saveSelfNode(self.node.id)
66         self.store.dumpRoutingTable(self.table.buckets)
67         self.store.expireValues(self.config['KEY_EXPIRE'])
68         self.refreshTable()
69         if auto:
70             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
71                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
72                               self.checkpoint, (1,))
73         
74     def _loadRoutingTable(self):
75         """
76             load routing table nodes from database
77             it's usually a good idea to call refreshTable(force=1) after loading the table
78         """
79         nodes = self.store.getRoutingTable()
80         for rec in nodes:
81             n = self.Node(rec[0], rec[1], int(rec[2]))
82             self.table.insertNode(n, contacted=0)
83             
84
85     #######
86     #######  LOCAL INTERFACE    - use these methods!
87     def addContact(self, host, port, callback=None, errback=None):
88         """
89             ping this node and add the contact info to the table on pong!
90         """
91         n = self.Node(NULL_ID, host, port)
92         self.sendJoin(n, callback=callback, errback=errback)
93
94     ## this call is async!
95     def findNode(self, id, callback, errback=None):
96         """ returns the contact info for node, or the k closest nodes, from the global table """
97         # get K nodes out of local table/cache, or the node we want
98         nodes = self.table.findNodes(id)
99         d = Deferred()
100         if errback:
101             d.addCallbacks(callback, errback)
102         else:
103             d.addCallback(callback)
104         if len(nodes) == 1 and nodes[0].id == id :
105             d.callback(nodes)
106         else:
107             # create our search state
108             state = FindNode(self, id, d.callback, self.config)
109             reactor.callLater(0, state.goWithNodes, nodes)
110     
111     def insertNode(self, n, contacted=1):
112         """
113         insert a node in our local table, pinging oldest contact in bucket, if necessary
114         
115         If all you have is a host/port, then use addContact, which calls this method after
116         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
117         a node into the table without it's peer-ID.  That means of course the node passed into this
118         method needs to be a properly formed Node object with a valid ID.
119         """
120         old = self.table.insertNode(n, contacted=contacted)
121         if (old and old.id != self.node.id and
122             (datetime.now() - old.lastSeen) > 
123              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
124             # the bucket is full, check to see if old node is still around and if so, replace it
125             
126             ## these are the callbacks used when we ping the oldest node in a bucket
127             def _staleNodeHandler(oldnode=old, newnode = n):
128                 """ called if the pinged node never responds """
129                 self.table.replaceStaleNode(old, newnode)
130             
131             def _notStaleNodeHandler(dict, old=old):
132                 """ called when we get a pong from the old node """
133                 dict = dict['rsp']
134                 if dict['id'] == old.id:
135                     self.table.justSeenNode(old.id)
136             
137             df = old.ping(self.node.id)
138             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
139
140     def sendJoin(self, node, callback=None, errback=None):
141         """
142             ping a node
143         """
144         df = node.join(self.node.id)
145         ## these are the callbacks we use when we issue a PING
146         def _pongHandler(dict, node=node, self=self, callback=callback):
147             n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
148             self.insertNode(n)
149             if callback:
150                 callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
151         def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
152             table.nodeFailed(node)
153             if errback:
154                 errback()
155             else:
156                 callback(None)
157         
158         df.addCallbacks(_pongHandler,_defaultPong)
159
160     def findCloseNodes(self, callback=lambda a: None, errback = None):
161         """
162             This does a findNode on the ID one away from our own.  
163             This will allow us to populate our table with nodes on our network closest to our own.
164             This is called as soon as we start up with an empty table
165         """
166         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
167         self.findNode(id, callback, errback)
168
169     def refreshTable(self, force=0):
170         """
171             force=1 will refresh table regardless of last bucket access time
172         """
173         def callback(nodes):
174             pass
175     
176         for bucket in self.table.buckets:
177             if force or (datetime.now() - bucket.lastAccessed > 
178                          timedelta(seconds=self.config['BUCKET_STALENESS'])):
179                 id = newIDInRange(bucket.min, bucket.max)
180                 self.findNode(id, callback)
181
182     def stats(self):
183         """
184         Returns (num_contacts, num_nodes)
185         num_contacts: number contacts in our routing table
186         num_nodes: number of nodes estimated in the entire dht
187         """
188         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
189         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
190         return (num_contacts, num_nodes)
191     
192     def shutdown(self):
193         """Closes the port and cancels pending later calls."""
194         self.listenport.stopListening()
195         try:
196             self.next_checkpoint.cancel()
197         except:
198             pass
199         self.store.close()
200
201     #### Remote Interface - called by remote nodes
202     def krpc_ping(self, id, _krpc_sender):
203         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
204         self.insertNode(n, contacted=0)
205         return {"id" : self.node.id}
206         
207     def krpc_join(self, id, _krpc_sender):
208         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
209         self.insertNode(n, contacted=0)
210         return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
211         
212     def krpc_find_node(self, target, id, _krpc_sender):
213         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
214         self.insertNode(n, contacted=0)
215         nodes = self.table.findNodes(target)
216         nodes = map(lambda node: node.contactInfo(), nodes)
217         token = sha(self.token_secrets[0] + _krpc_sender[0]).digest()
218         return {"nodes" : nodes, "token" : token, "id" : self.node.id}
219
220
221 ## This class provides read-only access to the DHT, valueForKey
222 ## you probably want to use this mixin and provide your own write methods
223 class KhashmirRead(KhashmirBase):
224     _Node = KNodeRead
225
226     ## also async
227     def findValue(self, key, callback, errback=None):
228         """ returns the contact info for nodes that have values for the key, from the global table """
229         # get K nodes out of local table/cache
230         nodes = self.table.findNodes(key)
231         d = Deferred()
232         if errback:
233             d.addCallbacks(callback, errback)
234         else:
235             d.addCallback(callback)
236
237         # create our search state
238         state = FindValue(self, key, d.callback, self.config)
239         reactor.callLater(0, state.goWithNodes, nodes)
240
241     def valueForKey(self, key, callback, searchlocal = 1):
242         """ returns the values found for key in global table
243             callback will be called with a list of values for each peer that returns unique values
244             final callback will be an empty list - probably should change to 'more coming' arg
245         """
246         # get locals
247         if searchlocal:
248             l = self.store.retrieveValues(key)
249             if len(l) > 0:
250                 reactor.callLater(0, callback, key, l)
251         else:
252             l = []
253
254         def _getValueForKey(nodes, key=key, local_values=l, response=callback, table=self.table, config=self.config):
255             # create our search state
256             state = GetValue(table, key, 50 - len(local_values), response, config)
257             reactor.callLater(0, state.goWithNodes, nodes, local_values)
258             
259         # this call is asynch
260         self.findValue(key, _getValueForKey)
261
262     #### Remote Interface - called by remote nodes
263     def krpc_find_value(self, key, id, _krpc_sender):
264         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
265         self.insertNode(n, contacted=0)
266     
267         nodes = self.table.findNodes(key)
268         nodes = map(lambda node: node.contactInfo(), nodes)
269         num_values = self.store.countValues(key)
270         return {'nodes' : nodes, 'num' : num_values, "id": self.node.id}
271
272     def krpc_get_value(self, key, num, id, _krpc_sender):
273         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
274         self.insertNode(n, contacted=0)
275     
276         l = self.store.retrieveValues(key)
277         if num == 0 or num >= len(l):
278             return {'values' : l, "id": self.node.id}
279         else:
280             shuffle(l)
281             return {'values' : l[:num], "id": self.node.id}
282
283 ###  provides a generic write method, you probably don't want to deploy something that allows
284 ###  arbitrary value storage
285 class KhashmirWrite(KhashmirRead):
286     _Node = KNodeWrite
287     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
288     def storeValueForKey(self, key, value, callback=None):
289         """ stores the value and origination time for key in the global table, returns immediately, no status 
290             in this implementation, peers respond but don't indicate status to storing values
291             a key can have many values
292         """
293         def _storeValueForKey(nodes, key=key, value=value, response=callback, table=self.table, config=self.config):
294             if not response:
295                 # default callback
296                 def _storedValueHandler(key, value, sender):
297                     pass
298                 response=_storedValueHandler
299             action = StoreValue(table, key, value, response, config)
300             reactor.callLater(0, action.goWithNodes, nodes)
301             
302         # this call is asynch
303         self.findNode(key, _storeValueForKey)
304                     
305     #### Remote Interface - called by remote nodes
306     def krpc_store_value(self, key, value, token, id, _krpc_sender):
307         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
308         self.insertNode(n, contacted=0)
309         for secret in self.token_secrets:
310             this_token = sha(secret + _krpc_sender[0]).digest()
311             if token == this_token:
312                 self.store.storeValue(key, value)
313                 return {"id" : self.node.id}
314         raise krpc.KrpcError, (krpc.KRPC_ERROR_INVALID_TOKEN, 'token is invalid, do a find_nodes to get a fresh one')
315
316 # the whole shebang, for testing
317 class Khashmir(KhashmirWrite):
318     _Node = KNodeWrite
319
320 class SimpleTests(unittest.TestCase):
321     
322     timeout = 10
323     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
324                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
325                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
326                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
327                     'KEY_EXPIRE': 3600, 'SPEW': False, }
328
329     def setUp(self):
330         krpc.KRPC.noisy = 0
331         d = self.DHT_DEFAULTS.copy()
332         d['PORT'] = 4044
333         self.a = Khashmir(d)
334         d = self.DHT_DEFAULTS.copy()
335         d['PORT'] = 4045
336         self.b = Khashmir(d)
337         
338     def tearDown(self):
339         self.a.shutdown()
340         self.b.shutdown()
341         os.unlink(self.a.store.db)
342         os.unlink(self.b.store.db)
343
344     def testAddContact(self):
345         self.failUnlessEqual(len(self.a.table.buckets), 1)
346         self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
347
348         self.failUnlessEqual(len(self.b.table.buckets), 1)
349         self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
350
351         self.a.addContact('127.0.0.1', 4045)
352         reactor.iterate()
353         reactor.iterate()
354         reactor.iterate()
355         reactor.iterate()
356
357         self.failUnlessEqual(len(self.a.table.buckets), 1)
358         self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
359         self.failUnlessEqual(len(self.b.table.buckets), 1)
360         self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
361
362     def testStoreRetrieve(self):
363         self.a.addContact('127.0.0.1', 4045)
364         reactor.iterate()
365         reactor.iterate()
366         reactor.iterate()
367         reactor.iterate()
368         self.got = 0
369         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
370         reactor.iterate()
371         reactor.iterate()
372         reactor.iterate()
373         reactor.iterate()
374         reactor.iterate()
375         reactor.iterate()
376         self.a.valueForKey(sha('foo').digest(), self._cb)
377         reactor.iterate()
378         reactor.iterate()
379         reactor.iterate()
380         reactor.iterate()
381         reactor.iterate()
382         reactor.iterate()
383         reactor.iterate()
384
385     def _cb(self, key, val):
386         if not val:
387             self.failUnlessEqual(self.got, 1)
388         elif 'foobar' in val:
389             self.got = 1
390
391
392 class MultiTest(unittest.TestCase):
393     
394     timeout = 30
395     num = 20
396     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
397                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
398                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
399                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
400                     'KEY_EXPIRE': 3600, 'SPEW': False, }
401
402     def _done(self, val):
403         self.done = 1
404         
405     def setUp(self):
406         self.l = []
407         self.startport = 4088
408         for i in range(self.num):
409             d = self.DHT_DEFAULTS.copy()
410             d['PORT'] = self.startport + i
411             self.l.append(Khashmir(d))
412         reactor.iterate()
413         reactor.iterate()
414         
415         for i in self.l:
416             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
417             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
418             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
419             reactor.iterate()
420             reactor.iterate()
421             reactor.iterate() 
422             
423         for i in self.l:
424             self.done = 0
425             i.findCloseNodes(self._done)
426             while not self.done:
427                 reactor.iterate()
428         for i in self.l:
429             self.done = 0
430             i.findCloseNodes(self._done)
431             while not self.done:
432                 reactor.iterate()
433
434     def tearDown(self):
435         for i in self.l:
436             i.shutdown()
437             os.unlink(i.store.db)
438             
439         reactor.iterate()
440         
441     def testStoreRetrieve(self):
442         for i in range(10):
443             K = newID()
444             V = newID()
445             
446             for a in range(3):
447                 self.done = 0
448                 def _scb(key, value, result):
449                     self.done = 1
450                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
451                 while not self.done:
452                     reactor.iterate()
453
454
455                 def _rcb(key, val):
456                     if not val:
457                         self.done = 1
458                         self.failUnlessEqual(self.got, 1)
459                     elif V in val:
460                         self.got = 1
461                 for x in range(3):
462                     self.got = 0
463                     self.done = 0
464                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
465                     while not self.done:
466                         reactor.iterate()