1 ## Copyright 2002 Andrew Loewenstern, All Rights Reserved
3 from bencode import bencode, bdecode
4 from btemplate import *
8 # template checker for hash id
9 def hashid(thing, verbose):
10 if type(thing) != type(''):
11 raise ValueError, 'must be a string'
13 raise ValueError, 'must be 20 characters long'
# Message templates built with compile_template (from btemplate).  Every
# message dict carries 'id', 'tid' and 'type'; the id-like fields are
# checked with hashid.  The per-message templates give 'type' a literal
# string (presumably an exact-match check in btemplate -- confirm there).
BASE = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': string_template,
})

# liveness check request / reply
PING = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'ping',
})
PONG = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'pong',
})

# node lookup request / reply; each reply entry has 'id', 'host', 'port'
FIND_NODE = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'find node',
    "target": hashid,
})
GOT_NODES = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'got nodes',
    # NOTE(review): 'port': 1 is handed to btemplate as-is; its exact
    # meaning (int check? exact value?) is defined there -- confirm.
    "nodes": ListMarker({'id': hashid, 'host': string_template, 'port': 1}),
})

# store request / acknowledgement
STORE_VALUE = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'store value',
    "key": hashid,
    "value": string_template,
})
STORED_VALUE = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'stored value',
})

# value lookup request / reply; each reply entry has 'key' and 'value'
GET_VALUE = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'get value',
    "key": hashid,
})
GOT_VALUES = compile_template({
    'id': hashid,
    'tid': hashid,
    'type': 'got values',
    "values": ListMarker({'key': hashid, 'value': string_template}),
})

# a list of the two reply templates (presumably "either shape is accepted"
# in btemplate -- confirm)
GOT_NODES_OR_VALUES = compile_template([GOT_NODES, GOT_VALUES])
34 def __init__(self, id):
def encodePing(self, tid):
    """Return the bencoded 'ping' message sent by this factory's node.

    tid -- value placed in the 'tid' field (presumably a transaction id);
           the module's PING template checks it with hashid.
    """
    message = {'id': self.id, 'tid': tid, 'type': 'ping'}
    return bencode(message)
39 def decodePing(self, msg):
44 def encodePong(self, tid):
45 msg = {'id' : self.id, 'tid' : tid, 'type' : 'pong'}
48 def decodePong(self, msg):
def encodeFindNode(self, tid, target):
    """Return the bencoded 'find node' request asking about *target*.

    tid    -- value for the 'tid' field (checked with hashid by FIND_NODE).
    target -- id being looked up, sent as the 'target' field.
    """
    request = dict(id=self.id, tid=tid, type='find node', target=target)
    return bencode(request)
55 def decodeFindNode(self, msg):
def encodeStoreValue(self, tid, key, value):
    """Return the bencoded 'store value' request for the pair (*key*, *value*).

    tid   -- value for the 'tid' field.
    key   -- sent as 'key' (STORE_VALUE checks it with hashid).
    value -- sent as 'value' (STORE_VALUE checks it with string_template).
    """
    request = {}
    request['id'] = self.id
    request['tid'] = tid
    request['key'] = key
    request['type'] = 'store value'
    request['value'] = value
    return bencode(request)
62 def decodeStoreValue(self, msg):
def encodeStoredValue(self, tid):
    """Return the bencoded 'stored value' reply carrying only id/tid/type.

    tid -- value for the 'tid' field, checked with hashid by STORED_VALUE.
    """
    reply = dict(id=self.id, tid=tid, type='stored value')
    return bencode(reply)
70 def decodeStoredValue(self, msg):
def encodeGetValue(self, tid, key):
    """Return the bencoded 'get value' request for the values stored under *key*.

    tid -- value for the 'tid' field.
    key -- sent as 'key' (GET_VALUE checks it with hashid).
    """
    fields = [('id', self.id), ('tid', tid), ('key', key), ('type', 'get value')]
    return bencode(dict(fields))
78 def decodeGetValue(self, msg):
83 def encodeGotNodes(self, tid, nodes):
86 n.append({'id' : node.id, 'host' : node.host, 'port' : node.port})
87 return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got nodes', 'nodes' : n})
88 def decodeGotNodes(self, msg):
93 def encodeGotValues(self, tid, values):
96 n.append({'key' : value[0], 'value' : value[1]})
97 return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got values', 'values' : n})
98 def decodeGotValues(self, msg):
108 class TestMessageEncoding(unittest.TestCase):
111 self.a = sha('a').digest()
112 self.b = sha('b').digest()
116 m = MessageFactory(self.a)
117 s = m.encodePing(self.b)
118 msg = m.decodePing(s)
120 assert(msg['id'] == self.a)
121 assert(msg['tid'] == self.b)
124 m = MessageFactory(self.a)
125 s = m.encodePong(self.b)
126 msg = m.decodePong(s)
128 assert(msg['id'] == self.a)
129 assert(msg['tid'] == self.b)
131 def test_find_node(self):
132 m = MessageFactory(self.a)
133 s = m.encodeFindNode(self.a, self.b)
134 msg = m.decodeFindNode(s)
136 assert(msg['id'] == self.a)
137 assert(msg['tid'] == self.a)
138 assert(msg['target'] == self.b)
140 def test_store_value(self):
141 m = MessageFactory(self.a)
142 s = m.encodeStoreValue(self.a, self.b, 'foo')
143 msg = m.decodeStoreValue(s)
145 assert(msg['id'] == self.a)
146 assert(msg['tid'] == self.a)
147 assert(msg['key'] == self.b)
148 assert(msg['value'] == 'foo')
150 def test_stored_value(self):
151 m = MessageFactory(self.a)
152 s = m.encodeStoredValue(self.b)
153 msg = m.decodeStoredValue(s)
155 assert(msg['id'] == self.a)
156 assert(msg['tid'] == self.b)
158 def test_get_value(self):
159 m = MessageFactory(self.a)
160 s = m.encodeGetValue(self.a, self.b)
161 msg = m.decodeGetValue(s)
163 assert(msg['id'] == self.a)
164 assert(msg['tid'] == self.a)
165 assert(msg['key'] == self.b)
167 def test_got_nodes(self):
168 m = MessageFactory(self.a)
169 s = m.encodeGotNodes(self.a, [Node(self.b, 'localhost', 2002), Node(self.a, 'localhost', 2003)])
170 msg = m.decodeGotNodes(s)
172 assert(msg['id'] == self.a)
173 assert(msg['tid'] == self.a)
174 assert(msg['nodes'][0]['id'] == self.b)
176 def test_got_values(self):
177 m = MessageFactory(self.a)
178 s = m.encodeGotValues(self.a, [(self.b, 'localhost')])
179 msg = m.decodeGotValues(s)
181 assert(msg['id'] == self.a)
182 assert(msg['tid'] == self.a)
185 if __name__ == "__main__":