]> git.mxchange.org Git - quix0rs-apt-p2p.git/blob - airhook.py
bug fixes, more tests, looking solid now
[quix0rs-apt-p2p.git] / airhook.py
1 ##  Airhook Protocol http://airhook.org/protocol.html
2 ##  Copyright 2002, Andrew Loewenstern, All Rights Reserved
3
4 from random import uniform as rand
5 from struct import pack, unpack
6 from time import time
7 from StringIO import StringIO
8 import unittest
9 from bisect import insort_left
10
11 from twisted.internet import protocol
12 from twisted.internet import reactor
13
# Flag bits carried in the first octet of every Airhook packet.
FLAG_AIRHOOK = 1 << 7   # 128: must be clear for a datagram to be treated as Airhook
FLAG_OBSERVED = 1 << 4  # 16: a 4-octet observed session id follows the fixed header
FLAG_SESSION = 1 << 3   # 8: a 4-octet session id follows
FLAG_MISSED = 1 << 2    # 4: a missed-message list follows
FLAG_NEXT = 1 << 1      # 2: a next-message-number octet (and messages) follow
FLAG_INTERVAL = 1 << 0  # 1

# Size budget used when packing outgoing messages into one packet.
MAX_PACKET_SIZE = 1496

# Connection states.
pending, sent, confirmed = range(3)
27
class Delegate:
    """Mixin that forwards optional callbacks to a delegate object.

    The delegate may be any object; only methods it actually defines are
    invoked, so it can implement just the callbacks it cares about.
    """

    def setDelegate(self, delegate):
        """Install *delegate* as the receiver of msgDelegate() callbacks."""
        self.delegate = delegate

    def getDelegate(self):
        """Return the current delegate (AttributeError if none was ever set)."""
        return self.delegate

    def msgDelegate(self, method, args=(), kwargs=None):
        """Call ``delegate.<method>(*args, **kwargs)`` if that callable exists.

        Does nothing when no delegate is installed or when the delegate has
        no callable attribute named *method*.
        """
        delegate = getattr(self, 'delegate', None)
        callback = getattr(delegate, method, None)
        if callable(callback):
            # direct call instead of the deprecated apply(); no shared mutable default
            if kwargs is None:
                kwargs = {}
            callback(*args, **kwargs)
36
class Airhook(protocol.DatagramProtocol):
    """Datagram protocol that demultiplexes Airhook packets by remote address.

    One connection object (an instance of ``connection_class``) is created
    lazily per address and receives every Airhook datagram from it.  If a
    ``delegate`` attribute has been set on this protocol it is handed to each
    new connection, otherwise connections get ``None``.
    """

    def __init__(self, connection_class):
        self.connection_class = connection_class
        self.connections = {}

    def startProtocol(self):
        self.connections = {}

    def datagramReceived(self, datagram, addr):
        """Route *datagram* to the connection for *addr* if it is an Airhook packet."""
        if not datagram:
            return
        flag = ord(datagram[0])  # first octet is a character, not an int
        if not flag & FLAG_AIRHOOK:  # the FLAG_AIRHOOK bit must be 0 on Airhook packets
            conn = self.connectionForAddr(addr)
            conn.datagramReceived(datagram)

    def connectionForAddr(self, addr):
        """Return the connection for *addr*, creating it on first use."""
        if addr not in self.connections:
            conn = self.connection_class(self.transport, addr, getattr(self, 'delegate', None))
            self.connections[addr] = conn
        return self.connections[addr]
