]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - airhook.py
397cf53e41c4eb1a8c2b4d04b14ae5e7dd5de133
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ## Copyright 2002-2003 Andrew Loewenstern, All Rights Reserved
3 # see LICENSE.txt for license information
4
5 from random import uniform as rand
6 from struct import pack, unpack
7 from time import time
8 import unittest
9 from bisect import insort_left
10
11 from twisted.internet import protocol
12 from twisted.internet import abstract
13 from twisted.internet import reactor
14 from twisted.internet import app
15 from twisted.internet import interfaces
16
17 # flags
18 FLAG_AIRHOOK = 128
19 FLAG_OBSERVED = 16
20 FLAG_SESSION = 8
21 FLAG_MISSED = 4
22 FLAG_NEXT = 2
23 FLAG_INTERVAL = 1
24
25 MAX_PACKET_SIZE = 1450
26
27 pending = 0
28 sent = 1
29 confirmed = 2
30
31
32
33 class Airhook:       
34     def __init__(self):
35         self.noisy = None
36         # this should be changed to storage that drops old entries
37         self.connections = {}
38         
39     def datagramReceived(self, datagram, addr):
40         #print `addr`, `datagram`
41         #if addr != self.addr:
42         self.connectionForAddr(addr).datagramReceived(datagram)
43
44     def connectionForAddr(self, addr):
45         if addr == self.addr:
46             raise Exception
47         if not self.connections.has_key(addr):
48             conn = self.connection()
49             conn.addr = addr
50             conn.makeConnection(self.transport)
51             conn.protocol = self.factory.buildProtocol(addr)
52             conn.protocol.makeConnection(conn)
53             self.connections[addr] = conn
54         else:
55             conn = self.connections[addr]
56         return conn
57     def makeConnection(self, transport):
58         self.transport = transport
59         addr = transport.getHost()
60         self.addr = (addr.host, addr.port)
61
62     def doStart(self):
63         pass
64     def doStop(self):
65         pass
66     
67 class AirhookPacket:
68     def __init__(self, msg):
69         self.datagram = msg
70         self.oseq =  ord(msg[1])
71         self.seq = unpack("!H", msg[2:4])[0]
72         self.flags = ord(msg[0])
73         self.session = None
74         self.observed = None
75         self.next = None
76         self.missed = []
77         self.msgs = []
78         skip = 4
79         if self.flags & FLAG_OBSERVED:
80             self.observed = unpack("!L", msg[skip:skip+4])[0]
81             skip += 4
82         if self.flags & FLAG_SESSION:
83             self.session =  unpack("!L", msg[skip:skip+4])[0]
84             skip += 4
85         if self.flags & FLAG_NEXT:
86             self.next =  ord(msg[skip])
87             skip += 1
88         if self.flags & FLAG_MISSED:
89             num = ord(msg[skip]) + 1
90             skip += 1
91             for i in range(num):
92                 self.missed.append( ord(msg[skip+i]))
93             skip += num
94         if self.flags & FLAG_NEXT:
95             while len(msg) - skip > 0:
96                 n = ord(msg[skip]) + 1
97                 skip+=1
98                 self.msgs.append( msg[skip:skip+n])
99                 skip += n
100
101 class AirhookConnection:
102     def __init__(self):        
103         self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
104         self.observed = None  # their session id
105         self.sessionID = long(rand(0, 2**32))  # our session id
106         
107         self.lastTransmit = 0  # time we last sent a packet with messages
108         self.lastReceived = 0 # time we last received a packet with messages
109         self.lastTransmitSeq = -1 # last sequence we sent a packet
110         self.state = pending # one of pending, sent, confirmed
111         
112         self.sendSession = None  # send session/observed fields until obSeq > sendSession
113         self.response = 0 # if we know we have a response now (like resending missed packets)
114         self.noisy = 0
115         self.resetConnection()
116
117     def makeConnection(self, transport):
118         self.transport = transport
119         
120     def resetConnection(self):
121         self.weMissed = []
122         self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
123         self.inMsg = 0   # next incoming message number
124         self.outMsgNums = [0] * 256 # outgoing message numbers i = outNum % 256
125         self.next = 0  # next outgoing message number
126         self.scheduled = 0 # a sendNext is scheduled, don't schedule another
127         self.omsgq = [] # list of messages to go out
128         self.imsgq = [] # list of messages coming in
129         self.obSeq = 0   # highest sequence confirmed by remote
130         self.inSeq = 0   # last received sequence
131
132     def datagramReceived(self, datagram):
133         if not datagram:
134             return
135         if self.noisy:
136             print `datagram`
137         p = AirhookPacket(datagram)
138         
139             
140         # check for state change
141         if self.state == pending:
142             if p.observed != None and p.session != None:
143                 if p.observed == self.sessionID:
144                     self.observed = p.session
145                     self.state = confirmed
146                     self.response = 1
147                 else:
148                     self.observed = p.session
149                     self.response = 1
150             elif p.session != None:
151                 self.observed = p.session
152                 self.response = 1
153             else:
154                 self.response = 1
155         elif self.state == sent:
156             if p.observed != None and p.session != None:
157                 if p.observed == self.sessionID:
158                     self.observed = p.session
159                     self.sendSession = self.outSeq
160                     self.state = confirmed
161             if p.session != None:
162                 if not self.observed:
163                     self.observed = p.session
164                 elif self.observed != p.session:
165                     self.state = pending
166                     self.observed = p.session
167                     self.resetConnection()
168                     self.response = 1
169                     if hasattr(self.protocol, "resetConnection") and callable(self.protocol.resetConnection):
170                         self.protocol.resetConnection()
171                     self.inSeq = p.seq
172                     self.schedule()
173                     return
174             elif p.session == None and p.observed == None:
175                 self.response = 1
176                 self.schedule()
177     
178         elif self.state == confirmed:
179             if p.session != None or p.observed != None :
180                 if (p.session != None and p.session != self.observed) or (p.observed != None and p.observed != self.sessionID):
181                     self.state = pending
182                     self.observed = p.session
183                     self.resetConnection()
184                     self.inSeq = p.seq
185                     if hasattr(self.protocol, "resetConnection") and callable(self.protocol.resetConnection):
186                         self.protocol.resetConnection()
187                     self.schedule()
188                     return
189         # check to make sure sequence number isn't out of order
190         if (p.seq - self.inSeq) % 2**16 >= 256:
191             return
192     
193         if self.state == confirmed:     
194             msgs = []           
195             missed = []
196             
197             # see if they need us to resend anything
198             for i in p.missed:
199                 if self.outMsgs[i] != None:
200                     self.omsgq.append(self.outMsgs[i])
201                     self.outMsgs[i] = None
202                     
203             # see if we missed any messages
204             if p.next != None:
205                 missed_count = (p.next - self.inMsg) % 256
206                 if missed_count:
207                     self.lastReceived = time()
208                     for i in range(missed_count):
209                         missed += [(self.outSeq, (self.inMsg + i) % 256)]
210                     self.weMissed += missed
211                     self.response = 1
212                 # record highest message number seen
213                 self.inMsg = (p.next + len(p.msgs)) % 256
214             
215             # append messages, update sequence
216             self.imsgq += p.msgs
217             
218         if self.state == confirmed:
219             # unpack the observed sequence
220             tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
221             if ((self.outSeq - tseq)) % 2**16 > 255:
222                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
223             self.obSeq = tseq
224
225         self.inSeq = p.seq
226
227         self.lastReceived = time()
228         self.dataCameIn()
229         
230         self.schedule()
231         
232     def sendNext(self):
233         flags = 0
234         header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
235         ids = ""
236         missed = ""
237         msgs = ""
238         
239         # session / observed logic
240         if self.state == pending:
241             if self.observed != None:
242                 flags = flags | FLAG_OBSERVED
243                 ids +=  pack("!L", self.observed)
244             flags = flags | FLAG_SESSION
245             ids +=  pack("!L", self.sessionID)
246             self.state = sent
247         elif self.state == sent:
248             if self.observed != None:
249                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
250                 ids +=  pack("!LL", self.observed, self.sessionID)
251             else:
252                 flags = flags | FLAG_SESSION
253                 ids +=  pack("!L", self.sessionID)
254
255         else:
256             if self.state == sent or self.sendSession != None:
257                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
258                     self.sendSession = None
259                 else:
260                     flags = flags | FLAG_SESSION | FLAG_OBSERVED
261                     ids +=  pack("!LL", self.observed, self.sessionID)
262         
263         # missed header
264         if self.obSeq >= 0:
265             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
266
267         if len(self.weMissed) > 0:
268             flags = flags | FLAG_MISSED
269             missed += chr(len(self.weMissed) - 1)
270             for i in self.weMissed:
271                 missed += chr(i[1])
272                 
273         # append any outgoing messages
274         if self.state == confirmed and self.omsgq:
275             first = self.next
276             outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
277             while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
278                 msg = self.omsgq.pop()
279                 msgs += chr(len(msg) - 1) + msg
280                 self.outMsgs[self.next] = msg
281                 self.next = (self.next + 1) % 256
282                 outstanding+=1
283         # update outgoing message stat
284         if msgs:
285             flags = flags | FLAG_NEXT
286             ids += chr(first)
287             self.lastTransmitSeq = self.outSeq
288             #self.outMsgNums[self.outSeq % 256] = first
289         #else:
290         self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
291         
292         # do we need a NEXT flag despite not having sent any messages?
293         if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
294             flags = flags | FLAG_NEXT
295             ids += chr(self.next)
296         
297         # update stats and send packet
298         packet = chr(flags) + header + ids + missed + msgs
299         self.outSeq = (self.outSeq + 1) % 2**16
300         self.lastTransmit = time()
301         self.transport.write(packet, self.addr)
302         
303         self.scheduled = 0
304         self.schedule()
305         
306     def timeToSend(self):
307         if self.state == pending:
308             return (1, 0)
309
310         outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
311         # any outstanding messages and are we not too far ahead of our counterparty?
312         if len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
313         #if len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and not outstanding:
314             return (1, 0)
315
316         # do we explicitly need to send a response?
317         if self.response:
318             self.response = 0
319             return (1, 0)
320         # have we not sent anything in a while?
321         #elif time() - self.lastTransmit > 1.0:
322         #return (1, 1)
323             
324         # nothing to send
325         return (0, 0)
326
327     def schedule(self):
328         tts, t = self.timeToSend()
329         if tts and not self.scheduled:
330             self.scheduled = 1
331             reactor.callLater(t, self.sendNext)
332         
333     def write(self, data):
334         # micropackets can only be 255 bytes or less
335         if len(data) <= 255:
336             self.omsgq.insert(0, data)
337         self.schedule()
338         
339     def dataCameIn(self):
340         """
341         called when we get a packet bearing messages
342         """
343         for msg in self.imsgq:
344             self.protocol.dataReceived(msg)
345         self.imsgq = []
346
347 class ustr(str):
348     """
349         this subclass of string encapsulates each ordered message, caches it's sequence number,
350         and has comparison functions to sort by sequence number
351     """
352     def getseq(self):
353         if not hasattr(self, 'seq'):
354             self.seq = unpack("!H", self[0:2])[0]
355         return self.seq
356     def __lt__(self, other):
357         return (self.getseq() - other.getseq()) % 2**16 > 255
358     def __le__(self, other):
359         return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
360     def __eq__(self, other):
361         return self.getseq() == other.getseq()
362     def __ne__(self, other):
363         return self.getseq() != other.getseq()
364     def __gt__(self, other):
365         return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
366     def __ge__(self, other):
367         return (self.getseq() - other.getseq()) % 2**16 < 256
368         
369 class StreamConnection(AirhookConnection):
370     """
371         this implements a simple protocol for a stream over airhook
372         this is done for convenience, instead of making it a twisted.internet.protocol....
373         the first two octets of each message are interpreted as a 16-bit sequence number
374         253 bytes are used for payload
375         
376     """
377     def __init__(self):
378         AirhookConnection.__init__(self)
379         self.resetStream()
380         
381     def resetStream(self):
382         self.oseq = 0
383         self.iseq = 0
384         self.q = []
385
386     def resetConnection(self):
387         AirhookConnection.resetConnection(self)
388         self.resetStream()
389
390     def loseConnection(self):
391         pass
392     
393     def dataCameIn(self):
394         # put 'em together
395         for msg in self.imsgq:
396             insort_left(self.q, ustr(msg))
397         self.imsgq = []
398         data = ""
399         while self.q and self.iseq == self.q[0].getseq():
400             data += self.q[0][2:]
401             self.q = self.q[1:]
402             self.iseq = (self.iseq + 1) % 2**16
403         if data != '':
404             self.protocol.dataReceived(data)
405         
406     def write(self, data):
407         # chop it up and queue it up
408         while data:
409             p = pack("!H", self.oseq) + data[:253]
410             self.omsgq.insert(0, p)
411             data = data[253:]
412             self.oseq = (self.oseq + 1) % 2**16
413         self.schedule()
414         
415     def writeSequence(self, sequence):
416         for data in sequence:
417             self.write(data)
418
419
420 def listenAirhook(port, factory):
421     ah = Airhook()
422     ah.connection = AirhookConnection
423     ah.factory = factory
424     reactor.listenUDP(port, ah)
425     return ah
426
427 def listenAirhookStream(port, factory):
428     ah = Airhook()
429     ah.connection = StreamConnection
430     ah.factory = factory
431     reactor.listenUDP(port, ah)
432     return ah
433
434