From 884d2b86801508decc4a4577a0c7f9a54b20e61f Mon Sep 17 00:00:00 2001 From: burris Date: Sun, 22 Dec 2002 08:40:59 +0000 Subject: [PATCH] airhook reliable datagram protocol --- airhook.py | 266 +++++++++++++++++++++++++++++ test_airhook.py | 433 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 699 insertions(+) create mode 100644 airhook.py create mode 100644 test_airhook.py diff --git a/airhook.py b/airhook.py new file mode 100644 index 0000000..ec847bd --- /dev/null +++ b/airhook.py @@ -0,0 +1,266 @@ +## Airhook Protocol http://airhook.org/protocol.html +## Copyright 2002, Andrew Loewenstern, All Rights Reserved + +from random import uniform as rand +from struct import pack, unpack +from time import time +from StringIO import StringIO +import unittest + +from twisted.internet import protocol +from twisted.internet import reactor + +# flags +FLAG_AIRHOOK = 128 +FLAG_OBSERVED = 16 +FLAG_SESSION = 8 +FLAG_MISSED = 4 +FLAG_NEXT = 2 +FLAG_INTERVAL = 1 + +MAX_PACKET_SIZE = 1480 + +pending = 0 +sent = 1 +confirmed = 2 + +class Airhook(protocol.DatagramProtocol): + + def startProtocol(self): + self.connections = {} + + def datagramReceived(self, datagram, addr): + flag = datagram[0] + if not flag & FLAG_AIRHOOK: # first bit always must be 0 + conn = self.connectionForAddr(addr) + conn.datagramReceieved(datagram) + + def connectionForAddr(self, addr): + if not self.connections.has_key(addr): + conn = AirhookConnection(self.transport, addr) + self.connections[addr] = conn + return self.connections[addr] + +class AirhookPacket: + def __init__(self, msg): + self.datagram = msg + self.oseq = ord(msg[1]) + self.seq = unpack("!H", msg[2:4])[0] + self.flags = ord(msg[0]) + self.session = None + self.observed = None + self.next = None + self.missed = [] + self.msgs = [] + skip = 4 + if self.flags & FLAG_OBSERVED: + self.observed = unpack("!L", msg[skip:skip+4])[0] + skip += 4 + if self.flags & FLAG_SESSION: + self.session = unpack("!L", msg[skip:skip+4])[0] + skip += 4 + if self.flags & FLAG_NEXT: + self.next = ord(msg[skip]) + skip += 1 + if self.flags & FLAG_MISSED: + num = ord(msg[skip]) + 1 + skip += 1 + for i in range(num): + self.missed.append( ord(msg[skip+i])) + skip += num + if self.flags & FLAG_NEXT: + while len(msg) - skip > 0: + n = ord(msg[skip]) + 1 + skip+=1 + self.msgs.append( msg[skip:skip+n]) + skip += n + + + +class AirhookConnection: + def __init__(self, transport, addr): + self.addr = addr + type, self.host, self.port = addr + self.transport = transport + + self.outSeq = 0 # highest sequence we have sent, can't be 255 more than obSeq + self.obSeq = -1 # highest sequence confirmed by remote + self.inSeq = 0 # last received sequence + self.observed = None # their session id + self.sessionID = long(rand(0, 2**32)) # our session id + + self.lastTransmit = -1 # time we last sent a packet with messages + self.lastReceieved = 0 # time we last received a packet with messages + self.lastTransmitSeq = -1 # last sequence we sent a packet + self.state = pending + + self.outMsgs = [None] * 256 # outgoing messages (seq sent, message), index = message number + self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256 + self.next = -1 # next outgoing message number + self.omsgq = [] # list of messages to go out + self.imsgq = [] # list of messages coming in + self.sendSession = None # send session/observed fields until obSeq > sendSession + + self.resetMessages() + + def resetMessages(self): + self.weMissed = [] + self.inMsg = 0 # next incoming message number + + def datagramReceived(self, datagram): + if not datagram: + return + response = 0 # if we know we have a response now (like resending missed packets) + p = AirhookPacket(datagram) + + # check for state change + if self.state == pending: + if p.observed != None and p.session != None: + if p.observed == self.sessionID: + self.observed = p.session + self.state = confirmed + else: + # bogus packet! + return + elif p.session != None: + self.observed = p.session + self.state = sent + response = 1 + elif self.state == sent: + if p.observed != None and p.session != None: + if p.observed == self.sessionID: + self.observed = p.session + self.sendSession = self.outSeq + self.state = confirmed + if p.session != None: + if not self.observed: + self.observed = p.session + elif self.observed != p.session: + self.state = pending + self.resetMessages() + self.inSeq = p.seq + response = 1 + elif self.state == confirmed: + if p.session != None or p.observed != None : + if p.session != self.observed or p.observed != self.sessionID: + self.state = pending + if seq == 0: + self.resetMessages() + self.inSeq = p.seq + + if self.state != pending: + msgs = [] + missed = [] + + # check to make sure sequence number isn't out of wack + assert (p.seq - self.inSeq) % 2**16 < 256 + + # see if they need us to resend anything + for i in p.missed: + response = 1 + if self.outMsgs[i] != None: + self.omsgq.insert(0, self.outMsgs[i]) + self.outMsgs[i] = None + + # see if we need them to send anything + if p.next != None: + if p.next == 0 and self.inMsg == -1: + missed = 255 + missed_count = (p.next - self.inMsg) % 256 + if missed_count: + self.lastReceived = time() + for i in range(missed_count): + missed += [(self.outSeq, (self.inMsg + i) % 256)] + response = 1 + self.weMissed += missed + self.inMsg = (p.next + len(p.msgs)) % 256 + + self.imsgq += p.msgs + self.inSeq = p.seq + + if self.state == confirmed: + # unpack the observed sequence + tseq = unpack('!H', pack('!H', self.outSeq)[0] + chr(p.oseq))[0] + if ((self.outSeq - tseq)) % 2**16 > 255: + tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0] + self.obSeq = tseq + + if response: + reactor.callLater(0, self.sendNext) + self.lastReceived = time() + + + def sendNext(self): + flags = 0 + header = chr(self.inSeq & 255) + pack("!H", self.outSeq) + ids = "" + missed = "" + msgs = "" + + if self.state == pending: + flags = flags | FLAG_SESSION + ids += pack("!L", self.sessionID) + self.state = sent + elif self.state == sent: + if self.observed != None: + flags = flags | FLAG_SESSION | FLAG_OBSERVED + ids += pack("!LL", self.observed, self.sessionID) + else: + flags = flags | FLAG_SESSION + ids += pack("!L", self.sessionID) + + else: + if self.state == sent or self.sendSession: + if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256: + self.sendSession = None + else: + flags = flags | FLAG_SESSION | FLAG_OBSERVED + ids += pack("!LL", self.observed, self.sessionID) + + if self.obSeq >= 0: + self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed) + + if self.weMissed: + flags = flags | FLAG_MISSED + missed += chr(len(self.weMissed) - 1) + for i in self.weMissed: + missed += chr(i[1]) + + if self.state == confirmed and self.omsgq: + first = (self.next + 1) % 256 + while len(self.omsgq) and (len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE) : + if self.obSeq == -1: + highest = 0 + else: + highest = self.outMsgNums[self.obSeq % 256] + if self.next != -1 and (self.next + 1) % 256 == (highest - 1) % 256: + break + else: + self.next = (self.next + 1) % 256 + msg = self.omsgq.pop() + msgs += chr(len(msg) - 1) + msg + self.outMsgs[self.next] = msg + if msgs: + flags = flags | FLAG_NEXT + ids += chr(first) + self.lastTransmitSeq = self.outSeq + self.outMsgNums[self.outSeq % 256] = first + else: + if self.next == -1: + self.outMsgNums[self.outSeq % 256] = 0 + else: + self.outMsgNums[self.outSeq % 256] = self.next + + if (self.obSeq - self.lastTransmitSeq) % 2**16 > 256 and self.outMsgNums[self.obSeq % 256] != self.next and not flags & FLAG_NEXT: + flags = flags | FLAG_NEXT + ids += chr((self.next + 1) % 256) + packet = chr(flags) + header + ids + missed + msgs + self.outSeq = (self.outSeq + 1) % 2**16 + self.lastTransmit = time() + self.transport.write(packet) + + if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256: + reactor.callLater(0, self.sendNext) + else: + reactor.callLater(1, self.sendNext) + diff --git a/test_airhook.py b/test_airhook.py new file mode 100644 index 0000000..b00681e --- /dev/null +++ b/test_airhook.py @@ -0,0 +1,433 @@ +import unittest +from airhook import * +from random import uniform as rand + +if __name__ =="__main__": + tests = unittest.defaultTestLoader.loadTestsFromNames(['test_airhook']) + result = unittest.TextTestRunner().run(tests) + + + +def test_createStartPacket(): + flags = 0 | FLAG_AIRHOOK | FLAG_SESSION + packet = chr(flags) + "\xff" + "\x00\x00" + pack("!L", long(rand(0, 2**32))) + return packet + +def test_createReply(session, observed, obseq, seq): + flags = 0 | FLAG_AIRHOOK | FLAG_SESSION | FLAG_OBSERVED + packet = chr(flags) + pack("!H", seq)[1] + pack("!H", obseq + 1) + pack("!L", session) + pack("!L", observed) + return packet + + +def pscope(msg, noisy=0): + # packet scope + str = "" + p = AirhookPacket(msg) + str += "oseq: %s seq: %s " % (p.oseq, p.seq) + if noisy: + str += "packet: %s \n" % (`p.datagram`) + flags = p.flags + str += "flags: " + if flags & FLAG_SESSION: + str += "FLAG_SESSION " + if flags & FLAG_OBSERVED: + str += "FLAG_OBSERVED " + if flags & FLAG_MISSED: + str += "FLAG_MISSED " + if flags & FLAG_NEXT: + str += "FLAG_NEXT " + str += "\n" + + if p.observed != None: + str += "OBSERVED: %s\n" % p.observed + if p.session != None: + str += "SESSION: %s\n" % p.session + if p.next != None: + str += "NEXT: %s\n" % p.next + if p.missed: + if noisy: + str += "MISSED: " + `p.missed` + else: + str += "MISSED: " + `len(p.missed)` + str += "\n" + if p.msgs: + if noisy: + str += "MSGS: " + `p.msgs` + "\n" + else: + str += "MSGS: <%s> " % len(p.msgs) + str += "\n" + return str + +# testing function +def swap(a, dir="", noisy=0): + msg = "" + while not msg: + a.transport.seek(0) + msg= a.transport.read() + a.transport = StringIO() + if not msg: + a.sendNext() + if noisy: + print 6*dir + " " + pscope(msg) + return msg + +class SimpleTest(unittest.TestCase): + def setUp(self): + self.noisy = 0 + self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040)) + self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040)) + def testReallySimple(self): + # connect to eachother and send a few packets, observe sequence incrementing + self.noisy = 0 + a = self.a + b = self.b + self.assertEqual(a.state, pending) + self.assertEqual(b.state, pending) + self.assertEqual(a.outSeq, 0) + self.assertEqual(b.outSeq, 0) + self.assertEqual(a.obSeq, -1) + self.assertEqual(b.obSeq, -1) + + msg = swap(a, '>', self.noisy) + self.assertEqual(a.state, sent) + self.assertEqual(a.outSeq, 1) + self.assertEqual(a.obSeq, -1) + + b.datagramReceived(msg) + self.assertEqual(b.state, sent) + self.assertEqual(b.inSeq, 0) + self.assertEqual(b.obSeq, -1) + msg = swap(b, '<', self.noisy) + self.assertEqual(b.outSeq, 1) + + a.datagramReceived(msg) + self.assertEqual(a.state, confirmed) + self.assertEqual(a.obSeq, 0) + self.assertEqual(a.inSeq, 0) + msg = swap(a, '>', self.noisy) + self.assertEqual(a.outSeq, 2) + + b.datagramReceived(msg) + self.assertEqual(b.state, confirmed) + self.assertEqual(b.obSeq, 0) + self.assertEqual(b.inSeq, 1) + msg = swap(b, '<', self.noisy) + self.assertEqual(b.outSeq, 2) + + a.datagramReceived(msg) + self.assertEqual(a.outSeq, 2) + self.assertEqual(a.inSeq, 1) + self.assertEqual(a.obSeq, 1) + +class BasicTests(unittest.TestCase): + def setUp(self): + self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040)) + self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040)) + self.noisy = 0 + def testSimple(self): + a = self.a + b = self.b + + TESTMSG = "Howdy, Y'All!" + a.omsgq.append(TESTMSG) + a.sendNext() + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + msg = swap(b, '<', self.noisy) + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + b.datagramReceived(msg) + + self.assertEqual(b.inMsg, 1) + self.assertEqual(len(b.imsgq), 1) + self.assertEqual(b.imsgq[0], TESTMSG) + + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + b.datagramReceived(msg) + + def testLostFirst(self): + a = self.a + b = self.b + + TESTMSG = "Howdy, Y'All!" + TESTMSG2 = "Yee Haw" + + a.omsgq.append(TESTMSG) + msg = swap(a, '>', self.noisy) + b.datagramReceived(msg) + msg = swap(b, '<', self.noisy) + self.assertEqual(b.state, sent) + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + del(msg) # dropping first message + + a.omsgq.append(TESTMSG2) + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + self.assertEqual(b.state, confirmed) + self.assertEqual(len(b.imsgq), 1) + self.assertEqual(b.imsgq[0], TESTMSG2) + self.assertEqual(b.weMissed, [(1, 0)]) + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + self.assertEqual(len(b.imsgq), 2) + b.imsgq.sort() + l = [TESTMSG2, TESTMSG] + l.sort() + self.assertEqual(b.imsgq,l) + + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + b.datagramReceived(msg) + + msg = swap(b, '<', self.noisy) + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + b.datagramReceived(msg) + + msg = swap(b, '<', self.noisy) + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + self.assertEqual(len(b.imsgq), 2) + b.imsgq.sort() + l = [TESTMSG2, TESTMSG] + l.sort() + self.assertEqual(b.imsgq,l) + + def testLostSecond(self): + a = self.a + b = self.b + + TESTMSG = "Howdy, Y'All!" + TESTMSG2 = "Yee Haw" + + a.omsgq.append(TESTMSG) + msg = swap(a, '>', self.noisy) + b.datagramReceived(msg) + msg = swap(b, '<', self.noisy) + self.assertEqual(b.state, sent) + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + a.omsgq.append(TESTMSG2) + msg2 = swap(a, '>', self.noisy) + del(msg2) # dropping second message + + assert(a.outMsgs[1] != None) + + b.datagramReceived(msg) + self.assertEqual(b.state, confirmed) + self.assertEqual(len(b.imsgq), 1) + self.assertEqual(b.imsgq[0], TESTMSG) + self.assertEqual(b.inMsg, 1) + self.assertEqual(b.weMissed, []) + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + assert(a.outMsgs[1] != None) + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + self.assertEqual(b.state, confirmed) + self.assertEqual(len(b.imsgq), 1) + self.assertEqual(b.imsgq[0], TESTMSG) + self.assertEqual(b.weMissed, [(2, 1)]) + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + self.assertEqual(len(b.imsgq), 2) + b.imsgq.sort() + l = [TESTMSG2, TESTMSG] + l.sort() + self.assertEqual(b.imsgq,l) + + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + + + msg = swap(a, '>', self.noisy) + + self.assertEqual(len(b.imsgq), 2) + b.imsgq.sort() + l = [TESTMSG2, TESTMSG] + l.sort() + self.assertEqual(b.imsgq,l) + + def testDoubleDouble(self): + a = self.a + b = self.b + + TESTMSGA = "Howdy, Y'All!" + TESTMSGB = "Yee Haw" + TESTMSGC = "FOO BAR" + TESTMSGD = "WING WANG" + + a.omsgq.append(TESTMSGA) + a.omsgq.append(TESTMSGB) + + b.omsgq.append(TESTMSGC) + b.omsgq.append(TESTMSGD) + + + msg = swap(a, '>', self.noisy) + + + b.datagramReceived(msg) + self.assertEqual(b.state, sent) + + msg = swap(b, '<', self.noisy) + a.datagramReceived(msg) + + msg = swap(a, '>', self.noisy) + + b.datagramReceived(msg) + self.assertEqual(len(b.imsgq), 2) + l = [TESTMSGA, TESTMSGB] + l.sort();b.imsgq.sort() + self.assertEqual(b.imsgq, l) + self.assertEqual(b.inMsg, 2) + + msg = swap(b, '<', self.noisy) + a.datagramReceived(msg) + + self.assertEqual(len(a.imsgq), 2) + l = [TESTMSGC, TESTMSGD] + l.sort();a.imsgq.sort() + self.assertEqual(a.imsgq, l) + self.assertEqual(a.inMsg, 2) + + def testDoubleDoubleProb(self, prob=0.25): + a = self.a + b = self.b + TESTMSGA = "Howdy, Y'All!" + TESTMSGB = "Yee Haw" + TESTMSGC = "FOO BAR" + TESTMSGD = "WING WANG" + + a.omsgq.append(TESTMSGA) + a.omsgq.append(TESTMSGB) + + b.omsgq.append(TESTMSGC) + b.omsgq.append(TESTMSGD) + + while a.state != confirmed or b.state != confirmed or ord(msga[0]) & FLAG_NEXT or ord(msgb[0]) & FLAG_NEXT : + msga = swap(a, '>', self.noisy) + + if rand(0,1) < prob: + b.datagramReceived(msga) + + msgb = swap(b, '<', self.noisy) + + if rand(0,1) < prob: + a.datagramReceived(msgb) + + self.assertEqual(a.state, confirmed) + self.assertEqual(b.state, confirmed) + self.assertEqual(len(b.imsgq), 2) + l = [TESTMSGA, TESTMSGB] + l.sort();b.imsgq.sort() + self.assertEqual(b.imsgq, l) + + self.assertEqual(len(a.imsgq), 2) + l = [TESTMSGC, TESTMSGD] + l.sort();a.imsgq.sort() + self.assertEqual(a.imsgq, l) + + def testOneWayBlast(self, num = 2**8): + a = self.a + b = self.b + import sha + + + for i in xrange(num): + a.omsgq.append(sha.sha(`i`).digest()) + msg = swap(a, '>', self.noisy) + while a.state != confirmed or ord(msg[0]) & FLAG_NEXT: + b.datagramReceived(msg) + msg = swap(b, '<', self.noisy) + + a.datagramReceived(msg) + msg = swap(a, '>', self.noisy) + + self.assertEqual(len(b.imsgq), num) + + def testTwoWayBlast(self, num = 2**15, prob=0.5): + a = self.a + b = self.b + import sha + + + for i in xrange(num): + a.omsgq.append(sha.sha('a' + `i`).digest()) + b.omsgq.append(sha.sha('b' + `i`).digest()) + + while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED): + if rand(0,1) < prob: + msga = swap(a, '>', self.noisy) + b.datagramReceived(msga) + else: + msga = swap(a, '>', 0) + if rand(0,1) < prob: + msgb = swap(b, '<', self.noisy) + a.datagramReceived(msgb) + else: + msgb = swap(b, '<', 0) + + + + self.assertEqual(len(a.imsgq), num) + self.assertEqual(len(b.imsgq), num) + + def testLimitMessageNumbers(self): + a = self.a + b = self.b + import sha + + msg = swap(a, noisy=self.noisy) + b.datagramReceived(msg) + + msg = swap(b, noisy=self.noisy) + a.datagramReceived(msg) + + + for i in range(5000): + a.omsgq.append(sha.sha('a' + 'i').digest()) + + for i in range(31): + msg = swap(a, noisy=self.noisy) + self.assertEqual(a.obSeq, 0) + self.assertEqual(a.outMsgNums[a.obSeq], 0) + self.assertEqual(a.next, 254) + self.assertEqual(a.outMsgNums[19], 254) -- 2.39.5