--- /dev/null
+## Airhook Protocol http://airhook.org/protocol.html
+## Copyright 2002, Andrew Loewenstern, All Rights Reserved
+
+from random import uniform as rand
+from struct import pack, unpack
+from time import time
+from StringIO import StringIO
+import unittest
+
+from twisted.internet import protocol
+from twisted.internet import reactor
+
+# flags
+FLAG_AIRHOOK = 128
+FLAG_OBSERVED = 16
+FLAG_SESSION = 8
+FLAG_MISSED = 4
+FLAG_NEXT = 2
+FLAG_INTERVAL = 1
+
+MAX_PACKET_SIZE = 1480
+
+pending = 0
+sent = 1
+confirmed = 2
+
+class Airhook(protocol.DatagramProtocol):
+
+ def startProtocol(self):
+ self.connections = {}
+
+ def datagramReceived(self, datagram, addr):
+ flag = datagram[0]
+ if not flag & FLAG_AIRHOOK: # first bit always must be 0
+ conn = self.connectionForAddr(addr)
+ conn.datagramReceieved(datagram)
+
+ def connectionForAddr(self, addr):
+ if not self.connections.has_key(addr):
+ conn = AirhookConnection(self.transport, addr)
+ self.connections[addr] = conn
+ return self.connections[addr]
+
+class AirhookPacket:
+ def __init__(self, msg):
+ self.datagram = msg
+ self.oseq = ord(msg[1])
+ self.seq = unpack("!H", msg[2:4])[0]
+ self.flags = ord(msg[0])
+ self.session = None
+ self.observed = None
+ self.next = None
+ self.missed = []
+ self.msgs = []
+ skip = 4
+ if self.flags & FLAG_OBSERVED:
+ self.observed = unpack("!L", msg[skip:skip+4])[0]
+ skip += 4
+ if self.flags & FLAG_SESSION:
+ self.session = unpack("!L", msg[skip:skip+4])[0]
+ skip += 4
+ if self.flags & FLAG_NEXT:
+ self.next = ord(msg[skip])
+ skip += 1
+ if self.flags & FLAG_MISSED:
+ num = ord(msg[skip]) + 1
+ skip += 1
+ for i in range(num):
+ self.missed.append( ord(msg[skip+i]))
+ skip += num
+ if self.flags & FLAG_NEXT:
+ while len(msg) - skip > 0:
+ n = ord(msg[skip]) + 1
+ skip+=1
+ self.msgs.append( msg[skip:skip+n])
+ skip += n
+
+
+
+class AirhookConnection:
+ def __init__(self, transport, addr):
+ self.addr = addr
+ type, self.host, self.port = addr
+ self.transport = transport
+
+ self.outSeq = 0 # highest sequence we have sent, can't be 255 more than obSeq
+ self.obSeq = -1 # highest sequence confirmed by remote
+ self.inSeq = 0 # last received sequence
+ self.observed = None # their session id
+ self.sessionID = long(rand(0, 2**32)) # our session id
+
+ self.lastTransmit = -1 # time we last sent a packet with messages
+ self.lastReceieved = 0 # time we last received a packet with messages
+ self.lastTransmitSeq = -1 # last sequence we sent a packet
+ self.state = pending
+
+ self.outMsgs = [None] * 256 # outgoing messages (seq sent, message), index = message number
+ self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
+ self.next = -1 # next outgoing message number
+ self.omsgq = [] # list of messages to go out
+ self.imsgq = [] # list of messages coming in
+ self.sendSession = None # send session/observed fields until obSeq > sendSession
+
+ self.resetMessages()
+
+ def resetMessages(self):
+ self.weMissed = []
+ self.inMsg = 0 # next incoming message number
+
+ def datagramReceived(self, datagram):
+ if not datagram:
+ return
+ response = 0 # if we know we have a response now (like resending missed packets)
+ p = AirhookPacket(datagram)
+
+ # check for state change
+ if self.state == pending:
+ if p.observed != None and p.session != None:
+ if p.observed == self.sessionID:
+ self.observed = p.session
+ self.state = confirmed
+ else:
+ # bogus packet!
+ return
+ elif p.session != None:
+ self.observed = p.session
+ self.state = sent
+ response = 1
+ elif self.state == sent:
+ if p.observed != None and p.session != None:
+ if p.observed == self.sessionID:
+ self.observed = p.session
+ self.sendSession = self.outSeq
+ self.state = confirmed
+ if p.session != None:
+ if not self.observed:
+ self.observed = p.session
+ elif self.observed != p.session:
+ self.state = pending
+ self.resetMessages()
+ self.inSeq = p.seq
+ response = 1
+ elif self.state == confirmed:
+ if p.session != None or p.observed != None :
+ if p.session != self.observed or p.observed != self.sessionID:
+ self.state = pending
+ if seq == 0:
+ self.resetMessages()
+ self.inSeq = p.seq
+
+ if self.state != pending:
+ msgs = []
+ missed = []
+
+ # check to make sure sequence number isn't out of wack
+ assert (p.seq - self.inSeq) % 2**16 < 256
+
+ # see if they need us to resend anything
+ for i in p.missed:
+ response = 1
+ if self.outMsgs[i] != None:
+ self.omsgq.insert(0, self.outMsgs[i])
+ self.outMsgs[i] = None
+
+ # see if we need them to send anything
+ if p.next != None:
+ if p.next == 0 and self.inMsg == -1:
+ missed = 255
+ missed_count = (p.next - self.inMsg) % 256
+ if missed_count:
+ self.lastReceived = time()
+ for i in range(missed_count):
+ missed += [(self.outSeq, (self.inMsg + i) % 256)]
+ response = 1
+ self.weMissed += missed
+ self.inMsg = (p.next + len(p.msgs)) % 256
+
+ self.imsgq += p.msgs
+ self.inSeq = p.seq
+
+ if self.state == confirmed:
+ # unpack the observed sequence
+ tseq = unpack('!H', pack('!H', self.outSeq)[0] + chr(p.oseq))[0]
+ if ((self.outSeq - tseq)) % 2**16 > 255:
+ tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
+ self.obSeq = tseq
+
+ if response:
+ reactor.callLater(0, self.sendNext)
+ self.lastReceived = time()
+
+
+ def sendNext(self):
+ flags = 0
+ header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
+ ids = ""
+ missed = ""
+ msgs = ""
+
+ if self.state == pending:
+ flags = flags | FLAG_SESSION
+ ids += pack("!L", self.sessionID)
+ self.state = sent
+ elif self.state == sent:
+ if self.observed != None:
+ flags = flags | FLAG_SESSION | FLAG_OBSERVED
+ ids += pack("!LL", self.observed, self.sessionID)
+ else:
+ flags = flags | FLAG_SESSION
+ ids += pack("!L", self.sessionID)
+
+ else:
+ if self.state == sent or self.sendSession:
+ if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
+ self.sendSession = None
+ else:
+ flags = flags | FLAG_SESSION | FLAG_OBSERVED
+ ids += pack("!LL", self.observed, self.sessionID)
+
+ if self.obSeq >= 0:
+ self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
+
+ if self.weMissed:
+ flags = flags | FLAG_MISSED
+ missed += chr(len(self.weMissed) - 1)
+ for i in self.weMissed:
+ missed += chr(i[1])
+
+ if self.state == confirmed and self.omsgq:
+ first = (self.next + 1) % 256
+ while len(self.omsgq) and (len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE) :
+ if self.obSeq == -1:
+ highest = 0
+ else:
+ highest = self.outMsgNums[self.obSeq % 256]
+ if self.next != -1 and (self.next + 1) % 256 == (highest - 1) % 256:
+ break
+ else:
+ self.next = (self.next + 1) % 256
+ msg = self.omsgq.pop()
+ msgs += chr(len(msg) - 1) + msg
+ self.outMsgs[self.next] = msg
+ if msgs:
+ flags = flags | FLAG_NEXT
+ ids += chr(first)
+ self.lastTransmitSeq = self.outSeq
+ self.outMsgNums[self.outSeq % 256] = first
+ else:
+ if self.next == -1:
+ self.outMsgNums[self.outSeq % 256] = 0
+ else:
+ self.outMsgNums[self.outSeq % 256] = self.next
+
+ if (self.obSeq - self.lastTransmitSeq) % 2**16 > 256 and self.outMsgNums[self.obSeq % 256] != self.next and not flags & FLAG_NEXT:
+ flags = flags | FLAG_NEXT
+ ids += chr((self.next + 1) % 256)
+ packet = chr(flags) + header + ids + missed + msgs
+ self.outSeq = (self.outSeq + 1) % 2**16
+ self.lastTransmit = time()
+ self.transport.write(packet)
+
+ if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
+ reactor.callLater(0, self.sendNext)
+ else:
+ reactor.callLater(1, self.sendNext)
+
--- /dev/null
+import unittest
+from airhook import *
+from random import uniform as rand
+
+if __name__ =="__main__":
+ tests = unittest.defaultTestLoader.loadTestsFromNames(['test_airhook'])
+ result = unittest.TextTestRunner().run(tests)
+
+
+
+def test_createStartPacket():
+ flags = 0 | FLAG_AIRHOOK | FLAG_SESSION
+ packet = chr(flags) + "\xff" + "\x00\x00" + pack("!L", long(rand(0, 2**32)))
+ return packet
+
+def test_createReply(session, observed, obseq, seq):
+ flags = 0 | FLAG_AIRHOOK | FLAG_SESSION | FLAG_OBSERVED
+ packet = chr(flags) + pack("!H", seq)[1] + pack("!H", obseq + 1) + pack("!L", session) + pack("!L", observed)
+ return packet
+
+
+def pscope(msg, noisy=0):
+ # packet scope
+ str = ""
+ p = AirhookPacket(msg)
+ str += "oseq: %s seq: %s " % (p.oseq, p.seq)
+ if noisy:
+ str += "packet: %s \n" % (`p.datagram`)
+ flags = p.flags
+ str += "flags: "
+ if flags & FLAG_SESSION:
+ str += "FLAG_SESSION "
+ if flags & FLAG_OBSERVED:
+ str += "FLAG_OBSERVED "
+ if flags & FLAG_MISSED:
+ str += "FLAG_MISSED "
+ if flags & FLAG_NEXT:
+ str += "FLAG_NEXT "
+ str += "\n"
+
+ if p.observed != None:
+ str += "OBSERVED: %s\n" % p.observed
+ if p.session != None:
+ str += "SESSION: %s\n" % p.session
+ if p.next != None:
+ str += "NEXT: %s\n" % p.next
+ if p.missed:
+ if noisy:
+ str += "MISSED: " + `p.missed`
+ else:
+ str += "MISSED: " + `len(p.missed)`
+ str += "\n"
+ if p.msgs:
+ if noisy:
+ str += "MSGS: " + `p.msgs` + "\n"
+ else:
+ str += "MSGS: <%s> " % len(p.msgs)
+ str += "\n"
+ return str
+
+# testing function
+def swap(a, dir="", noisy=0):
+ msg = ""
+ while not msg:
+ a.transport.seek(0)
+ msg= a.transport.read()
+ a.transport = StringIO()
+ if not msg:
+ a.sendNext()
+ if noisy:
+ print 6*dir + " " + pscope(msg)
+ return msg
+
+class SimpleTest(unittest.TestCase):
+ def setUp(self):
+ self.noisy = 0
+ self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+ self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+ def testReallySimple(self):
+ # connect to eachother and send a few packets, observe sequence incrementing
+ self.noisy = 0
+ a = self.a
+ b = self.b
+ self.assertEqual(a.state, pending)
+ self.assertEqual(b.state, pending)
+ self.assertEqual(a.outSeq, 0)
+ self.assertEqual(b.outSeq, 0)
+ self.assertEqual(a.obSeq, -1)
+ self.assertEqual(b.obSeq, -1)
+
+ msg = swap(a, '>', self.noisy)
+ self.assertEqual(a.state, sent)
+ self.assertEqual(a.outSeq, 1)
+ self.assertEqual(a.obSeq, -1)
+
+ b.datagramReceived(msg)
+ self.assertEqual(b.state, sent)
+ self.assertEqual(b.inSeq, 0)
+ self.assertEqual(b.obSeq, -1)
+ msg = swap(b, '<', self.noisy)
+ self.assertEqual(b.outSeq, 1)
+
+ a.datagramReceived(msg)
+ self.assertEqual(a.state, confirmed)
+ self.assertEqual(a.obSeq, 0)
+ self.assertEqual(a.inSeq, 0)
+ msg = swap(a, '>', self.noisy)
+ self.assertEqual(a.outSeq, 2)
+
+ b.datagramReceived(msg)
+ self.assertEqual(b.state, confirmed)
+ self.assertEqual(b.obSeq, 0)
+ self.assertEqual(b.inSeq, 1)
+ msg = swap(b, '<', self.noisy)
+ self.assertEqual(b.outSeq, 2)
+
+ a.datagramReceived(msg)
+ self.assertEqual(a.outSeq, 2)
+ self.assertEqual(a.inSeq, 1)
+ self.assertEqual(a.obSeq, 1)
+
+class BasicTests(unittest.TestCase):
+ def setUp(self):
+ self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+ self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+ self.noisy = 0
+ def testSimple(self):
+ a = self.a
+ b = self.b
+
+ TESTMSG = "Howdy, Y'All!"
+ a.omsgq.append(TESTMSG)
+ a.sendNext()
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+ msg = swap(b, '<', self.noisy)
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+ b.datagramReceived(msg)
+
+ self.assertEqual(b.inMsg, 1)
+ self.assertEqual(len(b.imsgq), 1)
+ self.assertEqual(b.imsgq[0], TESTMSG)
+
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+ b.datagramReceived(msg)
+
+ def testLostFirst(self):
+ a = self.a
+ b = self.b
+
+ TESTMSG = "Howdy, Y'All!"
+ TESTMSG2 = "Yee Haw"
+
+ a.omsgq.append(TESTMSG)
+ msg = swap(a, '>', self.noisy)
+ b.datagramReceived(msg)
+ msg = swap(b, '<', self.noisy)
+ self.assertEqual(b.state, sent)
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ del(msg) # dropping first message
+
+ a.omsgq.append(TESTMSG2)
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+ self.assertEqual(b.state, confirmed)
+ self.assertEqual(len(b.imsgq), 1)
+ self.assertEqual(b.imsgq[0], TESTMSG2)
+ self.assertEqual(b.weMissed, [(1, 0)])
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+ self.assertEqual(len(b.imsgq), 2)
+ b.imsgq.sort()
+ l = [TESTMSG2, TESTMSG]
+ l.sort()
+ self.assertEqual(b.imsgq,l)
+
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+ b.datagramReceived(msg)
+
+ msg = swap(b, '<', self.noisy)
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+ b.datagramReceived(msg)
+
+ msg = swap(b, '<', self.noisy)
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ self.assertEqual(len(b.imsgq), 2)
+ b.imsgq.sort()
+ l = [TESTMSG2, TESTMSG]
+ l.sort()
+ self.assertEqual(b.imsgq,l)
+
+ def testLostSecond(self):
+ a = self.a
+ b = self.b
+
+ TESTMSG = "Howdy, Y'All!"
+ TESTMSG2 = "Yee Haw"
+
+ a.omsgq.append(TESTMSG)
+ msg = swap(a, '>', self.noisy)
+ b.datagramReceived(msg)
+ msg = swap(b, '<', self.noisy)
+ self.assertEqual(b.state, sent)
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ a.omsgq.append(TESTMSG2)
+ msg2 = swap(a, '>', self.noisy)
+ del(msg2) # dropping second message
+
+ assert(a.outMsgs[1] != None)
+
+ b.datagramReceived(msg)
+ self.assertEqual(b.state, confirmed)
+ self.assertEqual(len(b.imsgq), 1)
+ self.assertEqual(b.imsgq[0], TESTMSG)
+ self.assertEqual(b.inMsg, 1)
+ self.assertEqual(b.weMissed, [])
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ assert(a.outMsgs[1] != None)
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+ self.assertEqual(b.state, confirmed)
+ self.assertEqual(len(b.imsgq), 1)
+ self.assertEqual(b.imsgq[0], TESTMSG)
+ self.assertEqual(b.weMissed, [(2, 1)])
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+ self.assertEqual(len(b.imsgq), 2)
+ b.imsgq.sort()
+ l = [TESTMSG2, TESTMSG]
+ l.sort()
+ self.assertEqual(b.imsgq,l)
+
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+
+
+ msg = swap(a, '>', self.noisy)
+
+ self.assertEqual(len(b.imsgq), 2)
+ b.imsgq.sort()
+ l = [TESTMSG2, TESTMSG]
+ l.sort()
+ self.assertEqual(b.imsgq,l)
+
+ def testDoubleDouble(self):
+ a = self.a
+ b = self.b
+
+ TESTMSGA = "Howdy, Y'All!"
+ TESTMSGB = "Yee Haw"
+ TESTMSGC = "FOO BAR"
+ TESTMSGD = "WING WANG"
+
+ a.omsgq.append(TESTMSGA)
+ a.omsgq.append(TESTMSGB)
+
+ b.omsgq.append(TESTMSGC)
+ b.omsgq.append(TESTMSGD)
+
+
+ msg = swap(a, '>', self.noisy)
+
+
+ b.datagramReceived(msg)
+ self.assertEqual(b.state, sent)
+
+ msg = swap(b, '<', self.noisy)
+ a.datagramReceived(msg)
+
+ msg = swap(a, '>', self.noisy)
+
+ b.datagramReceived(msg)
+ self.assertEqual(len(b.imsgq), 2)
+ l = [TESTMSGA, TESTMSGB]
+ l.sort();b.imsgq.sort()
+ self.assertEqual(b.imsgq, l)
+ self.assertEqual(b.inMsg, 2)
+
+ msg = swap(b, '<', self.noisy)
+ a.datagramReceived(msg)
+
+ self.assertEqual(len(a.imsgq), 2)
+ l = [TESTMSGC, TESTMSGD]
+ l.sort();a.imsgq.sort()
+ self.assertEqual(a.imsgq, l)
+ self.assertEqual(a.inMsg, 2)
+
+ def testDoubleDoubleProb(self, prob=0.25):
+ a = self.a
+ b = self.b
+ TESTMSGA = "Howdy, Y'All!"
+ TESTMSGB = "Yee Haw"
+ TESTMSGC = "FOO BAR"
+ TESTMSGD = "WING WANG"
+
+ a.omsgq.append(TESTMSGA)
+ a.omsgq.append(TESTMSGB)
+
+ b.omsgq.append(TESTMSGC)
+ b.omsgq.append(TESTMSGD)
+
+ while a.state != confirmed or b.state != confirmed or ord(msga[0]) & FLAG_NEXT or ord(msgb[0]) & FLAG_NEXT :
+ msga = swap(a, '>', self.noisy)
+
+ if rand(0,1) < prob:
+ b.datagramReceived(msga)
+
+ msgb = swap(b, '<', self.noisy)
+
+ if rand(0,1) < prob:
+ a.datagramReceived(msgb)
+
+ self.assertEqual(a.state, confirmed)
+ self.assertEqual(b.state, confirmed)
+ self.assertEqual(len(b.imsgq), 2)
+ l = [TESTMSGA, TESTMSGB]
+ l.sort();b.imsgq.sort()
+ self.assertEqual(b.imsgq, l)
+
+ self.assertEqual(len(a.imsgq), 2)
+ l = [TESTMSGC, TESTMSGD]
+ l.sort();a.imsgq.sort()
+ self.assertEqual(a.imsgq, l)
+
+ def testOneWayBlast(self, num = 2**8):
+ a = self.a
+ b = self.b
+ import sha
+
+
+ for i in xrange(num):
+ a.omsgq.append(sha.sha(`i`).digest())
+ msg = swap(a, '>', self.noisy)
+ while a.state != confirmed or ord(msg[0]) & FLAG_NEXT:
+ b.datagramReceived(msg)
+ msg = swap(b, '<', self.noisy)
+
+ a.datagramReceived(msg)
+ msg = swap(a, '>', self.noisy)
+
+ self.assertEqual(len(b.imsgq), num)
+
+ def testTwoWayBlast(self, num = 2**15, prob=0.5):
+ a = self.a
+ b = self.b
+ import sha
+
+
+ for i in xrange(num):
+ a.omsgq.append(sha.sha('a' + `i`).digest())
+ b.omsgq.append(sha.sha('b' + `i`).digest())
+
+ while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+ if rand(0,1) < prob:
+ msga = swap(a, '>', self.noisy)
+ b.datagramReceived(msga)
+ else:
+ msga = swap(a, '>', 0)
+ if rand(0,1) < prob:
+ msgb = swap(b, '<', self.noisy)
+ a.datagramReceived(msgb)
+ else:
+ msgb = swap(b, '<', 0)
+
+
+
+ self.assertEqual(len(a.imsgq), num)
+ self.assertEqual(len(b.imsgq), num)
+
+ def testLimitMessageNumbers(self):
+ a = self.a
+ b = self.b
+ import sha
+
+ msg = swap(a, noisy=self.noisy)
+ b.datagramReceived(msg)
+
+ msg = swap(b, noisy=self.noisy)
+ a.datagramReceived(msg)
+
+
+ for i in range(5000):
+ a.omsgq.append(sha.sha('a' + 'i').digest())
+
+ for i in range(31):
+ msg = swap(a, noisy=self.noisy)
+ self.assertEqual(a.obSeq, 0)
+ self.assertEqual(a.outMsgNums[a.obSeq], 0)
+ self.assertEqual(a.next, 254)
+ self.assertEqual(a.outMsgNums[19], 254)