]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - airhook.py
17b046e7e1d3d710ec0113a3ce153c7f79419c45
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
3 # see LICENSE.txt for license information
4
5 from random import uniform as rand
6 from struct import pack, unpack
7 from time import time
8 import unittest
9 from bisect import insort_left
10
11 from twisted.internet import protocol
12 from twisted.internet import abstract
13 from twisted.internet import reactor
14 from twisted.internet import app
15 from twisted.internet import interfaces
16
17 # flags
18 FLAG_AIRHOOK = 128
19 FLAG_OBSERVED = 16
20 FLAG_SESSION = 8
21 FLAG_MISSED = 4
22 FLAG_NEXT = 2
23 FLAG_INTERVAL = 1
24
25 MAX_PACKET_SIZE = 1450
26
27 pending = 0
28 sent = 1
29 confirmed = 2
30
31
32
33 class Airhook(protocol.DatagramProtocol):       
34     def __init__(self):
35         self.noisy = None
36         # this should be changed to storage that drops old entries
37         self.connections = {}
38         
39     def datagramReceived(self, datagram, addr):
40         #print `addr`, `datagram`
41         #if addr != self.addr:
42         self.connectionForAddr(addr).datagramReceived(datagram)
43
44     def connectionForAddr(self, addr):
45         if addr == self.addr:
46             raise Exception
47         if not self.connections.has_key(addr):
48             conn = self.connection()
49             conn.protocol = self.factory.buildProtocol(addr)
50             conn.protocol.makeConnection(conn)
51             conn.makeConnection(self.transport)
52             conn.addr = addr
53             self.connections[addr] = conn
54         else:
55             conn = self.connections[addr]
56         return conn
57     def makeConnection(self, transport):
58         protocol.DatagramProtocol.makeConnection(self, transport)
59         tup = transport.getHost()
60         self.addr = (tup[1], tup[2])
61         
62 class AirhookPacket:
63     def __init__(self, msg):
64         self.datagram = msg
65         self.oseq =  ord(msg[1])
66         self.seq = unpack("!H", msg[2:4])[0]
67         self.flags = ord(msg[0])
68         self.session = None
69         self.observed = None
70         self.next = None
71         self.missed = []
72         self.msgs = []
73         skip = 4
74         if self.flags & FLAG_OBSERVED:
75             self.observed = unpack("!L", msg[skip:skip+4])[0]
76             skip += 4
77         if self.flags & FLAG_SESSION:
78             self.session =  unpack("!L", msg[skip:skip+4])[0]
79             skip += 4
80         if self.flags & FLAG_NEXT:
81             self.next =  ord(msg[skip])
82             skip += 1
83         if self.flags & FLAG_MISSED:
84             num = ord(msg[skip]) + 1
85             skip += 1
86             for i in range(num):
87                 self.missed.append( ord(msg[skip+i]))
88             skip += num
89         if self.flags & FLAG_NEXT:
90             while len(msg) - skip > 0:
91                 n = ord(msg[skip]) + 1
92                 skip+=1
93                 self.msgs.append( msg[skip:skip+n])
94                 skip += n
95
96 class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConnectedTransport):
97     def __init__(self):        
98         self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
99         self.observed = None  # their session id
100         self.sessionID = long(rand(0, 2**32))  # our session id
101         
102         self.lastTransmit = 0  # time we last sent a packet with messages
103         self.lastReceived = 0 # time we last received a packet with messages
104         self.lastTransmitSeq = -1 # last sequence we sent a packet
105         self.state = pending # one of pending, sent, confirmed
106         
107         self.sendSession = None  # send session/observed fields until obSeq > sendSession
108         self.response = 0 # if we know we have a response now (like resending missed packets)
109         self.noisy = 0
110         self.resetConnection()
111     
112     def resetConnection(self):
113         self.weMissed = []
114         self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
115         self.inMsg = 0   # next incoming message number
116         self.outMsgNums = [0] * 256 # outgoing message numbers i = outNum % 256
117         self.next = 0  # next outgoing message number
118         self.scheduled = 0 # a sendNext is scheduled, don't schedule another
119         self.omsgq = [] # list of messages to go out
120         self.imsgq = [] # list of messages coming in
121         self.obSeq = 0   # highest sequence confirmed by remote
122         self.inSeq = 0   # last received sequence
123
124     def datagramReceived(self, datagram):
125         if not datagram:
126             return
127         if self.noisy:
128             print `datagram`
129         p = AirhookPacket(datagram)
130         
131             
132         # check for state change
133         if self.state == pending:
134             if p.observed != None and p.session != None:
135                 if p.observed == self.sessionID:
136                     self.observed = p.session
137                     self.state = confirmed
138                     self.response = 1
139                 else:
140                     self.observed = p.session
141                     self.response = 1
142             elif p.session != None:
143                 self.observed = p.session
144                 self.response = 1
145             else:
146                 self.response = 1
147         elif self.state == sent:
148             if p.observed != None and p.session != None:
149                 if p.observed == self.sessionID:
150                     self.observed = p.session
151                     self.sendSession = self.outSeq
152                     self.state = confirmed
153             if p.session != None:
154                 if not self.observed:
155                     self.observed = p.session
156                 elif self.observed != p.session:
157                     self.state = pending
158                     self.observed = p.session
159                     self.resetConnection()
160                     self.response = 1
161                     if hasattr(self.protocol, "resetConnection") and callable(self.protocol.resetConnection):
162                         self.protocol.resetConnection()
163                     self.inSeq = p.seq
164                     self.schedule()
165                     return
166             elif p.session == None and p.observed == None:
167                 self.response = 1
168                 self.schedule()
169     
170         elif self.state == confirmed:
171             if p.session != None or p.observed != None :
172                 if (p.session != None and p.session != self.observed) or (p.observed != None and p.observed != self.sessionID):
173                     self.state = pending
174                     self.observed = p.session
175                     self.resetConnection()
176                     self.inSeq = p.seq
177                     if hasattr(self.protocol, "resetConnection") and callable(self.protocol.resetConnection):
178                         self.protocol.resetConnection()
179                     self.schedule()
180                     return
181         # check to make sure sequence number isn't out of order
182         if (p.seq - self.inSeq) % 2**16 >= 256:
183             return
184     
185         if self.state == confirmed:     
186             msgs = []           
187             missed = []
188             
189             # see if they need us to resend anything
190             for i in p.missed:
191                 if self.outMsgs[i] != None:
192                     self.omsgq.append(self.outMsgs[i])
193                     self.outMsgs[i] = None
194                     
195             # see if we missed any messages
196             if p.next != None:
197                 missed_count = (p.next - self.inMsg) % 256
198                 if missed_count:
199                     self.lastReceived = time()
200                     for i in range(missed_count):
201                         missed += [(self.outSeq, (self.inMsg + i) % 256)]
202                     self.weMissed += missed
203                     self.response = 1
204                 # record highest message number seen
205                 self.inMsg = (p.next + len(p.msgs)) % 256
206             
207             # append messages, update sequence
208             self.imsgq += p.msgs
209             
210         if self.state == confirmed:
211             # unpack the observed sequence
212             tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
213             if ((self.outSeq - tseq)) % 2**16 > 255:
214                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
215             self.obSeq = tseq
216
217         self.inSeq = p.seq
218
219         self.lastReceived = time()
220         self.dataCameIn()
221         
222         self.schedule()
223         
224     def sendNext(self):
225         flags = 0
226         header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
227         ids = ""
228         missed = ""
229         msgs = ""
230         
231         # session / observed logic
232         if self.state == pending:
233             if self.observed != None:
234                 flags = flags | FLAG_OBSERVED
235                 ids +=  pack("!L", self.observed)
236             flags = flags | FLAG_SESSION
237             ids +=  pack("!L", self.sessionID)
238             self.state = sent
239         elif self.state == sent:
240             if self.observed != None:
241                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
242                 ids +=  pack("!LL", self.observed, self.sessionID)
243             else:
244                 flags = flags | FLAG_SESSION
245                 ids +=  pack("!L", self.sessionID)
246
247         else:
248             if self.state == sent or self.sendSession != None:
249                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
250                     self.sendSession = None
251                 else:
252                     flags = flags | FLAG_SESSION | FLAG_OBSERVED
253                     ids +=  pack("!LL", self.observed, self.sessionID)
254         
255         # missed header
256         if self.obSeq >= 0:
257             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
258
259         if len(self.weMissed) > 0:
260             flags = flags | FLAG_MISSED
261             missed += chr(len(self.weMissed) - 1)
262             for i in self.weMissed:
263                 missed += chr(i[1])
264                 
265         # append any outgoing messages
266         if self.state == confirmed and self.omsgq:
267             first = self.next
268             outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
269             while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
270                 msg = self.omsgq.pop()
271                 msgs += chr(len(msg) - 1) + msg
272                 self.outMsgs[self.next] = msg
273                 self.next = (self.next + 1) % 256
274                 outstanding+=1
275         # update outgoing message stat
276         if msgs:
277             flags = flags | FLAG_NEXT
278             ids += chr(first)
279             self.lastTransmitSeq = self.outSeq
280             #self.outMsgNums[self.outSeq % 256] = first
281         #else:
282         self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
283         
284         # do we need a NEXT flag despite not having sent any messages?
285         if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
286             flags = flags | FLAG_NEXT
287             ids += chr(self.next)
288         
289         # update stats and send packet
290         packet = chr(flags) + header + ids + missed + msgs
291         self.outSeq = (self.outSeq + 1) % 2**16
292         self.lastTransmit = time()
293         self.transport.write(packet, self.addr)
294         
295         self.scheduled = 0
296         self.schedule()
297         
298     def timeToSend(self):
299         if self.state == pending:
300             return (1, 0)
301         # any outstanding messages and are we not too far ahead of our counterparty?
302         elif len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
303             return (1, 0)
304         # do we explicitly need to send a response?
305         elif self.response:
306             self.response = 0
307             return (1, 0)
308         # have we not sent anything in a while?
309         elif time() - self.lastTransmit > 1.0:
310             return (1, 1)
311             
312         # nothing to send
313         return (0, 0)
314
315     def schedule(self):
316         tts, t = self.timeToSend()
317         if tts and not self.scheduled:
318             self.scheduled = 1
319             reactor.callLater(t, self.sendNext)
320         
321     def write(self, data):
322         # micropackets can only be 255 bytes or less
323         if len(data) <= 255:
324             self.omsgq.insert(0, data)
325         self.schedule()
326         
327     def dataCameIn(self):
328         """
329         called when we get a packet bearing messages
330         """
331         for msg in self.imsgq:
332             self.protocol.dataReceived(msg)
333         self.imsgq = []
334
335 class ustr(str):
336     """
337         this subclass of string encapsulates each ordered message, caches it's sequence number,
338         and has comparison functions to sort by sequence number
339     """
340     def getseq(self):
341         if not hasattr(self, 'seq'):
342             self.seq = unpack("!H", self[0:2])[0]
343         return self.seq
344     def __lt__(self, other):
345         return (self.getseq() - other.getseq()) % 2**16 > 255
346     def __le__(self, other):
347         return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
348     def __eq__(self, other):
349         return self.getseq() == other.getseq()
350     def __ne__(self, other):
351         return self.getseq() != other.getseq()
352     def __gt__(self, other):
353         return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
354     def __ge__(self, other):
355         return (self.getseq() - other.getseq()) % 2**16 < 256
356         
357 class StreamConnection(AirhookConnection):
358     """
359         this implements a simple protocol for a stream over airhook
360         this is done for convenience, instead of making it a twisted.internet.protocol....
361         the first two octets of each message are interpreted as a 16-bit sequence number
362         253 bytes are used for payload
363         
364     """
365     def __init__(self):
366         AirhookConnection.__init__(self)
367         self.resetStream()
368         
369     def resetStream(self):
370         self.oseq = 0
371         self.iseq = 0
372         self.q = []
373
374     def resetConnection(self):
375         AirhookConnection.resetConnection(self)
376         self.resetStream()
377
378     def loseConnection(self):
379         pass
380     
381     def dataCameIn(self):
382         # put 'em together
383         for msg in self.imsgq:
384             insort_left(self.q, ustr(msg))
385         self.imsgq = []
386         data = ""
387         while self.q and self.iseq == self.q[0].getseq():
388             data += self.q[0][2:]
389             self.q = self.q[1:]
390             self.iseq = (self.iseq + 1) % 2**16
391         if data != '':
392             self.protocol.dataReceived(data)
393         
394     def write(self, data):
395         # chop it up and queue it up
396         while data:
397             p = pack("!H", self.oseq) + data[:253]
398             self.omsgq.insert(0, p)
399             data = data[253:]
400             self.oseq = (self.oseq + 1) % 2**16
401
402         self.schedule()
403         
404     def writeSequence(self, sequence):
405         for data in sequence:
406             self.write(data)
407
408
409 def listenAirhook(port, factory):
410     ah = Airhook()
411     ah.connection = AirhookConnection
412     ah.factory = factory
413     reactor.listenUDP(port, ah)
414     return ah
415
416 def listenAirhookStream(port, factory):
417     ah = Airhook()
418     ah.connection = StreamConnection
419     ah.factory = factory
420     reactor.listenUDP(port, ah)
421     return ah
422
423