Rename all apt-dht files to apt-p2p.
[quix0rs-apt-p2p.git] / apt_p2p_Khashmir / krpc.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 """The KRPC communication protocol implementation.
5
6 @var KRPC_TIMEOUT: the number of seconds after which requests timeout
7 @var UDP_PACKET_LIMIT: the maximum number of bytes that can be sent in a
8     UDP packet without fragmentation
9
10 @var KRPC_ERROR: the code for a generic error
11 @var KRPC_ERROR_SERVER_ERROR: the code for a server error
12 @var KRPC_ERROR_MALFORMED_PACKET: the code for a malformed packet error
13 @var KRPC_ERROR_METHOD_UNKNOWN: the code for a method unknown error
14 @var KRPC_ERROR_MALFORMED_REQUEST: the code for a malformed request error
15 @var KRPC_ERROR_INVALID_TOKEN: the code for an invalid token error
16 @var KRPC_ERROR_RESPONSE_TOO_LONG: the code for a response too long error
17
18 @var KRPC_ERROR_INTERNAL: the code for an internal error
19 @var KRPC_ERROR_RECEIVED_UNKNOWN: the code for an unknown message type error
20 @var KRPC_ERROR_TIMEOUT: the code for a timeout error
21 @var KRPC_ERROR_PROTOCOL_STOPPED: the code for a stopped protocol error
22
23 @var TID: the identifier for the transaction ID
24 @var REQ: the identifier for a request packet
25 @var RSP: the identifier for a response packet
26 @var TYP: the identifier for the type of packet
27 @var ARG: the identifier for the argument to the request
28 @var ERR: the identifier for an error packet
29
30 @group Remote node error codes: KRPC_ERROR, KRPC_ERROR_SERVER_ERROR,
31     KRPC_ERROR_MALFORMED_PACKET, KRPC_ERROR_METHOD_UNKNOWN,
32     KRPC_ERROR_MALFORMED_REQUEST, KRPC_ERROR_INVALID_TOKEN,
33     KRPC_ERROR_RESPONSE_TOO_LONG
34 @group Local node error codes: KRPC_ERROR_INTERNAL, KRPC_ERROR_RECEIVED_UNKNOWN,
35     KRPC_ERROR_TIMEOUT, KRPC_ERROR_PROTOCOL_STOPPED
36 @group Command identifiers: TID, REQ, RSP, TYP, ARG, ERR
37
38 """
39
40 from bencode import bencode, bdecode
41 from time import asctime
42 from math import ceil
43
44 from twisted.internet.defer import Deferred
45 from twisted.internet import protocol, reactor
46 from twisted.python import log
47 from twisted.trial import unittest
48
49 from khash import newID
50
51 KRPC_TIMEOUT = 20
52 UDP_PACKET_LIMIT = 1472
53
54 # Remote node errors
55 KRPC_ERROR = 200
56 KRPC_ERROR_SERVER_ERROR = 201
57 KRPC_ERROR_MALFORMED_PACKET = 202
58 KRPC_ERROR_METHOD_UNKNOWN = 203
59 KRPC_ERROR_MALFORMED_REQUEST = 204
60 KRPC_ERROR_INVALID_TOKEN = 205
61 KRPC_ERROR_RESPONSE_TOO_LONG = 206
62
63 # Local errors
64 KRPC_ERROR_INTERNAL = 100
65 KRPC_ERROR_RECEIVED_UNKNOWN = 101
66 KRPC_ERROR_TIMEOUT = 102
67 KRPC_ERROR_PROTOCOL_STOPPED = 103
68
69 # commands
70 TID = 't'
71 REQ = 'q'
72 RSP = 'r'
73 TYP = 'y'
74 ARG = 'a'
75 ERR = 'e'
76
77 class KrpcError(Exception):
78     """An error occurred in the KRPC protocol."""
79     pass
80
81 def verifyMessage(msg):
82     """Check received message for corruption and errors.
83     
84     @type msg: C{dictionary}
85     @param msg: the dictionary of information received on the connection
86     @raise KrpcError: if the message is corrupt
87     """
88     
89     if type(msg) != dict:
90         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "not a dictionary")
91     if TYP not in msg:
92         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no message type")
93     if msg[TYP] == REQ:
94         if REQ not in msg:
95             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type not specified")
96         if type(msg[REQ]) != str:
97             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "request type is not a string")
98         if ARG not in msg:
99             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no arguments for request")
100         if type(msg[ARG]) != dict:
101             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "arguments for request are not in a dictionary")
102     elif msg[TYP] == RSP:
103         if RSP not in msg:
104             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response not specified")
105         if type(msg[RSP]) != dict:
106             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "response is not a dictionary")
107     elif msg[TYP] == ERR:
108         if ERR not in msg:
109             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error not specified")
110         if type(msg[ERR]) != list:
111             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a list")
112         if len(msg[ERR]) != 2:
113             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error is not a 2-element list")
114         if type(msg[ERR][0]) not in (int, long):
115             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error number is not a number")
116         if type(msg[ERR][1]) != str:
117             raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "error string is not a string")
118 #    else:
119 #        raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "unknown message type")
120     if TID not in msg:
121         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "no transaction ID specified")
122     if type(msg[TID]) != str:
123         raise KrpcError, (KRPC_ERROR_MALFORMED_PACKET, "transaction id is not a string")
124
125 class hostbroker(protocol.DatagramProtocol):
126     """The factory for the KRPC protocol.
127     
128     @type server: L{khashmir.Khashmir}
129     @ivar server: the main Khashmir program
130     @type config: C{dictionary}
131     @ivar config: the configuration parameters for the DHT
132     @type connections: C{dictionary}
133     @ivar connections: all the connections that have ever been made to the
134         protocol, keys are IP address and port pairs, values are L{KRPC}
135         protocols for the addresses
136     @ivar protocol: the protocol to use to handle incoming connections
137         (added externally)
138     @type addr: (C{string}, C{int})
139     @ivar addr: the IP address and port of this node
140     """
141     
142     def __init__(self, server, config):
143         """Initialize the factory.
144         
145         @type server: L{khashmir.Khashmir}
146         @param server: the main DHT program
147         @type config: C{dictionary}
148         @param config: the configuration parameters for the DHT
149         """
150         self.server = server
151         self.config = config
152         # this should be changed to storage that drops old entries
153         self.connections = {}
154         
155     def datagramReceived(self, datagram, addr):
156         """Optionally create a new protocol object, and handle the new datagram.
157         
158         @type datagram: C{string}
159         @param datagram: the data received from the transport.
160         @type addr: (C{string}, C{int})
161         @param addr: source IP address and port of datagram.
162         """
163         c = self.connectionForAddr(addr)
164         c.datagramReceived(datagram, addr)
165         #if c.idle():
166         #    del self.connections[addr]
167
168     def connectionForAddr(self, addr):
169         """Get a protocol object for the source.
170         
171         @type addr: (C{string}, C{int})
172         @param addr: source IP address and port of datagram.
173         """
174         # Don't connect to ourself
175         if addr == self.addr:
176             raise KrcpError
177         
178         # Create a new protocol object if necessary
179         if not self.connections.has_key(addr):
180             conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
181             self.connections[addr] = conn
182         else:
183             conn = self.connections[addr]
184         return conn
185
186     def makeConnection(self, transport):
187         """Make a connection to a transport and save our address."""
188         protocol.DatagramProtocol.makeConnection(self, transport)
189         tup = transport.getHost()
190         self.addr = (tup.host, tup.port)
191         
192     def stopProtocol(self):
193         """Stop all the open connections."""
194         for conn in self.connections.values():
195             conn.stop()
196         protocol.DatagramProtocol.stopProtocol(self)
197
198 class KRPC:
199     """The KRPC protocol implementation.
200     
201     @ivar transport: the transport to use for the protocol
202     @type factory: L{khashmir.Khashmir}
203     @ivar factory: the main Khashmir program
204     @type addr: (C{string}, C{int})
205     @ivar addr: the IP address and port of the source node
206     @type noisy: C{boolean}
207     @ivar noisy: whether to log additional details of the protocol
208     @type tids: C{dictionary}
209     @ivar tids: the transaction IDs outstanding for requests, keys are the
210         transaction ID of the request, values are the deferreds to call with
211         the results
212     @type stopped: C{boolean}
213     @ivar stopped: whether the protocol has been stopped
214     """
215     
216     def __init__(self, addr, server, transport, spew = False):
217         """Initialize the protocol.
218         
219         @type addr: (C{string}, C{int})
220         @param addr: the IP address and port of the source node
221         @type server: L{khashmir.Khashmir}
222         @param server: the main Khashmir program
223         @param transport: the transport to use for the protocol
224         @type spew: C{boolean}
225         @param spew: whether to log additional details of the protocol
226             (optional, defaults to False)
227         """
228         self.transport = transport
229         self.factory = server
230         self.addr = addr
231         self.noisy = spew
232         self.tids = {}
233         self.stopped = False
234
235     def datagramReceived(self, data, addr):
236         """Process the new datagram.
237         
238         @type data: C{string}
239         @param data: the data received from the transport.
240         @type addr: (C{string}, C{int})
241         @param addr: source IP address and port of datagram.
242         """
243         if self.stopped:
244             if self.noisy:
245                 log.msg("stopped, dropping message from %r: %s" % (addr, data))
246
247         # Bdecode the message
248         try:
249             msg = bdecode(data)
250         except Exception, e:
251             if self.noisy:
252                 log.msg("krpc bdecode error: ")
253                 log.err(e)
254             return
255
256         # Make sure the remote node isn't trying anything funny
257         try:
258             verifyMessage(msg)
259         except Exception, e:
260             log.msg("krpc message verification error: ")
261             log.err(e)
262             return
263
264         if self.noisy:
265             log.msg("%d received from %r: %s" % (self.factory.port, addr, msg))
266
267         # Process it based on its type
268         if msg[TYP]  == REQ:
269             ilen = len(data)
270             
271             # Requests are handled by the factory
272             f = getattr(self.factory ,"krpc_" + msg[REQ], None)
273             msg[ARG]['_krpc_sender'] =  self.addr
274             if f and callable(f):
275                 try:
276                     ret = f(*(), **msg[ARG])
277                 except KrpcError, e:
278                     log.msg('Got a Krpc error while running: krpc_%s' % msg[REQ])
279                     log.err(e)
280                     olen = self._sendResponse(addr, msg[TID], ERR, [e[0], e[1]])
281                 except TypeError, e:
282                     log.msg('Got a malformed request for: krpc_%s' % msg[REQ])
283                     log.err(e)
284                     olen = self._sendResponse(addr, msg[TID], ERR,
285                                               [KRPC_ERROR_MALFORMED_REQUEST, str(e)])
286                 except Exception, e:
287                     log.msg('Got an unknown error while running: krpc_%s' % msg[REQ])
288                     log.err(e)
289                     olen = self._sendResponse(addr, msg[TID], ERR,
290                                               [KRPC_ERROR_SERVER_ERROR, str(e)])
291                 else:
292                     olen = self._sendResponse(addr, msg[TID], RSP, ret)
293             else:
294                 # Request for unknown method
295                 log.msg("ERROR: don't know about method %s" % msg[REQ])
296                 olen = self._sendResponse(addr, msg[TID], ERR,
297                                           [KRPC_ERROR_METHOD_UNKNOWN, "unknown method "+str(msg[REQ])])
298             if self.noisy:
299                 log.msg("%s >>> %s - %s %s %s" % (addr, self.factory.node.port,
300                                                   ilen, msg[REQ], olen))
301         elif msg[TYP] == RSP:
302             # Responses get processed by their TID's deferred
303             if self.tids.has_key(msg[TID]):
304                 df = self.tids[msg[TID]]
305                 #       callback
306                 del(self.tids[msg[TID]])
307                 df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
308             else:
309                 # no tid, this transaction timed out already...
310                 if self.noisy:
311                     log.msg('timeout: %r' % msg[RSP]['id'])
312         elif msg[TYP] == ERR:
313             # Errors get processed by their TID's deferred's errback
314             if self.tids.has_key(msg[TID]):
315                 df = self.tids[msg[TID]]
316                 del(self.tids[msg[TID]])
317                 # callback
318                 df.errback(KrpcError(*msg[ERR]))
319             else:
320                 # day late and dollar short, just log it
321                 log.msg("Got an error for an unknown request: %r" % (msg[ERR], ))
322                 pass
323         else:
324             # Received an unknown message type
325             if self.noisy:
326                 log.msg("unknown message type: %r" % msg)
327             if msg[TID] in self.tids:
328                 df = self.tids[msg[TID]]
329                 del(self.tids[msg[TID]])
330                 # callback
331                 df.errback(KrpcError(KRPC_ERROR_RECEIVED_UNKNOWN,
332                                      "Received an unknown message type: %r" % msg[TYP]))
333                 
334     def _sendResponse(self, addr, tid, msgType, response):
335         """Helper function for sending responses to nodes.
336         
337         @type addr: (C{string}, C{int})
338         @param addr: source IP address and port of datagram.
339         @param tid: the transaction ID of the request
340         @param msgType: the type of message to respond with
341         @param response: the arguments for the response
342         """
343         if not response:
344             response = {}
345         
346         try:
347             # Create the response message
348             msg = {TID : tid, TYP : msgType, msgType : response}
349     
350             if self.noisy:
351                 log.msg("%d responding to %r: %s" % (self.factory.port, addr, msg))
352     
353             out = bencode(msg)
354             
355             # Make sure its not too long
356             if len(out) > UDP_PACKET_LIMIT:
357                 # Can we remove some values to shorten it?
358                 if 'values' in response:
359                     # Save the original list of values
360                     orig_values = response['values']
361                     len_orig_values = len(bencode(orig_values))
362                     
363                     # Caclulate the maximum value length possible
364                     max_len_values = len_orig_values - (len(out) - UDP_PACKET_LIMIT)
365                     assert max_len_values > 0
366                     
367                     # Start with a calculation of how many values should be included
368                     # (assumes all values are the same length)
369                     per_value = (float(len_orig_values) - 2.0) / float(len(orig_values))
370                     num_values = len(orig_values) - int(ceil(float(len(out) - UDP_PACKET_LIMIT) / per_value))
371     
372                     # Do a linear search for the actual maximum number possible
373                     bencoded_values = len(bencode(orig_values[:num_values]))
374                     while bencoded_values < max_len_values and num_values + 1 < len(orig_values):
375                         bencoded_values += len(bencode(orig_values[num_values]))
376                         num_values += 1
377                     while bencoded_values > max_len_values and num_values > 0:
378                         num_values -= 1
379                         bencoded_values -= len(bencode(orig_values[num_values]))
380                     assert num_values > 0
381     
382                     # Encode the result
383                     response['values'] = orig_values[:num_values]
384                     out = bencode(msg)
385                     assert len(out) < UDP_PACKET_LIMIT
386                     log.msg('Shortened a long packet from %d to %d values, new packet length: %d' % 
387                             (len(orig_values), num_values, len(out)))
388                 else:
389                     # Too long a response, send an error
390                     log.msg('Could not send response, too long: %d bytes' % len(out))
391                     msg = {TID : tid, TYP : ERR, ERR : [KRPC_ERROR_RESPONSE_TOO_LONG, "response was %d bytes" % len(out)]}
392                     out = bencode(msg)
393
394         except Exception, e:
395             # Unknown error, send an error message
396             msg = {TID : tid, TYP : ERR, ERR : [KRPC_ERROR_SERVER_ERROR, "unknown error sending response: %s" % str(e)]}
397             out = bencode(msg)
398                     
399         self.transport.write(out, addr)
400         return len(out)
401     
402     def sendRequest(self, method, args):
403         """Send a request to the remote node.
404         
405         @type method: C{string}
406         @param method: the methiod name to call on the remote node
407         @param args: the arguments to send to the remote node's method
408         """
409         if self.stopped:
410             raise KrpcError, (KRPC_ERROR_PROTOCOL_STOPPED, "cannot send, connection has been stopped")
411
412         # Create the request message
413         msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
414         if self.noisy:
415             log.msg("%d sending to %r: %s" % (self.factory.port, self.addr, msg))
416         data = bencode(msg)
417         
418         # Create the deferred and save it with the TID
419         d = Deferred()
420         self.tids[msg[TID]] = d
421
422         # Schedule a later timeout call
423         def timeOut(tids = self.tids, id = msg[TID], method = method, addr = self.addr):
424             """Call the deferred's errback if a timeout occurs."""
425             if tids.has_key(id):
426                 df = tids[id]
427                 del(tids[id])
428                 df.errback(KrpcError(KRPC_ERROR_TIMEOUT, "timeout waiting for '%s' from %r" % (method, addr)))
429         later = reactor.callLater(KRPC_TIMEOUT, timeOut)
430         
431         # Cancel the timeout call if a response is received
432         def dropTimeOut(dict, later_call = later):
433             """Cancel the timeout call when a response is received."""
434             if later_call.active():
435                 later_call.cancel()
436             return dict
437         d.addBoth(dropTimeOut)
438         
439         self.transport.write(data, self.addr)
440         return d
441     
442     def stop(self):
443         """Timeout all pending requests."""
444         for df in self.tids.values():
445             df.errback(KrpcError(KRPC_ERROR_PROTOCOL_STOPPED, 'connection has been stopped while waiting for response'))
446         self.tids = {}
447         self.stopped = True
448
449 #{ For testing the KRPC protocol
450 def connectionForAddr(host, port):
451     return host
452     
453 class Receiver(protocol.Factory):
454     protocol = KRPC
455     def __init__(self):
456         self.buf = []
457     def krpc_store(self, msg, _krpc_sender):
458         self.buf += [msg]
459         return {}
460     def krpc_echo(self, msg, _krpc_sender):
461         return {'msg': msg}
462     def krpc_values(self, length, num, _krpc_sender):
463         return {'values': ['1'*length]*num}
464
465 def make(port):
466     af = Receiver()
467     a = hostbroker(af, {'SPEW': False})
468     a.protocol = KRPC
469     p = reactor.listenUDP(port, a)
470     return af, a, p
471     
472 class KRPCTests(unittest.TestCase):
473     timeout = 2
474     
475     def setUp(self):
476         self.af, self.a, self.ap = make(1180)
477         self.bf, self.b, self.bp = make(1181)
478
479     def tearDown(self):
480         self.ap.stopListening()
481         self.bp.stopListening()
482
483     def bufEquals(self, result, value):
484         self.failUnlessEqual(self.bf.buf, value)
485
486     def testSimpleMessage(self):
487         d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
488         d.addCallback(self.bufEquals, ["This is a test."])
489         return d
490
491     def testMessageBlast(self):
492         for i in range(100):
493             d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
494         d.addCallback(self.bufEquals, ["This is a test."] * 100)
495         return d
496
497     def testEcho(self):
498         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
499         df.addCallback(self.gotMsg, "This is a test.")
500         return df
501
502     def gotMsg(self, dict, should_be):
503         _krpc_sender = dict['_krpc_sender']
504         msg = dict['rsp']
505         self.failUnlessEqual(msg['msg'], should_be)
506
507     def testManyEcho(self):
508         for i in xrange(100):
509             df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
510             df.addCallback(self.gotMsg, "This is a test.")
511         return df
512
513     def testMultiEcho(self):
514         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
515         df.addCallback(self.gotMsg, "This is a test.")
516
517         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
518         df.addCallback(self.gotMsg, "This is another test.")
519
520         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
521         df.addCallback(self.gotMsg, "This is yet another test.")
522         
523         return df
524
525     def testEchoReset(self):
526         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
527         df.addCallback(self.gotMsg, "This is a test.")
528
529         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
530         df.addCallback(self.gotMsg, "This is another test.")
531         df.addCallback(self.echoReset)
532         return df
533     
534     def echoReset(self, dict):
535         del(self.a.connections[('127.0.0.1', 1181)])
536         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
537         df.addCallback(self.gotMsg, "This is yet another test.")
538         return df
539
540     def testUnknownMeth(self):
541         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('blahblah', {'msg' : "This is a test."})
542         df.addBoth(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
543         return df
544
545     def testMalformedRequest(self):
546         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test.", 'foo': 'bar'})
547         df.addBoth(self.gotErr, KRPC_ERROR_MALFORMED_REQUEST)
548         return df
549
550     def gotErr(self, err, should_be):
551         self.failUnlessEqual(err.value[0], should_be)
552         
553     def testLongPackets(self):
554         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('values', {'length' : 1, 'num': 2000})
555         df.addCallback(self.gotLongRsp)
556         return df
557
558     def gotLongRsp(self, dict):
559         # Not quite accurate, but good enough
560         self.failUnless(len(bencode(dict))-10 < UDP_PACKET_LIMIT)
561