Change all unittests to use twisted's trial.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / khashmir.py
1 ## Copyright 2002-2004 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from time import time
5 from random import randrange
6 from sha import sha
7 import os
8 import sqlite  ## find this at http://pysqlite.sourceforge.net/
9
10 from twisted.internet.defer import Deferred
11 from twisted.internet import protocol, reactor
12 from twisted.trial import unittest
13
14 from ktable import KTable
15 from knode import KNodeBase, KNodeRead, KNodeWrite, NULL_ID
16 from khash import newID, newIDInRange
17 from actions import FindNode, GetValue, KeyExpirer, StoreValue
18 import krpc
19
20 class KhashmirDBExcept(Exception):
21     pass
22
23 # this is the base class, has base functionality and find node, no key-value mappings
24 class KhashmirBase(protocol.Factory):
25     _Node = KNodeBase
26     def __init__(self, config, cache_dir='/tmp'):
27         self.config = None
28         self.setup(config, cache_dir)
29         
30     def setup(self, config, cache_dir):
31         self.config = config
32         self.port = config['PORT']
33         self._findDB(os.path.join(cache_dir, 'khashmir.' + str(self.port) + '.db'))
34         self.node = self._loadSelfNode('', self.port)
35         self.table = KTable(self.node, config)
36         #self.app = service.Application("krpc")
37         self.udp = krpc.hostbroker(self)
38         self.udp.protocol = krpc.KRPC
39         self.listenport = reactor.listenUDP(self.port, self.udp)
40         self.last = time()
41         self._loadRoutingTable()
42         self.expirer = KeyExpirer(self.store, config)
43         self.refreshTable(force=1)
44         self.next_checkpoint = reactor.callLater(60, self.checkpoint, (1,))
45
46     def Node(self):
47         n = self._Node()
48         n.table = self.table
49         return n
50     
51     def __del__(self):
52         self.listenport.stopListening()
53         
54     def _loadSelfNode(self, host, port):
55         c = self.store.cursor()
56         c.execute('select id from self where num = 0;')
57         if c.rowcount > 0:
58             id = c.fetchone()[0]
59         else:
60             id = newID()
61         return self._Node().init(id, host, port)
62         
63     def _saveSelfNode(self):
64         c = self.store.cursor()
65         c.execute('delete from self where num = 0;')
66         c.execute("insert into self values (0, %s);", sqlite.encode(self.node.id))
67         self.store.commit()
68         
69     def checkpoint(self, auto=0):
70         self._saveSelfNode()
71         self._dumpRoutingTable()
72         self.refreshTable()
73         if auto:
74             self.next_checkpoint = reactor.callLater(randrange(int(self.config['CHECKPOINT_INTERVAL'] * .9), 
75                                         int(self.config['CHECKPOINT_INTERVAL'] * 1.1)), 
76                               self.checkpoint, (1,))
77         
78     def _findDB(self, db):
79         self.db = db
80         try:
81             os.stat(db)
82         except OSError:
83             self._createNewDB(db)
84         else:
85             self._loadDB(db)
86         
87     def _loadDB(self, db):
88         try:
89             self.store = sqlite.connect(db=db)
90             #self.store.autocommit = 0
91         except:
92             import traceback
93             raise KhashmirDBExcept, "Couldn't open DB", traceback.format_exc()
94         
95     def _createNewDB(self, db):
96         self.store = sqlite.connect(db=db)
97         s = """
98             create table kv (key binary, value binary, time timestamp, primary key (key, value));
99             create index kv_key on kv(key);
100             create index kv_timestamp on kv(time);
101             
102             create table nodes (id binary primary key, host text, port number);
103             
104             create table self (num number primary key, id binary);
105             """
106         c = self.store.cursor()
107         c.execute(s)
108         self.store.commit()
109
110     def _dumpRoutingTable(self):
111         """
112             save routing table nodes to the database
113         """
114         c = self.store.cursor()
115         c.execute("delete from nodes where id not NULL;")
116         for bucket in self.table.buckets:
117             for node in bucket.l:
118                 c.execute("insert into nodes values (%s, %s, %s);", (sqlite.encode(node.id), node.host, node.port))
119         self.store.commit()
120         
121     def _loadRoutingTable(self):
122         """
123             load routing table nodes from database
124             it's usually a good idea to call refreshTable(force=1) after loading the table
125         """
126         c = self.store.cursor()
127         c.execute("select * from nodes;")
128         for rec in c.fetchall():
129             n = self.Node().initWithDict({'id':rec[0], 'host':rec[1], 'port':int(rec[2])})
130             n.conn = self.udp.connectionForAddr((n.host, n.port))
131             self.table.insertNode(n, contacted=0)
132             
133
134     #######
135     #######  LOCAL INTERFACE    - use these methods!
136     def addContact(self, host, port, callback=None):
137         """
138             ping this node and add the contact info to the table on pong!
139         """
140         n =self.Node().init(NULL_ID, host, port) 
141         n.conn = self.udp.connectionForAddr((n.host, n.port))
142         self.sendPing(n, callback=callback)
143
144     ## this call is async!
145     def findNode(self, id, callback, errback=None):
146         """ returns the contact info for node, or the k closest nodes, from the global table """
147         # get K nodes out of local table/cache, or the node we want
148         nodes = self.table.findNodes(id)
149         d = Deferred()
150         if errback:
151             d.addCallbacks(callback, errback)
152         else:
153             d.addCallback(callback)
154         if len(nodes) == 1 and nodes[0].id == id :
155             d.callback(nodes)
156         else:
157             # create our search state
158             state = FindNode(self, id, d.callback, self.config)
159             reactor.callLater(0, state.goWithNodes, nodes)
160     
161     def insertNode(self, n, contacted=1):
162         """
163         insert a node in our local table, pinging oldest contact in bucket, if necessary
164         
165         If all you have is a host/port, then use addContact, which calls this method after
166         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
167         a node into the table without it's peer-ID.  That means of course the node passed into this
168         method needs to be a properly formed Node object with a valid ID.
169         """
170         old = self.table.insertNode(n, contacted=contacted)
171         if old and (time() - old.lastSeen) > self.config['MIN_PING_INTERVAL'] and old.id != self.node.id:
172             # the bucket is full, check to see if old node is still around and if so, replace it
173             
174             ## these are the callbacks used when we ping the oldest node in a bucket
175             def _staleNodeHandler(oldnode=old, newnode = n):
176                 """ called if the pinged node never responds """
177                 self.table.replaceStaleNode(old, newnode)
178             
179             def _notStaleNodeHandler(dict, old=old):
180                 """ called when we get a pong from the old node """
181                 dict = dict['rsp']
182                 if dict['id'] == old.id:
183                     self.table.justSeenNode(old.id)
184             
185             df = old.ping(self.node.id)
186             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
187
188     def sendPing(self, node, callback=None):
189         """
190             ping a node
191         """
192         df = node.ping(self.node.id)
193         ## these are the callbacks we use when we issue a PING
194         def _pongHandler(dict, node=node, table=self.table, callback=callback):
195             _krpc_sender = dict['_krpc_sender']
196             dict = dict['rsp']
197             sender = {'id' : dict['id']}
198             sender['host'] = _krpc_sender[0]
199             sender['port'] = _krpc_sender[1]
200             n = self.Node().initWithDict(sender)
201             n.conn = self.udp.connectionForAddr((n.host, n.port))
202             table.insertNode(n)
203             if callback:
204                 callback()
205         def _defaultPong(err, node=node, table=self.table, callback=callback):
206             table.nodeFailed(node)
207             if callback:
208                 callback()
209         
210         df.addCallbacks(_pongHandler,_defaultPong)
211
212     def findCloseNodes(self, callback=lambda a: None):
213         """
214             This does a findNode on the ID one away from our own.  
215             This will allow us to populate our table with nodes on our network closest to our own.
216             This is called as soon as we start up with an empty table
217         """
218         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
219         self.findNode(id, callback)
220
221     def refreshTable(self, force=0):
222         """
223             force=1 will refresh table regardless of last bucket access time
224         """
225         def callback(nodes):
226             pass
227     
228         for bucket in self.table.buckets:
229             if force or (time() - bucket.lastAccessed >= self.config['BUCKET_STALENESS']):
230                 id = newIDInRange(bucket.min, bucket.max)
231                 self.findNode(id, callback)
232
233     def stats(self):
234         """
235         Returns (num_contacts, num_nodes)
236         num_contacts: number contacts in our routing table
237         num_nodes: number of nodes estimated in the entire dht
238         """
239         num_contacts = reduce(lambda a, b: a + len(b.l), self.table.buckets, 0)
240         num_nodes = self.config['K'] * (2**(len(self.table.buckets) - 1))
241         return (num_contacts, num_nodes)
242
243     def krpc_ping(self, id, _krpc_sender):
244         sender = {'id' : id}
245         sender['host'] = _krpc_sender[0]
246         sender['port'] = _krpc_sender[1]        
247         n = self.Node().initWithDict(sender)
248         n.conn = self.udp.connectionForAddr((n.host, n.port))
249         self.insertNode(n, contacted=0)
250         return {"id" : self.node.id}
251         
252     def krpc_find_node(self, target, id, _krpc_sender):
253         nodes = self.table.findNodes(target)
254         nodes = map(lambda node: node.senderDict(), nodes)
255         sender = {'id' : id}
256         sender['host'] = _krpc_sender[0]
257         sender['port'] = _krpc_sender[1]        
258         n = self.Node().initWithDict(sender)
259         n.conn = self.udp.connectionForAddr((n.host, n.port))
260         self.insertNode(n, contacted=0)
261         return {"nodes" : nodes, "id" : self.node.id}
262
263
264 ## This class provides read-only access to the DHT, valueForKey
265 ## you probably want to use this mixin and provide your own write methods
266 class KhashmirRead(KhashmirBase):
267     _Node = KNodeRead
268     def retrieveValues(self, key):
269         c = self.store.cursor()
270         c.execute("select value from kv where key = %s;", sqlite.encode(key))
271         t = c.fetchone()
272         l = []
273         while t:
274             l.append(t['value'])
275             t = c.fetchone()
276         return l
277     ## also async
278     def valueForKey(self, key, callback, searchlocal = 1):
279         """ returns the values found for key in global table
280             callback will be called with a list of values for each peer that returns unique values
281             final callback will be an empty list - probably should change to 'more coming' arg
282         """
283         nodes = self.table.findNodes(key)
284         
285         # get locals
286         if searchlocal:
287             l = self.retrieveValues(key)
288             if len(l) > 0:
289                 reactor.callLater(0, callback, (l))
290         else:
291             l = []
292         
293         # create our search state
294         state = GetValue(self, key, callback, self.config)
295         reactor.callLater(0, state.goWithNodes, nodes, l)
296
297     def krpc_find_value(self, key, id, _krpc_sender):
298         sender = {'id' : id}
299         sender['host'] = _krpc_sender[0]
300         sender['port'] = _krpc_sender[1]        
301         n = self.Node().initWithDict(sender)
302         n.conn = self.udp.connectionForAddr((n.host, n.port))
303         self.insertNode(n, contacted=0)
304     
305         l = self.retrieveValues(key)
306         if len(l) > 0:
307             return {'values' : l, "id": self.node.id}
308         else:
309             nodes = self.table.findNodes(key)
310             nodes = map(lambda node: node.senderDict(), nodes)
311             return {'nodes' : nodes, "id": self.node.id}
312
313 ###  provides a generic write method, you probably don't want to deploy something that allows
314 ###  arbitrary value storage
315 class KhashmirWrite(KhashmirRead):
316     _Node = KNodeWrite
317     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
318     def storeValueForKey(self, key, value, callback=None):
319         """ stores the value for key in the global table, returns immediately, no status 
320             in this implementation, peers respond but don't indicate status to storing values
321             a key can have many values
322         """
323         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
324             if not response:
325                 # default callback
326                 def _storedValueHandler(sender):
327                     pass
328                 response=_storedValueHandler
329             action = StoreValue(self.table, key, value, response, self.config)
330             reactor.callLater(0, action.goWithNodes, nodes)
331             
332         # this call is asynch
333         self.findNode(key, _storeValueForKey)
334                     
335     def krpc_store_value(self, key, value, id, _krpc_sender):
336         t = "%0.6f" % time()
337         c = self.store.cursor()
338         try:
339             c.execute("insert into kv values (%s, %s, %s);", (sqlite.encode(key), sqlite.encode(value), t))
340         except sqlite.IntegrityError, reason:
341             # update last insert time
342             c.execute("update kv set time = %s where key = %s and value = %s;", (t, sqlite.encode(key), sqlite.encode(value)))
343         self.store.commit()
344         sender = {'id' : id}
345         sender['host'] = _krpc_sender[0]
346         sender['port'] = _krpc_sender[1]        
347         n = self.Node().initWithDict(sender)
348         n.conn = self.udp.connectionForAddr((n.host, n.port))
349         self.insertNode(n, contacted=0)
350         return {"id" : self.node.id}
351
352 # the whole shebang, for testing
353 class Khashmir(KhashmirWrite):
354     _Node = KNodeWrite
355
356 class SimpleTests(unittest.TestCase):
357     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
358                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
359                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
360                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
361                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
362                     'KE_AGE': 3600, }
363
364     def setUp(self):
365         d = self.DHT_DEFAULTS.copy()
366         d['PORT'] = 4044
367         self.a = Khashmir(d)
368         d = self.DHT_DEFAULTS.copy()
369         d['PORT'] = 4045
370         self.b = Khashmir(d)
371         
372     def tearDown(self):
373         self.a.listenport.stopListening()
374         self.b.listenport.stopListening()
375         try:
376             self.a.next_checkpoint.cancel()
377         except:
378             pass
379         try:
380             self.b.next_checkpoint.cancel()
381         except:
382             pass
383         try:
384             self.a.expirer.next_expire.cancel()
385         except:
386             pass
387         try:
388             self.b.expirer.next_expire.cancel()
389         except:
390             pass
391         self.a.store.close()
392         self.b.store.close()
393         os.unlink(self.a.db)
394         os.unlink(self.b.db)
395
396     def testAddContact(self):
397         self.assertEqual(len(self.a.table.buckets), 1)
398         self.assertEqual(len(self.a.table.buckets[0].l), 0)
399
400         self.assertEqual(len(self.b.table.buckets), 1)
401         self.assertEqual(len(self.b.table.buckets[0].l), 0)
402
403         self.a.addContact('127.0.0.1', 4045)
404         reactor.iterate()
405         reactor.iterate()
406         reactor.iterate()
407         reactor.iterate()
408
409         self.assertEqual(len(self.a.table.buckets), 1)
410         self.assertEqual(len(self.a.table.buckets[0].l), 1)
411         self.assertEqual(len(self.b.table.buckets), 1)
412         self.assertEqual(len(self.b.table.buckets[0].l), 1)
413
414     def testStoreRetrieve(self):
415         self.a.addContact('127.0.0.1', 4045)
416         reactor.iterate()
417         reactor.iterate()
418         reactor.iterate()
419         reactor.iterate()
420         self.got = 0
421         self.a.storeValueForKey(sha('foo').digest(), 'foobar')
422         reactor.iterate()
423         reactor.iterate()
424         reactor.iterate()
425         reactor.iterate()
426         reactor.iterate()
427         reactor.iterate()
428         self.a.valueForKey(sha('foo').digest(), self._cb)
429         reactor.iterate()
430         reactor.iterate()
431         reactor.iterate()
432         reactor.iterate()
433         reactor.iterate()
434         reactor.iterate()
435         reactor.iterate()
436
437     def _cb(self, val):
438         if not val:
439             self.assertEqual(self.got, 1)
440         elif 'foobar' in val:
441             self.got = 1
442
443
444 class MultiTest(unittest.TestCase):
445     num = 20
446     DHT_DEFAULTS = {'PORT': 9977, 'K': 8, 'HASH_LENGTH': 160,
447                     'CHECKPOINT_INTERVAL': 900, 'CONCURRENT_REQS': 4,
448                     'STORE_REDUNDANCY': 3, 'MAX_FAILURES': 3,
449                     'MIN_PING_INTERVAL': 900,'BUCKET_STALENESS': 3600,
450                     'KEINITIAL_DELAY': 15, 'KE_DELAY': 1200,
451                     'KE_AGE': 3600, }
452
453     def _done(self, val):
454         self.done = 1
455         
456     def setUp(self):
457         self.l = []
458         self.startport = 4088
459         for i in range(self.num):
460             d = self.DHT_DEFAULTS.copy()
461             d['PORT'] = self.startport + i
462             self.l.append(Khashmir(d))
463         reactor.iterate()
464         reactor.iterate()
465         
466         for i in self.l:
467             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
468             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
469             i.addContact('127.0.0.1', self.l[randrange(0,self.num)].port)
470             reactor.iterate()
471             reactor.iterate()
472             reactor.iterate() 
473             
474         for i in self.l:
475             self.done = 0
476             i.findCloseNodes(self._done)
477             while not self.done:
478                 reactor.iterate()
479         for i in self.l:
480             self.done = 0
481             i.findCloseNodes(self._done)
482             while not self.done:
483                 reactor.iterate()
484
485     def tearDown(self):
486         for i in self.l:
487             i.listenport.stopListening()
488             try:
489                 i.next_checkpoint.cancel()
490             except:
491                 pass
492             try:
493                 i.expirer.next_expire.cancel()
494             except:
495                 pass
496             i.store.close()
497             os.unlink(i.db)
498             
499         reactor.iterate()
500         
501     def testStoreRetrieve(self):
502         for i in range(10):
503             K = newID()
504             V = newID()
505             
506             for a in range(3):
507                 self.done = 0
508                 def _scb(val):
509                     self.done = 1
510                 self.l[randrange(0, self.num)].storeValueForKey(K, V, _scb)
511                 while not self.done:
512                     reactor.iterate()
513
514
515                 def _rcb(val):
516                     if not val:
517                         self.done = 1
518                         self.assertEqual(self.got, 1)
519                     elif V in val:
520                         self.got = 1
521                 for x in range(3):
522                     self.got = 0
523                     self.done = 0
524                     self.l[randrange(0, self.num)].valueForKey(K, _rcb)
525                     while not self.done:
526                         reactor.iterate()