55
56                 
class AirhookPacket:
    """Parse one raw Airhook datagram into its header fields and messages.

    Attributes:
      datagram -- the raw datagram as given
      flags    -- first octet
      oseq     -- second octet (low octet of the sender's last-received sequence)
      seq      -- octets 3-4, big-endian 16-bit packet sequence number
      observed -- 32-bit id when FLAG_OBSERVED is set, else None
      session  -- 32-bit id when FLAG_SESSION is set, else None
      next     -- one octet when FLAG_NEXT is set, else None
      missed   -- octets listed when FLAG_MISSED is set (count is stored minus one)
      msgs     -- when FLAG_NEXT is set, the rest of the datagram split into
                  messages, each preceded by a (length - 1) octet
    Optional fields appear in this order: observed, session, next, missed, msgs.
    """

    def __init__(self, msg):
        self.datagram = msg
        self.flags = ord(msg[0])
        self.oseq = ord(msg[1])
        self.seq = unpack("!H", msg[2:4])[0]

        self.observed = None
        self.session = None
        self.next = None
        self.missed = []
        self.msgs = []

        flags = self.flags
        pos = 4  # first octet after the fixed 4-octet header

        if flags & FLAG_OBSERVED:
            (self.observed,) = unpack("!L", msg[pos:pos + 4])
            pos += 4
        if flags & FLAG_SESSION:
            (self.session,) = unpack("!L", msg[pos:pos + 4])
            pos += 4
        if flags & FLAG_NEXT:
            self.next = ord(msg[pos])
            pos += 1
        if flags & FLAG_MISSED:
            count = ord(msg[pos]) + 1
            pos += 1
            self.missed = [ord(msg[pos + k]) for k in range(count)]
            pos += count
        if flags & FLAG_NEXT:
            total = len(msg)
            while pos < total:
                length = ord(msg[pos]) + 1
                pos += 1
                self.msgs.append(msg[pos:pos + length])
                pos += length
