63086ed890d4641bcb038f42857f35f60f594309
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / krpc.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from bencode import bencode, bdecode
5 from time import asctime
6
7 from twisted.internet.defer import Deferred
8 from twisted.internet import protocol, reactor
9 from twisted.python import log
10 from twisted.trial import unittest
11
12 from khash import newID
13
14 KRPC_TIMEOUT = 20
15
16 # Remote node errors
17 KRPC_ERROR = 200
18 KRPC_ERROR_SERVER_ERROR = 201
19 KRPC_ERROR_MALFORMED_PACKET = 202
20 KRPC_ERROR_METHOD_UNKNOWN = 203
21 KRPC_ERROR_MALFORMED_REQUEST = 204
22 KRPC_ERROR_INVALID_TOKEN = 205
23
24 # Local errors
25 KRPC_ERROR_INTERNAL = 100
26 KRPC_ERROR_RECEIVED_UNKNOWN = 101
27 KRPC_ERROR_TIMEOUT = 102
28 KRPC_ERROR_PROTOCOL_STOPPED = 103
29
30 # commands
31 TID = 't'
32 REQ = 'q'
33 RSP = 'r'
34 TYP = 'y'
35 ARG = 'a'
36 ERR = 'e'
37
38 class KrpcError(Exception):
39     pass
40
41 def verifyMessage(msg):
42     """Check received message for corruption and errors.
43     
44     @type msg: C{dictionary}
45     @param msg: the dictionary of information received on the connection
46     @raise KrpcError: if the message is corrupt
47     """
48     
49     if type(msg) != dict:
50         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "not a dictionary")
51     if TYP not in msg:
52         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no message type")
53     if msg[TYP] == REQ:
54         if REQ not in msg:
55             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type not specified")
56         if type(msg[REQ]) != str:
57             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type is not a string")
58         if ARG not in msg:
59             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no arguments for request")
60         if type(msg[ARG]) != dict:
61             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "arguments for request are not in a dictionary")
62     elif msg[TYP] == RSP:
63         if RSP not in msg:
64             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response not specified")
65         if type(msg[RSP]) != dict:
66             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response is not a dictionary")
67     elif msg[TYP] == ERR:
68         if ERR not in msg:
69             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error not specified")
70         if type(msg[ERR]) != list:
71             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a list")
72         if len(msg[ERR]) != 2:
73             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a 2-element list")
74         if type(msg[ERR][0]) not in (int, long):
75             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error number is not a number")
76         if type(msg[ERR][1]) != str:
77             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error string is not a string")
78 #    else:
79 #        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "unknown message type")
80     if TID not in msg:
81         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no transaction ID specified")
82     if type(msg[TID]) != str:
83         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "transaction id is not a string")
84
85 class hostbroker(protocol.DatagramProtocol):       
86     def __init__(self, server, config):
87         self.server = server
88         self.config = config
89         # this should be changed to storage that drops old entries
90         self.connections = {}
91         
92     def datagramReceived(self, datagram, addr):
93         #print `addr`, `datagram`
94         #if addr != self.addr:
95         c = self.connectionForAddr(addr)
96         c.datagramReceived(datagram, addr)
97         #if c.idle():
98         #    del self.connections[addr]
99
100     def connectionForAddr(self, addr):
101         if addr == self.addr:
102             raise Exception
103         if not self.connections.has_key(addr):
104             conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
105             self.connections[addr] = conn
106         else:
107             conn = self.connections[addr]
108         return conn
109
110     def makeConnection(self, transport):
111         protocol.DatagramProtocol.makeConnection(self, transport)
112         tup = transport.getHost()
113         self.addr = (tup.host, tup.port)
114         
115     def stopProtocol(self):
116         for conn in self.connections.values():
117             conn.stop()
118         protocol.DatagramProtocol.stopProtocol(self)
119
120 ## connection
121 class KRPC:
122     def __init__(self, addr, server, transport, spew = False):
123         self.transport = transport
124         self.factory = server
125         self.addr = addr
126         self.noisy = spew
127         self.tids = {}
128         self.stopped = False
129
130     def datagramReceived(self, data, addr):
131         if self.stopped:
132             if self.noisy:
133                 log.msg("stopped, dropping message from %r: %s" % (addr, data))
134         # bdecode
135         try:
136             msg = bdecode(data)
137         except Exception, e:
138             if self.noisy:
139                 log.msg("krpc bdecode error: ")
140                 log.err(e)
141             return
142
143         try:
144             verifyMessage(msg)
145         except Exception, e:
146             log.msg("krpc message verification error: ")
147             log.err(e)
148             return
149
150         if self.noisy:
151             log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
152         # look at msg type
153         if msg[TYP]  == REQ:
154             ilen = len(data)
155             # if request
156             #   tell factory to handle
157             f = getattr(self.factory ,"krpc_" + msg[REQ], None)
158             msg[ARG]['_krpc_sender'] =  self.addr
159             if f and callable(f):
160                 try:
161                     ret = f(*(), **msg[ARG])
162                 except KrpcError, e:
163                     olen = self._sendResponse(addr, msg[TID], ERR, [e[0], e[1]])
164                 except TypeError, e:
165                     olen = self._sendResponse(addr, msg[TID], ERR,
166                                               [KRPC_ERROR_MALFORMED_REQUEST, str(e)])
167                 except Exception, e:
168                     olen = self._sendResponse(addr, msg[TID], ERR,
169                                               [KRPC_ERROR_SERVER_ERROR, str(e)])
170                 else:
171                     olen = self._sendResponse(addr, msg[TID], RSP, ret)
172             else:
173                 if self.noisy:
174                     log.msg("don't know about method %s" % msg[REQ])
175                 # unknown method
176                 olen = self._sendResponse(addr, msg[TID], ERR,
177                                           [KRPC_ERROR_METHOD_UNKNOWN, "unknown method "+str(msg[REQ])])
178             if self.noisy:
179                 log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
180                                                   ilen, msg[REQ], olen))
181         elif msg[TYP] == RSP:
182             # if response
183             #   lookup tid
184             if self.tids.has_key(msg[TID]):
185                 df = self.tids[msg[TID]]
186                 #       callback
187                 del(self.tids[msg[TID]])
188                 df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
189             else:
190                 # no tid, this transaction timed out already...
191                 if self.noisy:
192                     log.msg('timeout: %r' % msg[RSP]['id'])
193         elif msg[TYP] == ERR:
194             # if error
195             #   lookup tid
196             if self.tids.has_key(msg[TID]):
197                 df = self.tids[msg[TID]]
198                 del(self.tids[msg[TID]])
199                 # callback
200                 df.errback(KrpcError(*msg[ERR]))
201             else:
202                 # day late and dollar short, just log it
203                 log.msg("Got an error for an unknown request: %r" % (msg[ERR], ))
204                 pass
205         else:
206             if self.noisy:
207                 log.msg("unknown message type: %r" % msg)
208             # unknown message type
209             if msg[TID] in self.tids:
210                 df = self.tids[msg[TID]]
211                 del(self.tids[msg[TID]])
212                 # callback
213                 df.errback(KrpcError(KRPC_ERROR_RECEIVED_UNKNOWN,
214                                      "Received an unknown message type: %r" % msg[TYP]))
215                 
216     def _sendResponse(self, addr, tid, msgType, response):
217         if not response:
218             response = {}
219             
220         msg = {TID : tid, TYP : msgType, msgType : response}
221
222         if self.noisy:
223             log.msg("%d responding to %r: %s" % (self.factory.port, addr, msg))
224
225         out = bencode(msg)
226         self.transport.write(out, addr)
227         return len(out)
228     
229     def sendRequest(self, method, args):
230         if self.stopped:
231             raise KrpcError, (KRPC_ERROR_PROTOCOL_STOPPED, "cannot send, connection has been stopped")
232         # make message
233         # send it
234         msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
235         if self.noisy:
236             log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
237         data = bencode(msg)
238         d = Deferred()
239         self.tids[msg[TID]] = d
240         def timeOut(tids = self.tids, id = msg[TID], method = method, addr = self.addr):
241             if tids.has_key(id):
242                 df = tids[id]
243                 del(tids[id])
244                 df.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % (method, addr)))
245         later = reactor.callLater(KRPC_TIMEOUT, timeOut)
246         def dropTimeOut(dict, later_call = later):
247             if later_call.active():
248                 later_call.cancel()
249             return dict
250         d.addBoth(dropTimeOut)
251         self.transport.write(data, self.addr)
252         return d
253     
254     def stop(self):
255         """Timeout all pending requests."""
256         for df in self.tids.values():
257             df.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED, 'connection has been stopped while waiting for response'))
258         self.tids = {}
259         self.stopped = True
260  
261 def connectionForAddr(host, port):
262     return host
263     
264 class Receiver(protocol.Factory):
265     protocol = KRPC
266     def __init__(self):
267         self.buf = []
268     def krpc_store(self, msg, _krpc_sender):
269         self.buf += [msg]
270         return {}
271     def krpc_echo(self, msg, _krpc_sender):
272         return {'msg': msg}
273
274 def make(port):
275     af = Receiver()
276     a = hostbroker(af, {'SPEW': False})
277     a.protocol = KRPC
278     p = reactor.listenUDP(port, a)
279     return af, a, p
280     
281 class KRPCTests(unittest.TestCase):
282     def setUp(self):
283         self.af, self.a, self.ap = make(1180)
284         self.bf, self.b, self.bp = make(1181)
285
286     def tearDown(self):
287         self.ap.stopListening()
288         self.bp.stopListening()
289
290     def bufEquals(self, result, value):
291         self.failUnlessEqual(self.bf.buf, value)
292
293     def testSimpleMessage(self):
294         d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
295         d.addCallback(self.bufEquals, ["This is a test."])
296         return d
297
298     def testMessageBlast(self):
299         for i in range(100):
300             d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
301         d.addCallback(self.bufEquals, ["This is a test."] * 100)
302         return d
303
304     def testEcho(self):
305         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
306         df.addCallback(self.gotMsg, "This is a test.")
307         return df
308
309     def gotMsg(self, dict, should_be):
310         _krpc_sender = dict['_krpc_sender']
311         msg = dict['rsp']
312         self.failUnlessEqual(msg['msg'], should_be)
313
314     def testManyEcho(self):
315         for i in xrange(100):
316             df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
317             df.addCallback(self.gotMsg, "This is a test.")
318         return df
319
320     def testMultiEcho(self):
321         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
322         df.addCallback(self.gotMsg, "This is a test.")
323
324         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
325         df.addCallback(self.gotMsg, "This is another test.")
326
327         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
328         df.addCallback(self.gotMsg, "This is yet another test.")
329         
330         return df
331
332     def testEchoReset(self):
333         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
334         df.addCallback(self.gotMsg, "This is a test.")
335
336         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
337         df.addCallback(self.gotMsg, "This is another test.")
338         df.addCallback(self.echoReset)
339         return df
340     
341     def echoReset(self, dict):
342         del(self.a.connections[('127.0.0.1', 1181)])
343         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
344         df.addCallback(self.gotMsg, "This is yet another test.")
345         return df
346
347     def testUnknownMeth(self):
348         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('blahblah', {'msg' : "This is a test."})
349         df.addBoth(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
350         return df
351
352     def testMalformedRequest(self):
353         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test.", 'foo': 'bar'})
354         df.addBoth(self.gotErr, KRPC_ERROR_MALFORMED_REQUEST)
355         return df
356
357     def gotErr(self, err, should_be):
358         self.failUnlessEqual(err.value[0], should_be)