]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - messages.py
minor comment change
[quix0rs-apt-p2p.git] / messages.py
1 ## Copyright 2002 Andrew Loewenstern, All Rights Reserved
2
3 from bencode import bencode, bdecode
4 from btemplate import *
5 from node import Node
6
7
8 # template checker for hash id
9 def hashid(thing, verbose):
10     if type(thing) != type(''):
11         raise ValueError, 'must be a string'
12     if len(thing) != 20:
13         raise ValueError, 'must be 20 characters long'
14
15 ## our messages
16 BASE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : string_template})
17
18 PING = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'ping'})
19 PONG = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'pong'})
20
21 FIND_NODE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'find node', "target" : hashid})
22 GOT_NODES = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'got nodes', "nodes" : ListMarker({'id': hashid, 'host': string_template, 'port': 1})})
23
24 STORE_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'store value', "key" : hashid, "value" : string_template})
25 STORED_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'stored value'})
26
27 GET_VALUE = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'get value', "key" : hashid})
28 GOT_VALUES = compile_template({'id' : hashid, 'tid' : hashid, 'type' : 'got values', "values" : ListMarker({'key': hashid, 'value': string_template})})
29
30 GOT_NODES_OR_VALUES = compile_template([GOT_NODES, GOT_VALUES])
31
32
33 class MessageFactory:
34     def __init__(self, id):
35         self.id = id
36         
37     def encodePing(self, tid):
38         return bencode({'id' : self.id, 'tid' : tid, 'type' : 'ping'})
39     def decodePing(self, msg):
40         msg = bdecode(msg)
41         PING(msg)
42         return msg
43         
44     def encodePong(self, tid):
45         msg = {'id' : self.id, 'tid' : tid, 'type' : 'pong'}
46         PONG(msg)
47         return bencode(msg)
48     def decodePong(self, msg):
49         msg = bdecode(msg)
50         PONG(msg)
51         return msg
52
53     def encodeFindNode(self, tid, target):
54         return bencode({'id' : self.id, 'tid' : tid, 'type' : 'find node', 'target' : target})
55     def decodeFindNode(self, msg):
56         msg = bdecode(msg)
57         FIND_NODE(msg)
58         return msg
59
60     def encodeStoreValue(self, tid, key, value):
61         return bencode({'id' : self.id, 'tid' : tid, 'key' : key, 'type' : 'store value', 'value' : value})
62     def decodeStoreValue(self, msg):
63         msg = bdecode(msg)
64         STORE_VALUE(msg)
65         return msg
66     
67     
68     def encodeStoredValue(self, tid):
69         return bencode({'id' : self.id, 'tid' : tid, 'type' : 'stored value'})
70     def decodeStoredValue(self, msg):
71         msg = bdecode(msg)
72         STORED_VALUE(msg)
73         return msg
74
75     
76     def encodeGetValue(self, tid, key):
77         return bencode({'id' : self.id, 'tid' : tid, 'key' : key, 'type' : 'get value'})
78     def decodeGetValue(self, msg):
79         msg = bdecode(msg)
80         GET_VALUE(msg)
81         return msg
82
83     def encodeGotNodes(self, tid, nodes):
84         n = []
85         for node in nodes:
86             n.append({'id' : node.id, 'host' : node.host, 'port' : node.port})
87         return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got nodes', 'nodes' : n})
88     def decodeGotNodes(self, msg):
89         msg = bdecode(msg)
90         GOT_NODES(msg)
91         return msg
92
93     def encodeGotValues(self, tid, values):
94         n = []
95         for value in values:
96             n.append({'key' : value[0], 'value' : value[1]})
97         return bencode({'id' : self.id, 'tid' : tid, 'type' : 'got values', 'values' : n})
98     def decodeGotValues(self, msg):
99         msg = bdecode(msg)
100         GOT_VALUES(msg)
101         return msg
102         
103         
104
105 ######
106 import unittest
107
108 class TestMessageEncoding(unittest.TestCase):
109     def setUp(self):
110         from sha import sha
111         self.a = sha('a').digest()
112         self.b = sha('b').digest()
113
114     
115     def test_ping(self):
116         m = MessageFactory(self.a)
117         s = m.encodePing(self.b)
118         msg = m.decodePing(s)
119         PING(msg)
120         assert(msg['id'] == self.a)
121         assert(msg['tid'] == self.b)
122         
123     def test_pong(self):
124         m = MessageFactory(self.a)
125         s = m.encodePong(self.b)
126         msg = m.decodePong(s)
127         PONG(msg)
128         assert(msg['id'] == self.a)
129         assert(msg['tid'] == self.b)
130         
131     def test_find_node(self):
132         m = MessageFactory(self.a)
133         s = m.encodeFindNode(self.a, self.b)
134         msg = m.decodeFindNode(s)
135         FIND_NODE(msg)
136         assert(msg['id'] == self.a)
137         assert(msg['tid'] == self.a)
138         assert(msg['target'] == self.b)
139
140     def test_store_value(self):
141         m = MessageFactory(self.a)
142         s = m.encodeStoreValue(self.a, self.b, 'foo')
143         msg = m.decodeStoreValue(s)
144         STORE_VALUE(msg)
145         assert(msg['id'] == self.a)
146         assert(msg['tid'] == self.a)
147         assert(msg['key'] == self.b)
148         assert(msg['value'] == 'foo')
149         
150     def test_stored_value(self):
151         m = MessageFactory(self.a)
152         s = m.encodeStoredValue(self.b)
153         msg = m.decodeStoredValue(s)
154         STORED_VALUE(msg)
155         assert(msg['id'] == self.a)
156         assert(msg['tid'] == self.b)
157     
158     def test_get_value(self):
159         m = MessageFactory(self.a)
160         s = m.encodeGetValue(self.a, self.b)
161         msg = m.decodeGetValue(s)
162         GET_VALUE(msg)
163         assert(msg['id'] == self.a)
164         assert(msg['tid'] == self.a)
165         assert(msg['key'] == self.b)
166     
167     def test_got_nodes(self):
168         m = MessageFactory(self.a)
169         s = m.encodeGotNodes(self.a, [Node(self.b, 'localhost', 2002), Node(self.a, 'localhost', 2003)])
170         msg = m.decodeGotNodes(s)
171         GOT_NODES(msg)
172         assert(msg['id'] == self.a)
173         assert(msg['tid'] == self.a)
174         assert(msg['nodes'][0]['id'] == self.b)
175     
176     def test_got_values(self):
177         m = MessageFactory(self.a)
178         s = m.encodeGotValues(self.a, [(self.b, 'localhost')])
179         msg = m.decodeGotValues(s)
180         GOT_VALUES(msg)
181         assert(msg['id'] == self.a)
182         assert(msg['tid'] == self.a)
183
184
185 if __name__ == "__main__":
186         unittest.main()