90
class AirhookConnection(Delegate):
    """State for one Airhook session with a single remote address.

    Runs the session handshake (pending -> sent -> confirmed), tracks packet
    sequence numbers in both directions, keeps sent messages so they can be
    resent when the remote reports them missed, and hands received messages
    to the delegate through dataCameIn().
    """

    def __init__(self, transport, addr, delegate):
        self.delegate = delegate
        self.addr = addr
        # NOTE(review): addr is unpacked as a 3-tuple (type, host, port) -- confirm
        # this matches what the transport passes to datagramReceived.
        _, self.host, self.port = addr
        self.transport = transport

        self.outSeq = 0  # sequence number of the next packet we send; can't be 255 more than obSeq
        self.obSeq = 0   # highest of our sequence numbers confirmed by the remote
        self.inSeq = 0   # last received sequence
        self.observed = None  # their session id
        self.sessionID = long(rand(0, 2**32))  # our session id

        self.lastTransmit = -1  # time we last sent any packet
        self.lastReceived = 0  # time we last received a packet
        self.lastTransmitSeq = -1  # sequence number of the last packet that carried messages
        self.state = pending

        self.outMsgs = [None] * 256  # sent messages kept for resending, index = message number
        self.omsgq = []  # messages waiting to go out; sendNext pops from the end
        self.imsgq = []  # received messages not yet taken by the delegate
        self.sendSession = None  # send session/observed fields until obSeq reaches sendSession
        self.sendTimer = None  # handle of the pending reactor.callLater(..., sendNext)

        self.resetMessages()

    def resetMessages(self):
        """Forget per-session message numbering (at creation and on session reset)."""
        self.weMissed = []  # (our outSeq when noticed, their message number) pairs to report
        self.inMsg = 0   # next incoming message number
        self.outMsgNums = [None] * 256  # last message number used as of packet seq, index = seq % 256
        self.next = 0  # next outgoing message number

    def _outstanding(self):
        """Return how many message numbers were used since the packet the remote last confirmed.

        Returns 0 when that packet has no record (e.g. the table was cleared by
        resetMessages()), instead of failing on None arithmetic.
        """
        last = self.outMsgNums[self.obSeq % 256]
        if last is None:
            return 0
        return (self.next - 1 - last) % 256

    def _unwrapObserved(self, oseq):
        """Expand the remote's 8-bit observed sequence into a 16-bit sequence number.

        Picks the value whose low octet is *oseq* and that lies at most 255
        behind outSeq, borrowing from the high octet modulo 2**16 (so it works
        across the 65535 -> 0 wraparound).
        """
        tseq = (self.outSeq & 0xFF00) | oseq
        if (self.outSeq - tseq) % 2**16 > 255:
            tseq = (tseq - 256) % 2**16
        return tseq

    def datagramReceived(self, datagram):
        """Process one raw datagram from the remote end.

        Advances the session state machine, queues messages the remote reports
        missing for resending, records which of its messages we missed, appends
        received messages to imsgq and finally calls dataCameIn().  Packets whose
        sequence number is not within 256 ahead of inSeq are ignored.
        """
        if not datagram:
            return
        response = 0  # set when we know a reply is needed now (e.g. resending missed messages)
        p = AirhookPacket(datagram)

        # check to make sure sequence number isn't out of order
        if (p.seq - self.inSeq) % 2**16 >= 256:
            return

        # check for state change
        if self.state == pending:
            if p.observed is not None and p.session is not None:
                if p.observed == self.sessionID:
                    self.observed = p.session
                    self.state = confirmed
                else:
                    # bogus packet!
                    return
            elif p.session is not None:
                self.observed = p.session
                self.state = sent
                response = 1
        elif self.state == sent:
            if p.observed is not None and p.session is not None:
                if p.observed == self.sessionID:
                    self.observed = p.session
                    self.sendSession = self.outSeq
                    self.state = confirmed
            if p.session is not None:
                if not self.observed:
                    self.observed = p.session
                elif self.observed != p.session:
                    self.state = pending
                    self.resetMessages()
                    self.inSeq = p.seq
            response = 1
        elif self.state == confirmed:
            if p.session is not None or p.observed is not None:
                if p.session != self.observed or p.observed != self.sessionID:
                    self.state = pending
                    if p.seq == 0:  # was the undefined name `seq`
                        self.resetMessages()
                        self.inSeq = p.seq

        if self.state != pending:
            missed = []

            # see if they need us to resend anything
            for i in p.missed:
                response = 1
                if self.outMsgs[i] is not None:
                    self.omsgq.append(self.outMsgs[i])
                    self.outMsgs[i] = None

            # see if we missed any messages
            if p.next is not None:
                missed_count = (p.next - self.inMsg) % 256
                if missed_count:
                    self.lastReceived = time()
                    for i in range(missed_count):
                        missed.append((self.outSeq, (self.inMsg + i) % 256))
                    response = 1
                    self.weMissed += missed
                # record highest message number seen
                self.inMsg = (p.next + len(p.msgs)) % 256

            # append messages, update sequence
            self.imsgq += p.msgs

        if self.state == confirmed:
            # unpack the observed sequence
            self.obSeq = self._unwrapObserved(p.oseq)

        self.inSeq = p.seq

        if response:
            reactor.callLater(0, self.sendNext)
        self.lastReceived = time()
        self.dataCameIn()

    def sendNext(self):
        """Assemble and transmit the next packet, then schedule the following one.

        The packet carries session/observed ids while the handshake is in
        progress, the list of messages we missed, and as many queued messages
        as fit.  The next call is scheduled immediately when more queued
        messages can go out, otherwise after one second.  Any previously
        scheduled call is cancelled first so only one timer chain exists.
        """
        flags = 0
        header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
        ids = ""
        missed = ""
        msgs = ""

        # session / observed logic
        if self.state == pending:
            flags = flags | FLAG_SESSION
            ids += pack("!L", self.sessionID)
            self.state = sent
        elif self.state == sent:
            if self.observed is not None:
                flags = flags | FLAG_SESSION | FLAG_OBSERVED
                ids += pack("!LL", self.observed, self.sessionID)
            else:
                flags = flags | FLAG_SESSION
                ids += pack("!L", self.sessionID)
        elif self.sendSession is not None:
            # confirmed: keep repeating both ids until the remote has seen a
            # packet at or after sendSession (a sendSession of 0 is valid)
            if (self.obSeq - self.sendSession) % 2**16 < 256:
                self.sendSession = None
            else:
                flags = flags | FLAG_SESSION | FLAG_OBSERVED
                ids += pack("!LL", self.observed, self.sessionID)

        # missed header: keep only reports recorded after the packet the remote last confirmed
        # NOTE(review): this comparison ignores 16-bit wraparound, and more than 256
        # entries would overflow the count octet -- confirm both are unreachable.
        self.weMissed = [m for m in self.weMissed if m[0] > self.obSeq]
        if self.weMissed:
            flags = flags | FLAG_MISSED
            missed += chr(len(self.weMissed) - 1)
            for entry in self.weMissed:
                missed += chr(entry[1])

        # append any outgoing messages
        first = self.next
        if self.state == confirmed and self.omsgq:
            outstanding = self._outstanding()
            overhead = len(missed) + len(ids) + len(header) + 1  # +1 for the message length octet
            while self.omsgq and outstanding < 255 and len(self.omsgq[-1]) + len(msgs) + overhead <= MAX_PACKET_SIZE:
                msg = self.omsgq.pop()
                msgs += chr(len(msg) - 1) + msg
                self.outMsgs[self.next] = msg
                self.next = (self.next + 1) % 256
                outstanding += 1

        # update outgoing message stat
        if msgs:
            flags = flags | FLAG_NEXT
            ids += chr(first)
            self.lastTransmitSeq = self.outSeq
        self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256

        # do we need a NEXT flag despite not having sent any messages?
        if not flags & FLAG_NEXT and self._outstanding() > 0:
            flags = flags | FLAG_NEXT
            ids += chr(self.next)

        # update stats and send packet
        packet = chr(flags) + header + ids + missed + msgs
        self.outSeq = (self.outSeq + 1) % 2**16
        self.lastTransmit = time()
        self.transport.write(packet)

        # call later: replace any pending timer instead of starting another chain
        if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
            delay = 0
        else:
            delay = 1
        timer = self.sendTimer
        if timer is not None and timer.active():
            timer.cancel()
        self.sendTimer = reactor.callLater(delay, self.sendNext)

    def dataCameIn(self):
        """
        called when we get a packet bearing messages
        delegate must do something with the messages or they will get dropped
        """
        self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
        if hasattr(self, 'delegate') and self.delegate is not None:
            self.imsgq = []
