Check response packet lengths before sending.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / krpc.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from bencode import bencode, bdecode
5 from time import asctime
6 from math import ceil
7
8 from twisted.internet.defer import Deferred
9 from twisted.internet import protocol, reactor
10 from twisted.python import log
11 from twisted.trial import unittest
12
13 from khash import newID
14
15 KRPC_TIMEOUT = 20
16 UDP_PACKET_LIMIT = 1472
17
18 # Remote node errors
19 KRPC_ERROR = 200
20 KRPC_ERROR_SERVER_ERROR = 201
21 KRPC_ERROR_MALFORMED_PACKET = 202
22 KRPC_ERROR_METHOD_UNKNOWN = 203
23 KRPC_ERROR_MALFORMED_REQUEST = 204
24 KRPC_ERROR_INVALID_TOKEN = 205
25 KRPC_ERROR_RESPONSE_TOO_LONG = 206
26
27 # Local errors
28 KRPC_ERROR_INTERNAL = 100
29 KRPC_ERROR_RECEIVED_UNKNOWN = 101
30 KRPC_ERROR_TIMEOUT = 102
31 KRPC_ERROR_PROTOCOL_STOPPED = 103
32
33 # commands
34 TID = 't'
35 REQ = 'q'
36 RSP = 'r'
37 TYP = 'y'
38 ARG = 'a'
39 ERR = 'e'
40
41 class KrpcError(Exception):
42     pass
43
44 def verifyMessage(msg):
45     """Check received message for corruption and errors.
46     
47     @type msg: C{dictionary}
48     @param msg: the dictionary of information received on the connection
49     @raise KrpcError: if the message is corrupt
50     """
51     
52     if type(msg) != dict:
53         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "not a dictionary")
54     if TYP not in msg:
55         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no message type")
56     if msg[TYP] == REQ:
57         if REQ not in msg:
58             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type not specified")
59         if type(msg[REQ]) != str:
60             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type is not a string")
61         if ARG not in msg:
62             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no arguments for request")
63         if type(msg[ARG]) != dict:
64             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "arguments for request are not in a dictionary")
65     elif msg[TYP] == RSP:
66         if RSP not in msg:
67             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response not specified")
68         if type(msg[RSP]) != dict:
69             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response is not a dictionary")
70     elif msg[TYP] == ERR:
71         if ERR not in msg:
72             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error not specified")
73         if type(msg[ERR]) != list:
74             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a list")
75         if len(msg[ERR]) != 2:
76             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a 2-element list")
77         if type(msg[ERR][0]) not in (int, long):
78             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error number is not a number")
79         if type(msg[ERR][1]) != str:
80             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error string is not a string")
81 #    else:
82 #        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "unknown message type")
83     if TID not in msg:
84         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no transaction ID specified")
85     if type(msg[TID]) != str:
86         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "transaction id is not a string")
87
88 class hostbroker(protocol.DatagramProtocol):       
89     def __init__(self, server, config):
90         self.server = server
91         self.config = config
92         # this should be changed to storage that drops old entries
93         self.connections = {}
94         
95     def datagramReceived(self, datagram, addr):
96         #print `addr`, `datagram`
97         #if addr != self.addr:
98         c = self.connectionForAddr(addr)
99         c.datagramReceived(datagram, addr)
100         #if c.idle():
101         #    del self.connections[addr]
102
103     def connectionForAddr(self, addr):
104         if addr == self.addr:
105             raise Exception
106         if not self.connections.has_key(addr):
107             conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
108             self.connections[addr] = conn
109         else:
110             conn = self.connections[addr]
111         return conn
112
113     def makeConnection(self, transport):
114         protocol.DatagramProtocol.makeConnection(self, transport)
115         tup = transport.getHost()
116         self.addr = (tup.host, tup.port)
117         
118     def stopProtocol(self):
119         for conn in self.connections.values():
120             conn.stop()
121         protocol.DatagramProtocol.stopProtocol(self)
122
123 ## connection
124 class KRPC:
125     def __init__(self, addr, server, transport, spew = False):
126         self.transport = transport
127         self.factory = server
128         self.addr = addr
129         self.noisy = spew
130         self.tids = {}
131         self.stopped = False
132
133     def datagramReceived(self, data, addr):
134         if self.stopped:
135             if self.noisy:
136                 log.msg("stopped, dropping message from %r: %s" % (addr, data))
137         # bdecode
138         try:
139             msg = bdecode(data)
140         except Exception, e:
141             if self.noisy:
142                 log.msg("krpc bdecode error: ")
143                 log.err(e)
144             return
145
146         try:
147             verifyMessage(msg)
148         except Exception, e:
149             log.msg("krpc message verification error: ")
150             log.err(e)
151             return
152
153         if self.noisy:
154             log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
155         # look at msg type
156         if msg[TYP]  == REQ:
157             ilen = len(data)
158             # if request
159             #   tell factory to handle
160             f = getattr(self.factory ,"krpc_" + msg[REQ], None)
161             msg[ARG]['_krpc_sender'] =  self.addr
162             if f and callable(f):
163                 try:
164                     ret = f(*(), **msg[ARG])
165                 except KrpcError, e:
166                     log.msg('Got a Krpc error while running: krpc_%s' % msg[REQ])
167                     log.err(e)
168                     olen = self._sendResponse(addr, msg[TID], ERR, [e[0], e[1]])
169                 except TypeError, e:
170                     log.msg('Got a malformed request for: krpc_%s' % msg[REQ])
171                     log.err(e)
172                     olen = self._sendResponse(addr, msg[TID], ERR,
173                                               [KRPC_ERROR_MALFORMED_REQUEST, str(e)])
174                 except Exception, e:
175                     log.msg('Got an unknown error while running: krpc_%s' % msg[REQ])
176                     log.err(e)
177                     olen = self._sendResponse(addr, msg[TID], ERR,
178                                               [KRPC_ERROR_SERVER_ERROR, str(e)])
179                 else:
180                     olen = self._sendResponse(addr, msg[TID], RSP, ret)
181             else:
182                 # unknown method
183                 log.msg("ERROR: don't know about method %s" % msg[REQ])
184                 olen = self._sendResponse(addr, msg[TID], ERR,
185                                           [KRPC_ERROR_METHOD_UNKNOWN, "unknown method "+str(msg[REQ])])
186             if self.noisy:
187                 log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
188                                                   ilen, msg[REQ], olen))
189         elif msg[TYP] == RSP:
190             # if response
191             #   lookup tid
192             if self.tids.has_key(msg[TID]):
193                 df = self.tids[msg[TID]]
194                 #       callback
195                 del(self.tids[msg[TID]])
196                 df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
197             else:
198                 # no tid, this transaction timed out already...
199                 if self.noisy:
200                     log.msg('timeout: %r' % msg[RSP]['id'])
201         elif msg[TYP] == ERR:
202             # if error
203             #   lookup tid
204             if self.tids.has_key(msg[TID]):
205                 df = self.tids[msg[TID]]
206                 del(self.tids[msg[TID]])
207                 # callback
208                 df.errback(KrpcError(*msg[ERR]))
209             else:
210                 # day late and dollar short, just log it
211                 log.msg("Got an error for an unknown request: %r" % (msg[ERR], ))
212                 pass
213         else:
214             if self.noisy:
215                 log.msg("unknown message type: %r" % msg)
216             # unknown message type
217             if msg[TID] in self.tids:
218                 df = self.tids[msg[TID]]
219                 del(self.tids[msg[TID]])
220                 # callback
221                 df.errback(KrpcError(KRPC_ERROR_RECEIVED_UNKNOWN,
222                                      "Received an unknown message type: %r" % msg[TYP]))
223                 
224     def _sendResponse(self, addr, tid, msgType, response):
225         if not response:
226             response = {}
227         
228         try:
229             msg = {TID : tid, TYP : msgType, msgType : response}
230     
231             if self.noisy:
232                 log.msg("%d responding to %r: %s" % (self.factory.port, addr, msg))
233     
234             out = bencode(msg)
235             
236             if len(out) > UDP_PACKET_LIMIT:
237                 if 'values' in response:
238                     # Save the original list of values
239                     orig_values = response['values']
240                     len_orig_values = len(bencode(orig_values))
241                     
242                     # Caclulate the maximum value length possible
243                     max_len_values = len_orig_values - (len(out) - UDP_PACKET_LIMIT)
244                     assert max_len_values > 0
245                     
246                     # Start with a calculation of how many values should be included
247                     # (assumes all values are the same length)
248                     per_value = (float(len_orig_values) - 2.0) / float(len(orig_values))
249                     num_values = len(orig_values) - int(ceil(float(len(out) - UDP_PACKET_LIMIT) / per_value))
250     
251                     # Do a linear search for the actual maximum number possible
252                     bencoded_values = len(bencode(orig_values[:num_values]))
253                     while bencoded_values < max_len_values and num_values + 1 < len(orig_values):
254                         bencoded_values += len(bencode(orig_values[num_values]))
255                         num_values += 1
256                     while bencoded_values > max_len_values and num_values > 0:
257                         num_values -= 1
258                         bencoded_values -= len(bencode(orig_values[num_values]))
259                     assert num_values > 0
260     
261                     # Encode the result
262                     response['values'] = orig_values[:num_values]
263                     out = bencode(msg)
264                     assert len(out) < UDP_PACKET_LIMIT
265                     log.msg('Shortened a long packet from %d to %d values, new packet length: %d' % 
266                             (len(orig_values), num_values, len(out)))
267                 else:
268                     # Too long a response, send an error
269                     log.msg('Could not send response, too long: %d bytes' % len(out))
270                     msg = {TID : tid, TYP : ERR, ERR : [KRPC_ERROR_RESPONSE_TOO_LONG, "response was %d bytes" % len(out)]}
271                     out = bencode(msg)
272
273         except Exception, e:
274             # Unknown error, send an error message
275             msg = {TID : tid, TYP : ERR, ERR : [KRPC_ERROR_SERVER_ERROR, "unknown error sending response: %s" % str(e)]}
276             out = bencode(msg)
277                     
278         self.transport.write(out, addr)
279         return len(out)
280     
281     def sendRequest(self, method, args):
282         if self.stopped:
283             raise KrpcError, (KRPC_ERROR_PROTOCOL_STOPPED, "cannot send, connection has been stopped")
284         # make message
285         # send it
286         msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
287         if self.noisy:
288             log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
289         data = bencode(msg)
290         d = Deferred()
291         self.tids[msg[TID]] = d
292         def timeOut(tids = self.tids, id = msg[TID], method = method, addr = self.addr):
293             if tids.has_key(id):
294                 df = tids[id]
295                 del(tids[id])
296                 df.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % (method, addr)))
297         later = reactor.callLater(KRPC_TIMEOUT, timeOut)
298         def dropTimeOut(dict, later_call = later):
299             if later_call.active():
300                 later_call.cancel()
301             return dict
302         d.addBoth(dropTimeOut)
303         self.transport.write(data, self.addr)
304         return d
305     
306     def stop(self):
307         """Timeout all pending requests."""
308         for df in self.tids.values():
309             df.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED, 'connection has been stopped while waiting for response'))
310         self.tids = {}
311         self.stopped = True
312  
313 def connectionForAddr(host, port):
314     return host
315     
316 class Receiver(protocol.Factory):
317     protocol = KRPC
318     def __init__(self):
319         self.buf = []
320     def krpc_store(self, msg, _krpc_sender):
321         self.buf += [msg]
322         return {}
323     def krpc_echo(self, msg, _krpc_sender):
324         return {'msg': msg}
325     def krpc_values(self, length, num, _krpc_sender):
326         return {'values': ['1'*length]*num}
327
328 def make(port):
329     af = Receiver()
330     a = hostbroker(af, {'SPEW': False})
331     a.protocol = KRPC
332     p = reactor.listenUDP(port, a)
333     return af, a, p
334     
335 class KRPCTests(unittest.TestCase):
336     timeout = 2
337     
338     def setUp(self):
339         self.af, self.a, self.ap = make(1180)
340         self.bf, self.b, self.bp = make(1181)
341
342     def tearDown(self):
343         self.ap.stopListening()
344         self.bp.stopListening()
345
346     def bufEquals(self, result, value):
347         self.failUnlessEqual(self.bf.buf, value)
348
349     def testSimpleMessage(self):
350         d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
351         d.addCallback(self.bufEquals, ["This is a test."])
352         return d
353
354     def testMessageBlast(self):
355         for i in range(100):
356             d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
357         d.addCallback(self.bufEquals, ["This is a test."] * 100)
358         return d
359
360     def testEcho(self):
361         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
362         df.addCallback(self.gotMsg, "This is a test.")
363         return df
364
365     def gotMsg(self, dict, should_be):
366         _krpc_sender = dict['_krpc_sender']
367         msg = dict['rsp']
368         self.failUnlessEqual(msg['msg'], should_be)
369
370     def testManyEcho(self):
371         for i in xrange(100):
372             df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
373             df.addCallback(self.gotMsg, "This is a test.")
374         return df
375
376     def testMultiEcho(self):
377         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
378         df.addCallback(self.gotMsg, "This is a test.")
379
380         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
381         df.addCallback(self.gotMsg, "This is another test.")
382
383         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
384         df.addCallback(self.gotMsg, "This is yet another test.")
385         
386         return df
387
388     def testEchoReset(self):
389         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
390         df.addCallback(self.gotMsg, "This is a test.")
391
392         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
393         df.addCallback(self.gotMsg, "This is another test.")
394         df.addCallback(self.echoReset)
395         return df
396     
397     def echoReset(self, dict):
398         del(self.a.connections[('127.0.0.1', 1181)])
399         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
400         df.addCallback(self.gotMsg, "This is yet another test.")
401         return df
402
403     def testUnknownMeth(self):
404         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('blahblah', {'msg' : "This is a test."})
405         df.addBoth(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
406         return df
407
408     def testMalformedRequest(self):
409         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test.", 'foo': 'bar'})
410         df.addBoth(self.gotErr, KRPC_ERROR_MALFORMED_REQUEST)
411         return df
412
413     def gotErr(self, err, should_be):
414         self.failUnlessEqual(err.value[0], should_be)
415         
416     def testLongPackets(self):
417         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('values', {'length' : 1, 'num': 2000})
418         df.addCallback(self.gotLongRsp)
419         return df
420
421     def gotLongRsp(self, dict):
422         # Not quite accurate, but good enough
423         self.failUnless(len(bencode(dict))-10 < UDP_PACKET_LIMIT)
424
425