Upgrade the security in khashmir by using longer TIDs.
[quix0rs-apt-p2p.git] / apt_dht_Khashmir / krpc.py
1 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
2 # see LICENSE.txt for license information
3
4 from bencode import bencode, bdecode
5 from time import asctime
6 import sys
7 from traceback import format_exception
8
9 from twisted.internet.defer import Deferred
10 from twisted.internet import protocol, reactor
11 from twisted.trial import unittest
12
13 from khash import newID
14
15 KRPC_TIMEOUT = 20
16
17 KRPC_ERROR = 1
18 KRPC_ERROR_METHOD_UNKNOWN = 2
19 KRPC_ERROR_RECEIVED_UNKNOWN = 3
20 KRPC_ERROR_TIMEOUT = 4
21
22 # commands
23 TID = 't'
24 REQ = 'q'
25 RSP = 'r'
26 TYP = 'y'
27 ARG = 'a'
28 ERR = 'e'
29
30 class ProtocolError(Exception):
31     pass
32
33 class hostbroker(protocol.DatagramProtocol):       
34     def __init__(self, server, config):
35         self.server = server
36         self.config = config
37         # this should be changed to storage that drops old entries
38         self.connections = {}
39         
40     def datagramReceived(self, datagram, addr):
41         #print `addr`, `datagram`
42         #if addr != self.addr:
43         c = self.connectionForAddr(addr)
44         c.datagramReceived(datagram, addr)
45         #if c.idle():
46         #    del self.connections[addr]
47
48     def connectionForAddr(self, addr):
49         if addr == self.addr:
50             raise Exception
51         if not self.connections.has_key(addr):
52             conn = self.protocol(addr, self.server, self.transport, self.config['SPEW'])
53             self.connections[addr] = conn
54         else:
55             conn = self.connections[addr]
56         return conn
57
58     def makeConnection(self, transport):
59         protocol.DatagramProtocol.makeConnection(self, transport)
60         tup = transport.getHost()
61         self.addr = (tup.host, tup.port)
62         
63     def stopProtocol(self):
64         for conn in self.connections.values():
65             conn.stop()
66         protocol.DatagramProtocol.stopProtocol(self)
67
68 ## connection
69 class KRPC:
70     def __init__(self, addr, server, transport, spew = False):
71         self.transport = transport
72         self.factory = server
73         self.addr = addr
74         self.noisy = spew
75         self.tids = {}
76         self.stopped = False
77
78     def datagramReceived(self, str, addr):
79         if self.stopped:
80             if self.noisy:
81                 print "stopped, dropping message from", addr, str
82         # bdecode
83         try:
84             msg = bdecode(str)
85         except Exception, e:
86             if self.noisy:
87                 print "response decode error: " + `e`
88         else:
89             if self.noisy:
90                 print self.factory.port, "received from", addr, self.addr, ":", msg
91             # look at msg type
92             if msg[TYP]  == REQ:
93                 ilen = len(str)
94                 # if request
95                 #       tell factory to handle
96                 f = getattr(self.factory ,"krpc_" + msg[REQ], None)
97                 msg[ARG]['_krpc_sender'] =  self.addr
98                 if f and callable(f):
99                     try:
100                         ret = f(*(), **msg[ARG])
101                     except Exception, e:
102                         olen = self._sendResponse(addr, msg[TID], ERR, `format_exception(type(e), e, sys.exc_info()[2])`)
103                     else:
104                         olen = self._sendResponse(addr, msg[TID], RSP, ret)
105                 else:
106                     if self.noisy:
107                         print "don't know about method %s" % msg[REQ]
108                     # unknown method
109                     olen = self._sendResponse(addr, msg[TID], ERR, KRPC_ERROR_METHOD_UNKNOWN)
110                 if self.noisy:
111                     print "%s %s >>> %s - %s %s %s" % (asctime(), addr, self.factory.node.port, 
112                                                     ilen, msg[REQ], olen)
113             elif msg[TYP] == RSP:
114                 # if response
115                 #       lookup tid
116                 if self.tids.has_key(msg[TID]):
117                     df = self.tids[msg[TID]]
118                     #   callback
119                     del(self.tids[msg[TID]])
120                     df.callback({'rsp' : msg[RSP], '_krpc_sender': addr})
121                 else:
122                     print 'timeout ' + `msg[RSP]['id']`
123                     # no tid, this transaction timed out already...
124             elif msg[TYP] == ERR:
125                 # if error
126                 #       lookup tid
127                 if self.tids.has_key(msg[TID]):
128                     df = self.tids[msg[TID]]
129                     #   callback
130                     df.errback(msg[ERR])
131                     del(self.tids[msg[TID]])
132                 else:
133                     # day late and dollar short
134                     pass
135             else:
136                 print "unknown message type " + `msg`
137                 # unknown message type
138                 df = self.tids[msg[TID]]
139                 #       callback
140                 df.errback(KRPC_ERROR_RECEIVED_UNKNOWN)
141                 del(self.tids[msg[TID]])
142                 
143     def _sendResponse(self, addr, tid, msgType, response):
144         if not response:
145             response = {}
146             
147         msg = {TID : tid, TYP : msgType, msgType : response}
148
149         if self.noisy:
150             print self.factory.port, "responding to", addr, ":", msg
151
152         out = bencode(msg)
153         self.transport.write(out, addr)
154         return len(out)
155     
156     def sendRequest(self, method, args):
157         if self.stopped:
158             raise ProtocolError, "connection has been stopped"
159         # make message
160         # send it
161         msg = {TID : newID(), TYP : REQ,  REQ : method, ARG : args}
162         if self.noisy:
163             print self.factory.port, "sending to", self.addr, ":", msg
164         str = bencode(msg)
165         d = Deferred()
166         self.tids[msg[TID]] = d
167         def timeOut(tids = self.tids, id = msg[TID], msg = msg):
168             if tids.has_key(id):
169                 df = tids[id]
170                 del(tids[id])
171                 print ">>>>>> KRPC_ERROR_TIMEOUT"
172                 df.errback(ProtocolError('timeout waiting for %r' % msg))
173         later = reactor.callLater(KRPC_TIMEOUT, timeOut)
174         def dropTimeOut(dict, later_call = later):
175             if later_call.active():
176                 later_call.cancel()
177             return dict
178         d.addBoth(dropTimeOut)
179         self.transport.write(str, self.addr)
180         return d
181     
182     def stop(self):
183         """Timeout all pending requests."""
184         for df in self.tids.values():
185             df.errback(ProtocolError('connection has been closed'))
186         self.tids = {}
187         self.stopped = True
188  
189 def connectionForAddr(host, port):
190     return host
191     
192 class Receiver(protocol.Factory):
193     protocol = KRPC
194     def __init__(self):
195         self.buf = []
196     def krpc_store(self, msg, _krpc_sender):
197         self.buf += [msg]
198     def krpc_echo(self, msg, _krpc_sender):
199         return msg
200
201 def make(port):
202     af = Receiver()
203     a = hostbroker(af, {'SPEW': False})
204     a.protocol = KRPC
205     p = reactor.listenUDP(port, a)
206     return af, a, p
207     
208 class KRPCTests(unittest.TestCase):
209     def setUp(self):
210         self.af, self.a, self.ap = make(1180)
211         self.bf, self.b, self.bp = make(1181)
212
213     def tearDown(self):
214         self.ap.stopListening()
215         self.bp.stopListening()
216
217     def bufEquals(self, result, value):
218         self.failUnlessEqual(self.bf.buf, value)
219
220     def testSimpleMessage(self):
221         d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
222         d.addCallback(self.bufEquals, ["This is a test."])
223         return d
224
225     def testMessageBlast(self):
226         for i in range(100):
227             d = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('store', {'msg' : "This is a test."})
228         d.addCallback(self.bufEquals, ["This is a test."] * 100)
229         return d
230
231     def testEcho(self):
232         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
233         df.addCallback(self.gotMsg, "This is a test.")
234         return df
235
236     def gotMsg(self, dict, should_be):
237         _krpc_sender = dict['_krpc_sender']
238         msg = dict['rsp']
239         self.failUnlessEqual(msg, should_be)
240
241     def testManyEcho(self):
242         for i in xrange(100):
243             df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
244             df.addCallback(self.gotMsg, "This is a test.")
245         return df
246
247     def testMultiEcho(self):
248         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
249         df.addCallback(self.gotMsg, "This is a test.")
250
251         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
252         df.addCallback(self.gotMsg, "This is another test.")
253
254         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
255         df.addCallback(self.gotMsg, "This is yet another test.")
256         
257         return df
258
259     def testEchoReset(self):
260         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is a test."})
261         df.addCallback(self.gotMsg, "This is a test.")
262
263         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is another test."})
264         df.addCallback(self.gotMsg, "This is another test.")
265         df.addCallback(self.echoReset)
266         return df
267     
268     def echoReset(self, dict):
269         del(self.a.connections[('127.0.0.1', 1181)])
270         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('echo', {'msg' : "This is yet another test."})
271         df.addCallback(self.gotMsg, "This is yet another test.")
272         return df
273
274     def testUnknownMeth(self):
275         df = self.a.connectionForAddr(('127.0.0.1', 1181)).sendRequest('blahblah', {'msg' : "This is a test."})
276         df.addErrback(self.gotErr, KRPC_ERROR_METHOD_UNKNOWN)
277         return df
278
279     def gotErr(self, err, should_be):
280         self.failUnlessEqual(err.value, should_be)