]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - airhook.py
4053dfe5906493a09b012950b5f40e6c8add07f9
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ##  Copyright 2002, Andrew Loewenstern, All Rights Reserved
3
4 from random import uniform as rand
5 from struct import pack, unpack
6 from time import time
7 from StringIO import StringIO
8 import unittest
9 from bisect import insort_left
10
11 from twisted.internet import protocol
12 from twisted.internet import reactor
13
14 # flags
15 FLAG_AIRHOOK = 128
16 FLAG_OBSERVED = 16
17 FLAG_SESSION = 8
18 FLAG_MISSED = 4
19 FLAG_NEXT = 2
20 FLAG_INTERVAL = 1
21
22 MAX_PACKET_SIZE = 1480
23
24 pending = 0
25 sent = 1
26 confirmed = 2
27
28 class Delegate:
29         def setDelegate(self, delegate):
30                 self.delegate = delegate
31         def getDelegate(self):
32                 return self.delegate
33         def msgDelegate(self, method, args=(), kwargs={}):
34                 if hasattr(self, 'delegate') and hasattr(self.delegate, method) and callable(getattr(self.delegate, method)):
35                         apply(getattr(self.delegate, method) , args, kwargs)
36
37 class Airhook(protocol.DatagramProtocol):
38
39         def __init__(self, connection_class):
40                 self.connection_class = connection_class
41         def startProtocol(self):
42                 self.connections = {}
43                                 
44         def datagramReceived(self, datagram, addr):
45                 flag = datagram[0]
46                 if not flag & FLAG_AIRHOOK:  # first bit always must be 0
47                         conn = self.connectionForAddr(addr)
48                         conn.datagramReceieved(datagram)
49
50         def connectionForAddr(self, addr):
51                 if not self.connections.has_key(addr):
52                         conn = connection_class(self.transport, addr, self.delegate)
53                         self.connections[addr] = conn
54                 return self.connections[addr]
55
56                 
57 class AirhookPacket:
58         def __init__(self, msg):
59                 self.datagram = msg
60                 self.oseq =  ord(msg[1])
61                 self.seq = unpack("!H", msg[2:4])[0]
62                 self.flags = ord(msg[0])
63                 self.session = None
64                 self.observed = None
65                 self.next = None
66                 self.missed = []
67                 self.msgs = []
68                 skip = 4
69                 if self.flags & FLAG_OBSERVED:
70                         self.observed = unpack("!L", msg[skip:skip+4])[0]
71                         skip += 4
72                 if self.flags & FLAG_SESSION:
73                         self.session =  unpack("!L", msg[skip:skip+4])[0]
74                         skip += 4
75                 if self.flags & FLAG_NEXT:
76                         self.next =  ord(msg[skip])
77                         skip += 1
78                 if self.flags & FLAG_MISSED:
79                         num = ord(msg[skip]) + 1
80                         skip += 1
81                         for i in range(num):
82                                 self.missed.append( ord(msg[skip+i]))
83                         skip += num
84                 if self.flags & FLAG_NEXT:
85                         while len(msg) - skip > 0:
86                                 n = ord(msg[skip]) + 1
87                                 skip+=1
88                                 self.msgs.append( msg[skip:skip+n])
89                                 skip += n
90
91 class AirhookConnection(Delegate):
92         def __init__(self, transport, addr, delegate):
93                 self.delegate = delegate
94                 self.addr = addr
95                 type, self.host, self.port = addr
96                 self.transport = transport
97                 
98                 self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
99                 self.obSeq = -1   # highest sequence confirmed by remote
100                 self.inSeq = 0   # last received sequence
101                 self.observed = None  # their session id
102                 self.sessionID = long(rand(0, 2**32))  # our session id
103                 
104                 self.lastTransmit = -1  # time we last sent a packet with messages
105                 self.lastReceieved = 0 # time we last received a packet with messages
106                 self.lastTransmitSeq = -1 # last sequence we sent a packet
107                 self.state = pending
108                 
109                 self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
110                 self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
111                 self.next = -1  # next outgoing message number
112                 self.omsgq = [] # list of messages to go out
113                 self.imsgq = [] # list of messages coming in
114                 self.sendSession = None  # send session/observed fields until obSeq > sendSession
115
116                 self.resetMessages()
117         
118         def resetMessages(self):
119                 self.weMissed = []
120                 self.inMsg = 0   # next incoming message number
121         
122         def datagramReceived(self, datagram):
123                 if not datagram:
124                         return
125                 response = 0 # if we know we have a response now (like resending missed packets)
126                 p = AirhookPacket(datagram)
127                 
128                 # check for state change
129                 if self.state == pending:
130                         if p.observed != None and p.session != None:
131                                 if p.observed == self.sessionID:
132                                         self.observed = p.session
133                                         self.state = confirmed
134                                 else:
135                                         # bogus packet!
136                                         return
137                         elif p.session != None:
138                                 self.observed = p.session
139                                 self.state = sent
140                                 response = 1
141                 elif self.state == sent:
142                         if p.observed != None and p.session != None:
143                                 if p.observed == self.sessionID:
144                                         self.observed = p.session
145                                         self.sendSession = self.outSeq
146                                         self.state = confirmed
147                         if p.session != None:
148                                 if not self.observed:
149                                         self.observed = p.session
150                                 elif self.observed != p.session:
151                                         self.state = pending
152                                         self.resetMessages()
153                                         self.inSeq = p.seq
154                         response = 1
155                 elif self.state == confirmed:
156                         if p.session != None or p.observed != None :
157                                 if p.session != self.observed or p.observed != self.sessionID:
158                                         self.state = pending
159                                         if seq == 0:
160                                                 self.resetMessages()
161                                                 self.inSeq = p.seq
162         
163                 if self.state != pending:       
164                         msgs = []               
165                         missed = []
166
167                         # check to make sure sequence number isn't out of wack
168                         assert (p.seq - self.inSeq) % 2**16 < 256
169                         
170                         # see if they need us to resend anything
171                         for i in p.missed:
172                                 response = 1
173                                 if self.outMsgs[i] != None:
174                                         self.omsgq.insert(0, self.outMsgs[i])
175                                         self.outMsgs[i] = None
176                                         
177                         # see if we need them to send anything
178                         if p.next != None:
179                                 if p.next == 0 and self.inMsg == -1:
180                                         missed = 255
181                                 missed_count = (p.next - self.inMsg) % 256
182                                 if missed_count:
183                                         self.lastReceived = time()
184                                         for i in range(missed_count):
185                                                 missed += [(self.outSeq, (self.inMsg + i) % 256)]
186                                         response = 1
187                                         self.weMissed += missed
188                                 self.inMsg = (p.next + len(p.msgs)) % 256
189                                 
190                         self.imsgq += p.msgs
191                         self.inSeq = p.seq
192                         
193                 if self.state == confirmed:
194                         # unpack the observed sequence
195                         tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
196                         if ((self.outSeq - tseq)) % 2**16 > 255:
197                                 tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
198                         self.obSeq = tseq
199
200                 if response:
201                         reactor.callLater(0, self.sendNext)
202                 self.lastReceived = time()
203                 self.dataCameIn()
204                 
205         def sendNext(self):
206                 flags = 0
207                 header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
208                 ids = ""
209                 missed = ""
210                 msgs = ""
211                 
212                 if self.state == pending:
213                         flags = flags | FLAG_SESSION
214                         ids +=  pack("!L", self.sessionID)
215                         self.state = sent
216                 elif self.state == sent:
217                         if self.observed != None:
218                                 flags = flags | FLAG_SESSION | FLAG_OBSERVED
219                                 ids +=  pack("!LL", self.observed, self.sessionID)
220                         else:
221                                 flags = flags | FLAG_SESSION
222                                 ids +=  pack("!L", self.sessionID)
223
224                 else:
225                         if self.state == sent or self.sendSession:
226                                 if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
227                                         self.sendSession = None
228                                 else:
229                                         flags = flags | FLAG_SESSION | FLAG_OBSERVED
230                                         ids +=  pack("!LL", self.observed, self.sessionID)
231                 
232                 if self.obSeq >= 0:
233                         self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
234
235                 if self.weMissed:
236                         flags = flags | FLAG_MISSED
237                         missed += chr(len(self.weMissed) - 1)
238                         for i in self.weMissed:
239                                 missed += chr(i[1])
240                                 
241                 if self.state == confirmed and self.omsgq:
242                         first = (self.next + 1) % 256
243                         while len(self.omsgq) and (len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE) :
244                                 if self.obSeq == -1:
245                                         highest = 0
246                                 else:
247                                         highest = self.outMsgNums[self.obSeq % 256]
248                                 if self.next != -1 and (self.next + 1) % 256 == (highest - 1) % 256:
249                                         break
250                                 else:
251                                         self.next = (self.next + 1) % 256
252                                         msg = self.omsgq.pop()
253                                         msgs += chr(len(msg) - 1) + msg
254                                         self.outMsgs[self.next] = msg
255                 if msgs:
256                         flags = flags | FLAG_NEXT
257                         ids += chr(first)
258                         self.lastTransmitSeq = self.outSeq
259                         self.outMsgNums[self.outSeq % 256] = first
260                 else:
261                         if self.next == -1:
262                                 self.outMsgNums[self.outSeq % 256] = 0
263                         else:
264                                 self.outMsgNums[self.outSeq % 256] = self.next
265                         
266                 if (self.obSeq - self.lastTransmitSeq) % 2**16 > 256 and self.outMsgNums[self.obSeq % 256] != self.next and  not flags & FLAG_NEXT:
267                                 flags = flags | FLAG_NEXT
268                                 ids += chr((self.next + 1) % 256)
269                 packet = chr(flags) + header + ids + missed + msgs
270                 self.outSeq = (self.outSeq + 1) % 2**16
271                 self.lastTransmit = time()
272                 self.transport.write(packet)
273                 
274                 if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
275                         reactor.callLater(0, self.sendNext)
276                 else:
277                         reactor.callLater(1, self.sendNext)
278
279
280         def dataCameIn(self):
281                 self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
282                 if hasattr(self, 'delegate') and self.delegate != None:
283                         self.imsgq = []
284
285 class ustr(str):
286         def getseq(self):
287                 if not hasattr(self, 'seq'):
288                         self.seq = unpack("!H", self[0:2])[0]
289                 return self.seq
290         def __lt__(self, other):
291                 return self.getseq() < other.getseq()
292         def __le__(self, other):
293                 return self.getseq() <= other.getseq()
294         def __eq__(self, other):
295                 return self.getseq() != other.getseq()
296         def __ne__(self, other):
297                 return self.getseq() <= other.getseq()
298         def __gt__(self, other):
299                 return self.getseq() > other.getseq()
300         def __ge__(self, other):
301                 return self.getseq() >= other.getseq()
302
303 class OrderedConnection(AirhookConnection):
304         def __init__(self, transport, addr, delegate):
305                 AirhookConnection.__init__(self, transport, addr, delegate)
306                 self.oseq = 0
307                 self.iseq = 0
308                 self.q = []
309
310         def dataCameIn(self):
311                 # put 'em together
312                 for msg in self.imsgq:
313                         insort_left(self.q, ustr(msg))
314                 self.imsgq = []
315                 data = ""
316                 while self.q and self.iseq == self.q[0].getseq():
317                         data += self.q[0][2:]
318                         self.iseq = (self.iseq + 1) % 2**16
319                         self.q = self.q[1:]
320                 if data:
321                         self.msgDelegate('dataCameIn', (self.host, self.port, data))
322                 
323         def sendSomeData(self, data):
324                 # chop it up and queue it up
325                 while data:
326                         p = "%s%s" % (pack("!H", self.oseq), data[:253])
327                         self.omsgq.append(p)
328                         data = data[253:]
329                         self.oseq = (self.oseq + 1) % 2**16
330
331                 if self.omsgq:
332                         self.sendNext()