ripped out xmlrpc, experimented with xmlrpc but with bencode, finally
[quix0rs-apt-p2p.git] / khashmir.py
1 ## Copyright 2002 Andrew Loewenstern, All Rights Reserved
2
3 from const import reactor
4 import const
5
6 import time
7
8 from sha import sha
9
10 from ktable import KTable, K
11 from knode import KNode as Node
12
13 from hash import newID, newIDInRange
14
15 from actions import FindNode, GetValue, KeyExpirer, StoreValue
16 import krpc
17 import airhook
18
19 from twisted.internet.defer import Deferred
20 from twisted.internet import protocol
21 from twisted.python import threadable
22 from twisted.internet.app import Application
23 from twisted.web import server
24 threadable.init()
25 import sys
26
27 import sqlite  ## find this at http://pysqlite.sourceforge.net/
28 import pysqlite_exceptions
29
30 class KhashmirDBExcept(Exception):
31     pass
32
33 # this is the main class!
34 class Khashmir(protocol.Factory):
35     __slots__ = ('listener', 'node', 'table', 'store', 'app', 'last', 'protocol')
36     protocol = krpc.KRPC
37     def __init__(self, host, port, db='khashmir.db'):
38         self.setup(host, port, db)
39         
40     def setup(self, host, port, db='khashmir.db'):
41         self._findDB(db)
42         self.node = self._loadSelfNode(host, port)
43         self.table = KTable(self.node)
44         self.app = Application("krpc")
45         self.airhook = airhook.listenAirhookStream(port, self)
46         self.last = time.time()
47         self._loadRoutingTable()
48         KeyExpirer(store=self.store)
49         #self.refreshTable(force=1)
50         reactor.callLater(60, self.checkpoint, (1,))
51         
52     def _loadSelfNode(self, host, port):
53         c = self.store.cursor()
54         c.execute('select id from self where num = 0;')
55         if c.rowcount > 0:
56             id = c.fetchone()[0].decode('hex')
57         else:
58             id = newID()
59         return Node().init(id, host, port)
60         
61     def _saveSelfNode(self):
62         self.store.autocommit = 0
63         c = self.store.cursor()
64         c.execute('delete from self where num = 0;')
65         c.execute("insert into self values (0, '%s');" % self.node.id.encode('hex'))
66         self.store.commit()
67         self.store.autocommit = 1
68         
69     def checkpoint(self, auto=0):
70         self._saveSelfNode()
71         self._dumpRoutingTable()
72         if auto:
73             reactor.callLater(const.CHECKPOINT_INTERVAL, self.checkpoint)
74         
75     def _findDB(self, db):
76         import os
77         try:
78             os.stat(db)
79         except OSError:
80             self._createNewDB(db)
81         else:
82             self._loadDB(db)
83         
84     def _loadDB(self, db):
85         try:
86             self.store = sqlite.connect(db=db)
87             self.store.autocommit = 1
88         except:
89             import traceback
90             raise KhashmirDBExcept, "Couldn't open DB", traceback.exc_traceback
91         
92     def _createNewDB(self, db):
93         self.store = sqlite.connect(db=db)
94         self.store.autocommit = 1
95         s = """
96             create table kv (key text, value text, time timestamp, primary key (key, value));
97             create index kv_key on kv(key);
98             create index kv_timestamp on kv(time);
99             
100             create table nodes (id text primary key, host text, port number);
101             
102             create table self (num number primary key, id text);
103             """
104         c = self.store.cursor()
105         c.execute(s)
106
107     def _dumpRoutingTable(self):
108         """
109             save routing table nodes to the database
110         """
111         self.store.autocommit = 0;
112         c = self.store.cursor()
113         c.execute("delete from nodes where id not NULL;")
114         for bucket in self.table.buckets:
115             for node in bucket.l:
116                 d = node.senderDict()
117                 c.execute("insert into nodes values ('%s', '%s', '%s');" % (d['id'].encode('hex'), d['host'], d['port']))
118         self.store.commit()
119         self.store.autocommit = 1;
120         
121     def _loadRoutingTable(self):
122         """
123             load routing table nodes from database
124             it's usually a good idea to call refreshTable(force=1) after loading the table
125         """
126         c = self.store.cursor()
127         c.execute("select * from nodes;")
128         for rec in c.fetchall():
129             n = Node().initWithDict({'id':rec[0].decode('hex'), 'host':rec[1], 'port':int(rec[2])})
130             n.conn = self.airhook.connectionForAddr((n.host, n.port))
131             self.table.insertNode(n, contacted=0)
132             
133
134     #######
135     #######  LOCAL INTERFACE    - use these methods!
136     def addContact(self, host, port, callback=None):
137         """
138             ping this node and add the contact info to the table on pong!
139         """
140         n =Node().init(const.NULL_ID, host, port) 
141         n.conn = self.airhook.connectionForAddr((n.host, n.port))
142         self.sendPing(n, callback=callback)
143
144     ## this call is async!
145     def findNode(self, id, callback, errback=None):
146         """ returns the contact info for node, or the k closest nodes, from the global table """
147         # get K nodes out of local table/cache, or the node we want
148         nodes = self.table.findNodes(id)
149         d = Deferred()
150         if errback:
151             d.addCallbacks(callback, errback)
152         else:
153             d.addCallback(callback)
154         if len(nodes) == 1 and nodes[0].id == id :
155             d.callback(nodes)
156         else:
157             # create our search state
158             state = FindNode(self, id, d.callback)
159             reactor.callFromThread(state.goWithNodes, nodes)
160     
161     
162     ## also async
163     def valueForKey(self, key, callback):
164         """ returns the values found for key in global table
165             callback will be called with a list of values for each peer that returns unique values
166             final callback will be an empty list - probably should change to 'more coming' arg
167         """
168         nodes = self.table.findNodes(key)
169         
170         # get locals
171         l = self.retrieveValues(key)
172         
173         # create our search state
174         state = GetValue(self, key, callback)
175         reactor.callFromThread(state.goWithNodes, nodes, l)
176
177     ## async, callback indicates nodes we got a response from (but no guarantee they didn't drop it on the floor)
178     def storeValueForKey(self, key, value, callback=None):
179         """ stores the value for key in the global table, returns immediately, no status 
180             in this implementation, peers respond but don't indicate status to storing values
181             a key can have many values
182         """
183         def _storeValueForKey(nodes, key=key, value=value, response=callback , table=self.table):
184             if not response:
185                 # default callback
186                 def _storedValueHandler(sender):
187                     pass
188                 response=_storedValueHandler
189             action = StoreValue(self.table, key, value, response)
190             reactor.callFromThread(action.goWithNodes, nodes)
191             
192         # this call is asynch
193         self.findNode(key, _storeValueForKey)
194         
195     
196     def insertNode(self, n, contacted=1):
197         """
198         insert a node in our local table, pinging oldest contact in bucket, if necessary
199         
200         If all you have is a host/port, then use addContact, which calls this method after
201         receiving the PONG from the remote node.  The reason for the seperation is we can't insert
202         a node into the table without it's peer-ID.  That means of course the node passed into this
203         method needs to be a properly formed Node object with a valid ID.
204         """
205         old = self.table.insertNode(n, contacted=contacted)
206         if old and (time.time() - old.lastSeen) > const.MIN_PING_INTERVAL and old.id != self.node.id:
207             # the bucket is full, check to see if old node is still around and if so, replace it
208             
209             ## these are the callbacks used when we ping the oldest node in a bucket
210             def _staleNodeHandler(oldnode=old, newnode = n):
211                 """ called if the pinged node never responds """
212                 self.table.replaceStaleNode(old, newnode)
213             
214             def _notStaleNodeHandler(dict, old=old):
215                 """ called when we get a pong from the old node """
216                 sender = dict['sender']
217                 if sender['id'] == old.id:
218                     self.table.justSeenNode(old.id)
219             
220             df = old.ping(self.node.senderDict())
221             df.addCallbacks(_notStaleNodeHandler, _staleNodeHandler)
222
223     def sendPing(self, node, callback=None):
224         """
225             ping a node
226         """
227         df = node.ping(self.node.senderDict())
228         ## these are the callbacks we use when we issue a PING
229         def _pongHandler(dict, node=node, table=self.table, callback=callback):
230             sender = dict['sender']
231             if node.id != const.NULL_ID and node.id != sender['id']:
232                 # whoah, got response from different peer than we were expecting
233                 self.table.invalidateNode(node)
234             else:
235                 sender['host'] = node.host
236                 sender['port'] = node.port
237                 n = Node().initWithDict(sender)
238                 n.conn = self.airhook.connectionForAddr((n.host, n.port))
239                 table.insertNode(n)
240                 if callback:
241                     callback()
242         def _defaultPong(err, node=node, table=self.table, callback=callback):
243             table.nodeFailed(node)
244             if callback:
245                 callback()
246         
247         df.addCallbacks(_pongHandler,_defaultPong)
248
249     def findCloseNodes(self, callback=lambda a: None):
250         """
251             This does a findNode on the ID one away from our own.  
252             This will allow us to populate our table with nodes on our network closest to our own.
253             This is called as soon as we start up with an empty table
254         """
255         id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256)
256         self.findNode(id, callback)
257
258     def refreshTable(self, force=0):
259         """
260             force=1 will refresh table regardless of last bucket access time
261         """
262         def callback(nodes):
263             pass
264     
265         for bucket in self.table.buckets:
266             if force or (time.time() - bucket.lastAccessed >= const.BUCKET_STALENESS):
267                 id = newIDInRange(bucket.min, bucket.max)
268                 self.findNode(id, callback)
269
270
271     def retrieveValues(self, key):
272         s = "select value from kv where key = '%s';" % key.encode('hex')
273         c = self.store.cursor()
274         c.execute(s)
275         t = c.fetchone()
276         l = []
277         while t:
278             l.append(t['value'].decode('base64'))
279             t = c.fetchone()
280         return l
281     
282     #####
283     ##### INCOMING MESSAGE HANDLERS
284     
285     def krpc_ping(self, sender, _krpc_sender):
286         """
287             takes sender dict = {'id', <id>, 'port', port} optional keys = 'ip'
288             returns sender dict
289         """
290         sender['host'] = _krpc_sender[0]
291         n = Node().initWithDict(sender)
292         n.conn = self.airhook.connectionForAddr((n.host, n.port))
293         self.insertNode(n, contacted=0)
294         return {"sender" : self.node.senderDict()}
295         
296     def krpc_find_node(self, target, sender, _krpc_sender):
297         nodes = self.table.findNodes(target)
298         nodes = map(lambda node: node.senderDict(), nodes)
299         sender['host'] = _krpc_sender[0]
300         n = Node().initWithDict(sender)
301         n.conn = self.airhook.connectionForAddr((n.host, n.port))
302         self.insertNode(n, contacted=0)
303         return {"nodes" : nodes, "sender" : self.node.senderDict()}
304             
305     def krpc_store_value(self, key, value, sender, _krpc_sender):
306         t = "%0.6f" % time.time()
307         s = "insert into kv values ('%s', '%s', '%s');" % (key.encode("hex"), value.encode("base64"), t)
308         c = self.store.cursor()
309         try:
310             c.execute(s)
311         except pysqlite_exceptions.IntegrityError, reason:
312             # update last insert time
313             s = "update kv set time = '%s' where key = '%s' and value = '%s';" % (t, key.encode("hex"), value.encode("base64"))
314             c.execute(s)
315         sender['host'] = _krpc_sender[0]
316         n = Node().initWithDict(sender)
317         n.conn = self.airhook.connectionForAddr((n.host, n.port))
318         self.insertNode(n, contacted=0)
319         return {"sender" : self.node.senderDict()}
320     
321     def krpc_find_value(self, key, sender, _krpc_sender):
322         sender['host'] = _krpc_sender[0]
323         n = Node().initWithDict(sender)
324         n.conn = self.airhook.connectionForAddr((n.host, n.port))
325         self.insertNode(n, contacted=0)
326     
327         l = self.retrieveValues(key)
328         if len(l) > 0:
329             return {'values' : l, "sender": self.node.senderDict()}
330         else:
331             nodes = self.table.findNodes(key)
332             nodes = map(lambda node: node.senderDict(), nodes)
333             return {'nodes' : nodes, "sender": self.node.senderDict()}
334
335 ### TESTING ###
336 from random import randrange
337 import threading, thread, sys, time
338 from sha import sha
339 from hash import newID
340
341
342 def test_net(peers=24, startport=2001, dbprefix='/tmp/test'):
343     import thread
344     l = []
345     for i in xrange(peers):
346         a = Khashmir('127.0.0.1', startport + i, db = dbprefix+`i`)
347         l.append(a)
348     thread.start_new_thread(l[0].app.run, ())
349     for peer in l[1:]:
350         peer.app.run()  
351     return l
352     
353 def test_build_net(quiet=0, peers=24, host='127.0.0.1',  pause=0, startport=2001, dbprefix='/tmp/test'):
354     from whrandom import randrange
355     import threading
356     import thread
357     import sys
358     port = startport
359     l = []
360     if not quiet:
361         print "Building %s peer table." % peers
362     
363     for i in xrange(peers):
364         a = Khashmir(host, port + i, db = dbprefix +`i`)
365         l.append(a)
366     
367     
368     thread.start_new_thread(l[0].app.run, ())
369     time.sleep(1)
370     for peer in l[1:]:
371         peer.app.run()
372     #time.sleep(3)
373     
374     def spewer(frame, s, ignored):
375         from twisted.python import reflect
376         if frame.f_locals.has_key('self'):
377             se = frame.f_locals['self']
378             print 'method %s of %s at %s' % (
379                 frame.f_code.co_name, reflect.qual(se.__class__), id(se)
380                 )
381     #sys.settrace(spewer)
382
383     print "adding contacts...."
384     def makecb(flag):
385         def cb(f=flag):
386             f.set()
387         return cb
388
389     for peer in l:
390         p = l[randrange(0, len(l))]
391         if p != peer:
392             n = p.node
393             flag = threading.Event()
394             peer.addContact(host, n.port, makecb(flag))
395             flag.wait()
396         p = l[randrange(0, len(l))]
397         if p != peer:
398             n = p.node
399             flag = threading.Event()
400             peer.addContact(host, n.port, makecb(flag))
401             flag.wait()
402         p = l[randrange(0, len(l))]
403         if p != peer:
404             n = p.node
405             flag = threading.Event()
406             peer.addContact(host, n.port, makecb(flag))
407             flag.wait()
408     
409     print "finding close nodes...."
410     
411     for peer in l:
412         flag = threading.Event()
413         def cb(nodes, f=flag):
414             f.set()
415         peer.findCloseNodes(cb)
416         flag.wait()
417     #    for peer in l:
418     #   peer.refreshTable()
419     return l
420         
421 def test_find_nodes(l, quiet=0):
422     flag = threading.Event()
423     
424     n = len(l)
425     
426     a = l[randrange(0,n)]
427     b = l[randrange(0,n)]
428     
429     def callback(nodes, flag=flag, id = b.node.id):
430         if (len(nodes) >0) and (nodes[0].id == id):
431             print "test_find_nodes      PASSED"
432         else:
433             print "test_find_nodes      FAILED"
434         flag.set()
435     a.findNode(b.node.id, callback)
436     flag.wait()
437     
438 def test_find_value(l, quiet=0):
439     
440     fa = threading.Event()
441     fb = threading.Event()
442     fc = threading.Event()
443     
444     n = len(l)
445     a = l[randrange(0,n)]
446     b = l[randrange(0,n)]
447     c = l[randrange(0,n)]
448     d = l[randrange(0,n)]
449     
450     key = newID()
451     value = newID()
452     if not quiet: print "inserting value..."
453     a.storeValueForKey(key, value)
454     time.sleep(3)
455     if not quiet:
456         print "finding..."
457     
458     class cb:
459         def __init__(self, flag, value=value):
460             self.flag = flag
461             self.val = value
462             self.found = 0
463         def callback(self, values):
464             try:
465                 if(len(values) == 0):
466                     if not self.found:
467                         print "find                NOT FOUND"
468                     else:
469                         print "find                FOUND"
470                 else:
471                     if self.val in values:
472                         self.found = 1
473             finally:
474                 self.flag.set()
475     
476     b.valueForKey(key, cb(fa).callback)
477     fa.wait()
478     c.valueForKey(key, cb(fb).callback)
479     fb.wait()
480     d.valueForKey(key, cb(fc).callback)    
481     fc.wait()
482     
483 def test_one(host, port, db='/tmp/test'):
484     import thread
485     k = Khashmir(host, port, db)
486     thread.start_new_thread(reactor.run, ())
487     return k
488     
489 if __name__ == "__main__":
490     import sys
491     n = 8
492     if len(sys.argv) > 1: n = int(sys.argv[1])
493     l = test_build_net(peers=n)
494     time.sleep(3)
495     print "finding nodes..."
496     for i in range(10):
497         test_find_nodes(l)
498     print "inserting and fetching values..."
499     for i in range(10):
500         test_find_value(l)