]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - airhook.py
fixed reset connection handling
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ##  Copyright 2002, Andrew Loewenstern, All Rights Reserved
3
4 from random import uniform as rand
5 from struct import pack, unpack
6 from time import time
7 import unittest
8 from bisect import insort_left
9
10 from twisted.internet import protocol
11 from twisted.internet import abstract
12 from twisted.internet import reactor
13 from twisted.internet import app
14 from twisted.internet import interfaces
15
16 # flags
17 FLAG_AIRHOOK = 128
18 FLAG_OBSERVED = 16
19 FLAG_SESSION = 8
20 FLAG_MISSED = 4
21 FLAG_NEXT = 2
22 FLAG_INTERVAL = 1
23
24 MAX_PACKET_SIZE = 1450
25
26 pending = 0
27 sent = 1
28 confirmed = 2
29
30
31
32 class Airhook(protocol.DatagramProtocol):       
33     def __init__(self):
34         self.noisy = None
35         # this should be changed to storage that drops old entries
36         self.connections = {}
37         
38     def datagramReceived(self, datagram, addr):
39         #print `addr`, `datagram`
40         #if addr != self.addr:
41         self.connectionForAddr(addr).datagramReceived(datagram)
42
43     def connectionForAddr(self, addr):
44         if not self.connections.has_key(addr):
45             conn = self.connection()
46             conn.protocol = self.factory.buildProtocol(addr)
47             conn.protocol.makeConnection(conn)
48             conn.makeConnection(self.transport)
49             conn.addr = addr
50             self.connections[addr] = conn
51         else:
52             conn = self.connections[addr]
53         return conn
54 #    def makeConnection(self, transport):
55 #        protocol.DatagramProtocol.makeConnection(self, transport)
56 #        tup = transport.getHost()
57 #        self.addr = (tup[1], tup[2])
58         
59 class AirhookPacket:
60     def __init__(self, msg):
61         self.datagram = msg
62         self.oseq =  ord(msg[1])
63         self.seq = unpack("!H", msg[2:4])[0]
64         self.flags = ord(msg[0])
65         self.session = None
66         self.observed = None
67         self.next = None
68         self.missed = []
69         self.msgs = []
70         skip = 4
71         if self.flags & FLAG_OBSERVED:
72             self.observed = unpack("!L", msg[skip:skip+4])[0]
73             skip += 4
74         if self.flags & FLAG_SESSION:
75             self.session =  unpack("!L", msg[skip:skip+4])[0]
76             skip += 4
77         if self.flags & FLAG_NEXT:
78             self.next =  ord(msg[skip])
79             skip += 1
80         if self.flags & FLAG_MISSED:
81             num = ord(msg[skip]) + 1
82             skip += 1
83             for i in range(num):
84                 self.missed.append( ord(msg[skip+i]))
85             skip += num
86         if self.flags & FLAG_NEXT:
87             while len(msg) - skip > 0:
88                 n = ord(msg[skip]) + 1
89                 skip+=1
90                 self.msgs.append( msg[skip:skip+n])
91                 skip += n
92
93 class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConnectedTransport):
94     def __init__(self):        
95         self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
96         self.obSeq = 0   # highest sequence confirmed by remote
97         self.inSeq = 0   # last received sequence
98         self.observed = None  # their session id
99         self.sessionID = long(rand(0, 2**32))  # our session id
100         
101         self.lastTransmit = 0  # time we last sent a packet with messages
102         self.lastReceieved = 0 # time we last received a packet with messages
103         self.lastTransmitSeq = -1 # last sequence we sent a packet
104         self.state = pending
105         
106         self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
107         self.omsgq = [] # list of messages to go out
108         self.imsgq = [] # list of messages coming in
109         self.sendSession = None  # send session/observed fields until obSeq > sendSession
110         self.response = 0 # if we know we have a response now (like resending missed packets)
111         self.noisy = 0
112         self.scheduled = 0 # a sendNext is scheduled, don't schedule another
113         self.resetMessages()
114     
115     def resetMessages(self):
116         self.weMissed = []
117         self.inMsg = 0   # next incoming message number
118         self.outMsgNums = [0] * 256 # outgoing message numbers i = outNum % 256
119         self.next = 0  # next outgoing message number
120
121     def datagramReceived(self, datagram):
122         if not datagram:
123             return
124         if self.noisy:
125             print `datagram`
126         p = AirhookPacket(datagram)
127         
128             
129         # check for state change
130         if self.state == pending:
131             if p.observed != None and p.session != None:
132                 if p.observed == self.sessionID:
133                     self.observed = p.session
134                     self.state = confirmed
135                 else:
136                     # bogus packet!
137                     return
138             elif p.session != None:
139                 self.observed = p.session
140                 self.response = 1
141         elif self.state == sent:
142             if p.observed != None and p.session != None:
143                 if p.observed == self.sessionID:
144                     self.observed = p.session
145                     self.sendSession = self.outSeq
146                     self.state = confirmed
147             if p.session != None:
148                 if not self.observed:
149                     self.observed = p.session
150                 elif self.observed != p.session:
151                     self.state = pending
152                     self.resetMessages()
153                     self.inSeq = p.seq
154         elif self.state == confirmed:
155             if p.session != None or p.observed != None :
156                 if (p.session != None and p.session != self.observed) or (p.observed != None and p.observed != self.sessionID):
157                     self.state = pending
158                     self.observed = p.session
159                     self.resetMessages()
160                     self.inSeq = p.seq
161
162         # check to make sure sequence number isn't out of order
163         if (p.seq - self.inSeq) % 2**16 >= 256:
164             return
165     
166         if self.state == confirmed:     
167             msgs = []           
168             missed = []
169             
170             # see if they need us to resend anything
171             for i in p.missed:
172                 if self.outMsgs[i] != None:
173                     self.omsgq.append(self.outMsgs[i])
174                     self.outMsgs[i] = None
175                     
176             # see if we missed any messages
177             if p.next != None:
178                 missed_count = (p.next - self.inMsg) % 256
179                 if missed_count:
180                     self.lastReceived = time()
181                     for i in range(missed_count):
182                         missed += [(self.outSeq, (self.inMsg + i) % 256)]
183                     self.weMissed += missed
184                     self.response = 1
185                 # record highest message number seen
186                 self.inMsg = (p.next + len(p.msgs)) % 256
187             
188             # append messages, update sequence
189             self.imsgq += p.msgs
190             
191         if self.state == confirmed:
192             # unpack the observed sequence
193             tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
194             if ((self.outSeq - tseq)) % 2**16 > 255:
195                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
196             self.obSeq = tseq
197
198         self.inSeq = p.seq
199
200         self.lastReceived = time()
201         self.dataCameIn()
202         
203         self.schedule()
204         
205     def sendNext(self):
206         flags = 0
207         header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
208         ids = ""
209         missed = ""
210         msgs = ""
211         
212         # session / observed logic
213         if self.state == pending:
214             if self.observed != None:
215                 flags = flags | FLAG_OBSERVED
216                 ids +=  pack("!L", self.observed)
217             flags = flags | FLAG_SESSION
218             ids +=  pack("!L", self.sessionID)
219             self.state = sent
220         elif self.state == sent:
221             if self.observed != None:
222                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
223                 ids +=  pack("!LL", self.observed, self.sessionID)
224             else:
225                 flags = flags | FLAG_SESSION
226                 ids +=  pack("!L", self.sessionID)
227
228         else:
229             if self.state == sent or self.sendSession:
230                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
231                     self.sendSession = None
232                 else:
233                     flags = flags | FLAG_SESSION | FLAG_OBSERVED
234                     ids +=  pack("!LL", self.observed, self.sessionID)
235         
236         # missed header
237         if self.obSeq >= 0:
238             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
239
240         if len(self.weMissed) > 0:
241             flags = flags | FLAG_MISSED
242             missed += chr(len(self.weMissed) - 1)
243             for i in self.weMissed:
244                 missed += chr(i[1])
245                 
246         # append any outgoing messages
247         if self.state == confirmed and self.omsgq:
248             first = self.next
249             outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
250             while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
251                 msg = self.omsgq.pop()
252                 msgs += chr(len(msg) - 1) + msg
253                 self.outMsgs[self.next] = msg
254                 self.next = (self.next + 1) % 256
255                 outstanding+=1
256         # update outgoing message stat
257         if msgs:
258             flags = flags | FLAG_NEXT
259             ids += chr(first)
260             self.lastTransmitSeq = self.outSeq
261             #self.outMsgNums[self.outSeq % 256] = first
262         #else:
263         self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
264         
265         # do we need a NEXT flag despite not having sent any messages?
266         if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
267             flags = flags | FLAG_NEXT
268             ids += chr(self.next)
269         
270         # update stats and send packet
271         packet = chr(flags) + header + ids + missed + msgs
272         self.outSeq = (self.outSeq + 1) % 2**16
273         self.lastTransmit = time()
274         self.transport.write(packet, self.addr)
275         
276         self.scheduled = 0
277         self.schedule()
278         
279     def timeToSend(self):
280         if self.state == pending:
281             return (1, 0)
282         # any outstanding messages and are we not too far ahead of our counterparty?
283         elif len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
284             return (1, 0)
285         # do we explicitly need to send a response?
286         elif self.response:
287             self.response = 0
288             return (1, 0)
289         # have we not sent anything in a while?
290         elif time() - self.lastTransmit > 1.0:
291             return (1, 1)
292             
293         # nothing to send
294         return (0, 0)
295
296     def schedule(self):
297         tts, t = self.timeToSend()
298         if tts and not self.scheduled:
299             self.scheduled = 1
300             reactor.callLater(t, self.sendNext)
301         
302     def write(self, data):
303         # micropackets can only be 255 bytes or less
304         if len(data) <= 255:
305             self.omsgq.insert(0, data)
306         self.schedule()
307         
308     def dataCameIn(self):
309         """
310         called when we get a packet bearing messages
311         """
312         for msg in self.imsgq:
313             self.protocol.dataReceived(msg)
314         self.imsgq = []
315
316 class ustr(str):
317     """
318         this subclass of string encapsulates each ordered message, caches it's sequence number,
319         and has comparison functions to sort by sequence number
320     """
321     def getseq(self):
322         if not hasattr(self, 'seq'):
323             self.seq = unpack("!H", self[0:2])[0]
324         return self.seq
325     def __lt__(self, other):
326         return (self.getseq() - other.getseq()) % 2**16 > 255
327     def __le__(self, other):
328         return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
329     def __eq__(self, other):
330         return self.getseq() == other.getseq()
331     def __ne__(self, other):
332         return self.getseq() != other.getseq()
333     def __gt__(self, other):
334         return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
335     def __ge__(self, other):
336         return (self.getseq() - other.getseq()) % 2**16 < 256
337         
338 class StreamConnection(AirhookConnection):
339     """
340         this implements a simple protocol for a stream over airhook
341         this is done for convenience, instead of making it a twisted.internet.protocol....
342         the first two octets of each message are interpreted as a 16-bit sequence number
343         253 bytes are used for payload
344         
345     """
346     def __init__(self):
347         AirhookConnection.__init__(self)
348         self.oseq = 0
349         self.iseq = 0
350         self.q = []
351
352     def dataCameIn(self):
353         # put 'em together
354         for msg in self.imsgq:
355             insort_left(self.q, ustr(msg))
356         self.imsgq = []
357         data = ""
358         while self.q and self.iseq == self.q[0].getseq():
359             data += self.q[0][2:]
360             self.q = self.q[1:]
361             self.iseq = (self.iseq + 1) % 2**16
362         if data != '':
363             self.protocol.dataReceived(data)
364         
365     def write(self, data):
366         # chop it up and queue it up
367         while data:
368             p = pack("!H", self.oseq) + data[:253]
369             self.omsgq.insert(0, p)
370             data = data[253:]
371             self.oseq = (self.oseq + 1) % 2**16
372
373         self.schedule()
374         
375     def writeSequence(self, sequence):
376         for data in sequence:
377             self.write(data)
378
379
380 def listenAirhook(port, factory):
381     ah = Airhook()
382     ah.connection = AirhookConnection
383     ah.factory = factory
384     reactor.listenUDP(port, ah)
385     return ah
386
387 def listenAirhookStream(port, factory):
388     ah = Airhook()
389     ah.connection = StreamConnection
390     ah.factory = factory
391     reactor.listenUDP(port, ah)
392     return ah
393
394