From 1abfc43747d252f5d0bf11f119e0daabecc4f3a6 Mon Sep 17 00:00:00 2001 From: burris Date: Fri, 19 Jul 2002 20:23:12 +0000 Subject: [PATCH] Initial revision --- README.txt | 10 + bencode.py | 335 ++++++++++++++++++++++++++++++ btemplate.py | 528 ++++++++++++++++++++++++++++++++++++++++++++++++ dispatcher.py | 216 ++++++++++++++++++++ hash.py | 108 ++++++++++ khashmir.py | 488 ++++++++++++++++++++++++++++++++++++++++++++ ktable.py | 231 +++++++++++++++++++++ listener.py | 100 +++++++++ messages.py | 186 +++++++++++++++++ node.py | 56 +++++ test.py | 10 + transactions.py | 71 +++++++ 12 files changed, 2339 insertions(+) create mode 100644 README.txt create mode 100644 bencode.py create mode 100644 btemplate.py create mode 100644 dispatcher.py create mode 100644 hash.py create mode 100644 khashmir.py create mode 100644 ktable.py create mode 100644 listener.py create mode 100644 messages.py create mode 100644 node.py create mode 100644 test.py create mode 100644 transactions.py diff --git a/README.txt b/README.txt new file mode 100644 index 0000000..141e5fa --- /dev/null +++ b/README.txt @@ -0,0 +1,10 @@ +quick example: + +import khashmir, threading +k = khashmir.Khashmir('127.0.0.1', 4444) +start_new_thread(k.dispatcher.run, ()) +k.addContact('127.0.0.1', 8080) # right now we don't do gethostbyname +k.populateTable() + + +alternatively, you can call k.dispatcher.runOnce() periodically from whatever thread you choose \ No newline at end of file diff --git a/bencode.py b/bencode.py new file mode 100644 index 0000000..77f65da --- /dev/null +++ b/bencode.py @@ -0,0 +1,335 @@ +""" +A library for streaming and unstreaming of simple objects, designed +for speed, compactness, and ease of implementation. + +The basic functions are bencode and bdecode. bencode takes an object +and returns a string, bdecode takes a string and returns an object. +bdecode raises a ValueError if you give it an invalid string. + +The objects passed in may be nested dicts, lists, ints, strings, +and None. For example, all of the following may be bencoded - + +{'a': [0, 1], 'b': None} + +[None, ['a', 2, ['c', None]]] + +{'spam': (2,3,4)} + +{'name': 'Cronus', 'spouse': 'Rhea', 'children': ['Hades', 'Poseidon']} + +In general bdecode(bencode(spam)) == spam, but tuples and lists are +encoded the same, so bdecode(bencode((0, 1))) is [0, 1] rather +than (0, 1). Longs and ints are also encoded the same way, so +bdecode(bencode(4)) is a long. + +dict keys are required to be strings, to avoid a mess of potential +implementation incompatibilities. bencode is intended to be used +for protocols which are going to be re-implemented many times, so +it's very conservative in that regard. + +Which type is encoded is determined by the first character, 'i', 'n', +'d', 'l' and any digit. They indicate integer, null, dict, list, and +string, respectively. + +Strings are length-prefixed in base 10, followed by a colon. + +bencode('spam') == '4:spam' + +Nulls are indicated by a single 'n'. + +bencode(None) == 'n' + +integers are encoded base 10 and terminated with an 'e'. + +bencode(3) == 'i3e' +bencode(-20) == 'i-20e' + +Lists are encoded in list order, terminated by an 'e' - + +bencode(['abc', 'd']) == 'l3:abc1:de' +bencode([2, 'f']) == 'li2e1:fe' + +Dicts are encoded by containing alternating keys and values, +with the keys in sorted order, terminated by an 'e'. For example - + +bencode({'spam': 'eggs'}) == 'd4:spam4:eggse' +bencode({'ab': 2, 'a': None}) == 'd1:an2:abi2ee' + +Truncated strings come first, so in sort order 'a' comes before 'abc'. + +If a function is passed to bencode, it's called and it's return value +is included as a raw string, for example - + +bdecode(bencode(lambda: None)) == None +""" + +# This file is licensed under the GNU Lesser General Public License v2.1. +# originally written for Mojo Nation by Bryce Wilcox, Bram Cohen, and Greg P. Smith +# since then, almost completely rewritten by Bram Cohen + +from types import * +from cStringIO import StringIO +import re + +def bencode(data): + """ + encodes objects as strings, see module documentation for more info + """ + result = StringIO() + bwrite(data, result) + return result.getvalue() + +def bwrite(data, result): + encoder = encoders.get(type(data)) + assert encoder is not None, 'unsupported data type: ' + `type(data)` + encoder(data, result) + +encoders = {} + +def encode_int(data, result): + result.write('i' + str(data) + 'e') + +encoders[IntType] = encode_int +encoders[LongType] = encode_int + +def encode_list(data, result): + result.write('l') + for i in data: + bwrite(i, result) + result.write('e') + +encoders[TupleType] = encode_list +encoders[ListType] = encode_list + +def encode_string(data, result): + result.write(str(len(data)) + ':' + data) + +encoders[StringType] = encode_string + +def encode_dict(data, result): + result.write('d') + keys = data.keys() + keys.sort() + for key in keys: + assert type(key) is StringType, 'bencoded dictionary key must be a string' + bwrite(key, result) + bwrite(data[key], result) + result.write('e') + +encoders[DictType] = encode_dict + +encoders[NoneType] = lambda data, result: result.write('n') + +encoders[FunctionType] = lambda data, result: result.write(data()) +encoders[MethodType] = encoders[FunctionType] + +def bdecode(s): + """ + Does the opposite of bencode. Raises a ValueError if there's a problem. + """ + try: + result, index = bread(s, 0) + if index != len(s): + raise ValueError('left over stuff at end') + return result + except IndexError, e: + raise ValueError(str(e)) + except KeyError, e: + raise ValueError(str(e)) + +def bread(s, index): + return decoders[s[index]](s, index) + +decoders = {} + +_bre = re.compile(r'(0|[1-9][0-9]*):') + +def decode_raw_string(s, index): + x = _bre.match(s, index) + if x is None: + raise ValueError('invalid integer encoding') + endindex = x.end() + long(s[index:x.end() - 1]) + if endindex > len(s): + raise ValueError('length encoding indicated premature end of string') + return s[x.end(): endindex], endindex + +for c in '0123456789': + decoders[c] = decode_raw_string + +_int_re = re.compile(r'i(0|-?[1-9][0-9]*)e') + +def decode_int(s, index): + x = _int_re.match(s, index) + if x is None: + raise ValueError('invalid integer encoding') + return long(s[index + 1:x.end() - 1]), x.end() + +decoders['i'] = decode_int + +decoders['n'] = lambda s, index: (None, index + 1) + +def decode_list(s, index): + result = [] + index += 1 + while s[index] != 'e': + next, index = bread(s, index) + result.append(next) + return result, index + 1 + +decoders['l'] = decode_list + +def decode_dict(s, index): + result = {} + index += 1 + prevkey = None + while s[index] != 'e': + key, index = decode_raw_string(s, index) + if key <= prevkey: + raise ValueError("out of order keys") + prevkey = key + value, index = bread(s, index) + result[key] = value + return result, index + 1 + +decoders['d'] = decode_dict + +def test_decode_raw_string(): + assert decode_raw_string('1:a', 0) == ('a', 3) + assert decode_raw_string('0:', 0) == ('', 2) + assert decode_raw_string('10:aaaaaaaaaaaaaaaaaaaaaaaaa', 0) == ('aaaaaaaaaa', 13) + assert decode_raw_string('10:', 1) == ('', 3) + try: + decode_raw_string('01:a', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('--1:a', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('h', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('h:', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('1', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('', 0) + assert 0, 'failed' + except ValueError: + pass + try: + decode_raw_string('5:a', 0) + assert 0, 'failed' + except ValueError: + pass + +def test_dict_enforces_order(): + bdecode('d1:an1:bne') + try: + bdecode('d1:bn1:ane') + assert 0, 'failed' + except ValueError: + pass + +def test_dict_forbids_non_string_key(): + try: + bdecode('di3ene') + assert 0, 'failed' + except ValueError: + pass + +def test_dict_forbids_key_repeat(): + try: + bdecode('d1:an1:ane') + assert 0, 'failed' + except ValueError: + pass + +def test_empty_dict(): + assert bdecode('de') == {} + +def test_ValueError_in_decode_unknown(): + try: + bdecode('x') + assert 0, 'flunked' + except ValueError: + pass + +def test_encode_and_decode_none(): + assert bdecode(bencode(None)) == None + +def test_encode_and_decode_long(): + assert bdecode(bencode(-23452422452342L)) == -23452422452342L + +def test_encode_and_decode_int(): + assert bdecode(bencode(2)) == 2 + +def test_decode_noncanonical_int(): + try: + bdecode('i03e') + assert 0 + except ValueError: + pass + try: + bdecode('i3 e') + assert 0 + except ValueError: + pass + try: + bdecode('i 3e') + assert 0 + except ValueError: + pass + try: + bdecode('i-0e') + assert 0 + except ValueError: + pass + +def test_encode_and_decode_dict(): + x = {'42': 3} + assert bdecode(bencode(x)) == x + +def test_encode_and_decode_list(): + assert bdecode(bencode([])) == [] + +def test_encode_and_decode_tuple(): + assert bdecode(bencode(())) == [] + +def test_encode_and_decode_empty_dict(): + assert bdecode(bencode({})) == {} + +def test_encode_and_decode_complex_object(): + spam = [[], 0, -3, -345234523543245234523L, {}, 'spam', None, {'a': [3]}, {}] + assert bencode(bdecode(bencode(spam))) == bencode(spam) + assert bdecode(bencode(spam)) == spam + +def test_unfinished_list(): + try: + bdecode('ln') + assert 0 + except ValueError: + pass + +def test_unfinished_dict(): + try: + bdecode('d') + assert 0 + except ValueError: + pass + try: + bdecode('d1:a') + assert 0 + except ValueError: + pass diff --git a/btemplate.py b/btemplate.py new file mode 100644 index 0000000..038a1ac --- /dev/null +++ b/btemplate.py @@ -0,0 +1,528 @@ +# This file is licensed under the GNU Lesser General Public License v2.1. +# originally written for Mojo Nation by Bram Cohen, based on an earlier +# version by Bryce Wilcox +# The authors disclaim all liability for any damages resulting from +# any use of this software. + +import types + +def string_template(thing, verbose): + if type(thing) != types.StringType: + raise ValueError, "not a string" + +st = string_template + +def exact_length(l): + def func(s, verbose, l = l): + if type(s) != types.StringType: + raise ValueError, 'should have been string' + if len(s) != l: + raise ValueError, 'wrong length, should have been ' + str(l) + ' was ' + str(len(s)) + return func + +class MaxDepth: + def __init__(self, max_depth, template = None): + assert max_depth >= 0 + self.max_depth = max_depth + self.template = template + + def get_real_template(self): + assert self.template is not None, 'You forgot to set the template!' + if self.max_depth == 0: + return fail_too_deep + self.max_depth -= 1 + try: + return compile_inner(self.template) + finally: + self.max_depth += 1 + + def __repr__(self): + if hasattr(self, 'p'): + return '...' + try: + self.p = 1 + return 'MaxDepth(' + str(self.max_depth) + ', ' + `self.template` + ')' + finally: + del self.p + +def fail_too_deep(thing, verbose): + raise ValueError, 'recursed too deep' + +class ListMarker: + def __init__(self, template): + self.template = template + + def get_real_template(self): + return compile_list_template(self.template) + + def __repr__(self): + return 'ListMarker(' + `self.template` + ')' + +def compile_list_template(template): + def func(thing, verbose, template = compile_inner(template)): + if type(thing) not in (types.ListType, types.TupleType): + raise ValueError, 'not a list' + if verbose: + try: + for i in xrange(0, len(thing)): + template(thing[i], 1) + except ValueError, e: + reason = 'mismatch at index ' + str(i) + ': ' + str(e) + raise ValueError, reason + else: + for i in thing: + template(i, 0) + return func + +class ValuesMarker: + def __init__(self, template, t2 = string_template): + self.template = template + self.t2 = t2 + + def get_real_template(self): + return compile_values_template(self.template, self.t2) + + def __repr__(self): + return 'ValuesMarker(' + `self.template` + ')' + +def compile_values_template(template, t2): + def func(thing, verbose, template = compile_inner(template), + t2 = compile_inner(t2)): + if type(thing) != types.DictType: + raise ValueError, 'not a dict' + if verbose: + try: + for key, val in thing.items(): + template(val, 1) + t2(key, 1) + except ValueError, e: + raise ValueError, 'mismatch in key ' + `key` + ': ' + str(e) + else: + for key, val in thing.items(): + template(val, 0) + t2(key, 0) + return func + +compilers = {} + +def compile_string_template(template): + assert type(template) is types.StringType + def func(thing, verbose, template = template): + if thing != template: + raise ValueError, "didn't match string" + return func + +compilers[types.StringType] = compile_string_template + +def int_template(thing, verbose): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'thing not of integer type' + +def nonnegative_int_template(thing, verbose): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'thing not of integer type' + if thing < 0: + raise ValueError, 'thing less than zero' + +def positive_int_template(thing, verbose): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'thing not of integer type' + if thing <= 0: + raise ValueError, 'thing less than or equal to zero' + +def compile_int_template(s): + assert s in (-1, 0, 1) + if s == -1: + return int_template + elif s == 0: + return nonnegative_int_template + else: + return positive_int_template + +compilers[types.IntType] = compile_int_template +compilers[types.LongType] = compile_int_template + +def compile_slice(template): + assert type(template) is types.SliceType + assert template.step is None + assert template.stop is not None + start = template.start + if start is None: + start = 0 + def func(thing, verbose, start = start, stop = template.stop): + if type(thing) not in (types.IntType, types.LongType): + raise ValueError, 'not an int' + if thing < start: + raise ValueError, 'thing too small' + if thing >= stop: + raise ValueError, 'thing too large' + return func + +compilers[types.SliceType] = compile_slice + +class OptionMarker: + def __init__(self, template): + self.option_template = template + + def __repr__(self): + return 'OptionMarker(' + `self.option_template` + ')' + +def compile_dict_template(template): + assert type(template) is types.DictType + agroup = [] + bgroup = [] + cgroup = [] + optiongroup = [] + for key, value in template.items(): + if hasattr(value, 'option_template'): + optiongroup.append((key, compile_inner(value.option_template))) + elif type(value) is types.StringType: + agroup.append((key, compile_inner(value))) + elif type(value) in (types.IntType, types.LongType, types.SliceType): + bgroup.append((key, compile_inner(value))) + else: + cgroup.append((key, compile_inner(value))) + def func(thing, verbose, required = agroup + bgroup + cgroup, optional = optiongroup): + if type(thing) is not types.DictType: + raise ValueError, 'not a dict' + try: + for key, template in required: + if not thing.has_key(key): + raise ValueError, 'key not present' + template(thing[key], verbose) + for key, template in optional: + if thing.has_key(key): + template(thing[key], verbose) + except ValueError, e: + if verbose: + reason = 'mismatch in key ' + `key` + ': ' + str(e) + raise ValueError, reason + else: + raise + return func + +compilers[types.DictType] = compile_dict_template + +def none_template(thing, verbose): + if thing is not None: + raise ValueError, 'thing was not None' + +compilers[types.NoneType] = lambda template: none_template + +def compile_or_template(template): + assert type(template) in (types.ListType, types.TupleType) + def func(thing, verbose, templ = [compile_inner(x) for x in template]): + if verbose: + failure_reason = ('did not match any of the ' + + str(len(templ)) + ' possible templates;') + for i in xrange(len(templ)): + try: + templ[i](thing, 1) + return + except ValueError, reason: + failure_reason += (' failed template at index ' + + str(i) + ' because (' + str(reason) + ')') + raise ValueError, failure_reason + else: + for i in templ: + try: + i(thing, 0) + return + except ValueError: + pass + raise ValueError, "did not match any possible templates" + return func + +compilers[types.ListType] = compile_or_template +compilers[types.TupleType] = compile_or_template + +def compile_inner(template): + while hasattr(template, 'get_real_template'): + template = template.get_real_template() + if callable(template): + return template + return compilers[type(template)](template) + +def compile_template(template): + def func(thing, verbose = None, t = compile_inner(template), s = `template`): + if verbose is not None: + t(thing, verbose) + return + try: + t(thing, 0) + except ValueError: + try: + t(thing, 1) + assert 0 + except ValueError, reason: + raise ValueError, 'failed template check because: (' + str(reason) + ') target was: (' + `thing` + ') template was: (' + s + ')' + return func + + + +###### +import unittest + +class TestBTemplate(unittest.TestCase): + + def test_slice(self): + f = compile_template(slice(4)) + f(0) + f(3L) + try: + f(-1) + assert 0 + except ValueError: + pass + try: + f(4L) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + f = compile_template(slice(-2, 3)) + f(-2L) + f(2) + try: + f(-3L) + assert 0 + except ValueError: + pass + try: + f(3) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + def test_int(self): + f = compile_template(0) + f(0) + f(1L) + try: + f(-1) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + f = compile_template(-1) + f(0) + f(1) + f(-1L) + try: + f('a') + assert 0 + except ValueError: + pass + + f = compile_template(1) + try: + f(0) + assert 0 + except ValueError: + pass + f(1) + try: + f(-1) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + def test_none(self): + f = compile_template(None) + f(None) + try: + f(0) + assert 0 + except ValueError: + pass + + def test_string(self): + f = compile_template('a') + f('a') + try: + f('b') + assert 0 + except ValueError: + pass + try: + f(0) + assert 0 + except ValueError: + pass + + def test_generic_string(self): + f = compile_template(st) + f('a') + try: + f(0) + assert 0 + except ValueError: + pass + + def test_values(self): + vt = compile_template(ValuesMarker('a', exact_length(1))) + vt({}) + vt({'x': 'a'}) + try: + vt(3) + assert 0 + except ValueError: + pass + try: + vt({'x': 'b'}) + assert 0 + except ValueError: + pass + try: + vt({'xx': 'a'}) + assert 0 + except ValueError: + pass + + def test_list(self): + f = compile_template(ListMarker('a')) + f(['a']) + f(('a', 'a')) + try: + f(('a', 'b')) + assert 0 + except ValueError: + pass + try: + f(('b', 'a')) + assert 0 + except ValueError: + pass + try: + f('a') + assert 0 + except ValueError: + pass + + def test_or(self): + f = compile_template(['a', 'b']) + f('a') + f('b') + try: + f('c') + assert 0 + except ValueError: + pass + + f = compile_template(('a', 'b')) + f('a') + f('b') + try: + f('c') + assert 0 + except ValueError: + pass + + def test_dict(self): + f = compile_template({'a': 'b', 'c': OptionMarker('d')}) + try: + f({}) + assert 0 + except ValueError: + pass + f({'a': 'b'}) + try: + f({'a': 'e'}) + assert 0 + except ValueError: + pass + try: + f({'c': 'd'}) + assert 0 + except ValueError: + pass + f({'a': 'b', 'c': 'd'}) + try: + f({'a': 'e', 'c': 'd'}) + assert 0 + except ValueError: + pass + try: + f({'c': 'f'}) + assert 0 + except ValueError: + pass + try: + f({'a': 'b', 'c': 'f'}) + assert 0 + except ValueError: + pass + try: + f({'a': 'e', 'c': 'f'}) + assert 0 + except ValueError: + pass + try: + f(None) + assert 0 + except ValueError: + pass + + def test_other_func(self): + def check3(thing, verbose): + if thing != 3: + raise ValueError + f = compile_template(check3) + f(3) + try: + f(4) + assert 0 + except ValueError: + pass + + def test_max_depth(self): + md = MaxDepth(2) + t = {'a': OptionMarker(ListMarker(md))} + md.template = t + f = compile_template(md) + f({'a': [{'a': []}]}) + f({'a': [{'a': []}]}) + try: + f({'a': [{'a': [{}]}]}) + assert 0 + except ValueError: + pass + try: + f({'a': [{'a': [{}]}]}) + assert 0 + except ValueError: + pass + f({'a': [{'a': []}]}) + try: + f({'a': [{'a': [{}]}]}) + assert 0 + except ValueError: + pass + + def test_use_compiled(self): + x = compile_template('a') + y = compile_template(ListMarker(x)) + y(['a']) + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher.py b/dispatcher.py new file mode 100644 index 0000000..08c62db --- /dev/null +++ b/dispatcher.py @@ -0,0 +1,216 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from bsddb3 import db ## find this at http://pybsddb.sf.net/ +from bsddb3._db import DBNotFoundError +import time +import hash +from node import Node +from bencode import bencode, bdecode +#from threading import RLock + +# max number of incoming or outgoing messages to process at a time +NUM_EVENTS = 5 + +class Transaction: + __slots__ = ['responseTemplate', 'id', 'dispatcher', 'target', 'payload', 'response', 'default', 'timeout'] + def __init__(self, dispatcher, node, response_handler, default_handler, id = None, payload = None, timeout=60): + if id == None: + id = hash.newID() + self.id = id + self.dispatcher = dispatcher + self.target = node + self.payload = payload + self.response = response_handler + self.default = default_handler + self.timeout = time.time() + timeout + + def setPayload(self, payload): + self.payload = payload + + def setResponseTemplate(self, t): + self.responseTemplate = t + + def responseHandler(self, msg): + if self.responseTemplate and callable(self.responseTemplate): + try: + self.responseTemplate(msg) + except ValueError, reason: + print "response %s" % (reason) + print `msg['id'], self.target.id` + return + self.response(self, msg) + + def defaultHandler(self): + self.default(self) + + def dispatch(self): + if callable(self.response) and callable(self.default): + self.dispatcher.initiate(self) + else: + self.dispatchNoResponse() + def dispatchNoResponse(self): + self.dispatcher.initiateNoResponse(self) + + + +class Dispatcher: + def __init__(self, listener, base_template, id): + self.id = id + self.listener = listener + self.transactions = {} + self.handlers = {} + self.timeout = db.DB() + self.timeout.set_flags(db.DB_DUP) + self.timeout.open(None, None, db.DB_BTREE) + self.BASE = base_template + self.stopped = 0 + #self.tlock = RLock() + + def registerHandler(self, key, handler, template): + assert(callable(handler)) + assert(callable(template)) + self.handlers[key] = (handler, template) + + def initiate(self, transaction): + #self.tlock.acquire() + #ignore messages to ourself + if transaction.target.id == self.id: + return + self.transactions[transaction.id] = transaction + self.timeout.put(`transaction.timeout`, transaction.id) + ## queue the message! + self.listener.qMsg(transaction.payload, transaction.target.host, transaction.target.port) + #self.tlock.release() + + def initiateNoResponse(self, transaction): + #ignore messages to ourself + if transaction.target.id == self.id: + return + #self.tlock.acquire() + self.listener.qMsg(transaction.payload, transaction.target.host, transaction.target.port) + #self.tlock.release() + + def postEvent(self, callback, delay, extras=None): + #self.tlock.acquire() + t = Transaction(self, None, None, callback, timeout=delay) + t.extras = extras + self.transactions[t.id] = t + self.timeout.put(`t.timeout`, t.id) + #self.tlock.release() + + def flushExpiredEvents(self): + events = 0 + tstamp = `time.time()` + #self.tlock.acquire() + c = self.timeout.cursor() + e = c.first() + while e and e[0] < tstamp: + events = events + 1 + try: + t = self.transactions[e[1]] + del(self.transactions[e[1]]) + except KeyError: + # transaction must have completed or was otherwise cancelled + pass + ## default callback! + else: + t.defaultHandler() + tmp = c.next() + # handle duplicates in a silly way + if tmp and e != tmp: + self.timeout.delete(e[0]) + e = tmp + #self.tlock.release() + return events + + def flushOutgoing(self): + events = 0 + n = self.listener.qLen() + if n > NUM_EVENTS: + n = NUM_EVENTS + for i in range(n): + self.listener.dispatchMsg() + events = events + 1 + return events + + def handleIncoming(self): + events = 0 + #self.tlock.acquire() + for i in range(NUM_EVENTS): + try: + msg, addr = self.listener.receiveMsg() + except ValueError: + break + + ## decode message, handle message! + try: + msg = bdecode(msg) + except ValueError: + # wrongly encoded message? + print "Bogus message received: %s" % msg + continue + try: + # check base template for correctness + self.BASE(msg) + except ValueError, reason: + # bad message! + print "Incoming message: %s" % reason + continue + try: + # check to see if we already know about this transaction + t = self.transactions[msg['tid']] + if msg['id'] != t.target.id and t.target.id != " "*20: + # we're expecting a response from someone else + if msg['id'] == self.id: + print "received our own response! " + `self.id` + else: + print "response from wrong peer! "+ `msg['id'],t.target.id` + else: + del(self.transactions[msg['tid']]) + self.timeout.delete(`t.timeout`) + t.addr = addr + # call transaction response handler + t.responseHandler(msg) + except KeyError: + # we don't know about it, must be unsolicited + n = Node(msg['id'], addr[0], addr[1]) + t = Transaction(self, n, None, None, msg['tid']) + if self.handlers.has_key(msg['type']): + ## handle this transaction + try: + # check handler template + self.handlers[msg['type']][1](msg) + except ValueError, reason: + print "BAD MESSAGE: %s" % reason + else: + self.handlers[msg['type']][0](t, msg) + else: + ## no transaction, no handler, drop it on the floor! + pass + events = events + 1 + #self.tlock.release() + return events + + def stop(self): + self.stopped = 1 + + def run(self): + self.stopped = 0 + while(not self.stopped): + events = self.runOnce() + ## sleep + if events == 0: + time.sleep(0.1) + + def runOnce(self): + events = 0 + ## handle some incoming messages + events = events + self.handleIncoming() + ## process some outstanding events + events = events + self.flushExpiredEvents() + ## send outgoing messages + events = events + self.flushOutgoing() + return events + + + diff --git a/hash.py b/hash.py new file mode 100644 index 0000000..8fa3e8e --- /dev/null +++ b/hash.py @@ -0,0 +1,108 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from sha import sha +from whrandom import randrange + +## takes a 20 bit hash, big-endian, and returns it expressed a python integer +## ha ha ha ha if this were a C module I wouldn't resort to such sillyness +def intify(hstr): + assert(len(hstr) == 20) + i = 0L + i = i + ord(hstr[19]) + i = i + ord(hstr[18]) * 256L + i = i + ord(hstr[17]) * 65536L + i = i + ord(hstr[16]) * 16777216L + i = i + ord(hstr[15]) * 4294967296L + i = i + ord(hstr[14]) * 1099511627776L + i = i + ord(hstr[13]) * 281474976710656L + i = i + ord(hstr[12]) * 72057594037927936L + i = i + ord(hstr[11]) * 18446744073709551616L + i = i + ord(hstr[10]) * 4722366482869645213696L + i = i + ord(hstr[9]) * 1208925819614629174706176L + i = i + ord(hstr[8]) * 309485009821345068724781056L + i = i + ord(hstr[7]) * 79228162514264337593543950336L + i = i + ord(hstr[6]) * 20282409603651670423947251286016L + i = i + ord(hstr[5]) * 5192296858534827628530496329220096L + i = i + ord(hstr[4]) * 1329227995784915872903807060280344576L + i = i + ord(hstr[3]) * 340282366920938463463374607431768211456L + i = i + ord(hstr[2]) * 87112285931760246646623899502532662132736L + i = i + ord(hstr[1]) * 22300745198530623141535718272648361505980416L + i = i + ord(hstr[0]) * 5708990770823839524233143877797980545530986496L + return i + +## returns the distance between two 160-bit hashes expressed as 20-character strings +def distance(a, b): + return intify(a) ^ intify(b) + + +## returns a new pseudorandom globally unique ID string +def newID(): + h = sha() + for i in range(20): + h.update(chr(randrange(0,256))) + return h.digest() + +def randRange(min, max): + return min + intify(newID()) % (max - min) + +import unittest + +class NewID(unittest.TestCase): + def testLength(self): + self.assertEqual(len(newID()), 20) + def testHundreds(self): + for x in xrange(100): + self.testLength + +class Intify(unittest.TestCase): + known = [('\0' * 20, 0), + ('\xff' * 20, 2**160 - 1), + ] + def testKnown(self): + for str, value in self.known: + self.assertEqual(intify(str), value) + def testEndianessOnce(self): + h = newID() + while h[-1] == '\xff': + h = newID() + k = h[:-1] + chr(ord(h[-1]) + 1) + self.assertEqual(intify(k) - intify(h), 1) + def testEndianessLots(self): + for x in xrange(100): + self.testEndianessOnce() + +class Disantance(unittest.TestCase): + known = [ + (("\0" * 20, "\xff" * 20), 2**160 -1), + ((sha("foo").digest(), sha("foo").digest()), 0), + ((sha("bar").digest(), sha("bar").digest()), 0) + ] + def testKnown(self): + for pair, dist in self.known: + self.assertEqual(distance(pair[0], pair[1]), dist) + def testCommutitive(self): + for i in xrange(100): + x, y, z = newID(), newID(), newID() + self.assertEqual(distance(x,y) ^ distance(y, z), distance(x, z)) + +class RandRange(unittest.TestCase): + def testOnce(self): + a = intify(newID()) + b = intify(newID()) + if a < b: + c = randRange(a, b) + self.assertEqual(a <= c < b, 1, "output out of range %d %d %d" % (b, c, a)) + else: + c = randRange(b, a) + assert b <= c < a, "output out of range %d %d %d" % (b, c, a) + + def testOneHundredTimes(self): + for i in xrange(100): + self.testOnce() + + + +if __name__ == '__main__': + unittest.main() + + \ No newline at end of file diff --git a/khashmir.py b/khashmir.py new file mode 100644 index 0000000..363d0a0 --- /dev/null +++ b/khashmir.py @@ -0,0 +1,488 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from listener import Listener +from ktable import KTable, K +from node import Node +from dispatcher import Dispatcher +from hash import newID, intify +import messages +import transactions + +import time + +from bsddb3 import db ## find this at http://pybsddb.sf.net/ +from bsddb3._db import DBNotFoundError + +# don't ping unless it's been at least this many seconds since we've heard from a peer +MAX_PING_INTERVAL = 60 * 15 # fifteen minutes + +# concurrent FIND_NODE/VALUE requests! +N = 3 + + +# this is the main class! +class Khashmir: + __slots__ = ['listener', 'node', 'table', 'dispatcher', 'tf', 'store'] + def __init__(self, host, port): + self.listener = Listener(host, port) + self.node = Node(newID(), host, port) + self.table = KTable(self.node) + self.dispatcher = Dispatcher(self.listener, messages.BASE, self.node.id) + self.tf = transactions.TransactionFactory(self.node.id, self.dispatcher) + + self.store = db.DB() + self.store.open(None, None, db.DB_BTREE) + + #### register unsolicited incoming message handlers + self.dispatcher.registerHandler('ping', self._pingHandler, messages.PING) + + self.dispatcher.registerHandler('find node', self._findNodeHandler, messages.FIND_NODE) + + self.dispatcher.registerHandler('get value', self._findValueHandler, messages.GET_VALUE) + + self.dispatcher.registerHandler('store value', self._storeValueHandler, messages.STORE_VALUE) + + + ####### + ####### LOCAL INTERFACE - use these methods! + def addContact(self, host, port): + """ + ping this node and add the contact info to the table on pong! + """ + n =Node(" "*20, host, port) # note, we + self.sendPing(n) + + + ## this call is async! + def findNode(self, id, callback): + """ returns the contact info for node, or the k closest nodes, from the global table """ + # get K nodes out of local table/cache, or the node we want + nodes = self.table.findNodes(id) + if len(nodes) == 1 and nodes[0].id == id : + # we got it in our table! + def tcall(t, callback=callback): + callback(t.extras) + self.dispatcher.postEvent(tcall, 0, extras=nodes) + else: + # create our search state + state = FindNode(self, self.dispatcher, id, callback) + # handle this in our own thread + self.dispatcher.postEvent(state.goWithNodes, 0, extras=nodes) + + + ## also async + def valueForKey(self, key, callback): + """ returns the values found for key in global table """ + nodes = self.table.findNodes(key) + # create our search state + state = GetValue(self, self.dispatcher, key, callback) + # handle this in our own thread + self.dispatcher.postEvent(state.goWithNodes, 0, extras=nodes) + + + ## async, but in the current implementation there is no guarantee a store does anything so there is no callback right now + def storeValueForKey(self, key, value): + """ stores the value for key in the global table, returns immediately, no status + in this implementation, peers respond but don't indicate status to storing values + values are stored in peers on a first-come first-served basis + this will probably change so more than one value can be stored under a key + """ + def _storeValueForKey(nodes, tf=self.tf, key=key, value=value, response= self._storedValueHandler, default= lambda t: "didn't respond"): + for node in nodes: + if node.id != self.node.id: + t = tf.StoreValue(node, key, value, response, default) + t.dispatch() + # this call is asynch + self.findNode(key, _storeValueForKey) + + + def insertNode(self, n): + """ + insert a node in our local table, pinging oldest contact in bucket, if necessary + + If all you have is a host/port, then use addContact, which calls this function after + receiving the PONG from the remote node. The reason for the seperation is we can't insert + a node into the table without it's peer-ID. That means of course the node passed into this + method needs to be a properly formed Node object with a valid ID. + """ + old = self.table.insertNode(n) + if old and (time.time() - old.lastSeen) > MAX_PING_INTERVAL and old.id != self.node.id: + # the bucket is full, check to see if old node is still around and if so, replace it + t = self.tf.Ping(old, self._notStaleNodeHandler, self._staleNodeHandler) + t.newnode = n + t.dispatch() + + + def sendPing(self, node): + """ + ping a node + """ + t = self.tf.Ping(node, self._pongHandler, self._defaultPong) + t.dispatch() + + + def findCloseNodes(self): + """ + This does a findNode on the ID one away from our own. + This will allow us to populate our table with nodes on our network closest to our own. + This is called as soon as we start up with an empty table + """ + id = self.node.id[:-1] + chr((ord(self.node.id[-1]) + 1) % 256) + def callback(nodes): + pass + self.findNode(id, callback) + + def refreshTable(self): + """ + + """ + def callback(nodes): + pass + + for bucket in self.table.buckets: + if time.time() - bucket.lastAccessed >= 60 * 60: + id = randRange(bucket.min, bucket.max) + self.findNode(id, callback) + + + ##### + ##### UNSOLICITED INCOMING MESSAGE HANDLERS + + def _pingHandler(self, t, msg): + #print "Got PING from %s at %s:%s" % (`t.target.id`, t.target.host, t.target.port) + self.insertNode(t.target) + # respond, no callbacks, we don't care if they get it or not + nt = self.tf.Pong(t) + nt.dispatch() + + def _findNodeHandler(self, t, msg): + #print "Got FIND_NODES from %s:%s at %s:%s" % (t.target.host, t.target.port, self.node.host, self.node.port) + nodes = self.table.findNodes(msg['target']) + # respond, no callbacks, we don't care if they get it or not + nt = self.tf.GotNodes(t, nodes) + nt.dispatch() + + def _storeValueHandler(self, t, msg): + if not self.store.has_key(msg['key']): + self.store.put(msg['key'], msg['value']) + nt = self.tf.StoredValue(t) + nt.dispatch() + + def _findValueHandler(self, t, msg): + if self.store.has_key(msg['key']): + t = self.tf.GotValues(t, [(msg['key'], self.store[msg['key']])]) + else: + nodes = self.table.findNodes(msg['key']) + t = self.tf.GotNodes(t, nodes) + t.dispatch() + + + ### + ### message response callbacks + # called when we get a response to store value + def _storedValueHandler(self, t, msg): + self.table.insertNode(t.target) + + + ## these are the callbacks used when we ping the oldest node in a bucket + def _staleNodeHandler(self, t): + """ called if the pinged node never responds """ + self.table.replaceStaleNode(t.target, t.newnode) + + def _notStaleNodeHandler(self, t, msg): + """ called when we get a ping from the remote node """ + self.table.insertNode(t.target) + + + ## these are the callbacks we use when we issue a PING + def _pongHandler(self, t, msg): + #print "Got PONG from %s at %s:%s" % (`msg['id']`, t.target.host, t.target.port) + n = Node(msg['id'], t.addr[0], t.addr[1]) + self.table.insertNode(n) + + def _defaultPong(self, t): + # this should probably increment a failed message counter and dump the node if it gets over a threshold + print "Never got PONG from %s at %s:%s" % (`t.target.id`, t.target.host, t.target.port) + + + +class ActionBase: + """ base class for some long running asynchronous proccesses like finding nodes or values """ + def __init__(self, table, dispatcher, target, callback): + self.table = table + self.dispatcher = dispatcher + self.target = target + self.int = intify(target) + self.found = {} + self.queried = {} + self.answered = {} + self.callback = callback + self.outstanding = 0 + self.finished = 0 + + def sort(a, b, int=self.int): + """ this function is for sorting nodes relative to the ID we are looking for """ + x, y = int ^ a.int, int ^ b.int + if x > y: + return 1 + elif x < y: + return -1 + return 0 + self.sort = sort + + def goWithNodes(self, t): + pass + +class FindNode(ActionBase): + """ find node action merits it's own class as it is a long running stateful process """ + def handleGotNodes(self, t, msg): + if self.finished or self.answered.has_key(t.id): + # a day late and a dollar short + return + self.outstanding = self.outstanding - 1 + self.answered[t.id] = 1 + for node in msg['nodes']: + if not self.found.has_key(node['id']): + n = Node(node['id'], node['host'], node['port']) + self.found[n.id] = n + self.table.insertNode(n) + self.schedule() + + def schedule(self): + """ + send messages to new peers, if necessary + """ + if self.finished: + return + l = self.found.values() + l.sort(self.sort) + + for node in l[:K]: + if node.id == self.target: + self.finished=1 + return self.callback([node]) + if not self.queried.has_key(node.id) and node.id != self.table.node.id: + t = self.table.tf.FindNode(node, self.target, self.handleGotNodes, self.defaultGotNodes) + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + t.timeout = time.time() + 15 + t.dispatch() + if self.outstanding >= N: + break + assert(self.outstanding) >=0 + if self.outstanding == 0: + ## all done!! + self.finished=1 + self.callback(l[:K]) + + def defaultGotNodes(self, t): + if self.finished: + return + self.outstanding = self.outstanding - 1 + self.schedule() + + + def goWithNodes(self, t): + """ + this starts the process, our argument is a transaction with t.extras being our list of nodes + it's a transaction since we got called from the dispatcher + """ + nodes = t.extras + for node in nodes: + if node.id == self.table.node.id: + continue + self.found[node.id] = node + t = self.table.tf.FindNode(node, self.target, self.handleGotNodes, self.defaultGotNodes) + t.timeout = time.time() + 15 + t.dispatch() + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + if self.outstanding == 0: + self.callback(nodes) + + + +class GetValue(FindNode): + """ get value task """ + def handleGotNodes(self, t, msg): + if self.finished or self.answered.has_key(t.id): + # a day late and a dollar short + return + self.outstanding = self.outstanding - 1 + self.answered[t.id] = 1 + # go through nodes + # if we have any closer than what we already got, query them + if msg['type'] == 'got nodes': + for node in msg['nodes']: + if not self.found.has_key(node['id']): + n = Node(node['id'], node['host'], node['port']) + self.found[n.id] = n + self.table.insertNode(n) + elif msg['type'] == 'got values': + ## done + self.finished = 1 + return self.callback(msg['values']) + self.schedule() + + ## get value + def schedule(self): + if self.finished: + return + l = self.found.values() + l.sort(self.sort) + + for node in l[:K]: + if not self.queried.has_key(node.id) and node.id != self.table.node.id: + t = self.table.tf.GetValue(node, self.target, self.handleGotNodes, self.defaultGotNodes) + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + t.timeout = time.time() + 15 + t.dispatch() + if self.outstanding >= N: + break + assert(self.outstanding) >=0 + if self.outstanding == 0: + ## all done, didn't find it!! + self.finished=1 + self.callback([]) + + ## get value + def goWithNodes(self, t): + nodes = t.extras + for node in nodes: + if node.id == self.table.node.id: + continue + self.found[node.id] = node + t = self.table.tf.GetValue(node, self.target, self.handleGotNodes, self.defaultGotNodes) + t.timeout = time.time() + 15 + t.dispatch() + self.outstanding = self.outstanding + 1 + self.queried[node.id] = 1 + if self.outstanding == 0: + self.callback([]) + + +#------ +def test_build_net(quiet=0): + from whrandom import randrange + import thread + port = 2001 + l = [] + peers = 100 + + if not quiet: + print "Building %s peer table." % peers + + for i in xrange(peers): + a = Khashmir('localhost', port + i) + l.append(a) + + def run(l=l): + while(1): + events = 0 + for peer in l: + events = events + peer.dispatcher.runOnce() + if events == 0: + time.sleep(.25) + + for i in range(10): + thread.start_new_thread(run, (l[i*10:(i+1)*10],)) + #thread.start_new_thread(l[i].dispatcher.run, ()) + + for peer in l[1:]: + n = l[randrange(0, len(l))].node + peer.addContact(n.host, n.port) + n = l[randrange(0, len(l))].node + peer.addContact(n.host, n.port) + n = l[randrange(0, len(l))].node + peer.addContact(n.host, n.port) + + time.sleep(5) + + for peer in l: + peer.findCloseNodes() + time.sleep(5) + for peer in l: + peer.refreshTable() + return l + +def test_find_nodes(l, quiet=0): + import threading, sys + from whrandom import randrange + flag = threading.Event() + + n = len(l) + + a = l[randrange(0,n)] + b = l[randrange(0,n)] + + def callback(nodes, l=l, flag=flag): + if (len(nodes) >0) and (nodes[0].id == b.node.id): + print "test_find_nodes PASSED" + else: + print "test_find_nodes FAILED" + flag.set() + a.findNode(b.node.id, callback) + flag.wait() + +def test_find_value(l, quiet=0): + from whrandom import randrange + from sha import sha + import time, threading, sys + + fa = threading.Event() + fb = threading.Event() + fc = threading.Event() + + n = len(l) + a = l[randrange(0,n)] + b = l[randrange(0,n)] + c = l[randrange(0,n)] + d = l[randrange(0,n)] + + key = sha(`randrange(0,100000)`).digest() + value = sha(`randrange(0,100000)`).digest() + if not quiet: + print "inserting value...", + sys.stdout.flush() + a.storeValueForKey(key, value) + time.sleep(3) + print "finding..." + + def mc(flag, value=value): + def callback(values, f=flag, val=value): + try: + if(len(values) == 0): + print "find FAILED" + else: + if values[0]['value'] != val: + print "find FAILED" + else: + print "find FOUND" + finally: + f.set() + return callback + b.valueForKey(key, mc(fa)) + c.valueForKey(key, mc(fb)) + d.valueForKey(key, mc(fc)) + + fa.wait() + fb.wait() + fc.wait() + +if __name__ == "__main__": + l = test_build_net() + time.sleep(3) + print "finding nodes..." + test_find_nodes(l) + test_find_nodes(l) + test_find_nodes(l) + print "inserting and fetching values..." + test_find_value(l) + test_find_value(l) + test_find_value(l) + test_find_value(l) + test_find_value(l) + test_find_value(l) + for i in l: + i.dispatcher.stop() diff --git a/ktable.py b/ktable.py new file mode 100644 index 0000000..9cd3732 --- /dev/null +++ b/ktable.py @@ -0,0 +1,231 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +import hash +from bisect import * +import time +from types import * + +from node import Node + +# The all-powerful, magical Kademlia "k" constant, bucket depth +K = 20 + +# how many bits wide is our hash? +HASH_LENGTH = 160 + + +# the local routing table for a kademlia like distributed hash table +class KTable: + def __init__(self, node): + # this is the root node, a.k.a. US! + self.node = node + self.buckets = [KBucket([], 0L, 2L**HASH_LENGTH)] + self.insertNode(node) + + def _bucketIndexForInt(self, int): + """returns the index of the bucket that should hold int""" + return bisect_left(self.buckets, int) + + def findNodes(self, id): + """ return k nodes in our own local table closest to the ID + ignoreSelf means we will return K closest nodes to ourself if we search for our own ID + note, K closest nodes may actually include ourself, it's the callers responsibilty to + not send messages to itself if it matters + """ + if type(id) == StringType: + int = hash.intify(id) + elif type(id) == InstanceType: + int = id.int + elif type(id) == IntType or type(id) == LongType: + int = id + else: + raise TypeError, "findLocalNodes requires an int, string, or Node instance type" + + nodes = [] + + def sort(a, b, int=int): + """ this function is for sorting nodes relative to the ID we are looking for """ + x, y = int ^ a.int, int ^ b.int + if x > y: + return 1 + elif x < y: + return -1 + return 0 + + i = self._bucketIndexForInt(int) + + ## see if this node is already in our table and return it + try: + index = self.buckets[i].l.index(int) + except ValueError: + pass + else: + self.buckets[i].touch() + return [self.buckets[i].l[index]] + + nodes = nodes + self.buckets[i].l + if len(nodes) == K: + nodes.sort(sort) + return nodes + else: + # need more nodes + min = i - 1 + max = i + 1 + while (len(nodes) < K and (min >= 0 and max < len(self.buckets))): + if min >= 0: + nodes = nodes + self.buckets[min].l + self.buckets[min].touch() + if max < len(self.buckets): + nodes = nodes + self.buckets[max].l + self.buckets[max].touch() + + nodes.sort(sort) + return nodes[:K-1] + + def _splitBucket(self, a): + diff = (a.max - a.min) / 2 + b = KBucket([], a.max - diff, a.max) + self.buckets.insert(self.buckets.index(a.min) + 1, b) + a.max = a.max - diff + # transfer nodes to new bucket + for anode in a.l[:]: + if anode.int >= a.max: + a.l.remove(anode) + b.l.append(anode) + + def replaceStaleNode(self, stale, new): + """ this is used by clients to replace a node returned by insertNode after + it fails to respond to a Pong message + """ + i = self._bucketIndexForInt(stale.int) + try: + it = self.buckets[i].l.index(stale.int) + except ValueError: + return + + del(self.buckets[i].l[it]) + self.buckets[i].l.append(new) + + def insertNode(self, node): + """ + this insert the node, returning None if successful, returns the oldest node in the bucket if it's full + the caller responsible for pinging the returned node and calling replaceStaleNode if it is found to be stale!! + """ + # get the bucket for this node + i = self. _bucketIndexForInt(node.int) + ## check to see if node is in the bucket already + try: + it = self.buckets[i].l.index(node.int) + except ValueError: + ## no + pass + else: + node.updateLastSeen() + # move node to end of bucket + del(self.buckets[i].l[it]) + self.buckets[i].l.append(node) + self.buckets[i].touch() + return + + # we don't have this node, check to see if the bucket is full + if len(self.buckets[i].l) < K: + # no, append this node and return + self.buckets[i].l.append(node) + self.buckets[i].touch() + return + + # bucket is full, check to see if self.node is in the bucket + try: + me = self.buckets[i].l.index(self.node) + except ValueError: + return self.buckets[i].l[0] + + ## this bucket is full and contains our node, split the bucket + if len(self.buckets) >= HASH_LENGTH: + # our table is FULL + print "Hash Table is FULL! Increase K!" + return + + self._splitBucket(self.buckets[i]) + + ## now that the bucket is split and balanced, try to insert the node again + return self.insertNode(node) + + def justSeenNode(self, node): + """ call this any time you get a message from a node, to update it in the table if it's there """ + try: + n = self.findNodes(node.int)[0] + except IndexError: + return None + else: + tstamp = n.lastSeen + n.updateLastSeen() + return tstamp + + +class KBucket: + __slots = ['min', 'max', 'lastAccessed'] + def __init__(self, contents, min, max): + self.l = contents + self.min = min + self.max = max + self.lastAccessed = time.time() + + def touch(self): + self.lastAccessed = time.time() + + def getNodeWithInt(self, int): + try: + return self.l[self.l.index(int)] + self.touch() + except IndexError: + raise ValueError + + def __repr__(self): + return "" % (len(self.l), self.min, self.max) + + ## comparators, necessary for bisecting list of buckets with a hash expressed as an integer or a distance + ## compares integer or node object with the bucket's range + def __lt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.max <= a + def __le__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min < a + def __gt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min > a + def __ge__(self, a): + if type(a) == InstanceType: + a = a.int + return self.max >= a + def __eq__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min <= a and self.max > a + def __ne__(self, a): + if type(a) == InstanceType: + a = a.int + return self.min >= a or self.max < a + + + +############## +import unittest + +class TestKTable(unittest.TestCase): + def setUp(self): + self.a = Node(hash.newID(), 'localhost', 2002) + self.t = KTable(self.a) + + def test_replace_stale_node(self): + self.b = Node(hash.newID(), 'localhost', 2003) + self.t.replaceStaleNode(self.a, self.b) + assert(len(self.t.buckets[0].l) == 1) + assert(self.t.buckets[0].l[0].id == self.b.id) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/listener.py b/listener.py new file mode 100644 index 0000000..3999b32 --- /dev/null +++ b/listener.py @@ -0,0 +1,100 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from socket import * + +# simple UDP communicator + +class Listener: + def __init__(self, host, port): + self.msgq = [] + self.sock = socket(AF_INET, SOCK_DGRAM) + self.sock.setblocking(0) + self.sock.bind((host, port)) + + def qMsg(self, msg, host, port): + self.msgq.append((msg, host, port)) + + def qLen(self): + return len(self.msgq) + + def dispatchMsg(self): + if self.qLen() > 0: + msg, host, port = self.msgq[0] + del self.msgq[0] + self.sock.sendto(msg, 0, (host, port)) + + def receiveMsg(self): + msg = () + try: + msg = self.sock.recvfrom(65536) + except error, tup: + if tup[1] == "Resource temporarily unavailable": + # no message + return msg + print error, tup + else: + return msg + + def __del__(self): + self.sock.close() + + + +########################### +import unittest + +class ListenerTest(unittest.TestCase): + def setUp(self): + self.a = Listener('localhost', 8080) + self.b = Listener('localhost', 8081) + def tearDown(self): + del(self.a) + del(self.b) + + def testQueue(self): + assert self.a.qLen() == 0, "expected queue to be empty" + self.a.qMsg('hello', 'localhost', 8081) + assert self.a.qLen() == 1, "expected one message to be in queue" + self.a.qMsg('hello', 'localhost', 8081) + assert self.a.qLen() == 2, "expected two messages to be in queue" + self.a.dispatchMsg() + assert self.a.qLen() == 1, "expected one message to be in queue" + self.a.dispatchMsg() + assert self.a.qLen() == 0, "expected all messages to be flushed from queue" + + def testSendReceiveOne(self): + self.a.qMsg('hello', 'localhost', 8081) + self.a.dispatchMsg() + + assert self.b.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.b.receiveMsg() == (), "received unexpected message" + + self.b.qMsg('hello', 'localhost', 8080) + self.b.dispatchMsg() + + assert self.a.receiveMsg()[0] == "hello", "did not receive expected message" + + assert self.a.receiveMsg() == (), "received unexpected message" + + def testSendReceiveInterleaved(self): + self.a.qMsg('hello', 'localhost', 8081) + self.a.qMsg('hello', 'localhost', 8081) + self.a.dispatchMsg() + self.a.dispatchMsg() + + assert self.b.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.b.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.b.receiveMsg() == (), "received unexpected message" + + self.b.qMsg('hello', 'localhost', 8080) + self.b.qMsg('hello', 'localhost', 8080) + self.b.dispatchMsg() + self.b.dispatchMsg() + + assert self.a.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.a.receiveMsg()[0] == "hello", "did not receive expected message" + assert self.a.receiveMsg() == (), "received unexpected message" + + +if __name__ == '__main__': + unittest.main() diff --git a/messages.py b/messages.py new file mode 100644 index 0000000..542fe23 --- /dev/null +++ b/messages.py @@ -0,0 +1,186 @@ +## Copyright 2002 Andrew Loewenstern, All Rights Reserved + +from bencode import bencode, bdecode +from btemplate import * +from node import Node + + +# template checker for hash id +def hashid(thing, verbose): + if type(thing) != type(''): + raise ValueError, 'must be a string' + if len(thing) != 20: + raise ValueError, 'must be 20 characters long' + +## our messages +BASE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : string_template}) + +PING = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'ping'}) +PONG = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'pong'}) + +FIND_NODE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'find node', "target" : hashid}) +GOT_NODES = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'got nodes', "nodes" : ListMarker({'id': hashid, 'host': string_template, 'port': 1})}) + +STORE_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'store value', "key" : hashid, "value" : string_template}) +STORED_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'stored value'}) + +GET_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'get value', "key" : hashid}) +GOT_VALUES = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'got values', "values" : ListMarker({'key': hashid, 'value': string_template})}) + +GOT_NODES_OR_VALUES = compile_template([GOT_NODES, GOT_VALUES]) + + +class MessageFactory: + def __init__(self, id): + self.id = id + + def encodePing(self, tid): + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'ping'}) + def decodePing(self, msg): + msg = bdecode(msg) + PING(msg) + return msg + + def encodePong(self, tid): + msg = {'id' : self.id, 'tid' : tid, 'type' : 'pong'} + PONG(msg) + return bencode(msg) + def decodePong(self, msg): + msg = bdecode(msg) + PONG(msg) + return msg + + def encodeFindNode(self, tid, target): + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'find node', 'target' : target}) + def decodeFindNode(self, msg): + msg = bdecode(msg) + FIND_NODE(msg) + return msg + + def encodeStoreValue(self, tid, key, value): + return bencode({'id' : self.id, 'tid' : tid, 'key' : key, 'type' : 'store value', 'value' : value}) + def decodeStoreValue(self, msg): + msg = bdecode(msg) + STORE_VALUE(msg) + return msg + + + def encodeStoredValue(self, tid): + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'stored value'}) + def decodeStoredValue(self, msg): + msg = bdecode(msg) + STORED_VALUE(msg) + return msg + + + def encodeGetValue(self, tid, key): + return bencode({'id' : self.id, 'tid' : tid, 'key' : key, 'type' : 'get value'}) + def decodeGetValue(self, msg): + msg = bdecode(msg) + GET_VALUE(msg) + return msg + + def encodeGotNodes(self, tid, nodes): + n = [] + for node in nodes: + n.append({'id' : node.id, 'host' : node.host, 'port' : node.port}) + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got nodes', 'nodes' : n}) + def decodeGotNodes(self, msg): + msg = bdecode(msg) + GOT_NODES(msg) + return msg + + def encodeGotValues(self, tid, values): + n = [] + for value in values: + n.append({'key' : value[0], 'value' : value[1]}) + return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got values', 'values' : n}) + def decodeGotValues(self, msg): + msg = bdecode(msg) + GOT_VALUES(msg) + return msg + + + +###### +import unittest + +class TestMessageEncoding(unittest.TestCase): + def setUp(self): + from sha import sha + self.a = sha('a').digest() + self.b = sha('b').digest() + + + def test_ping(self): + m = MessageFactory(self.a) + s = m.encodePing(self.b) + msg = m.decodePing(s) + PING(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.b) + + def test_pong(self): + m = MessageFactory(self.a) + s = m.encodePong(self.b) + msg = m.decodePong(s) + PONG(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.b) + + def test_find_node(self): + m = MessageFactory(self.a) + s = m.encodeFindNode(self.a, self.b) + msg = m.decodeFindNode(s) + FIND_NODE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['target'] == self.b) + + def test_store_value(self): + m = MessageFactory(self.a) + s = m.encodeStoreValue(self.a, self.b, 'foo') + msg = m.decodeStoreValue(s) + STORE_VALUE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['key'] == self.b) + assert(msg['value'] == 'foo') + + def test_stored_value(self): + m = MessageFactory(self.a) + s = m.encodeStoredValue(self.b) + msg = m.decodeStoredValue(s) + STORED_VALUE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.b) + + def test_get_value(self): + m = MessageFactory(self.a) + s = m.encodeGetValue(self.a, self.b) + msg = m.decodeGetValue(s) + GET_VALUE(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['key'] == self.b) + + def test_got_nodes(self): + m = MessageFactory(self.a) + s = m.encodeGotNodes(self.a, [Node(self.b, 'localhost', 2002), Node(self.a, 'localhost', 2003)]) + msg = m.decodeGotNodes(s) + GOT_NODES(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + assert(msg['nodes'][0]['id'] == self.b) + + def test_got_values(self): + m = MessageFactory(self.a) + s = m.encodeGotValues(self.a, [(self.b, 'localhost')]) + msg = m.decodeGotValues(s) + GOT_VALUES(msg) + assert(msg['id'] == self.a) + assert(msg['tid'] == self.a) + + +if __name__ == "__main__": + unittest.main() diff --git a/node.py b/node.py new file mode 100644 index 0000000..cb1940e --- /dev/null +++ b/node.py @@ -0,0 +1,56 @@ +import hash +import time +from types import * + +class Node: + """encapsulate contact info""" + def __init__(self, id, host, port): + self.id = id + self.int = hash.intify(id) + self.host = host + self.port = port + self.lastSeen = time.time() + + def updateLastSeen(self): + self.lastSeen = time.time() + + def __repr__(self): + return `(self.id, self.host, self.port)` + + ## these comparators let us bisect/index a list full of nodes with either a node or an int/long + def __lt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int < a + def __le__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int <= a + def __gt__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int > a + def __ge__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int >= a + def __eq__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int == a + def __ne__(self, a): + if type(a) == InstanceType: + a = a.int + return self.int != a + + +import unittest + +class TestNode(unittest.TestCase): + def setUp(self): + self.node = Node(hash.newID(), 'localhost', 2002) + def testUpdateLastSeen(self): + t = self.node.lastSeen + self.node.updateLastSeen() + assert t < self.node.lastSeen + \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..044aeff --- /dev/null +++ b/test.py @@ -0,0 +1,10 @@ +import unittest + +import hash, node, messages +import listener, dispatcher +import ktable, transactions, khashmir + +import bencode, btemplate + +tests = unittest.defaultTestLoader.loadTestsFromNames(['hash', 'node', 'bencode', 'btemplate', 'listener', 'messages', 'dispatcher', 'transactions', 'ktable']) +result = unittest.TextTestRunner().run(tests) diff --git a/transactions.py b/transactions.py new file mode 100644 index 0000000..b862aaa --- /dev/null +++ b/transactions.py @@ -0,0 +1,71 @@ +import messages +from dispatcher import Transaction + +class TransactionFactory: + def __init__(self, id, dispatcher): + self.id = id + self.dispatcher = dispatcher + self.mf = messages.MessageFactory(self.id) + + def Ping(self, node, response, default): + """ create a ping transaction """ + t = Transaction(self.dispatcher, node, response, default) + str = self.mf.encodePing(t.id) + t.setPayload(str) + t.setResponseTemplate(messages.PONG) + return t + + def FindNode(self, target, key, response, default): + """ find node query """ + t = Transaction(self.dispatcher, target, response, default) + str = self.mf.encodeFindNode(t.id, key) + t.setPayload(str) + t.setResponseTemplate(messages.GOT_NODES) + return t + + def StoreValue(self, target, key, value, response, default): + """ find node query """ + t = Transaction(self.dispatcher, target, response, default) + str = self.mf.encodeStoreValue(t.id, key, value) + t.setPayload(str) + t.setResponseTemplate(messages.STORED_VALUE) + return t + + def GetValue(self, target, key, response, default): + """ find value query, response is GOT_VALUES or GOT_NODES! """ + t = Transaction(self.dispatcher, target, response, default) + str = self.mf.encodeGetValue(t.id, key) + t.setPayload(str) + t.setResponseTemplate(messages.GOT_NODES_OR_VALUES) + return t + + def Pong(self, ping_t): + """ create a pong response to ping transaction """ + t = Transaction(self.dispatcher, ping_t.target, None, None, ping_t.id) + str = self.mf.encodePong(t.id) + t.setPayload(str) + return t + + def GotNodes(self, findNode_t, nodes): + """ respond with gotNodes """ + t = Transaction(self.dispatcher, findNode_t.target, None, None, findNode_t.id) + str = self.mf.encodeGotNodes(t.id, nodes) + t.setPayload(str) + return t + + def GotValues(self, findNode_t, values): + """ respond with gotNodes """ + t = Transaction(self.dispatcher, findNode_t.target, None, None, findNode_t.id) + str = self.mf.encodeGotValues(t.id, values) + t.setPayload(str) + return t + + def StoredValue(self, tr): + """ store value response, really just a pong """ + t = Transaction(self.dispatcher, tr.target, None, None, id = tr.id) + str = self.mf.encodeStoredValue(t.id) + t.setPayload(str) + return t + + +########### -- 2.39.2