Move the key expiring to the checkpoint function.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from datetime import datetime, timedelta
8 from random import randrange
9 from sha import sha
10 import os
11
12 from twisted.internet.defer import Deferred
13 from twisted.internet import protocol, reactor
14 from twisted.trial import unittest
15
16 from db import DB
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, GetValue, StoreValue
21 import krpc
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         self.token_secrets = [newID()]
37         #self.app = service.Application("krpc")
38         self.udp = krpc.hostbroker(self, config)
39         self.udp.protocol = krpc.KRPC
40         self.listenport = reactor.listenUDP(self.port, self.udp)
41         self._loadRoutingTable()
42         self.refreshTable(force=1)
43         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
44
45     def Node(self, id, host = None, port = None):
46         """Create a new node."""
47         n = self._Node(id, host, port)
48         n.table = self.table
49         n.conn = self.udp.connectionForAddr((n.host, n.port))
50         return n
51     
52     def __del__(self):
53         self.listenport.stopListening()
54         
55     def _loadSelfNode(self, host, port):
56         id = self.store.getSelfNode()
57         if not id:
58             id = newID()
59         return self._Node(id, host, port)
60         
61     def checkpoint(self, auto=0):
62         self.token_secrets.insert(0, newID())
63         if len(self.token_secrets) > 3:
64             self.token_secrets.pop()
65         self.store.saveSelfNode(self.node.id)
66         self.store.dumpRoutingTable(self.table.buckets)
67         self.store.expireValues(self.config['KEY_EXPIRE'])
68         self.refreshTable()
69         if auto:
70             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
71                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
72                               self.checkpoint, (1,))
73         
74     def _loadRoutingTable(self):
75         """
76             load routing table nodes from database
77             it's usually a good idea to call refreshTable(force=1) after loading the table
78         """
79         nodes = self.store.getRoutingTable()
80         for rec in nodes:
81             n = self.Node(rec[0], rec[1], int(rec[2]))
82             self.table.insertNode(n, contacted=0)
83             
84
85     #######
86     #######  LOCAL INTERFACE    - use these methods!
87     def addContact(self, host, port, callback=None, errback=None):
88         """
89             ping this node and add the contact info to the table on pong!
90         """
91         n = self.Node(NULL_ID, host, port)
92         self.sendJoin(n, callback=callback, errback=errback)
93
94     ## this call is async!
95     def findNode(self, id, callback, errback=None):
96         """ returns the contact info for node, or the k closest nodes, from the global table """
97         # get K nodes out of local table/cache, or the node we want
98         nodes = self.table.findNodes(id)
99         d = Deferred()
100         if errback:
101             d.addCallbacks(callback, errback)
102         else:
103             d.addCallback(callback)
104         if len(nodes) == 1 and nodes[0].id == id :
105             d.callback(nodes)
106         else:
107             # create our search state
108             state = FindNode(self, id, d.callback, self.config)
109             reactor.callLater(0, state.goWithNodes, nodes)
110     
111     def insertNode(self, n, contacted=1):
112         """
113         insert a node in our local table, pinging oldest contact in bucket, if necessary
114         
115         If all you have is a host/port, then use addContact, which calls this method after
116         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
117         a node into the table without it's peer-ID.  That means of course the node passed into this
118         method needs to be a properly formed Node object with a valid ID.
119         """
120         old = self.table.insertNode(n, contacted=contacted)
121         if (old and old.id != self.node.id and
122             (datetime.now() - old.lastSeen) > 
123              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
124             # the bucket is full, check to see if old node is still around and if so, replace it
125             
126             ## these are the callbacks used when we ping the oldest node in a bucket
127             def _staleNodeHandler(oldnode=old, newnode = n):
128                 """ called if the pinged node never responds """
129                 self.table.replaceStaleNode(old, newnode)
130             
131             def _notStaleNodeHandler(dict, old=old):
132                 """ called when we get a pong from the old node """
133                 dict = dict['rsp']
134                 if dict['id'] == old.id:
135                     self.table.justSeenNode(old.id)
136             
137             df = old.ping(self.node.id)
138             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
139
140     def sendJoin(self, node, callback=None, errback=None):
141         """
142             ping a node
143         """
144         df = node.join(self.node.id)
145         ## these are the callbacks we use when we issue a PING
146         def _pongHandler(dict, node=node, self=self, callback=callback):
147             n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
148             self.insertNode(n)
149             if callback:
150                 callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
151         def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
152             table.nodeFailed(node)
153             if errback:
154                 errback()
155             else:
156                 callback(None)
157         
158         df.addCallbacks(_pongHandler,_defaultPong)
159
160     def findCloseNodes(self, callback=lambda a: None, errback = None):
161         """
162             This does a findNode on the ID one away from our own.  
163             This will allow us to populate our table with nodes on our network closest to our own.
164             This is called as soon as we start up with an empty table
165         """
166         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
167         self.findNode(id, callback, errback)
168
169     def refreshTable(self, force=0):
170         """
171             force=1 will refresh table regardless of last bucket access time
172         """
173         def callback(nodes):
174             pass
175     
176         for bucket in self.table.buckets:
177             if force or (datetime.now() - bucket.lastAccessed > 
178                          timedelta(seconds=self.config['BUCKET_STALENESS'])):
179                 id = newIDInRange(bucket.min, bucket.max)
180                 self.findNode(id, callback)
181
182     def stats(self):
183         """
184         Returns (num_contacts, num_nodes)
185         num_contacts: number contacts in our routing table
186         num_nodes: number of nodes estimated in the entire dht
187         """
188         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
189         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
190         return (num_contacts, num_nodes)
191     
192     def shutdown(self):
193         """Closes the port and cancels pending later calls."""
194         self.listenport.stopListening()
195         try:
196             self.next_checkpoint.cancel()
197         except:
198             pass
199         self.store.close()
200
201     #### Remote Interface - called by remote nodes
202     def krpc_ping(self, id, _krpc_sender):
203         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
204         self.insertNode(n, contacted=0)
205         return {"id" : self.node.id}
206         
207     def krpc_join(self, id, _krpc_sender):
208         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
209         self.insertNode(n, contacted=0)
210         return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
211         
212     def krpc_find_node(self, target, id, _krpc_sender):
213         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
214         self.insertNode(n, contacted=0)
215         nodes = self.table.findNodes(target)
216         nodes = map(lambda node: node.contactInfo(), nodes)
217         token = sha(self.token_secrets[0] + _krpc_sender[0]).digest()
218         return {"nodes" : nodes, "token" : token, "id" : self.node.id}
219
220
221 ## This class provides read-only access to the DHT, valueForKey
222 ## you probably want to use this mixin and provide your own write methods
223 class KhashmirRead(KhashmirBase):
224     _Node = KNodeRead
225
226     ## also async
227     def valueForKey(self, key, callback, searchlocal = 1):
228         """ returns the values found for key in global table
229             callback will be called with a list of values for each peer that returns unique values
230             final callback will be an empty list - probably should change to 'more coming' arg
231         """
232         nodes = self.table.findNodes(key)
233         
234         # get locals
235         if searchlocal:
236             l = self.store.retrieveValues(key)
237             if len(l) > 0:
238                 reactor.callLater(0, callback, key, l)
239         else:
240             l = []
241         
242         # create our search state
243         state = GetValue(self, key, callback, self.config)
244         reactor.callLater(0, state.goWithNodes, nodes, l)
245
246     #### Remote Interface - called by remote nodes
247     def krpc_find_value(self, key, id, _krpc_sender):
248         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
249         self.insertNode(n, contacted=0)
250     
251         l = self.store.retrieveValues(key)
252         if len(l) > 0:
253             return {'values' : l, "id": self.node.id}
254         else:
255             nodes = self.table.findNodes(key)
256             nodes = map(lambda node: node.contactInfo(), nodes)
257             return {'nodes' : nodes, "id": self.node.id}
258
259 ###  provides a generic write method, you probably don't want to deploy something that allows
260 ###  arbitrary value storage
261 class KhashmirWrite(KhashmirRead):
262     _Node = KNodeWrite
263     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
264     def storeValueForKey(self, key, value, callback=None):
265         """ stores the value and origination time for key in the global table, returns immediately, no status 
266             in this implementation, peers respond but don't indicate status to storing values
267             a key can have many values
268         """
269         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
270             if not response:
271                 # default callback
272                 def _storedValueHandler(key, value, sender):
273                     pass
274                 response=_storedValueHandler
275             action = StoreValue(self.table, key, value, response, self.config)
276             reactor.callLater(0, action.goWithNodes, nodes)
277             
278         # this call is asynch
279         self.findNode(key, _storeValueForKey)
280                     
281     #### Remote Interface - called by remote nodes
282     def krpc_store_value(self, key, value, token, id, _krpc_sender):
283         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
284         self.insertNode(n, contacted=0)
285         for secret in self.token_secrets:
286             this_token = sha(secret + _krpc_sender[0]).digest()
287             if token == this_token:
288                 self.store.storeValue(key, value)
289                 return {"id" : self.node.id}
290         raise krpc.KrpcError, (krpc.KRPC_ERROR_INVALID_TOKEN, 'token is invalid, do a find_nodes to get a fresh one')
291
292 # the whole shebang, for testing
293 class Khashmir(KhashmirWrite):
294     _Node = KNodeWrite
295
296 class SimpleTests(unittest.TestCase):
297     
298     timeout = 10
299     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
300                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
301                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
302                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
303                     'KEY_EXPIRE': 3600, 'SPEW': False, }
304
305     def setUp(self):
306         krpc.KRPC.noisy = 0
307         d = self.DHT_DEFAULTS.copy()
308         d['PORT'] = 4044
309         self.a = Khashmir(d)
310         d = self.DHT_DEFAULTS.copy()
311         d['PORT'] = 4045
312         self.b = Khashmir(d)
313         
314     def tearDown(self):
315         self.a.shutdown()
316         self.b.shutdown()
317         os.unlink(self.a.store.db)
318         os.unlink(self.b.store.db)
319
320     def testAddContact(self):
321         self.failUnlessEqual(len(self.a.table.buckets), 1)
322         self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
323
324         self.failUnlessEqual(len(self.b.table.buckets), 1)
325         self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
326
327         self.a.addContact('127.0.0.1', 4045)
328         reactor.iterate()
329         reactor.iterate()
330         reactor.iterate()
331         reactor.iterate()
332
333         self.failUnlessEqual(len(self.a.table.buckets), 1)
334         self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
335         self.failUnlessEqual(len(self.b.table.buckets), 1)
336         self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
337
338     def testStoreRetrieve(self):
339         self.a.addContact('127.0.0.1', 4045)
340         reactor.iterate()
341         reactor.iterate()
342         reactor.iterate()
343         reactor.iterate()
344         self.got = 0
345         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
346         reactor.iterate()
347         reactor.iterate()
348         reactor.iterate()
349         reactor.iterate()
350         reactor.iterate()
351         reactor.iterate()
352         self.a.valueForKey(sha('foo').digest(), self._cb)
353         reactor.iterate()
354         reactor.iterate()
355         reactor.iterate()
356         reactor.iterate()
357         reactor.iterate()
358         reactor.iterate()
359         reactor.iterate()
360
361     def _cb(self, key, val):
362         if not val:
363             self.failUnlessEqual(self.got, 1)
364         elif 'foobar' in val:
365             self.got = 1
366
367
368 class MultiTest(unittest.TestCase):
369     
370     timeout = 30
371     num = 20
372     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
373                     'CHECKPOINT_INTERVAL': 300, 'CONCURRENT_REQS': 4,
374                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
375                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
376                     'KEY_EXPIRE': 3600, 'SPEW': False, }
377
378     def _done(self, val):
379         self.done = 1
380         
381     def setUp(self):
382         self.l = []
383         self.startport = 4088
384         for i in range(self.num):
385             d = self.DHT_DEFAULTS.copy()
386             d['PORT'] = self.startport + i
387             self.l.append(Khashmir(d))
388         reactor.iterate()
389         reactor.iterate()
390         
391         for i in self.l:
392             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
393             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
394             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
395             reactor.iterate()
396             reactor.iterate()
397             reactor.iterate() 
398             
399         for i in self.l:
400             self.done = 0
401             i.findCloseNodes(self._done)
402             while not self.done:
403                 reactor.iterate()
404         for i in self.l:
405             self.done = 0
406             i.findCloseNodes(self._done)
407             while not self.done:
408                 reactor.iterate()
409
410     def tearDown(self):
411         for i in self.l:
412             i.shutdown()
413             os.unlink(i.store.db)
414             
415         reactor.iterate()
416         
417     def testStoreRetrieve(self):
418         for i in range(10):
419             K = newID()
420             V = newID()
421             
422             for a in range(3):
423                 self.done = 0
424                 def _scb(key, value, result):
425                     self.done = 1
426                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
427                 while not self.done:
428                     reactor.iterate()
429
430
431                 def _rcb(key, val):
432                     if not val:
433                         self.done = 1
434                         self.failUnlessEqual(self.got, 1)
435                     elif V in val:
436                         self.got = 1
437                 for x in range(3):
438                     self.got = 0
439                     self.done = 0
440                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
441                     while not self.done:
442                         reactor.iterate()