]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - airhook.py
d3ae57a40e8a17f88d8ed17036c52ffe26cf93fc
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ##  Copyright 2002, Andrew Loewenstern, All Rights Reserved
3
4 from random import uniform as rand
5 from struct import pack, unpack
6 from time import time
7 import unittest
8 from bisect import insort_left
9
10 from twisted.internet import protocol
11 from twisted.internet import abstract
12 from twisted.internet import reactor
13 from twisted.internet import app
14 from twisted.internet import interfaces
15
16 # flags
17 FLAG_AIRHOOK = 128
18 FLAG_OBSERVED = 16
19 FLAG_SESSION = 8
20 FLAG_MISSED = 4
21 FLAG_NEXT = 2
22 FLAG_INTERVAL = 1
23
24 MAX_PACKET_SIZE = 1450
25
26 pending = 0
27 sent = 1
28 confirmed = 2
29
30
31
32 class Airhook(protocol.DatagramProtocol):       
33     def __init__(self):
34         self.noisy = None
35         # this should be changed to storage that drops old entries
36         self.connections = {}
37         
38     def datagramReceived(self, datagram, addr):
39         #print `addr`, `datagram`
40         #if addr != self.addr:
41         self.connectionForAddr(addr).datagramReceived(datagram)
42
43     def connectionForAddr(self, addr):
44         if not self.connections.has_key(addr):
45             conn = self.connection()
46             conn.protocol = self.factory.buildProtocol(addr)
47             conn.protocol.makeConnection(conn)
48             conn.makeConnection(self.transport)
49             conn.addr = addr
50             self.connections[addr] = conn
51         else:
52             conn = self.connections[addr]
53         return conn
54 #    def makeConnection(self, transport):
55 #        protocol.DatagramProtocol.makeConnection(self, transport)
56 #        tup = transport.getHost()
57 #        self.addr = (tup[1], tup[2])
58         
59 class AirhookPacket:
60     def __init__(self, msg):
61         self.datagram = msg
62         self.oseq =  ord(msg[1])
63         self.seq = unpack("!H", msg[2:4])[0]
64         self.flags = ord(msg[0])
65         self.session = None
66         self.observed = None
67         self.next = None
68         self.missed = []
69         self.msgs = []
70         skip = 4
71         if self.flags & FLAG_OBSERVED:
72             self.observed = unpack("!L", msg[skip:skip+4])[0]
73             skip += 4
74         if self.flags & FLAG_SESSION:
75             self.session =  unpack("!L", msg[skip:skip+4])[0]
76             skip += 4
77         if self.flags & FLAG_NEXT:
78             self.next =  ord(msg[skip])
79             skip += 1
80         if self.flags & FLAG_MISSED:
81             num = ord(msg[skip]) + 1
82             skip += 1
83             for i in range(num):
84                 self.missed.append( ord(msg[skip+i]))
85             skip += num
86         if self.flags & FLAG_NEXT:
87             while len(msg) - skip > 0:
88                 n = ord(msg[skip]) + 1
89                 skip+=1
90                 self.msgs.append( msg[skip:skip+n])
91                 skip += n
92
93 class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConnectedTransport):
94     def __init__(self):        
95         self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
96         self.obSeq = 0   # highest sequence confirmed by remote
97         self.inSeq = 0   # last received sequence
98         self.observed = None  # their session id
99         self.sessionID = long(rand(0, 2**32))  # our session id
100         
101         self.lastTransmit = 0  # time we last sent a packet with messages
102         self.lastReceieved = 0 # time we last received a packet with messages
103         self.lastTransmitSeq = -1 # last sequence we sent a packet
104         self.state = pending # one of pending, sent, confirmed
105         
106         self.omsgq = [] # list of messages to go out
107         self.imsgq = [] # list of messages coming in
108         self.sendSession = None  # send session/observed fields until obSeq > sendSession
109         self.response = 0 # if we know we have a response now (like resending missed packets)
110         self.noisy = 0
111         self.resetConnection()
112     
113     def resetConnection(self):
114         self.weMissed = []
115         self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
116         self.inMsg = 0   # next incoming message number
117         self.outMsgNums = [0] * 256 # outgoing message numbers i = outNum % 256
118         self.next = 0  # next outgoing message number
119         self.scheduled = 0 # a sendNext is scheduled, don't schedule another
120
121     def datagramReceived(self, datagram):
122         if not datagram:
123             return
124         if self.noisy:
125             print `datagram`
126         p = AirhookPacket(datagram)
127         
128             
129         # check for state change
130         if self.state == pending:
131             if p.observed != None and p.session != None:
132                 if p.observed == self.sessionID:
133                     self.observed = p.session
134                     self.state = confirmed
135                 else:
136                     # bogus packet!
137                     return
138             elif p.session != None:
139                 self.observed = p.session
140                 self.response = 1
141         elif self.state == sent:
142             if p.observed != None and p.session != None:
143                 if p.observed == self.sessionID:
144                     self.observed = p.session
145                     self.sendSession = self.outSeq
146                     self.state = confirmed
147             if p.session != None:
148                 if not self.observed:
149                     self.observed = p.session
150                 elif self.observed != p.session:
151                     self.state = pending
152                     self.resetConnection()
153                     self.inSeq = p.seq
154         elif self.state == confirmed:
155             if p.session != None or p.observed != None :
156                 if (p.session != None and p.session != self.observed) or (p.observed != None and p.observed != self.sessionID):
157                     self.state = pending
158                     self.observed = p.session
159                     self.resetConnection()
160                     self.inSeq = p.seq
161                     if hasattr(self.protocol, "resetConnection") and callable(self.protocol.resetConnection):
162                         self.protocol.resetConnection()
163
164         # check to make sure sequence number isn't out of order
165         if (p.seq - self.inSeq) % 2**16 >= 256:
166             return
167     
168         if self.state == confirmed:     
169             msgs = []           
170             missed = []
171             
172             # see if they need us to resend anything
173             for i in p.missed:
174                 if self.outMsgs[i] != None:
175                     self.omsgq.append(self.outMsgs[i])
176                     self.outMsgs[i] = None
177                     
178             # see if we missed any messages
179             if p.next != None:
180                 missed_count = (p.next - self.inMsg) % 256
181                 if missed_count:
182                     self.lastReceived = time()
183                     for i in range(missed_count):
184                         missed += [(self.outSeq, (self.inMsg + i) % 256)]
185                     self.weMissed += missed
186                     self.response = 1
187                 # record highest message number seen
188                 self.inMsg = (p.next + len(p.msgs)) % 256
189             
190             # append messages, update sequence
191             self.imsgq += p.msgs
192             
193         if self.state == confirmed:
194             # unpack the observed sequence
195             tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
196             if ((self.outSeq - tseq)) % 2**16 > 255:
197                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
198             self.obSeq = tseq
199
200         self.inSeq = p.seq
201
202         self.lastReceived = time()
203         self.dataCameIn()
204         
205         self.schedule()
206         
207     def sendNext(self):
208         flags = 0
209         header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
210         ids = ""
211         missed = ""
212         msgs = ""
213         
214         # session / observed logic
215         if self.state == pending:
216             if self.observed != None:
217                 flags = flags | FLAG_OBSERVED
218                 ids +=  pack("!L", self.observed)
219             flags = flags | FLAG_SESSION
220             ids +=  pack("!L", self.sessionID)
221             self.state = sent
222         elif self.state == sent:
223             if self.observed != None:
224                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
225                 ids +=  pack("!LL", self.observed, self.sessionID)
226             else:
227                 flags = flags | FLAG_SESSION
228                 ids +=  pack("!L", self.sessionID)
229
230         else:
231             if self.state == sent or self.sendSession != None:
232                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
233                     self.sendSession = None
234                 else:
235                     flags = flags | FLAG_SESSION | FLAG_OBSERVED
236                     ids +=  pack("!LL", self.observed, self.sessionID)
237         
238         # missed header
239         if self.obSeq >= 0:
240             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
241
242         if len(self.weMissed) > 0:
243             flags = flags | FLAG_MISSED
244             missed += chr(len(self.weMissed) - 1)
245             for i in self.weMissed:
246                 missed += chr(i[1])
247                 
248         # append any outgoing messages
249         if self.state == confirmed and self.omsgq:
250             first = self.next
251             outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
252             while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
253                 msg = self.omsgq.pop()
254                 msgs += chr(len(msg) - 1) + msg
255                 self.outMsgs[self.next] = msg
256                 self.next = (self.next + 1) % 256
257                 outstanding+=1
258         # update outgoing message stat
259         if msgs:
260             flags = flags | FLAG_NEXT
261             ids += chr(first)
262             self.lastTransmitSeq = self.outSeq
263             #self.outMsgNums[self.outSeq % 256] = first
264         #else:
265         self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
266         
267         # do we need a NEXT flag despite not having sent any messages?
268         if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
269             flags = flags | FLAG_NEXT
270             ids += chr(self.next)
271         
272         # update stats and send packet
273         packet = chr(flags) + header + ids + missed + msgs
274         self.outSeq = (self.outSeq + 1) % 2**16
275         self.lastTransmit = time()
276         self.transport.write(packet, self.addr)
277         
278         self.scheduled = 0
279         self.schedule()
280         
281     def timeToSend(self):
282         if self.state == pending:
283             return (1, 0)
284         # any outstanding messages and are we not too far ahead of our counterparty?
285         elif len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
286             return (1, 0)
287         # do we explicitly need to send a response?
288         elif self.response:
289             self.response = 0
290             return (1, 0)
291         # have we not sent anything in a while?
292         elif time() - self.lastTransmit > 1.0:
293             return (1, 1)
294             
295         # nothing to send
296         return (0, 0)
297
298     def schedule(self):
299         tts, t = self.timeToSend()
300         if tts and not self.scheduled:
301             self.scheduled = 1
302             reactor.callLater(t, self.sendNext)
303         
304     def write(self, data):
305         # micropackets can only be 255 bytes or less
306         if len(data) <= 255:
307             self.omsgq.insert(0, data)
308         self.schedule()
309         
310     def dataCameIn(self):
311         """
312         called when we get a packet bearing messages
313         """
314         for msg in self.imsgq:
315             self.protocol.dataReceived(msg)
316         self.imsgq = []
317
318 class ustr(str):
319     """
320         this subclass of string encapsulates each ordered message, caches it's sequence number,
321         and has comparison functions to sort by sequence number
322     """
323     def getseq(self):
324         if not hasattr(self, 'seq'):
325             self.seq = unpack("!H", self[0:2])[0]
326         return self.seq
327     def __lt__(self, other):
328         return (self.getseq() - other.getseq()) % 2**16 > 255
329     def __le__(self, other):
330         return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
331     def __eq__(self, other):
332         return self.getseq() == other.getseq()
333     def __ne__(self, other):
334         return self.getseq() != other.getseq()
335     def __gt__(self, other):
336         return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
337     def __ge__(self, other):
338         return (self.getseq() - other.getseq()) % 2**16 < 256
339         
340 class StreamConnection(AirhookConnection):
341     """
342         this implements a simple protocol for a stream over airhook
343         this is done for convenience, instead of making it a twisted.internet.protocol....
344         the first two octets of each message are interpreted as a 16-bit sequence number
345         253 bytes are used for payload
346         
347     """
348     def __init__(self):
349         AirhookConnection.__init__(self)
350         self.resetStream()
351         
352     def resetStream(self):
353         self.oseq = 0
354         self.iseq = 0
355         self.q = []
356
357     def resetConnection(self):
358         AirhookConnection.resetConnection(self)
359         self.resetStream()
360         
361     def dataCameIn(self):
362         # put 'em together
363         for msg in self.imsgq:
364             insort_left(self.q, ustr(msg))
365         self.imsgq = []
366         data = ""
367         while self.q and self.iseq == self.q[0].getseq():
368             data += self.q[0][2:]
369             self.q = self.q[1:]
370             self.iseq = (self.iseq + 1) % 2**16
371         if data != '':
372             self.protocol.dataReceived(data)
373         
374     def write(self, data):
375         # chop it up and queue it up
376         while data:
377             p = pack("!H", self.oseq) + data[:253]
378             self.omsgq.insert(0, p)
379             data = data[253:]
380             self.oseq = (self.oseq + 1) % 2**16
381
382         self.schedule()
383         
384     def writeSequence(self, sequence):
385         for data in sequence:
386             self.write(data)
387
388
389 def listenAirhook(port, factory):
390     ah = Airhook()
391     ah.connection = AirhookConnection
392     ah.factory = factory
393     reactor.listenUDP(port, ah)
394     return ah
395
396 def listenAirhookStream(port, factory):
397     ah = Airhook()
398     ah.connection = StreamConnection
399     ah.factory = factory
400     reactor.listenUDP(port, ah)
401     return ah
402
403