More work on the TODO.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from datetime import datetime, timedelta
8 from random import randrange
9 from sha import sha
10 import os
11
12 from twisted.internet.defer import Deferred
13 from twisted.internet import protocol, reactor
14 from twisted.trial import unittest
15
16 from db import DB
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, GetValue, KeyExpirer, StoreValue
21 import krpc
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         #self.app = service.Application("krpc")
37         self.udp = krpc.hostbroker(self, config)
38         self.udp.protocol = krpc.KRPC
39         self.listenport = reactor.listenUDP(self.port, self.udp)
40         self._loadRoutingTable()
41         self.expirer = KeyExpirer(self.store, config)
42         self.refreshTable(force=1)
43         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
44
45     def Node(self, id, host = None, port = None):
46         """Create a new node."""
47         n = self._Node(id, host, port)
48         n.table = self.table
49         n.conn = self.udp.connectionForAddr((n.host, n.port))
50         return n
51     
52     def __del__(self):
53         self.listenport.stopListening()
54         
55     def _loadSelfNode(self, host, port):
56         id = self.store.getSelfNode()
57         if not id:
58             id = newID()
59         return self._Node(id, host, port)
60         
61     def checkpoint(self, auto=0):
62         self.store.saveSelfNode(self.node.id)
63         self.store.dumpRoutingTable(self.table.buckets)
64         self.refreshTable()
65         if auto:
66             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
67                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
68                               self.checkpoint, (1,))
69         
70     def _loadRoutingTable(self):
71         """
72             load routing table nodes from database
73             it's usually a good idea to call refreshTable(force=1) after loading the table
74         """
75         nodes = self.store.getRoutingTable()
76         for rec in nodes:
77             n = self.Node(rec[0], rec[1], int(rec[2]))
78             self.table.insertNode(n, contacted=0)
79             
80
81     #######
82     #######  LOCAL INTERFACE    - use these methods!
83     def addContact(self, host, port, callback=None, errback=None):
84         """
85             ping this node and add the contact info to the table on pong!
86         """
87         n = self.Node(NULL_ID, host, port)
88         self.sendJoin(n, callback=callback, errback=errback)
89
90     ## this call is async!
91     def findNode(self, id, callback, errback=None):
92         """ returns the contact info for node, or the k closest nodes, from the global table """
93         # get K nodes out of local table/cache, or the node we want
94         nodes = self.table.findNodes(id)
95         d = Deferred()
96         if errback:
97             d.addCallbacks(callback, errback)
98         else:
99             d.addCallback(callback)
100         if len(nodes) == 1 and nodes[0].id == id :
101             d.callback(nodes)
102         else:
103             # create our search state
104             state = FindNode(self, id, d.callback, self.config)
105             reactor.callLater(0, state.goWithNodes, nodes)
106     
107     def insertNode(self, n, contacted=1):
108         """
109         insert a node in our local table, pinging oldest contact in bucket, if necessary
110         
111         If all you have is a host/port, then use addContact, which calls this method after
112         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
113         a node into the table without it's peer-ID.  That means of course the node passed into this
114         method needs to be a properly formed Node object with a valid ID.
115         """
116         old = self.table.insertNode(n, contacted=contacted)
117         if (old and old.id != self.node.id and
118             (datetime.now() - old.lastSeen) > 
119              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
120             # the bucket is full, check to see if old node is still around and if so, replace it
121             
122             ## these are the callbacks used when we ping the oldest node in a bucket
123             def _staleNodeHandler(oldnode=old, newnode = n):
124                 """ called if the pinged node never responds """
125                 self.table.replaceStaleNode(old, newnode)
126             
127             def _notStaleNodeHandler(dict, old=old):
128                 """ called when we get a pong from the old node """
129                 dict = dict['rsp']
130                 if dict['id'] == old.id:
131                     self.table.justSeenNode(old.id)
132             
133             df = old.ping(self.node.id)
134             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
135
136     def sendJoin(self, node, callback=None, errback=None):
137         """
138             ping a node
139         """
140         df = node.join(self.node.id)
141         ## these are the callbacks we use when we issue a PING
142         def _pongHandler(dict, node=node, self=self, callback=callback):
143             n = self.Node(dict['rsp']['id'], dict['_krpc_sender'][0], dict['_krpc_sender'][1])
144             self.insertNode(n)
145             if callback:
146                 callback((dict['rsp']['ip_addr'], dict['rsp']['port']))
147         def _defaultPong(err, node=node, table=self.table, callback=callback, errback=errback):
148             table.nodeFailed(node)
149             if errback:
150                 errback()
151             else:
152                 callback(None)
153         
154         df.addCallbacks(_pongHandler,_defaultPong)
155
156     def findCloseNodes(self, callback=lambda a: None, errback = None):
157         """
158             This does a findNode on the ID one away from our own.  
159             This will allow us to populate our table with nodes on our network closest to our own.
160             This is called as soon as we start up with an empty table
161         """
162         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
163         self.findNode(id, callback, errback)
164
165     def refreshTable(self, force=0):
166         """
167             force=1 will refresh table regardless of last bucket access time
168         """
169         def callback(nodes):
170             pass
171     
172         for bucket in self.table.buckets:
173             if force or (datetime.now() - bucket.lastAccessed > 
174                          timedelta(seconds=self.config['BUCKET_STALENESS'])):
175                 id = newIDInRange(bucket.min, bucket.max)
176                 self.findNode(id, callback)
177
178     def stats(self):
179         """
180         Returns (num_contacts, num_nodes)
181         num_contacts: number contacts in our routing table
182         num_nodes: number of nodes estimated in the entire dht
183         """
184         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
185         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
186         return (num_contacts, num_nodes)
187     
188     def shutdown(self):
189         """Closes the port and cancels pending later calls."""
190         self.listenport.stopListening()
191         try:
192             self.next_checkpoint.cancel()
193         except:
194             pass
195         self.expirer.shutdown()
196         self.store.close()
197
198     #### Remote Interface - called by remote nodes
199     def krpc_ping(self, id, _krpc_sender):
200         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
201         self.insertNode(n, contacted=0)
202         return {"id" : self.node.id}
203         
204     def krpc_join(self, id, _krpc_sender):
205         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
206         self.insertNode(n, contacted=0)
207         return {"ip_addr" : _krpc_sender[0], "port" : _krpc_sender[1], "id" : self.node.id}
208         
209     def krpc_find_node(self, target, id, _krpc_sender):
210         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
211         self.insertNode(n, contacted=0)
212         nodes = self.table.findNodes(target)
213         nodes = map(lambda node: node.senderDict(), nodes)
214         return {"nodes" : nodes, "id" : self.node.id}
215
216
217 ## This class provides read-only access to the DHT, valueForKey
218 ## you probably want to use this mixin and provide your own write methods
219 class KhashmirRead(KhashmirBase):
220     _Node = KNodeRead
221
222     ## also async
223     def valueForKey(self, key, callback, searchlocal = 1):
224         """ returns the values found for key in global table
225             callback will be called with a list of values for each peer that returns unique values
226             final callback will be an empty list - probably should change to 'more coming' arg
227         """
228         nodes = self.table.findNodes(key)
229         
230         # get locals
231         if searchlocal:
232             l = self.store.retrieveValues(key)
233             if len(l) > 0:
234                 reactor.callLater(0, callback, key, l)
235         else:
236             l = []
237         
238         # create our search state
239         state = GetValue(self, key, callback, self.config)
240         reactor.callLater(0, state.goWithNodes, nodes, l)
241
242     #### Remote Interface - called by remote nodes
243     def krpc_find_value(self, key, id, _krpc_sender):
244         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
245         self.insertNode(n, contacted=0)
246     
247         l = self.store.retrieveValues(key)
248         if len(l) > 0:
249             return {'values' : l, "id": self.node.id}
250         else:
251             nodes = self.table.findNodes(key)
252             nodes = map(lambda node: node.senderDict(), nodes)
253             return {'nodes' : nodes, "id": self.node.id}
254
255 ###  provides a generic write method, you probably don't want to deploy something that allows
256 ###  arbitrary value storage
257 class KhashmirWrite(KhashmirRead):
258     _Node = KNodeWrite
259     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
260     def storeValueForKey(self, key, value, originated, callback=None):
261         """ stores the value and origination time for key in the global table, returns immediately, no status 
262             in this implementation, peers respond but don't indicate status to storing values
263             a key can have many values
264         """
265         def _storeValueForKey(nodes, key=key, value=value, originated=originated, response=callback , table=self.table):
266             if not response:
267                 # default callback
268                 def _storedValueHandler(key, value, sender):
269                     pass
270                 response=_storedValueHandler
271             action = StoreValue(self.table, key, value, originated, response, self.config)
272             reactor.callLater(0, action.goWithNodes, nodes)
273             
274         # this call is asynch
275         self.findNode(key, _storeValueForKey)
276                     
277     #### Remote Interface - called by remote nodes
278     def krpc_store_value(self, key, value, originated, id, _krpc_sender):
279         n = self.Node(id, _krpc_sender[0], _krpc_sender[1])
280         self.insertNode(n, contacted=0)
281         self.store.storeValue(key, value, originated)
282         return {"id" : self.node.id}
283
284 # the whole shebang, for testing
285 class Khashmir(KhashmirWrite):
286     _Node = KNodeWrite
287
288 class SimpleTests(unittest.TestCase):
289     
290     timeout = 10
291     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
292                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
293                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
294                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
295                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
296                     'KE_AGE': 3600, 'SPEW': False, }
297
298     def setUp(self):
299         krpc.KRPC.noisy = 0
300         d = self.DHT_DEFAULTS.copy()
301         d['PORT'] = 4044
302         self.a = Khashmir(d)
303         d = self.DHT_DEFAULTS.copy()
304         d['PORT'] = 4045
305         self.b = Khashmir(d)
306         
307     def tearDown(self):
308         self.a.shutdown()
309         self.b.shutdown()
310         os.unlink(self.a.store.db)
311         os.unlink(self.b.store.db)
312
313     def testAddContact(self):
314         self.failUnlessEqual(len(self.a.table.buckets), 1)
315         self.failUnlessEqual(len(self.a.table.buckets[0].l), 0)
316
317         self.failUnlessEqual(len(self.b.table.buckets), 1)
318         self.failUnlessEqual(len(self.b.table.buckets[0].l), 0)
319
320         self.a.addContact('127.0.0.1', 4045)
321         reactor.iterate()
322         reactor.iterate()
323         reactor.iterate()
324         reactor.iterate()
325
326         self.failUnlessEqual(len(self.a.table.buckets), 1)
327         self.failUnlessEqual(len(self.a.table.buckets[0].l), 1)
328         self.failUnlessEqual(len(self.b.table.buckets), 1)
329         self.failUnlessEqual(len(self.b.table.buckets[0].l), 1)
330
331     def testStoreRetrieve(self):
332         self.a.addContact('127.0.0.1', 4045)
333         reactor.iterate()
334         reactor.iterate()
335         reactor.iterate()
336         reactor.iterate()
337         self.got = 0
338         self.a.storeValueForKey(sha('foo').digest(), 'foobar', datetime.utcnow())
339         reactor.iterate()
340         reactor.iterate()
341         reactor.iterate()
342         reactor.iterate()
343         reactor.iterate()
344         reactor.iterate()
345         self.a.valueForKey(sha('foo').digest(), self._cb)
346         reactor.iterate()
347         reactor.iterate()
348         reactor.iterate()
349         reactor.iterate()
350         reactor.iterate()
351         reactor.iterate()
352         reactor.iterate()
353
354     def _cb(self, key, val):
355         if not val:
356             self.failUnlessEqual(self.got, 1)
357         elif 'foobar' in val:
358             self.got = 1
359
360
361 class MultiTest(unittest.TestCase):
362     
363     timeout = 30
364     num = 20
365     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
366                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
367                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
368                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
369                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
370                     'KE_AGE': 3600, 'SPEW': False, }
371
372     def _done(self, val):
373         self.done = 1
374         
375     def setUp(self):
376         self.l = []
377         self.startport = 4088
378         for i in range(self.num):
379             d = self.DHT_DEFAULTS.copy()
380             d['PORT'] = self.startport + i
381             self.l.append(Khashmir(d))
382         reactor.iterate()
383         reactor.iterate()
384         
385         for i in self.l:
386             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
387             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
388             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
389             reactor.iterate()
390             reactor.iterate()
391             reactor.iterate() 
392             
393         for i in self.l:
394             self.done = 0
395             i.findCloseNodes(self._done)
396             while not self.done:
397                 reactor.iterate()
398         for i in self.l:
399             self.done = 0
400             i.findCloseNodes(self._done)
401             while not self.done:
402                 reactor.iterate()
403
404     def tearDown(self):
405         for i in self.l:
406             i.shutdown()
407             os.unlink(i.store.db)
408             
409         reactor.iterate()
410         
411     def testStoreRetrieve(self):
412         for i in range(10):
413             K = newID()
414             V = newID()
415             
416             for a in range(3):
417                 self.done = 0
418                 def _scb(key, value, result):
419                     self.done = 1
420                 self.l[randrange(0, self.num)].storeValueForKey(K, V, datetime.utcnow(), _scb)
421                 while not self.done:
422                     reactor.iterate()
423
424
425                 def _rcb(key, val):
426                     if not val:
427                         self.done = 1
428                         self.failUnlessEqual(self.got, 1)
429                     elif V in val:
430                         self.got = 1
431                 for x in range(3):
432                     self.got = 0
433                     self.done = 0
434                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
435                     while not self.done:
436                         reactor.iterate()