ripped out xmlrpc, experimented with xmlrpc but with bencode, finally
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ##  Copyright 2002, Andrew Loewenstern, All Rights Reserved
3
4 from random import uniform as rand
5 from struct import pack, unpack
6 from time import time
7 import unittest
8 from bisect import insort_left
9
10 from twisted.internet import protocol
11 from twisted.internet import abstract
12 from twisted.internet import reactor
13 from twisted.internet import app
14 from twisted.internet import interfaces
15
16 # flags
17 FLAG_AIRHOOK = 128
18 FLAG_OBSERVED = 16
19 FLAG_SESSION = 8
20 FLAG_MISSED = 4
21 FLAG_NEXT = 2
22 FLAG_INTERVAL = 1
23
24 MAX_PACKET_SIZE = 1450
25
26 pending = 0
27 sent = 1
28 confirmed = 2
29
30
31
32 class Airhook(protocol.DatagramProtocol):       
33     def __init__(self):
34         self.noisy = None
35         # this should be changed to storage that drops old entries
36         self.connections = {}
37         
38     def datagramReceived(self, datagram, addr):
39         #print `addr`, `datagram`
40         #if addr != self.addr:
41         self.connectionForAddr(addr).datagramReceived(datagram)
42
43     def connectionForAddr(self, addr):
44         if not self.connections.has_key(addr):
45             conn = self.connection()
46             conn.protocol = self.factory.buildProtocol(addr)
47             conn.protocol.makeConnection(conn)
48             conn.makeConnection(self.transport)
49             conn.addr = addr
50             self.connections[addr] = conn
51         else:
52             conn = self.connections[addr]
53         return conn
54 #    def makeConnection(self, transport):
55 #        protocol.DatagramProtocol.makeConnection(self, transport)
56 #        tup = transport.getHost()
57 #        self.addr = (tup[1], tup[2])
58         
59 class AirhookPacket:
60     def __init__(self, msg):
61         self.datagram = msg
62         self.oseq =  ord(msg[1])
63         self.seq = unpack("!H", msg[2:4])[0]
64         self.flags = ord(msg[0])
65         self.session = None
66         self.observed = None
67         self.next = None
68         self.missed = []
69         self.msgs = []
70         skip = 4
71         if self.flags & FLAG_OBSERVED:
72             self.observed = unpack("!L", msg[skip:skip+4])[0]
73             skip += 4
74         if self.flags & FLAG_SESSION:
75             self.session =  unpack("!L", msg[skip:skip+4])[0]
76             skip += 4
77         if self.flags & FLAG_NEXT:
78             self.next =  ord(msg[skip])
79             skip += 1
80         if self.flags & FLAG_MISSED:
81             num = ord(msg[skip]) + 1
82             skip += 1
83             for i in range(num):
84                 self.missed.append( ord(msg[skip+i]))
85             skip += num
86         if self.flags & FLAG_NEXT:
87             while len(msg) - skip > 0:
88                 n = ord(msg[skip]) + 1
89                 skip+=1
90                 self.msgs.append( msg[skip:skip+n])
91                 skip += n
92
93 class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConnectedTransport):
94     def __init__(self):        
95         self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
96         self.obSeq = 0   # highest sequence confirmed by remote
97         self.inSeq = 0   # last received sequence
98         self.observed = None  # their session id
99         self.sessionID = long(rand(0, 2**32))  # our session id
100         
101         self.lastTransmit = 0  # time we last sent a packet with messages
102         self.lastReceieved = 0 # time we last received a packet with messages
103         self.lastTransmitSeq = -1 # last sequence we sent a packet
104         self.state = pending
105         
106         self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
107         self.omsgq = [] # list of messages to go out
108         self.imsgq = [] # list of messages coming in
109         self.sendSession = None  # send session/observed fields until obSeq > sendSession
110         self.response = 0 # if we know we have a response now (like resending missed packets)
111         self.noisy = 0
112         self.scheduled = 0 # a sendNext is scheduled, don't schedule another
113         self.resetMessages()
114     
115     def resetMessages(self):
116         self.weMissed = []
117         self.inMsg = 0   # next incoming message number
118         self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
119         self.next = 0  # next outgoing message number
120
121     def datagramReceived(self, datagram):
122         if not datagram:
123             return
124         p = AirhookPacket(datagram)
125         
126         # check to make sure sequence number isn't out of order
127         if (p.seq - self.inSeq) % 2**16 >= 256:
128             return
129             
130         # check for state change
131         if self.state == pending:
132             if p.observed != None and p.session != None:
133                 if p.observed == self.sessionID:
134                     self.observed = p.session
135                     self.state = confirmed
136                 else:
137                     # bogus packet!
138                     return
139             elif p.session != None:
140                 self.observed = p.session
141                 self.state = sent
142                 self.response = 1
143         elif self.state == sent:
144             if p.observed != None and p.session != None:
145                 if p.observed == self.sessionID:
146                     self.observed = p.session
147                     self.sendSession = self.outSeq
148                     self.state = confirmed
149             if p.session != None:
150                 if not self.observed:
151                     self.observed = p.session
152                 elif self.observed != p.session:
153                     self.state = pending
154                     self.resetMessages()
155                     self.inSeq = p.seq
156         elif self.state == confirmed:
157             if p.session != None or p.observed != None :
158                 if (p.session != None and p.session != self.observed) or (p.observed != None and p.observed != self.sessionID):
159                     self.state = pending
160                     self.resetMessages()
161                     self.inSeq = p.seq
162     
163         if self.state != pending:       
164             msgs = []           
165             missed = []
166             
167             # see if they need us to resend anything
168             for i in p.missed:
169                 if self.outMsgs[i] != None:
170                     self.omsgq.append(self.outMsgs[i])
171                     self.outMsgs[i] = None
172                     
173             # see if we missed any messages
174             if p.next != None:
175                 missed_count = (p.next - self.inMsg) % 256
176                 if missed_count:
177                     self.lastReceived = time()
178                     for i in range(missed_count):
179                         missed += [(self.outSeq, (self.inMsg + i) % 256)]
180                     self.weMissed += missed
181                     self.response = 1
182                 # record highest message number seen
183                 self.inMsg = (p.next + len(p.msgs)) % 256
184             
185             # append messages, update sequence
186             self.imsgq += p.msgs
187             
188         if self.state == confirmed:
189             # unpack the observed sequence
190             tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
191             if ((self.outSeq - tseq)) % 2**16 > 255:
192                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
193             self.obSeq = tseq
194
195         self.inSeq = p.seq
196
197         self.lastReceived = time()
198         self.dataCameIn()
199         
200         self.schedule()
201         
202     def sendNext(self):
203         flags = 0
204         header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
205         ids = ""
206         missed = ""
207         msgs = ""
208         
209         # session / observed logic
210         if self.state == pending:
211             flags = flags | FLAG_SESSION
212             ids +=  pack("!L", self.sessionID)
213             self.state = sent
214         elif self.state == sent:
215             if self.observed != None:
216                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
217                 ids +=  pack("!LL", self.observed, self.sessionID)
218             else:
219                 flags = flags | FLAG_SESSION
220                 ids +=  pack("!L", self.sessionID)
221
222         else:
223             if self.state == sent or self.sendSession:
224                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
225                     self.sendSession = None
226                 else:
227                     flags = flags | FLAG_SESSION | FLAG_OBSERVED
228                     ids +=  pack("!LL", self.observed, self.sessionID)
229         
230         # missed header
231         if self.obSeq >= 0:
232             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
233
234         if self.weMissed:
235             flags = flags | FLAG_MISSED
236             missed += chr(len(self.weMissed) - 1)
237             for i in self.weMissed:
238                 missed += chr(i[1])
239                 
240         # append any outgoing messages
241         if self.state == confirmed and self.omsgq:
242             first = self.next
243             outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
244             while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
245                 msg = self.omsgq.pop()
246                 msgs += chr(len(msg) - 1) + msg
247                 self.outMsgs[self.next] = msg
248                 self.next = (self.next + 1) % 256
249                 outstanding+=1
250         # update outgoing message stat
251         if msgs:
252             flags = flags | FLAG_NEXT
253             ids += chr(first)
254             self.lastTransmitSeq = self.outSeq
255             #self.outMsgNums[self.outSeq % 256] = first
256         #else:
257         self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
258         
259         # do we need a NEXT flag despite not having sent any messages?
260         if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
261             flags = flags | FLAG_NEXT
262             ids += chr(self.next)
263         
264         # update stats and send packet
265         packet = chr(flags) + header + ids + missed + msgs
266         self.outSeq = (self.outSeq + 1) % 2**16
267         self.lastTransmit = time()
268         self.transport.write(packet, self.addr)
269         
270         self.scheduled = 0
271         self.schedule()
272         
273     def timeToSend(self):
274         # any outstanding messages and are we not too far ahead of our counterparty?
275         if len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
276             return (1, 0)
277         # do we explicitly need to send a response?
278         elif self.response:
279             self.response = 0
280             return (1, 0)
281         # have we not sent anything in a while?
282         elif time() - self.lastTransmit > 1.0:
283             return (1, 1)
284         elif self.state == pending:
285             return (1, 1)
286             
287         # nothing to send
288         return (0, 0)
289
290     def schedule(self):
291         tts, t = self.timeToSend()
292         if tts and not self.scheduled:
293             self.scheduled = 1
294             reactor.callLater(t, self.sendNext)
295         
296     def write(self, data):
297         # micropackets can only be 255 bytes or less
298         if len(data) <= 255:
299             self.omsgq.insert(0, data)
300         self.schedule()
301         
302     def dataCameIn(self):
303         """
304         called when we get a packet bearing messages
305         """
306         for msg in self.imsgq:
307             self.protocol.dataReceived(msg)
308         self.imsgq = []
309
310 class ustr(str):
311     """
312         this subclass of string encapsulates each ordered message, caches it's sequence number,
313         and has comparison functions to sort by sequence number
314     """
315     def getseq(self):
316         if not hasattr(self, 'seq'):
317             self.seq = unpack("!H", self[0:2])[0]
318         return self.seq
319     def __lt__(self, other):
320         return (self.getseq() - other.getseq()) % 2**16 > 255
321     def __le__(self, other):
322         return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
323     def __eq__(self, other):
324         return self.getseq() == other.getseq()
325     def __ne__(self, other):
326         return self.getseq() != other.getseq()
327     def __gt__(self, other):
328         return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
329     def __ge__(self, other):
330         return (self.getseq() - other.getseq()) % 2**16 < 256
331         
332 class StreamConnection(AirhookConnection):
333     """
334         this implements a simple protocol for a stream over airhook
335         this is done for convenience, instead of making it a twisted.internet.protocol....
336         the first two octets of each message are interpreted as a 16-bit sequence number
337         253 bytes are used for payload
338         
339     """
340     def __init__(self):
341         AirhookConnection.__init__(self)
342         self.oseq = 0
343         self.iseq = 0
344         self.q = []
345
346     def dataCameIn(self):
347         # put 'em together
348         for msg in self.imsgq:
349             insort_left(self.q, ustr(msg))
350         self.imsgq = []
351         data = ""
352         while self.q and self.iseq == self.q[0].getseq():
353             data += self.q[0][2:]
354             self.q = self.q[1:]
355             self.iseq = (self.iseq + 1) % 2**16
356         if data != '':
357             self.protocol.dataReceived(data)
358         
359     def write(self, data):
360         # chop it up and queue it up
361         while data:
362             p = pack("!H", self.oseq) + data[:253]
363             self.omsgq.insert(0, p)
364             data = data[253:]
365             self.oseq = (self.oseq + 1) % 2**16
366
367         self.schedule()
368         
369     def writeSequence(self, sequence):
370         for data in sequence:
371             self.write(data)
372
373
374 def listenAirhook(port, factory):
375     ah = Airhook()
376     ah.connection = AirhookConnection
377     ah.factory = factory
378     reactor.listenUDP(port, ah)
379     return ah
380
381 def listenAirhookStream(port, factory):
382     ah = Airhook()
383     ah.connection = StreamConnection
384     ah.factory = factory
385     reactor.listenUDP(port, ah)
386     return ah
387
388