96009d4a751b2fe42ab4b9feaae36a6d3a671384
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ##  Copyright 2002, Andrew Loewenstern, All Rights Reserved
3
4 from random import uniform as rand
5 from struct import pack, unpack
6 from time import time
7 import unittest
8 from bisect import insort_left
9
10 from twisted.internet import protocol
11 from twisted.internet import abstract
12 from twisted.internet import reactor
13 from twisted.internet import app
14 from twisted.internet import interfaces
15
16 # flags
17 FLAG_AIRHOOK = 128
18 FLAG_OBSERVED = 16
19 FLAG_SESSION = 8
20 FLAG_MISSED = 4
21 FLAG_NEXT = 2
22 FLAG_INTERVAL = 1
23
24 MAX_PACKET_SIZE = 1450
25
26 pending = 0
27 sent = 1
28 confirmed = 2
29
30
31
32 class Airhook(protocol.DatagramProtocol):       
33     def __init__(self):
34         self.noisy = None
35         # this should be changed to storage that drops old entries
36         self.connections = {}
37         
38     def datagramReceived(self, datagram, addr):
39         self.connectionForAddr(addr).datagramReceived(datagram)
40
41     def connectionForAddr(self, addr):
42         if not self.connections.has_key(addr):
43             conn = self.connection()
44             conn.protocol = self.factory.buildProtocol(addr)
45             conn.protocol.makeConnection(conn)
46             conn.makeConnection(self.transport)
47             conn.addr = addr
48             self.connections[addr] = conn
49         else:
50             conn = self.connections[addr]
51         return conn
52     
53 class AirhookPacket:
54     def __init__(self, msg):
55         self.datagram = msg
56         self.oseq =  ord(msg[1])
57         self.seq = unpack("!H", msg[2:4])[0]
58         self.flags = ord(msg[0])
59         self.session = None
60         self.observed = None
61         self.next = None
62         self.missed = []
63         self.msgs = []
64         skip = 4
65         if self.flags & FLAG_OBSERVED:
66             self.observed = unpack("!L", msg[skip:skip+4])[0]
67             skip += 4
68         if self.flags & FLAG_SESSION:
69             self.session =  unpack("!L", msg[skip:skip+4])[0]
70             skip += 4
71         if self.flags & FLAG_NEXT:
72             self.next =  ord(msg[skip])
73             skip += 1
74         if self.flags & FLAG_MISSED:
75             num = ord(msg[skip]) + 1
76             skip += 1
77             for i in range(num):
78                 self.missed.append( ord(msg[skip+i]))
79             skip += num
80         if self.flags & FLAG_NEXT:
81             while len(msg) - skip > 0:
82                 n = ord(msg[skip]) + 1
83                 skip+=1
84                 self.msgs.append( msg[skip:skip+n])
85                 skip += n
86
87 class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConnectedTransport):
88     def __init__(self):        
89         self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
90         self.obSeq = 0   # highest sequence confirmed by remote
91         self.inSeq = 0   # last received sequence
92         self.observed = None  # their session id
93         self.sessionID = long(rand(0, 2**32))  # our session id
94         
95         self.lastTransmit = -1  # time we last sent a packet with messages
96         self.lastReceieved = 0 # time we last received a packet with messages
97         self.lastTransmitSeq = -1 # last sequence we sent a packet
98         self.state = pending
99         
100         self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
101         self.omsgq = [] # list of messages to go out
102         self.imsgq = [] # list of messages coming in
103         self.sendSession = None  # send session/observed fields until obSeq > sendSession
104         self.response = 0 # if we know we have a response now (like resending missed packets)
105         self.noisy = 0
106         self.resetMessages()
107     
108     def resetMessages(self):
109         self.weMissed = []
110         self.inMsg = 0   # next incoming message number
111         self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
112         self.next = 0  # next outgoing message number
113
114     def datagramReceived(self, datagram):
115         if not datagram:
116             return
117         p = AirhookPacket(datagram)
118         
119         # check to make sure sequence number isn't out of order
120         if (p.seq - self.inSeq) % 2**16 >= 256:
121             return
122             
123         # check for state change
124         if self.state == pending:
125             if p.observed != None and p.session != None:
126                 if p.observed == self.sessionID:
127                     self.observed = p.session
128                     self.state = confirmed
129                 else:
130                     # bogus packet!
131                     return
132             elif p.session != None:
133                 self.observed = p.session
134                 self.state = sent
135                 self.response = 1
136         elif self.state == sent:
137             if p.observed != None and p.session != None:
138                 if p.observed == self.sessionID:
139                     self.observed = p.session
140                     self.sendSession = self.outSeq
141                     self.state = confirmed
142             if p.session != None:
143                 if not self.observed:
144                     self.observed = p.session
145                 elif self.observed != p.session:
146                     self.state = pending
147                     self.resetMessages()
148                     self.inSeq = p.seq
149             self.response = 1
150         elif self.state == confirmed:
151             if p.session != None or p.observed != None :
152                 if p.session != self.observed or p.observed != self.sessionID:
153                     self.state = pending
154                     if seq == 0:
155                         self.resetMessages()
156                         self.inSeq = p.seq
157     
158         if self.state != pending:       
159             msgs = []           
160             missed = []
161             
162             # see if they need us to resend anything
163             for i in p.missed:
164                 if self.outMsgs[i] != None:
165                     self.omsgq.append(self.outMsgs[i])
166                     self.outMsgs[i] = None
167                     
168             # see if we missed any messages
169             if p.next != None:
170                 missed_count = (p.next - self.inMsg) % 256
171                 if missed_count:
172                     self.lastReceived = time()
173                     for i in range(missed_count):
174                         missed += [(self.outSeq, (self.inMsg + i) % 256)]
175                     self.weMissed += missed
176                     self.response = 1
177                 # record highest message number seen
178                 self.inMsg = (p.next + len(p.msgs)) % 256
179             
180             # append messages, update sequence
181             self.imsgq += p.msgs
182             
183         if self.state == confirmed:
184             # unpack the observed sequence
185             tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
186             if ((self.outSeq - tseq)) % 2**16 > 255:
187                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
188             self.obSeq = tseq
189
190         self.inSeq = p.seq
191
192         self.lastReceived = time()
193         self.dataCameIn()
194         
195         self.schedule()
196         
197     def sendNext(self):
198         flags = 0
199         header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
200         ids = ""
201         missed = ""
202         msgs = ""
203         
204         # session / observed logic
205         if self.state == pending:
206             flags = flags | FLAG_SESSION
207             ids +=  pack("!L", self.sessionID)
208             self.state = sent
209         elif self.state == sent:
210             if self.observed != None:
211                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
212                 ids +=  pack("!LL", self.observed, self.sessionID)
213             else:
214                 flags = flags | FLAG_SESSION
215                 ids +=  pack("!L", self.sessionID)
216
217         else:
218             if self.state == sent or self.sendSession:
219                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
220                     self.sendSession = None
221                 else:
222                     flags = flags | FLAG_SESSION | FLAG_OBSERVED
223                     ids +=  pack("!LL", self.observed, self.sessionID)
224         
225         # missed header
226         if self.obSeq >= 0:
227             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
228
229         if self.weMissed:
230             flags = flags | FLAG_MISSED
231             missed += chr(len(self.weMissed) - 1)
232             for i in self.weMissed:
233                 missed += chr(i[1])
234                 
235         # append any outgoing messages
236         if self.state == confirmed and self.omsgq:
237             first = self.next
238             outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
239             while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
240                 msg = self.omsgq.pop()
241                 msgs += chr(len(msg) - 1) + msg
242                 self.outMsgs[self.next] = msg
243                 self.next = (self.next + 1) % 256
244                 outstanding+=1
245         # update outgoing message stat
246         if msgs:
247             flags = flags | FLAG_NEXT
248             ids += chr(first)
249             self.lastTransmitSeq = self.outSeq
250             #self.outMsgNums[self.outSeq % 256] = first
251         #else:
252         self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
253         
254         # do we need a NEXT flag despite not having sent any messages?
255         if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
256             flags = flags | FLAG_NEXT
257             ids += chr(self.next)
258         
259         # update stats and send packet
260         packet = chr(flags) + header + ids + missed + msgs
261         self.outSeq = (self.outSeq + 1) % 2**16
262         self.lastTransmit = time()
263         self.transport.write(packet, self.addr)
264         
265         self.schedule()
266         
267     def timeToSend(self):
268         # any outstanding messages and are we not too far ahead of our counterparty?
269         if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
270             return 1
271         # do we explicitly need to send a response?
272         elif self.response:
273             self.response = 0
274             return 1
275         # have we not sent anything in a while?
276         elif time() - self.lastTransmit > 1.0:
277             return 1
278         
279         # nothing to send
280         return 0
281
282     def schedule(self):
283         if self.timeToSend():
284             reactor.callLater(0, self.sendNext)
285         else:
286             reactor.callLater(1, self.sendNext)
287         
288     def write(self, data):
289         # micropackets can only be 255 bytes or less
290         if len(data) <= 255:
291             self.omsgq.insert(0, data)
292         self.schedule()
293         
294     def dataCameIn(self):
295         """
296         called when we get a packet bearing messages
297         """
298         for msg in self.imsgq:
299             self.protocol.dataReceived(msg)
300         self.imsgq = []
301
302 class ustr(str):
303     """
304         this subclass of string encapsulates each ordered message, caches it's sequence number,
305         and has comparison functions to sort by sequence number
306     """
307     def getseq(self):
308         if not hasattr(self, 'seq'):
309             self.seq = unpack("!H", self[0:2])[0]
310         return self.seq
311     def __lt__(self, other):
312         return (self.getseq() - other.getseq()) % 2**16 > 255
313     def __le__(self, other):
314         return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
315     def __eq__(self, other):
316         return self.getseq() == other.getseq()
317     def __ne__(self, other):
318         return self.getseq() != other.getseq()
319     def __gt__(self, other):
320         return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
321     def __ge__(self, other):
322         return (self.getseq() - other.getseq()) % 2**16 < 256
323         
324 class StreamConnection(AirhookConnection):
325     """
326         this implements a simple protocol for a stream over airhook
327         this is done for convenience, instead of making it a twisted.internet.protocol....
328         the first two octets of each message are interpreted as a 16-bit sequence number
329         253 bytes are used for payload
330         
331     """
332     def __init__(self):
333         AirhookConnection.__init__(self)
334         self.oseq = 0
335         self.iseq = 0
336         self.q = []
337
338     def dataCameIn(self):
339         # put 'em together
340         for msg in self.imsgq:
341             insort_left(self.q, ustr(msg))
342         self.imsgq = []
343         data = ""
344         while self.q and self.iseq == self.q[0].getseq():
345             data += self.q[0][2:]
346             self.q = self.q[1:]
347             self.iseq = (self.iseq + 1) % 2**16
348         if data != '':
349             self.protocol.dataReceived(data)
350         
351     def write(self, data):
352         # chop it up and queue it up
353         while data:
354             p = pack("!H", self.oseq) + data[:253]
355             self.omsgq.insert(0, p)
356             data = data[253:]
357             self.oseq = (self.oseq + 1) % 2**16
358
359         self.schedule()
360         
361     def writeSequence(self, sequence):
362         for data in sequence:
363             self.write(data)
364
365
366 def listenAirhook(port, factory):
367     ah = Airhook()
368     ah.connection = AirhookConnection
369     ah.factory = factory
370     reactor.listenUDP(port, ah)
371     return ah
372
373 def listenAirhookStream(port, factory):
374     ah = Airhook()
375     ah.connection = StreamConnection
376     ah.factory = factory
377     reactor.listenUDP(port, ah)
378     return ah
379
380