290
class ustr(str):
    """
    this subclass of string encapsulates each ordered message, caches its sequence number,
    and has comparison functions to sort by sequence number

    Ordering is circular over 16-bit sequence numbers: a message counts as
    "greater or equal" when it is 0..255 ahead of the other (mod 2**16).
    Equality and hashing use the sequence number only, not the payload.
    """
    def getseq(self):
        """Return (and cache) the big-endian 16-bit number in the first two octets."""
        if not hasattr(self, 'seq'):
            self.seq = unpack("!H", self[0:2])[0]
        return self.seq

    def _ahead(self, other):
        """Return how far this message's sequence is ahead of *other*'s, mod 2**16."""
        return (self.getseq() - other.getseq()) % 2**16

    def __lt__(self, other):
        return self._ahead(other) > 255

    def __le__(self, other):
        distance = self._ahead(other)
        return distance > 255 or distance == 0

    def __eq__(self, other):
        return self.getseq() == other.getseq()

    def __ne__(self, other):
        return self.getseq() != other.getseq()

    def __gt__(self, other):
        return 0 < self._ahead(other) < 256

    def __ge__(self, other):
        return self._ahead(other) < 256

    def __hash__(self):
        # must agree with __eq__, which compares sequence numbers only
        return hash(self.getseq())
312                 
class OrderedConnection(AirhookConnection):
    """
    this implements a simple protocol for ordered messages over airhook
    the first two octets of each message are interpreted as a 16-bit sequence number
    253 bytes are used for payload
    """
    def __init__(self, transport, addr, delegate):
        AirhookConnection.__init__(self, transport, addr, delegate)
        self.oseq = 0  # sequence number for the next message we send
        self.iseq = 0  # sequence number of the next message to deliver
        self.q = []    # received messages not yet delivered, kept sorted by sequence

    def dataCameIn(self):
        """Reassemble in-order payloads and pass them to the delegate.

        Messages are sorted into q.  Payloads (minus the 2-octet sequence
        prefix) are concatenated while the head of q is the next expected
        sequence number.  Messages whose sequence was already delivered
        (duplicates or late retransmissions) are dropped instead of blocking
        the queue forever.
        """
        # put 'em together
        for msg in self.imsgq:
            insort_left(self.q, ustr(msg))
        self.imsgq = []

        chunks = []
        while self.q:
            ahead = (self.q[0].getseq() - self.iseq) % 2**16
            if 0 < ahead < 2**15:
                break  # gap: wait for the missing message
            head = self.q.pop(0)
            if ahead == 0:
                chunks.append(head[2:])
                self.iseq = (self.iseq + 1) % 2**16
            # otherwise the head is behind iseq (already delivered): drop it
        data = "".join(chunks)
        if data:
            self.msgDelegate('dataCameIn', (self.host, self.port, data))

    def sendSomeData(self, data):
        """Split *data* into 253-octet chunks, prefix each with a 16-bit
        sequence number, queue them, and try to send if anything is queued."""
        # chop it up and queue it up
        if data:
            for start in range(0, len(data), 253):
                chunk = pack("!H", self.oseq) + data[start:start + 253]
                # sendNext pops from the end, so inserting at the front keeps queue order
                self.omsgq.insert(0, chunk)
                self.oseq = (self.oseq + 1) % 2**16

        if self.omsgq:
            self.sendNext()