From: burris Date: Fri, 19 Jul 2002 20:23:12 +0000 (+0000) Subject: Initial revision X-Git-Url: https://git.mxchange.org/?a=commitdiff_plain;h=1abfc43747d252f5d0bf11f119e0daabecc4f3a6;p=quix0rs-apt-p2p.git Initial revision --- 1abfc43747d252f5d0bf11f119e0daabecc4f3a6 diff --git a/README.txt b/README.txt new file mode 100644 index 0000000..141e5fa --- /dev/null +++ b/README.txt @@ -0,0 +1,10 @@ +quick example: + +import khashmir, threading +k = khashmir.Khashmir('127.0.0.1', 4444) +start_new_thread(k.dispatcher.run, ()) +k.addContact('127.0.0.1', 8080) # right now we don't do gethostbyname +k.populateTable() + + +alternatively, you can call k.dispatcher.runOnce() periodically from whatever thread you choose \ No newline at end of file diff --git a/bencode.py b/bencode.py new file mode 100644 index 0000000..77f65da --- /dev/null +++ b/bencode.py @@ -0,0 +1,335 @@ +""" +A library for streaming and unstreaming of simple objects, designed +for speed, compactness, and ease of implementation. + +The basic functions are bencode and bdecode. bencode takes an object +and returns a string, bdecode takes a string and returns an object. +bdecode raises a ValueError if you give it an invalid string. + +The objects passed in may be nested dicts, lists, ints, strings, +and None. For example, all of the following may be bencoded - + +{'a': [0, 1], 'b': None} + +[None, ['a', 2, ['c', None]]] + +{'spam': (2,3,4)} + +{'name': 'Cronus', 'spouse': 'Rhea', 'children': ['Hades', 'Poseidon']} + +In general bdecode(bencode(spam)) == spam, but tuples and lists are +encoded the same, so bdecode(bencode((0, 1))) is [0, 1] rather +than (0, 1). Longs and ints are also encoded the same way, so +bdecode(bencode(4)) is a long. + +dict keys are required to be strings, to avoid a mess of potential +implementation incompatibilities. bencode is intended to be used +for protocols which are going to be re-implemented many times, so +it's very conservative in that regard. + +Which type is encoded is determined by the first character, 'i', 'n', +'d', 'l' and any digit. They indicate integer, null, dict, list, and +string, respectively. + +Strings are length-prefixed in base 10, followed by a colon. + +bencode('spam') == '4:spam' + +Nulls are indicated by a single 'n'. + +bencode(None) == 'n' + +integers are encoded base 10 and terminated with an 'e'. + +bencode(3) == 'i3e' +bencode(-20) == 'i-20e' + +Lists are encoded in list order, terminated by an 'e' - + +bencode(['abc', 'd']) == 'l3:abc1:de' +bencode([2, 'f']) == 'li2e1:fe' + +Dicts are encoded by containing alternating keys and values, +with the keys in sorted order, terminated by an 'e'. For example - + +bencode({'spam': 'eggs'}) == 'd4:spam4:eggse' +bencode({'ab': 2, 'a': None}) == 'd1:an2:abi2ee' + +Truncated strings come first, so in sort order 'a' comes before 'abc'. + +If a function is passed to bencode, it's called and it's return value +is included as a raw string, for example - + +bdecode(bencode(lambda: None)) == None +""" + +# This file is licensed under the GNU Lesser General Public License v2.1. +# originally written for Mojo Nation by Bryce Wilcox, Bram Cohen, and Greg P. Smith +# since then, almost completely rewritten by Bram Cohen + +from types import * +from cStringIO import StringIO +import re + +def bencode(data): + """ + encodes objects as strings, see module documentation for more info + """ + result = StringIO() + bwrite(data, result) + return result.getvalue() + +def bwrite(data, result): + encoder = encoders.get(type(data)) + assert encoder is not None, 'unsupported data type: ' + `type(data)` + encoder(data, result) + +encoders = {} + +def encode_int(data, result): + result.write('i' + str(data) + 'e') + +encoders[IntType] = encode_int +encoders[LongType] = encode_int + +def encode_list(data, result): + result.write('l') + for i in data: + bwrite(i, result) + result.write('e') + +encoders[TupleType] = encode_list +encoders[ListType] = encode_list + +def encode_string(data, result): + result.write(str(len(data)) + ':' + data) + +encoders[StringType] = encode_string + +def encode_dict(data, result): + result.write('d') + keys = data.keys() + keys.sort() + for key in keys: + assert type(key) is StringType, 'bencoded dictionary key must be a string' + bwrite(key, result) + bwrite(data[key], result) + result.write('e') + +encoders[DictType] = encode_dict + +encoders[NoneType] = lambda data, result: result.write('n') + +encoders[FunctionType] = lambda data, result: result.write(data()) +encoders[MethodType] = encoders[FunctionType] + +def bdecode(s): + """ + Does the opposite of bencode. Raises a ValueError if there's a problem. + """ + try: + result, index = bread(s, 0) + if index != len(s): + raise ValueError('left over stuff at end') + return result + except IndexError, e: + raise ValueError(str(e)) + except KeyError, e: + raise ValueError(str(e)) + +def bread(s, index): + return decoders[s[index]](s, index) + +decoders = {} + +_bre = re.compile(r'(0|[1-9][0-9]*):') + +def decode_raw_string(s, index): + x = _bre.match(s, index) + if x is None: + raise ValueError('invalid integer encoding') + endindex = x.end() + long(s[index:x.end() - 1]) + if endindex > len(s): + raise ValueError('length encoding indicated premature end of string') + return s[x.end(): endindex], endindex + +for c in '0123456789': + decoders[c] = decode_raw_string + +_int_re = re.compile(r'i(0|-?[1-9][0-9]*)e') + +def decode_int(s, index): + x = _int_re.match(s, index) + if x is None: + raise ValueError('invalid integer encoding') + return long(s[index + 1:x.end() - 1]), x.end() + +decoders['i'] = decode_int + +decoders['n'] = lambda s, index: (None, index + 1) + +def decode_list(s, index): + result = [] + index += 1 + while s[index] != 'e': + next, index = bread(s, index) + result.append(next) + return result, index + 1 + +decoders['l'] = decode_list + +def decode_dict(s, index): + result = {} + index += 1 + prevkey = None + while s[index] != 'e': + key, index = decode_raw_string(s, index) + if key <= prevkey: + raise ValueError("out of order keys") + prevkey = key + value, index = bread(s, index) + result[key] = value + return result, index + 1 + +decoders['d'] = decode_dict + +def test_decode_raw_string(): + assert decode_raw_string('1:a', 0) == ('a', 3) + assert decode_raw_string('0:', 0) == ('', 2) + assert decode_raw_string('10:aaaaaaaaaaaaaaaaaaaaaaaaa', 0) == ('aaaaaaaaaa', 13) + assert decode_raw_string('10:', 1) == ('', 3) + try: + decode_raw_string('01:a', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('--1:a', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('h', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('h:', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('1', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('5:a', 0) + assert 0, 'failed' + except ValueError: + pass + +def test_dict_enforces_order(): + bdecode('d1:an1:bne') + try: + bdecode('d1:bn1:ane') + assert 0, 'failed' + except ValueError: + pass + +def test_dict_forbids_non_string_key(): + try: + bdecode('di3ene') + assert 0, 'failed' + except ValueError: + pass + +def test_dict_forbids_key_repeat(): + try: + bdecode('d1:an1:ane') + assert 0, 'failed' + except ValueError: + pass + +def test_empty_dict(): + assert bdecode('de') == {} + +def test_ValueError_in_decode_unknown(): + try: + bdecode('x') + assert 0, 'flunked' + except ValueError: + pass + +def test_encode_and_decode_none(): + assert bdecode(bencode(None)) == None + +def test_encode_and_decode_long(): + assert bdecode(bencode(-23452422452342L)) == -23452422452342L + +def test_encode_and_decode_int(): + assert bdecode(bencode(2)) == 2 + +def test_decode_noncanonical_int(): + try: + bdecode('i03e') + assert 0 + except ValueError: + pass + try: + bdecode('i3 e') + assert 0 + except ValueError: + pass + try: + bdecode('i 3e') + assert 0 + except ValueError: + pass + try: + bdecode('i-0e') + assert 0 + except ValueError: + pass + +def test_encode_and_decode_dict(): + x = {'42': 3} + assert bdecode(bencode(x)) == x + +def test_encode_and_decode_list(): + assert bdecode(bencode([])) == [] + +def test_encode_and_decode_tuple(): + assert bdecode(bencode(())) == [] + +def test_encode_and_decode_empty_dict(): + assert bdecode(bencode({})) == {} + +def test_encode_and_decode_complex_object(): + spam = [[], 0, -3, -345234523543245234523L, {}, 'spam', None, {'a': [3]}, {}] + assert bencode(bdecode(bencode(spam))) == bencode(spam) + assert bdecode(bencode(spam)) == spam + +def test_unfinished_list(): + try: + bdecode('ln') + assert 0 + except ValueError: + pass + +def test_unfinished_dict(): + try: + bdecode('d') + assert 0 + except ValueError: + pass + try: + bdecode('d1:a') + assert 0 + except ValueError: + pass diff --git a/btemplate.py b/btemplate.py new file mode 100644 index 0000000..038a1ac --- /dev/null +++ b/btemplate.py @@ -0,0 +1,528 @@ +# This file is licensed under the GNU Lesser General Public License v2.1. +# originally written for Mojo Nation by Bram Cohen, based on an earlier +# version by Bryce Wilcox +# The authors disclaim all liability for any damages resulting from +# any use of this software. + +import types + +def string_template(thing, verbose): + if type(thing) != types.StringType: + raise ValueError, "not a string" + +st = string_template + +def exact_length(l): + def func(s, verbose, l = l): + if type(s) != types.StringType: + raise ValueError, 'should have been string' + if len(s) != l: + raise ValueError, 'wrong length, should have been ' + str(l) + ' was ' + str(len(s)) + return func + +class MaxDepth: + def __init__(self, max_depth, template = None): + assert max_depth >= 0 + self.max_depth = max_depth + self.template = template + + def get_real_template(self): + assert self.template is not None, 'You forgot to set the template!' + if self.max_depth == 0: + return fail_too_deep + self.max_depth -= 1 + try: + return compile_inner(self.template) + finally: + self.max_depth += 1 + + def __repr__(self): + if hasattr(self, 'p'): + return '...' + try: + self.p = 1 + return 'MaxDepth(' + str(self.max_depth) + ', ' + `self.template` + ')' + finally: + del self.p + +def fail_too_deep(thing, verbose): + raise ValueError, 'recursed too deep' + +class ListMarker: + def __init__(self, template): + self.template = template + + def get_real_template(self): + return compile_list_template(self.template) + + def __repr__(self): + return 'ListMarker(' + `self.template` + ')' + +def compile_list_template(template): + def func(thing, verbose, template = compile_inner(template)): + if type(thing) not in (types.ListType, types.TupleType): + raise ValueError, 'not a list' + if verbose: + try: + for i in xrange(0, len(thing)): + template(thing[i], 1) + except ValueError, e: + reason = 'mismatch at index ' + str(i) + ': ' + str(e) + raise ValueError, reason + else: + for i in thing: + template(i, 0) + return func + +class ValuesMarker: + def __init__(self, template, t2 = string_template): + self.template = template + self.t2 = t2 + + def get_real_template(self): + return compile_values_template(self.template, self.t2) + + def __repr__(self): + return 'ValuesMarker(' + `self.template` + ')' + +def compile_values_template(template, t2): + def func(thing, verbose, template = compile_inner(template), + t2 = compile_inner(t2)): + if type(thing) != types.DictType: + raise ValueError, 'not a dict' + if verbose: + try: + for key, val in thing.items(): + template(val, 1) + t2(key, 1) + except ValueError, e: + raise ValueError, 'mismatch in key ' + `key` + ': ' + str(e) + else: + for key, val in thing.items(): + template(val, 0) + t2(key, 0) + return func + +compilers = {} + +def compile_string_template(template): + assert type(template) is types.StringType + def func(thing, verbose, template = template): + if thing != template: + raise ValueError, "didn't match string" + return func + +compilers[types.StringType] = compile_string_template + +def int_template(thing, verbose): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'thing not of integer type' + +def nonnegative_int_template(thing, verbose): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'thing not of integer type' + if thing < 0: + raise ValueError, 'thing less than zero' + +def positive_int_template(thing, verbose): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'thing not of integer type' + if thing <= 0: + raise ValueError, 'thing less than or equal to zero' + +def compile_int_template(s): + assert s in (-1, 0, 1) + if s == -1: + return int_template + elif s == 0: + return nonnegative_int_template + else: + return positive_int_template + +compilers[types.IntType] = compile_int_template +compilers[types.LongType] = compile_int_template + +def compile_slice(template): + assert type(template) is types.SliceType + assert template.step is None + assert template.stop is not None + start = template.start + if start is None: + start = 0 + def func(thing, verbose, start = start, stop = template.stop): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'not an int' + if thing < start: + raise ValueError, 'thing too small' + if thing >= stop: + raise ValueError, 'thing too large' + return func + +compilers[types.SliceType] = compile_slice + +class OptionMarker: + def __init__(self, template): + self.option_template = template + + def __repr__(self): + return 'OptionMarker(' + `self.option_template` + ')' + +def compile_dict_template(template): + assert type(template) is types.DictType + agroup = [] + bgroup = [] + cgroup = [] + optiongroup = [] + for key, value in template.items(): + if hasattr(value, 'option_template'): + optiongroup.append((key, compile_inner(value.option_template))) + elif type(value) is types.StringType: + agroup.append((key, compile_inner(value))) + elif type(value) in (types.IntType, types.LongType, types.SliceType): + bgroup.append((key, compile_inner(value))) + else: + cgroup.append((key, compile_inner(value))) + def func(thing, verbose, required = agroup + bgroup + cgroup, optional = optiongroup): + if type(thing) is not types.DictType: + raise ValueError, 'not a dict' + try: + for key, template in required: + if not thing.has_key(key): + raise ValueError, 'key not present' + template(thing[key], verbose) + for key, template in optional: + if thing.has_key(key): + template(thing[key], verbose) + except ValueError, e: + if verbose: + reason = 'mismatch in key ' + `key` + ': ' + str(e) + raise ValueError, reason + else: + raise + return func + +compilers[types.DictType] = compile_dict_template + +def none_template(thing, verbose): + if thing is not None: + raise ValueError, 'thing was not None' + +compilers[types.NoneType] = lambda template: none_template + +def compile_or_template(template): + assert type(template) in (types.ListType, types.TupleType) + def func(thing, verbose, templ = [compile_inner(x) for x in template]): + if verbose: + failure_reason = ('did not match any of the ' + + str(len(templ)) + ' possible templates;') + for i in xrange(len(templ)): + try: + templ[i](thing, 1) + return + except ValueError, reason: + failure_reason += (' failed template at index ' + + str(i) + ' because (' + str(reason) + ')') + raise ValueError, failure_reason + else: + for i in templ: + try: + i(thing, 0) + return + except ValueError: + pass + raise ValueError, "did not match any possible templates" + return func + +compilers[types.ListType] = compile_or_template +compilers[types.TupleType] = compile_or_template + +def compile_inner(template): + while hasattr(template, 'get_real_template'): + template = template.get_real_template() + if callable(template): + return template + return compilers[type(template)](template) + +def compile_template(template): + def func(thing, verbose = None, t = compile_inner(template), s = `template`): + if verbose is not None: + t(thing, verbose) + return + try: + t(thing, 0) + except ValueError: + try: + t(thing, 1) + assert 0 + except ValueError, reason: + raise ValueError, 'failed template check because: (' + str(reason) + ') target was: (' + `thing` + ') template was: (' + s + ')' + return func + + + +###### +import unittest + +class TestBTemplate(unittest.TestCase): + + def test_slice(self): + f = compile_template(slice(4)) + f(0) + f(3L) + try: + f(-1) + assert 0 + except ValueError: + pass + try: + f(4L) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + f = compile_template(slice(-2, 3)) + f(-2L) + f(2) + try: + f(-3L) + assert 0 + except ValueError: + pass + try: + f(3) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + def test_int(self): + f = compile_template(0) + f(0) + f(1L) + try: + f(-1) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + f = compile_template(-1) + f(0) + f(1) + f(-1L) + try: + f('a') + assert 0 + except ValueError: + pass + + f = compile_template(1) + try: + f(0) + assert 0 + except ValueError: + pass + f(1) + try: + f(-1) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + def test_none(self): + f = compile_template(None) + f(None) + try: + f(0) + assert 0 + except ValueError: + pass + + def test_string(self): + f = compile_template('a') + f('a') + try: + f('b') + assert 0 + except ValueError: + pass + try: + f(0) + assert 0 + except ValueError: + pass + + def test_generic_string(self): + f = compile_template(st) + f('a') + try: + f(0) + assert 0 + except ValueError: + pass + + def test_values(self): + vt = compile_template(ValuesMarker('a', exact_length(1))) + vt({}) + vt({'x': 'a'}) + try: + vt(3) + assert 0 + except ValueError: + pass + try: + vt({'x': 'b'}) + assert 0 + except ValueError: + pass + try: + vt({'xx': 'a'}) + assert 0 + except ValueError: + pass + + def test_list(self): + f = compile_template(ListMarker('a')) + f(['a']) + f(('a', 'a')) + try: + f(('a', 'b')) + assert 0 + except ValueError: + pass + try: + f(('b', 'a')) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + def test_or(self): + f = compile_template(['a', 'b']) + f('a') + f('b') + try: + f('c') + assert 0 + except ValueError: + pass + + f = compile_template(('a', 'b')) + f('a') + f('b') + try: + f('c') + assert 0 + except ValueError: + pass + + def test_dict(self): + f = compile_template({'a': 'b', 'c': OptionMarker('d')}) + try: + f({}) + assert 0 + except ValueError: + pass + f({'a': 'b'}) + try: + f({'a': 'e'}) + assert 0 + except ValueError: + pass + try: + f({'c': 'd'}) + assert 0 + except ValueError: + pass + f({'a': 'b', 'c': 'd'}) + try: + f({'a': 'e', 'c': 'd'}) + assert 0 + except ValueError: + pass + try: + f({'c': 'f'}) + assert 0 + except ValueError: + pass + try: + f({'a': 'b', 'c': 'f'}) + assert 0 + except ValueError: + pass + try: + f({'a': 'e', 'c': 'f'}) + assert 0 + except ValueError: + pass + try: + f(None) + assert 0 + except ValueError: + pass + + def test_other_func(self): + def check3(thing, verbose): + if thing != 3: + raise ValueError + f = compile_template(check3) + f(3) + try: + f(4) + assert 0 + except ValueError: + pass + + def test_max_depth(self): + md = MaxDepth(2) + t = {'a': OptionMarker(ListMarker(md))} + md.template = t + f = compile_template(md) + f({'a': [{'a': []}]}) + f({'a': [{'a': []}]}) + try: + f({'a': [{'a': [{}]}]}) + assert 0 + except ValueError: + pass + try: + f({'a': [{'a': [{}]}]}) + assert 0 + except ValueError: + pass + f({'a': [{'a': []}]}) + try: + f({'a': [{'a': [{}]}]}) + assert 0 + except ValueError: + pass + + def test_use_compiled(self): + x = compile_template('a') + y = compile_template(ListMarker(x)) + y(['a']) + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher.py b/dispatcher.py new file mode 100644 index 0000000..08c62db --- /dev/null +++ b/dispatcher.py @@ -0,0 +1,216 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from bsddb3 import db ## find this at http://pybsddb.sf.net/ +from bsddb3._db import DBNotFoundError +import time +import hash +from node import Node +from bencode import bencode, bdecode +#from threading import RLock + +# max number of incoming or outgoing messages to process at a time +NUM_EVENTS = 5 + +class Transaction: + __slots__ = ['responseTemplate', 'id', 'dispatcher', 'target', 'payload', 'response', 'default', 'timeout'] + def __init__(self, dispatcher, node, response_handler, default_handler, id = None, payload = None, timeout=60): + if id == None: + id = hash.newID() + self.id = id + self.dispatcher = dispatcher + self.target = node + self.payload = payload + self.response = response_handler + self.default = default_handler + self.timeout = time.time() + timeout + + def setPayload(self, payload): + self.payload = payload + + def setResponseTemplate(self, t): + self.responseTemplate = t + + def responseHandler(self, msg): + if self.responseTemplate and callable(self.responseTemplate): + try: + self.responseTemplate(msg) + except ValueError, reason: + print "response %s" % (reason) + print `msg['id'], self.target.id` + return + self.response(self, msg) + + def defaultHandler(self): + self.default(self) + + def dispatch(self): + if callable(self.response) and callable(self.default): + self.dispatcher.initiate(self) + else: + self.dispatchNoResponse() + def dispatchNoResponse(self): + self.dispatcher.initiateNoResponse(self) + + + +class Dispatcher: + def __init__(self, listener, base_template, id): + self.id = id + self.listener = listener + self.transactions = {} + self.handlers = {} + self.timeout = db.DB() + self.timeout.set_flags(db.DB_DUP) + self.timeout.open(None, None, db.DB_BTREE) + self.BASE = base_template + self.stopped = 0 + #self.tlock = RLock() + + def registerHandler(self, key, handler, template): + assert(callable(handler)) + assert(callable(template)) + self.handlers[key] = (handler, template) + + def initiate(self, transaction): + #self.tlock.acquire() + #ignore messages to ourself + if transaction.target.id == self.id: + return + self.transactions[transaction.id] = transaction + self.timeout.put(`transaction.timeout`, transaction.id) + ## queue the message! + self.listener.qMsg(transaction.payload, transaction.target.host, transaction.target.port) + #self.tlock.release() + + def initiateNoResponse(self, transaction): + #ignore messages to ourself + if transaction.target.id == self.id: + return + #self.tlock.acquire() + self.listener.qMsg(transaction.payload, transaction.target.host, transaction.target.port) + #self.tlock.release() + + def postEvent(self, callback, delay, extras=None): + #self.tlock.acquire() + t = Transaction(self, None, None, callback, timeout=delay) + t.extras = extras + self.transactions[t.id] = t + self.timeout.put(`t.timeout`, t.id) + #self.tlock.release() + + def flushExpiredEvents(self): + events = 0 + tstamp = `time.time()` + #self.tlock.acquire() + c = self.timeout.cursor() + e = c.first() + while e and e[0] < tstamp: + events = events + 1 + try: + t = self.transactions[e[1]] + del(self.transactions[e[1]]) + except KeyError: + # transaction must have completed or was otherwise cancelled + pass + ## default callback! + else: + t.defaultHandler() + tmp = c.next() + # handle duplicates in a silly way + if tmp and e != tmp: + self.timeout.delete(e[0]) + e = tmp + #self.tlock.release() + return events + + def flushOutgoing(self): + events = 0 + n = self.listener.qLen() + if n > NUM_EVENTS: + n = NUM_EVENTS + for i in range(n): + self.listener.dispatchMsg() + events = events + 1 + return events + + def handleIncoming(self): + events = 0 + #self.tlock.acquire() + for i in range(NUM_EVENTS): + try: + msg, addr = self.listener.receiveMsg() + except ValueError: + break + + ## decode message, handle message! + try: + msg = bdecode(msg) + except ValueError: + # wrongly encoded message? + print "Bogus message received: %s" % msg + continue + try: + # check base template for correctness + self.BASE(msg) + except ValueError, reason: + # bad message! + print "Incoming message: %s" % reason + continue + try: + # check to see if we already know about this transaction + t = self.transactions[msg['tid']] + if msg['id'] != t.target.id and t.target.id != " "*20: + # we're expecting a response from someone else + if msg['id'] == self.id: + print "received our own response! " + `self.id` + else: + print "response from wrong peer! "+ `msg['id'],t.target.id` + else: + del(self.transactions[msg['tid']]) + self.timeout.delete(`t.timeout`) + t.addr = addr + # call transaction response handler + t.responseHandler(msg) + except KeyError: + # we don't know about it, must be unsolicited + n = Node(msg['id'], addr[0], addr[1]) + t = Transaction(self, n, None, None, msg['tid']) + if self.handlers.has_key(msg['type']): + ## handle this transaction + try: + # check handler template + self.handlers[msg['type']][1](msg) + except ValueError, reason: + print "BAD MESSAGE: %s" % reason + else: + self.handlers[msg['type']][0](t, msg) + else: + ## no transaction, no handler, drop it on the floor! + pass + events = events + 1 + #self.tlock.release() + return events + + def stop(self): + self.stopped = 1 + + def run(self): + self.stopped = 0 + while(not self.stopped): + events = self.runOnce() + ## sleep + if events == 0: + time.sleep(0.1) + + def runOnce(self): + events = 0 + ## handle some incoming messages + events = events + self.handleIncoming() + ## process some outstanding events + events = events + self.flushExpiredEvents() + ## send outgoing messages + events = events + self.flushOutgoing() + return events + + + diff --git a/hash.py b/hash.py new file mode 100644 index 0000000..8fa3e8e --- /dev/null +++ b/hash.py @@ -0,0 +1,108 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from sha import sha +from whrandom import randrange + +## takes a 20 bit hash, big-endian, and returns it expressed a python integer +## ha ha ha ha if this were a C module I wouldn't resort to such sillyness +def intify(hstr): + assert(len(hstr) == 20) + i = 0L + i = i + ord(hstr[19]) + i = i + ord(hstr[18]) * 256L + i = i + ord(hstr[17]) * 65536L + i = i + ord(hstr[16]) * 16777216L + i = i + ord(hstr[15]) * 4294967296L + i = i + ord(hstr[14]) * 1099511627776L + i = i + ord(hstr[13]) * 281474976710656L + i = i + ord(hstr[12]) * 72057594037927936L + i = i + ord(hstr[11]) * 18446744073709551616L + i = i + ord(hstr[10]) * 4722366482869645213696L + i = i + ord(hstr[9]) * 1208925819614629174706176L + i = i + ord(hstr[8]) * 309485009821345068724781056L + i = i + ord(hstr[7]) * 79228162514264337593543950336L + i = i + ord(hstr[6]) * 20282409603651670423947251286016L + i = i + ord(hstr[5]) * 5192296858534827628530496329220096L + i = i + ord(hstr[4]) * 1329227995784915872903807060280344576L + i = i + ord(hstr[3]) * 340282366920938463463374607431768211456L + i = i + ord(hstr[2]) * 87112285931760246646623899502532662132736L + i = i + ord(hstr[1]) * 22300745198530623141535718272648361505980416L + i = i + ord(hstr[0]) * 5708990770823839524233143877797980545530986496L + return i + +## returns the distance between two 160-bit hashes expressed as 20-character strings +def distance(a, b): + return intify(a) ^ intify(b) + + +## returns a new pseudorandom globally unique ID string +def newID(): + h = sha() + for i in range(20): + h.update(chr(randrange(0,256))) + return h.digest() + +def randRange(min, max): + return min + intify(newID()) % (max - min) + +import unittest + +class NewID(unittest.TestCase): + def testLength(self): + self.assertEqual(len(newID()), 20) + def testHundreds(self): + for x in xrange(100): + self.testLength + +class Intify(unittest.TestCase): + known = [('\0' * 20, 0), + ('\xff' * 20, 2**160 - 1), + ] + def testKnown(self): + for str, value in self.known: + self.assertEqual(intify(str), value) + def testEndianessOnce(self): + h = newID() + while h[-1] == '\xff': + h = newID() + k = h[:-1] + chr(ord(h[-1]) + 1) + self.assertEqual(intify(k) - intify(h), 1) + def testEndianessLots(self): + for x in xrange(100): + self.testEndianessOnce() + +class Disantance(unittest.TestCase): + known = [ + (("\0" * 20, "\xff" * 20), 2**160 -1), + ((sha("foo").digest(), sha("foo").digest()), 0), + ((sha("bar").digest(), sha("bar").digest()), 0) + ] + def testKnown(self): + for pair, dist in self.known: + self.assertEqual(distance(pair[0], pair[1]), dist) + def testCommutitive(self): + for i in xrange(100): + x, y, z = newID(), newID(), newID() + self.assertEqual(distance(x,y) ^ distance(y, z), distance(x, z)) + +class RandRange(unittest.TestCase): + def testOnce(self): + a = intify(newID()) + b = intify(newID()) + if a < b: + c = randRange(a, b) + self.assertEqual(a <= c < b, 1, "output out of range %d %d %d" % (b, c, a)) + else: + c = randRange(b, a) + assert b <= c < a, "output out of range %d %d %d" % (b, c, a) + + def testOneHundredTimes(self): + for i in xrange(100): + self.testOnce() + + + +if __name__ == '__main__': + unittest.main() + + \ No newline at end of file diff --git a/khashmir.py b/khashmir.py new file mode 100644 index 0000000..363d0a0 --- /dev/null +++ b/khashmir.py @@ -0,0 +1,488 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from listener import Listener +from ktable import KTable, K +from node import Node +from dispatcher import Dispatcher +from hash import newID, intify +import messages +import transactions + +import time + +from bsddb3 import db ## find this at http://pybsddb.sf.net/ +from bsddb3._db import DBNotFoundError + +# don't ping unless it's been at least this many seconds since we've heard from a peer +MAX_PING_INTERVAL = 60 * 15 # fifteen minutes + +# concurrent FIND_NODE/VALUE requests! +N = 3 + + +# this is the main class! +class Khashmir: + __slots__ = ['listener', 'node', 'table', 'dispatcher', 'tf', 'store'] + def __init__(self, host, port): + self.listener = Listener(host, port) + self.node = Node(newID(), host, port) + self.table = KTable(self.node) + self.dispatcher = Dispatcher(self.listener, messages.BASE, self.node.id) + self.tf = transactions.TransactionFactory(self.node.id, self.dispatcher) + + self.store = db.DB() + self.store.open(None, None, db.DB_BTREE) + + #### register unsolicited incoming message handlers + self.dispatcher.registerHandler('ping', self._pingHandler, messages.PING) + + self.dispatcher.registerHandler('find node', self._findNodeHandler, messages.FIND_NODE) + + self.dispatcher.registerHandler('get value', self._findValueHandler, messages.GET_VALUE) + + self.dispatcher.registerHandler('store value', self._storeValueHandler, messages.STORE_VALUE) + + + ####### + ####### LOCAL INTERFACE - use these methods! + def addContact(self, host, port): + """ + ping this node and add the contact info to the table on pong! + """ + n =Node(" "*20, host, port) # note, we + self.sendPing(n) + + + ## this call is async! + def findNode(self, id, callback): + """ returns the contact info for node, or the k closest nodes, from the global table """ + # get K nodes out of local table/cache, or the node we want + nodes = self.table.findNodes(id) + if len(nodes) == 1 and nodes[0].id == id : + # we got it in our table! + def tcall(t, callback=callback): + callback(t.extras) + self.dispatcher.postEvent(tcall, 0, extras=nodes) + else: + # create our search state + state = FindNode(self, self.dispatcher, id, callback) + # handle this in our own thread + self.dispatcher.postEvent(state.goWithNodes, 0, extras=nodes) + + + ## also async + def valueForKey(self, key, callback): + """ returns the values found for key in global table """ + nodes = self.table.findNodes(key) + # create our search state + state = GetValue(self, self.dispatcher, key, callback) + # handle this in our own thread + self.dispatcher.postEvent(state.goWithNodes, 0, extras=nodes) + + + ## async, but in the current implementation there is no guarantee a store does anything so there is no callback right now + def storeValueForKey(self, key, value): + """ stores the value for key in the global table, returns immediately, no status + in this implementation, peers respond but don't indicate status to storing values + values are stored in peers on a first-come first-served basis + this will probably change so more than one value can be stored under a key + """ + def _storeValueForKey(nodes, tf=self.tf, key=key, value=value, response= self._storedValueHandler, default= lambda t: "didn't respond"): + for node in nodes: + if node.id != self.node.id: + t = tf.StoreValue(node, key, value, response, default) + t.dispatch() + # this call is asynch + self.findNode(key, _storeValueForKey) + + + def insertNode(self, n): + """ + insert a node in our local table, pinging oldest contact in bucket, if necessary + + If all you have is a host/port, then use addContact, which calls this function after + receiving the PONG from the remote node. The reason for the seperation is we can't insert + a node into the table without it's peer-ID. That means of course the node passed into this + method needs to be a properly formed Node object with a valid ID. + """ + old = self.table.insertNode(n) + if old and (time.time() - old.lastSeen) > MAX_PING_INTERVAL and old.id != self.node.id: + # the bucket is full, check to see if old node is still around and if so, replace it + t = self.tf.Ping(old, self._notStaleNodeHandler, self._staleNodeHandler) + t.newnode = n + t.dispatch() + + + def sendPing(self, node): + """ + ping a node + """ + t = self.tf.Ping(node, self._pongHandler, self._defaultPong) + t.dispatch() + + + def findCloseNodes(self): + """ + This does a findNode on the ID one away from our own. + This will allow us to populate our table with nodes on our network closest to our own. + This is called as soon as we start up with an empty table + """ + id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256) + def callback(nodes): + pass + self.findNode(id, callback) + + def refreshTable(self): + """ + + """ + def callback(nodes): + pass + + for bucket in self.table.buckets: + if time.time() - bucket.lastAccessed >= 60 * 60: + id = randRange(bucket.min, bucket.max) + self.findNode(id, callback) + + + ##### + ##### UNSOLICITED INCOMING MESSAGE HANDLERS + + def _pingHandler(self, t, msg): + #print "Got PING from %s at %s:%s" % (`t.target.id`, t.target.host, t.target.port) + self.insertNode(t.target) + # respond, no callbacks, we don't care if they get it or not + nt = self.tf.Pong(t) + nt.dispatch() + + def _findNodeHandler(self, t, msg): + #print "Got FIND_NODES from %s:%s at %s:%s" % (t.target.host, t.target.port, self.node.host, self.node.port) + nodes = self.table.findNodes(msg['target']) + # respond, no callbacks, we don't care if they get it or not + nt = self.tf.GotNodes(t, nodes) + nt.dispatch() + + def _storeValueHandler(self, t, msg): + if not self.store.has_key(msg['key']): + self.store.put(msg['key'], msg['value']) + nt = self.tf.StoredValue(t) + nt.dispatch() + + def _findValueHandler(self, t, msg): + if self.store.has_key(msg['key']): + t = self.tf.GotValues(t, [(msg['key'], self.store[msg['key']])]) + else: + nodes = self.table.findNodes(msg['key']) + t = self.tf.GotNodes(t, nodes) + t.dispatch() + + + ### + ### message response callbacks + # called when we get a response to store value + def _storedValueHandler(self, t, msg): + self.table.insertNode(t.target) + + + ## these are the callbacks used when we ping the oldest node in a bucket + def _staleNodeHandler(self, t): + """ called if the pinged node never responds """ + self.table.replaceStaleNode(t.target, t.newnode) + + def _notStaleNodeHandler(self, t, msg): + """ called when we get a ping from the remote node """ + self.table.insertNode(t.target) + + + ## these are the callbacks we use when we issue a PING + def _pongHandler(self, t, msg): + #print "Got PONG from %s at %s:%s" % (`msg['id']`, t.target.host, t.target.port) + n = Node(msg['id'], t.addr[0], t.addr[1]) + self.table.insertNode(n) + + def _defaultPong(self, t): + # this should probably increment a failed message counter and dump the node if it gets over a threshold + print "Never got PONG from %s at %s:%s" % (`t.target.id`, t.target.host, t.target.port) + + + +class ActionBase: + """ base class for some long running asynchronous proccesses like finding nodes or values """ + def __init__(self, table, dispatcher, target, callback): + self.table = table + self.dispatcher = dispatcher + self.target = target + self.int = intify(target) + self.found = {} + self.queried = {} + self.answered = {} + self.callback = callback + self.outstanding = 0 + self.finished = 0 + + def sort(a, b, int=self.int): + """ this function is for sorting nodes relative to the ID we are looking for """ + x, y = int ^ a.int, int ^ b.int + if x > y: + return 1 + elif x < y: + return -1 + return 0 + self.sort = sort + + def goWithNodes(self, t): + pass + +class FindNode(ActionBase): + """ find node action merits it's own class as it is a long running stateful process """ + def handleGotNodes(self, t, msg): + if self.finished or self.answered.has_key(t.id): + # a day late and a dollar short + return + self.outstanding = self.outstanding - 1 + self.answered[t.id] = 1 + for node in msg['nodes']: + if not self.found.has_key(node['id']): + n = Node(node['id'], node['host'], node['port']) + self.found[n.id] = n + self.table.insertNode(n) + self.schedule() + + def schedule(self): + """ + send messages to new peers, if necessary + """ + if self.finished: + return + l = self.found.values() + l.sort(self.sort) + + for node in l[:K]: + if node.id == self.target: + self.finished=1 + return self.callback([node]) + if not self.queried.has_key(node.id) and node.id != self.table.node.id: + t = self.table.tf.FindNode(node, self.target, self.handleGotNodes, self.defaultGotNodes) + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + t.timeout = time.time() + 15 + t.dispatch() + if self.outstanding >= N: + break + assert(self.outstanding) >=0 + if self.outstanding == 0: + ## all done!! + self.finished=1 + self.callback(l[:K]) + + def defaultGotNodes(self, t): + if self.finished: + return + self.outstanding = self.outstanding - 1 + self.schedule() + + + def goWithNodes(self, t): + """ + this starts the process, our argument is a transaction with t.extras being our list of nodes + it's a transaction since we got called from the dispatcher + """ + nodes = t.extras + for node in nodes: + if node.id == self.table.node.id: + continue + self.found[node.id] = node + t = self.table.tf.FindNode(node, self.target, self.handleGotNodes, self.defaultGotNodes) + t.timeout = time.time() + 15 + t.dispatch() + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + if self.outstanding == 0: + self.callback(nodes) + + + +class GetValue(FindNode): + """ get value task """ + def handleGotNodes(self, t, msg): + if self.finished or self.answered.has_key(t.id): + # a day late and a dollar short + return + self.outstanding = self.outstanding - 1 + self.answered[t.id] = 1 + # go through nodes + # if we have any closer than what we already got, query them + if msg['type'] == 'got nodes': + for node in msg['nodes']: + if not self.found.has_key(node['id']): + n = Node(node['id'], node['host'], node['port']) + self.found[n.id] = n + self.table.insertNode(n) + elif msg['type'] == 'got values': + ## done + self.finished = 1 + return self.callback(msg['values']) + self.schedule() + + ## get value + def schedule(self): + if self.finished: + return + l = self.found.values() + l.sort(self.sort) + + for node in l[:K]: + if not self.queried.has_key(node.id) and node.id != self.table.node.id: + t = self.table.tf.GetValue(node, self.target, self.handleGotNodes, self.defaultGotNodes) + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + t.timeout = time.time() + 15 + t.dispatch() + if self.outstanding >= N: + break + assert(self.outstanding) >=0 + if self.outstanding == 0: + ## all done, didn't find it!! + self.finished=1 + self.callback([]) + + ## get value + def goWithNodes(self, t): + nodes = t.extras + for node in nodes: + if node.id == self.table.node.id: + continue + self.found[node.id] = node + t = self.table.tf.GetValue(node, self.target, self.handleGotNodes, self.defaultGotNodes) + t.timeout = time.time() + 15 + t.dispatch() + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + if self.outstanding == 0: + self.callback([]) + + +#------ +def test_build_net(quiet=0): + from whrandom import randrange + import thread + port = 2001 + l = [] + peers = 100 + + if not quiet: + print "Building %s peer table." % peers + + for i in xrange(peers): + a = Khashmir('localhost', port + i) + l.append(a) + + def run(l=l): + while(1): + events = 0 + for peer in l: + events = events + peer.dispatcher.runOnce() + if events == 0: + time.sleep(.25) + + for i in range(10): + thread.start_new_thread(run, (l[i*10:(i+1)*10],)) + #thread.start_new_thread(l[i].dispatcher.run, ()) + + for peer in l[1:]: + n = l[randrange(0, len(l))].node + peer.addContact(n.host, n.port) + n = l[randrange(0, len(l))].node + peer.addContact(n.host, n.port) + n = l[randrange(0, len(l))].node + peer.addContact(n.host, n.port) + + time.sleep(5) + + for peer in l: + peer.findCloseNodes() + time.sleep(5) + for peer in l: + peer.refreshTable() + return l + +def test_find_nodes(l, quiet=0): + import threading, sys + from whrandom import randrange + flag = threading.Event() + + n = len(l) + + a = l[randrange(0,n)] + b = l[randrange(0,n)] + + def callback(nodes, l=l, flag=flag): + if (len(nodes) >0) and (nodes[0].id == b.node.id): + print "test_find_nodes PASSED" + else: + print "test_find_nodes FAILED" + flag.set() + a.findNode(b.node.id, callback) + flag.wait() + +def test_find_value(l, quiet=0): + from whrandom import randrange + from sha import sha + import time, threading, sys + + fa = threading.Event() + fb = threading.Event() + fc = threading.Event() + + n = len(l) + a = l[randrange(0,n)] + b = l[randrange(0,n)] + c = l[randrange(0,n)] + d = l[randrange(0,n)] + + key = sha(`randrange(0,100000)`).digest() + value = sha(`randrange(0,100000)`).digest() + if not quiet: + print "inserting value...", + sys.stdout.flush() + a.storeValueForKey(key, value) + time.sleep(3) + print "finding..." + + def mc(flag, value=value): + def callback(values, f=flag, val=value): + try: + if(len(values) == 0): + print "find FAILED" + else: + if values[0]['value'] != val: + print "find FAILED" + else: + print "find FOUND" + finally: + f.set() + return callback + b.valueForKey(key, mc(fa)) + c.valueForKey(key, mc(fb)) + d.valueForKey(key, mc(fc)) + + fa.wait() + fb.wait() + fc.wait() + +if __name__ == "__main__": + l = test_build_net() + time.sleep(3) + print "finding nodes..." + test_find_nodes(l) + test_find_nodes(l) + test_find_nodes(l) + print "inserting and fetching values..." + test_find_value(l) + test_find_value(l) + test_find_value(l) + test_find_value(l) + test_find_value(l) + test_find_value(l) + for i in l: + i.dispatcher.stop() diff --git a/ktable.py b/ktable.py new file mode 100644 index 0000000..9cd3732 --- /dev/null +++ b/ktable.py @@ -0,0 +1,231 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +import hash +from bisect import * +import time +from types import * + +from node import Node + +# The all-powerful, magical Kademlia "k" constant, bucket depth +K = 20 + +# how many bits wide is our hash? +HASH_LENGTH = 160 + + +# the local routing table for a kademlia like distributed hash table +class KTable: + def __init__(self, node): + # this is the root node, a.k.a. US! + self.node = node + self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)] + self.insertNode(node) + + def _bucketIndexForInt(self, int): + """returns the index of the bucket that should hold int""" + return bisect_left(self.buckets, int) + + def findNodes(self, id): + """ return k nodes in our own local table closest to the ID + ignoreSelf means we will return K closest nodes to ourself if we search for our own ID + note, K closest nodes may actually include ourself, it's the callers responsibilty to + not send messages to itself if it matters + """ + if type(id) == StringType: + int = hash.intify(id) + elif type(id) == InstanceType: + int = id.int + elif type(id) == IntType or type(id) == LongType: + int = id + else: + raise TypeError, "findLocalNodes requires an int, string, or Node instance type" + + nodes = [] + + def sort(a, b, int=int): + """ this function is for sorting nodes relative to the ID we are looking for """ + x, y = int ^ a.int, int ^ b.int + if x > y: + return 1 + elif x < y: + return -1 + return 0 + + i = self._bucketIndexForInt(int) + + ## see if this node is already in our table and return it + try: + index = self.buckets[i].l.index(int) + except ValueError: + pass + else: + self.buckets[i].touch() + return [self.buckets[i].l[index]] + + nodes = nodes + self.buckets[i].l + if len(nodes) == K: + nodes.sort(sort) + return nodes + else: + # need more nodes + min = i - 1 + max = i + 1 + while (len(nodes) < K and (min >= 0 and max < len(self.buckets))): + if min >= 0: + nodes = nodes + self.buckets[min].l + self.buckets[min].touch() + if max < len(self.buckets): + nodes = nodes + self.buckets[max].l + self.buckets[max].touch() + + nodes.sort(sort) + return nodes[:K-1] + + def _splitBucket(self, a): + diff = (a.max - a.min) / 2 + b = KBucket([], a.max - diff, a.max) + self.buckets.insert(self.buckets.index(a.min) + 1, b) + a.max = a.max - diff + # transfer nodes to new bucket + for anode in a.l[:]: + if anode.int >= a.max: + a.l.remove(anode) + b.l.append(anode) + + def replaceStaleNode(self, stale, new): + """ this is used by clients to replace a node returned by insertNode after + it fails to respond to a Pong message + """ + i = self._bucketIndexForInt(stale.int) + try: + it = self.buckets[i].l.index(stale.int) + except ValueError: + return + + del(self.buckets[i].l[it]) + self.buckets[i].l.append(new) + + def insertNode(self, node): + """ + this insert the node, returning None if successful, returns the oldest node in the bucket if it's full + the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!! + """ + # get the bucket for this node + i = self. _bucketIndexForInt(node.int) + ## check to see if node is in the bucket already + try: + it = self.buckets[i].l.index(node.int) + except ValueError: + ## no + pass + else: + node.updateLastSeen() + # move node to end of bucket + del(self.buckets[i].l[it]) + self.buckets[i].l.append(node) + self.buckets[i].touch() + return + + # we don't have this node, check to see if the bucket is full + if len(self.buckets[i].l) < K: + # no, append this node and return + self.buckets[i].l.append(node) + self.buckets[i].touch() + return + + # bucket is full, check to see if self.node is in the bucket + try: + me = self.buckets[i].l.index(self.node) + except ValueError: + return self.buckets[i].l[0] + + ## this bucket is full and contains our node, split the bucket + if len(self.buckets) >= HASH_LENGTH: + # our table is FULL + print "Hash Table is FULL! Increase K!" + return + + self._splitBucket(self.buckets[i]) + + ## now that the bucket is split and balanced, try to insert the node again + return self.insertNode(node) + + def justSeenNode(self, node): + """ call this any time you get a message from a node, to update it in the table if it's there """ + try: + n = self.findNodes(node.int)[0] + except IndexError: + return None + else: + tstamp = n.lastSeen + n.updateLastSeen() + return tstamp + + +class KBucket: + __slots = ['min', 'max', 'lastAccessed'] + def __init__(self, contents, min, max): + self.l = contents + self.min = min + self.max = max + self.lastAccessed = time.time() + + def touch(self): + self.lastAccessed = time.time() + + def getNodeWithInt(self, int): + try: + return self.l[self.l.index(int)] + self.touch() + except IndexError: + raise ValueError + + def __repr__(self): + return "" % (len(self.l), self.min, self.max) + + ## comparators, necessary for bisecting list of buckets with a hash expressed as an integer or a distance + ## compares integer or node object with the bucket's range + def __lt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.max <= a + def __le__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min < a + def __gt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min > a + def __ge__(self, a): + if type(a) == InstanceType: + a = a.int + return self.max >= a + def __eq__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min <= a and self.max > a + def __ne__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min >= a or self.max < a + + + +############## +import unittest + +class TestKTable(unittest.TestCase): + def setUp(self): + self.a = Node(hash.newID(), 'localhost', 2002) + self.t = KTable(self.a) + + def test_replace_stale_node(self): + self.b = Node(hash.newID(), 'localhost', 2003) + self.t.replaceStaleNode(self.a, self.b) + assert(len(self.t.buckets[0].l) == 1) + assert(self.t.buckets[0].l[0].id == self.b.id) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/listener.py b/listener.py new file mode 100644 index 0000000..3999b32 --- /dev/null +++ b/listener.py @@ -0,0 +1,100 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from socket import * + +# simple UDP communicator + +class Listener: + def __init__(self, host, port): + self.msgq = [] + self.sock = socket(AF_INET, SOCK_DGRAM) + self.sock.setblocking(0) + self.sock.bind((host, port)) + + def qMsg(self, msg, host, port): + self.msgq.append((msg, host, port)) + + def qLen(self): + return len(self.msgq) + + def dispatchMsg(self): + if self.qLen() > 0: + msg, host, port = self.msgq[0] + del self.msgq[0] + self.sock.sendto(msg, 0, (host, port)) + + def receiveMsg(self): + msg = () + try: + msg = self.sock.recvfrom(65536) + except error, tup: + if tup[1] == "Resource temporarily unavailable": + # no message + return msg + print error, tup + else: + return msg + + def __del__(self): + self.sock.close() + + + +########################### +import unittest + +class ListenerTest(unittest.TestCase): + def setUp(self): + self.a = Listener('localhost', 8080) + self.b = Listener('localhost', 8081) + def tearDown(self): + del(self.a) + del(self.b) + + def testQueue(self): + assert self.a.qLen() == 0, "expected queue to be empty" + self.a.qMsg('hello', 'localhost', 8081) + assert self.a.qLen() == 1, "expected one message to be in queue" + self.a.qMsg('hello', 'localhost', 8081) + assert self.a.qLen() == 2, "expected two messages to be in queue" + self.a.dispatchMsg() + assert self.a.qLen() == 1, "expected one message to be in queue" + self.a.dispatchMsg() + assert self.a.qLen() == 0, "expected all messages to be flushed from queue" + + def testSendReceiveOne(self): + self.a.qMsg('hello', 'localhost', 8081) + self.a.dispatchMsg() + + assert self.b.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.b.receiveMsg() == (), "received unexpected message" + + self.b.qMsg('hello', 'localhost', 8080) + self.b.dispatchMsg() + + assert self.a.receiveMsg()[0] == "hello", "did not receive expected message" + + assert self.a.receiveMsg() == (), "received unexpected message" + + def testSendReceiveInterleaved(self): + self.a.qMsg('hello', 'localhost', 8081) + self.a.qMsg('hello', 'localhost', 8081) + self.a.dispatchMsg() + self.a.dispatchMsg() + + assert self.b.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.b.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.b.receiveMsg() == (), "received unexpected message" + + self.b.qMsg('hello', 'localhost', 8080) + self.b.qMsg('hello', 'localhost', 8080) + self.b.dispatchMsg() + self.b.dispatchMsg() + + assert self.a.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.a.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.a.receiveMsg() == (), "received unexpected message" + + +if __name__ == '__main__': + unittest.main() diff --git a/messages.py b/messages.py new file mode 100644 index 0000000..542fe23 --- /dev/null +++ b/messages.py @@ -0,0 +1,186 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from bencode import bencode, bdecode +from btemplate import * +from node import Node + + +# template checker for hash id +def hashid(thing, verbose): + if type(thing) != type(''): + raise ValueError, 'must be a string' + if len(thing) != 20: + raise ValueError, 'must be 20 characters long' + +## our messages +BASE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : string_template}) + +PING = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'ping'}) +PONG = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'pong'}) + +FIND_NODE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'find node', "target" : hashid}) +GOT_NODES = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'got nodes', "nodes" : ListMarker({'id': hashid, 'host': string_template, 'port': 1})}) + +STORE_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'store value', "key" : hashid, "value" : string_template}) +STORED_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'stored value'}) + +GET_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'get value', "key" : hashid}) +GOT_VALUES = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'got values', "values" : ListMarker({'key': hashid, 'value': string_template})}) + +GOT_NODES_OR_VALUES = compile_template([GOT_NODES, GOT_VALUES]) + + +class MessageFactory: + def __init__(self, id): + self.id = id + + def encodePing(self, tid): + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'ping'}) + def decodePing(self, msg): + msg = bdecode(msg) + PING(msg) + return msg + + def encodePong(self, tid): + msg = {'id' : self.id, 'tid' : tid, 'type' : 'pong'} + PONG(msg) + return bencode(msg) + def decodePong(self, msg): + msg = bdecode(msg) + PONG(msg) + return msg + + def encodeFindNode(self, tid, target): + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'find node', 'target' : target}) + def decodeFindNode(self, msg): + msg = bdecode(msg) + FIND_NODE(msg) + return msg + + def encodeStoreValue(self, tid, key, value): + return bencode({'id' : self.id, 'tid' : tid, 'key' : key, 'type' : 'store value', 'value' : value}) + def decodeStoreValue(self, msg): + msg = bdecode(msg) + STORE_VALUE(msg) + return msg + + + def encodeStoredValue(self, tid): + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'stored value'}) + def decodeStoredValue(self, msg): + msg = bdecode(msg) + STORED_VALUE(msg) + return msg + + + def encodeGetValue(self, tid, key): + return bencode({'id' : self.id, 'tid' : tid, 'key' : key, 'type' : 'get value'}) + def decodeGetValue(self, msg): + msg = bdecode(msg) + GET_VALUE(msg) + return msg + + def encodeGotNodes(self, tid, nodes): + n = [] + for node in nodes: + n.append({'id' : node.id, 'host' : node.host, 'port' : node.port}) + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got nodes', 'nodes' : n}) + def decodeGotNodes(self, msg): + msg = bdecode(msg) + GOT_NODES(msg) + return msg + + def encodeGotValues(self, tid, values): + n = [] + for value in values: + n.append({'key' : value[0], 'value' : value[1]}) + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got values', 'values' : n}) + def decodeGotValues(self, msg): + msg = bdecode(msg) + GOT_VALUES(msg) + return msg + + + +###### +import unittest + +class TestMessageEncoding(unittest.TestCase): + def setUp(self): + from sha import sha + self.a = sha('a').digest() + self.b = sha('b').digest() + + + def test_ping(self): + m = MessageFactory(self.a) + s = m.encodePing(self.b) + msg = m.decodePing(s) + PING(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.b) + + def test_pong(self): + m = MessageFactory(self.a) + s = m.encodePong(self.b) + msg = m.decodePong(s) + PONG(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.b) + + def test_find_node(self): + m = MessageFactory(self.a) + s = m.encodeFindNode(self.a, self.b) + msg = m.decodeFindNode(s) + FIND_NODE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['target'] == self.b) + + def test_store_value(self): + m = MessageFactory(self.a) + s = m.encodeStoreValue(self.a, self.b, 'foo') + msg = m.decodeStoreValue(s) + STORE_VALUE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['key'] == self.b) + assert(msg['value'] == 'foo') + + def test_stored_value(self): + m = MessageFactory(self.a) + s = m.encodeStoredValue(self.b) + msg = m.decodeStoredValue(s) + STORED_VALUE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.b) + + def test_get_value(self): + m = MessageFactory(self.a) + s = m.encodeGetValue(self.a, self.b) + msg = m.decodeGetValue(s) + GET_VALUE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['key'] == self.b) + + def test_got_nodes(self): + m = MessageFactory(self.a) + s = m.encodeGotNodes(self.a, [Node(self.b, 'localhost', 2002), Node(self.a, 'localhost', 2003)]) + msg = m.decodeGotNodes(s) + GOT_NODES(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['nodes'][0]['id'] == self.b) + + def test_got_values(self): + m = MessageFactory(self.a) + s = m.encodeGotValues(self.a, [(self.b, 'localhost')]) + msg = m.decodeGotValues(s) + GOT_VALUES(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + + +if __name__ == "__main__": + unittest.main() diff --git a/node.py b/node.py new file mode 100644 index 0000000..cb1940e --- /dev/null +++ b/node.py @@ -0,0 +1,56 @@ +import hash +import time +from types import * + +class Node: + """encapsulate contact info""" + def __init__(self, id, host, port): + self.id = id + self.int = hash.intify(id) + self.host = host + self.port = port + self.lastSeen = time.time() + + def updateLastSeen(self): + self.lastSeen = time.time() + + def __repr__(self): + return `(self.id, self.host, self.port)` + + ## these comparators let us bisect/index a list full of nodes with either a node or an int/long + def __lt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int < a + def __le__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int <= a + def __gt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int > a + def __ge__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int >= a + def __eq__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int == a + def __ne__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int != a + + +import unittest + +class TestNode(unittest.TestCase): + def setUp(self): + self.node = Node(hash.newID(), 'localhost', 2002) + def testUpdateLastSeen(self): + t = self.node.lastSeen + self.node.updateLastSeen() + assert t < self.node.lastSeen + \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..044aeff --- /dev/null +++ b/test.py @@ -0,0 +1,10 @@ +import unittest + +import hash, node, messages +import listener, dispatcher +import ktable, transactions, khashmir + +import bencode, btemplate + +tests = unittest.defaultTestLoader.loadTestsFromNames(['hash', 'node', 'bencode', 'btemplate', 'listener', 'messages', 'dispatcher', 'transactions', 'ktable']) +result = unittest.TextTestRunner().run(tests) diff --git a/transactions.py b/transactions.py new file mode 100644 index 0000000..b862aaa --- /dev/null +++ b/transactions.py @@ -0,0 +1,71 @@ +import messages +from dispatcher import Transaction + +class TransactionFactory: + def __init__(self, id, dispatcher): + self.id = id + self.dispatcher = dispatcher + self.mf = messages.MessageFactory(self.id) + + def Ping(self, node, response, default): + """ create a ping transaction """ + t = Transaction(self.dispatcher, node, response, default) + str = self.mf.encodePing(t.id) + t.setPayload(str) + t.setResponseTemplate(messages.PONG) + return t + + def FindNode(self, target, key, response, default): + """ find node query """ + t = Transaction(self.dispatcher, target, response, default) + str = self.mf.encodeFindNode(t.id, key) + t.setPayload(str) + t.setResponseTemplate(messages.GOT_NODES) + return t + + def StoreValue(self, target, key, value, response, default): + """ find node query """ + t = Transaction(self.dispatcher, target, response, default) + str = self.mf.encodeStoreValue(t.id, key, value) + t.setPayload(str) + t.setResponseTemplate(messages.STORED_VALUE) + return t + + def GetValue(self, target, key, response, default): + """ find value query, response is GOT_VALUES or GOT_NODES! """ + t = Transaction(self.dispatcher, target, response, default) + str = self.mf.encodeGetValue(t.id, key) + t.setPayload(str) + t.setResponseTemplate(messages.GOT_NODES_OR_VALUES) + return t + + def Pong(self, ping_t): + """ create a pong response to ping transaction """ + t = Transaction(self.dispatcher, ping_t.target, None, None, ping_t.id) + str = self.mf.encodePong(t.id) + t.setPayload(str) + return t + + def GotNodes(self, findNode_t, nodes): + """ respond with gotNodes """ + t = Transaction(self.dispatcher, findNode_t.target, None, None, findNode_t.id) + str = self.mf.encodeGotNodes(t.id, nodes) + t.setPayload(str) + return t + + def GotValues(self, findNode_t, values): + """ respond with gotNodes """ + t = Transaction(self.dispatcher, findNode_t.target, None, None, findNode_t.id) + str = self.mf.encodeGotValues(t.id, values) + t.setPayload(str) + return t + + def StoredValue(self, tr): + """ store value response, really just a pong """ + t = Transaction(self.dispatcher, tr.target, None, None, id = tr.id) + str = self.mf.encodeStoredValue(t.id) + t.setPayload(str) + return t + + +###########