c4d018c1a1980a5c5c10f0a4ad70c6dedaba9298
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 import warnings
5 warnings.simplefilter("ignore", DeprecationWarning)
6
7 from datetime import datetime, timedelta
8 from random import randrange
9 from sha import sha
10 import os
11
12 from twisted.internet.defer import Deferred
13 from twisted.internet import protocol, reactor
14 from twisted.trial import unittest
15
16 from db import DB
17 from ktable import KTable
18 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
19 from khash import newID, newIDInRange
20 from actions import FindNode, GetValue, KeyExpirer, StoreValue
21 import krpc
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self.store = DB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         #self.app = service.Application("krpc")
37         self.udp = krpc.hostbroker(self)
38         self.udp.protocol = krpc.KRPC
39         self.listenport = reactor.listenUDP(self.port, self.udp)
40         self._loadRoutingTable()
41         self.expirer = KeyExpirer(self.store, config)
42         self.refreshTable(force=1)
43         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
44
45     def Node(self):
46         n = self._Node()
47         n.table = self.table
48         return n
49     
50     def __del__(self):
51         self.listenport.stopListening()
52         
53     def _loadSelfNode(self, host, port):
54         id = self.store.getSelfNode()
55         if not id:
56             id = newID()
57         return self._Node().init(id, host, port)
58         
59     def checkpoint(self, auto=0):
60         self.store.saveSelfNode(self.node.id)
61         self.store.dumpRoutingTable(self.table.buckets)
62         self.refreshTable()
63         if auto:
64             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
65                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
66                               self.checkpoint, (1,))
67         
68     def _loadRoutingTable(self):
69         """
70             load routing table nodes from database
71             it's usually a good idea to call refreshTable(force=1) after loading the table
72         """
73         nodes = self.store.getRoutingTable()
74         for rec in nodes:
75             n = self.Node().initWithDict({'id':rec[0], 'host':rec[1], 'port':int(rec[2])})
76             n.conn = self.udp.connectionForAddr((n.host, n.port))
77             self.table.insertNode(n, contacted=0)
78             
79     def _update_node(self, id, host, port):
80         n = self.Node().init(id, host, port)
81         n.conn = self.udp.connectionForAddr((host, port))
82         self.insertNode(n, contacted=0)
83     
84
85     #######
86     #######  LOCAL INTERFACE    - use these methods!
87     def addContact(self, host, port, callback=None):
88         """
89             ping this node and add the contact info to the table on pong!
90         """
91         n =self.Node().init(NULL_ID, host, port) 
92         n.conn = self.udp.connectionForAddr((n.host, n.port))
93         self.sendPing(n, callback=callback)
94
95     ## this call is async!
96     def findNode(self, id, callback, errback=None):
97         """ returns the contact info for node, or the k closest nodes, from the global table """
98         # get K nodes out of local table/cache, or the node we want
99         nodes = self.table.findNodes(id)
100         d = Deferred()
101         if errback:
102             d.addCallbacks(callback, errback)
103         else:
104             d.addCallback(callback)
105         if len(nodes) == 1 and nodes[0].id == id :
106             d.callback(nodes)
107         else:
108             # create our search state
109             state = FindNode(self, id, d.callback, self.config)
110             reactor.callLater(0, state.goWithNodes, nodes)
111     
112     def insertNode(self, n, contacted=1):
113         """
114         insert a node in our local table, pinging oldest contact in bucket, if necessary
115         
116         If all you have is a host/port, then use addContact, which calls this method after
117         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
118         a node into the table without it's peer-ID.  That means of course the node passed into this
119         method needs to be a properly formed Node object with a valid ID.
120         """
121         old = self.table.insertNode(n, contacted=contacted)
122         if (old and old.id != self.node.id and
123             (datetime.now() - old.lastSeen) > 
124              timedelta(seconds=self.config['MIN_PING_INTERVAL'])):
125             # the bucket is full, check to see if old node is still around and if so, replace it
126             
127             ## these are the callbacks used when we ping the oldest node in a bucket
128             def _staleNodeHandler(oldnode=old, newnode = n):
129                 """ called if the pinged node never responds """
130                 self.table.replaceStaleNode(old, newnode)
131             
132             def _notStaleNodeHandler(dict, old=old):
133                 """ called when we get a pong from the old node """
134                 dict = dict['rsp']
135                 if dict['id'] == old.id:
136                     self.table.justSeenNode(old.id)
137             
138             df = old.ping(self.node.id)
139             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
140
141     def sendPing(self, node, callback=None):
142         """
143             ping a node
144         """
145         df = node.ping(self.node.id)
146         ## these are the callbacks we use when we issue a PING
147         def _pongHandler(dict, node=node, table=self.table, callback=callback):
148             _krpc_sender = dict['_krpc_sender']
149             dict = dict['rsp']
150             sender = {'id' : dict['id']}
151             sender['host'] = _krpc_sender[0]
152             sender['port'] = _krpc_sender[1]
153             n = self.Node().initWithDict(sender)
154             n.conn = self.udp.connectionForAddr((n.host, n.port))
155             table.insertNode(n)
156             if callback:
157                 callback()
158         def _defaultPong(err, node=node, table=self.table, callback=callback):
159             table.nodeFailed(node)
160             if callback:
161                 callback()
162         
163         df.addCallbacks(_pongHandler,_defaultPong)
164
165     def findCloseNodes(self, callback=lambda a: None):
166         """
167             This does a findNode on the ID one away from our own.  
168             This will allow us to populate our table with nodes on our network closest to our own.
169             This is called as soon as we start up with an empty table
170         """
171         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
172         self.findNode(id, callback)
173
174     def refreshTable(self, force=0):
175         """
176             force=1 will refresh table regardless of last bucket access time
177         """
178         def callback(nodes):
179             pass
180     
181         for bucket in self.table.buckets:
182             if force or (datetime.now() - bucket.lastAccessed > 
183                          timedelta(seconds=self.config['BUCKET_STALENESS'])):
184                 id = newIDInRange(bucket.min, bucket.max)
185                 self.findNode(id, callback)
186
187     def stats(self):
188         """
189         Returns (num_contacts, num_nodes)
190         num_contacts: number contacts in our routing table
191         num_nodes: number of nodes estimated in the entire dht
192         """
193         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
194         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
195         return (num_contacts, num_nodes)
196     
197     def shutdown(self):
198         """Closes the port and cancels pending later calls."""
199         self.listenport.stopListening()
200         try:
201             self.next_checkpoint.cancel()
202         except:
203             pass
204         self.expirer.shutdown()
205         self.store.close()
206
207     #### Remote Interface - called by remote nodes
208     def krpc_ping(self, id, _krpc_sender):
209         self._update_node(id, _krpc_sender[0], _krpc_sender[1])
210         return {"id" : self.node.id}
211         
212     def krpc_find_node(self, target, id, _krpc_sender):
213         self._update_node(id, _krpc_sender[0], _krpc_sender[1])
214         nodes = self.table.findNodes(target)
215         nodes = map(lambda node: node.senderDict(), nodes)
216         return {"nodes" : nodes, "id" : self.node.id}
217
218
219 ## This class provides read-only access to the DHT, valueForKey
220 ## you probably want to use this mixin and provide your own write methods
221 class KhashmirRead(KhashmirBase):
222     _Node = KNodeRead
223
224     ## also async
225     def valueForKey(self, key, callback, searchlocal = 1):
226         """ returns the values found for key in global table
227             callback will be called with a list of values for each peer that returns unique values
228             final callback will be an empty list - probably should change to 'more coming' arg
229         """
230         nodes = self.table.findNodes(key)
231         
232         # get locals
233         if searchlocal:
234             l = self.store.retrieveValues(key)
235             if len(l) > 0:
236                 reactor.callLater(0, callback, key, l)
237         else:
238             l = []
239         
240         # create our search state
241         state = GetValue(self, key, callback, self.config)
242         reactor.callLater(0, state.goWithNodes, nodes, l)
243
244     #### Remote Interface - called by remote nodes
245     def krpc_find_value(self, key, id, _krpc_sender):
246         self._update_node(id, _krpc_sender[0], _krpc_sender[1])
247     
248         l = self.store.retrieveValues(key)
249         if len(l) > 0:
250             return {'values' : l, "id": self.node.id}
251         else:
252             nodes = self.table.findNodes(key)
253             nodes = map(lambda node: node.senderDict(), nodes)
254             return {'nodes' : nodes, "id": self.node.id}
255
256 ###  provides a generic write method, you probably don't want to deploy something that allows
257 ###  arbitrary value storage
258 class KhashmirWrite(KhashmirRead):
259     _Node = KNodeWrite
260     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
261     def storeValueForKey(self, key, value, callback=None):
262         """ stores the value for key in the global table, returns immediately, no status 
263             in this implementation, peers respond but don't indicate status to storing values
264             a key can have many values
265         """
266         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
267             if not response:
268                 # default callback
269                 def _storedValueHandler(key, value, sender):
270                     pass
271                 response=_storedValueHandler
272             action = StoreValue(self.table, key, value, response, self.config)
273             reactor.callLater(0, action.goWithNodes, nodes)
274             
275         # this call is asynch
276         self.findNode(key, _storeValueForKey)
277                     
278     #### Remote Interface - called by remote nodes
279     def krpc_store_value(self, key, value, id, _krpc_sender):
280         self._update_node(id, _krpc_sender[0], _krpc_sender[1])
281         self.store.storeValue(key, value)
282         return {"id" : self.node.id}
283
284 # the whole shebang, for testing
285 class Khashmir(KhashmirWrite):
286     _Node = KNodeWrite
287
288 class SimpleTests(unittest.TestCase):
289     
290     timeout = 10
291     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
292                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
293                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
294                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
295                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
296                     'KE_AGE': 3600, }
297
298     def setUp(self):
299         krpc.KRPC.noisy = 0
300         d = self.DHT_DEFAULTS.copy()
301         d['PORT'] = 4044
302         self.a = Khashmir(d)
303         d = self.DHT_DEFAULTS.copy()
304         d['PORT'] = 4045
305         self.b = Khashmir(d)
306         
307     def tearDown(self):
308         self.a.shutdown()
309         self.b.shutdown()
310         os.unlink(self.a.store.db)
311         os.unlink(self.b.store.db)
312
313     def testAddContact(self):
314         self.assertEqual(len(self.a.table.buckets), 1)
315         self.assertEqual(len(self.a.table.buckets[0].l), 0)
316
317         self.assertEqual(len(self.b.table.buckets), 1)
318         self.assertEqual(len(self.b.table.buckets[0].l), 0)
319
320         self.a.addContact('127.0.0.1', 4045)
321         reactor.iterate()
322         reactor.iterate()
323         reactor.iterate()
324         reactor.iterate()
325
326         self.assertEqual(len(self.a.table.buckets), 1)
327         self.assertEqual(len(self.a.table.buckets[0].l), 1)
328         self.assertEqual(len(self.b.table.buckets), 1)
329         self.assertEqual(len(self.b.table.buckets[0].l), 1)
330
331     def testStoreRetrieve(self):
332         self.a.addContact('127.0.0.1', 4045)
333         reactor.iterate()
334         reactor.iterate()
335         reactor.iterate()
336         reactor.iterate()
337         self.got = 0
338         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
339         reactor.iterate()
340         reactor.iterate()
341         reactor.iterate()
342         reactor.iterate()
343         reactor.iterate()
344         reactor.iterate()
345         self.a.valueForKey(sha('foo').digest(), self._cb)
346         reactor.iterate()
347         reactor.iterate()
348         reactor.iterate()
349         reactor.iterate()
350         reactor.iterate()
351         reactor.iterate()
352         reactor.iterate()
353
354     def _cb(self, key, val):
355         if not val:
356             self.assertEqual(self.got, 1)
357         elif 'foobar' in val:
358             self.got = 1
359
360
361 class MultiTest(unittest.TestCase):
362     
363     timeout = 30
364     num = 20
365     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
366                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
367                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
368                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
369                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
370                     'KE_AGE': 3600, }
371
372     def _done(self, val):
373         self.done = 1
374         
375     def setUp(self):
376         self.l = []
377         self.startport = 4088
378         for i in range(self.num):
379             d = self.DHT_DEFAULTS.copy()
380             d['PORT'] = self.startport + i
381             self.l.append(Khashmir(d))
382         reactor.iterate()
383         reactor.iterate()
384         
385         for i in self.l:
386             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
387             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
388             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
389             reactor.iterate()
390             reactor.iterate()
391             reactor.iterate() 
392             
393         for i in self.l:
394             self.done = 0
395             i.findCloseNodes(self._done)
396             while not self.done:
397                 reactor.iterate()
398         for i in self.l:
399             self.done = 0
400             i.findCloseNodes(self._done)
401             while not self.done:
402                 reactor.iterate()
403
404     def tearDown(self):
405         for i in self.l:
406             i.shutdown()
407             os.unlink(i.store.db)
408             
409         reactor.iterate()
410         
411     def testStoreRetrieve(self):
412         for i in range(10):
413             K = newID()
414             V = newID()
415             
416             for a in range(3):
417                 self.done = 0
418                 def _scb(key, value, result):
419                     self.done = 1
420                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
421                 while not self.done:
422                     reactor.iterate()
423
424
425                 def _rcb(key, val):
426                     if not val:
427                         self.done = 1
428                         self.assertEqual(self.got, 1)
429                     elif V in val:
430                         self.got = 1
431                 for x in range(3):
432                     self.got = 0
433                     self.done = 0
434                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
435                     while not self.done:
436                         reactor.iterate()