From e06b4f0f3d5c3df708399d179cce6a2e98c0366b Mon Sep 17 00:00:00 2001 From: burris Date: Mon, 23 Dec 2002 02:59:11 +0000 Subject: [PATCH] bug fixes, more tests, looking solid now --- airhook.py | 109 ++++++++++++++++++++++++++-------------------- test_airhook.py | 113 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 154 insertions(+), 68 deletions(-) diff --git a/airhook.py b/airhook.py index 4053dfe..8d9443c 100644 --- a/airhook.py +++ b/airhook.py @@ -19,7 +19,7 @@ FLAG_MISSED = 4 FLAG_NEXT = 2 FLAG_INTERVAL = 1 -MAX_PACKET_SIZE = 1480 +MAX_PACKET_SIZE = 1496 pending = 0 sent = 1 @@ -96,7 +96,7 @@ class AirhookConnection(Delegate): self.transport = transport self.outSeq = 0 # highest sequence we have sent, can't be 255 more than obSeq - self.obSeq = -1 # highest sequence confirmed by remote + self.obSeq = 0 # highest sequence confirmed by remote self.inSeq = 0 # last received sequence self.observed = None # their session id self.sessionID = long(rand(0, 2**32)) # our session id @@ -107,8 +107,6 @@ class AirhookConnection(Delegate): self.state = pending self.outMsgs = [None] * 256 # outgoing messages (seq sent, message), index = message number - self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256 - self.next = -1 # next outgoing message number self.omsgq = [] # list of messages to go out self.imsgq = [] # list of messages coming in self.sendSession = None # send session/observed fields until obSeq > sendSession @@ -118,13 +116,19 @@ class AirhookConnection(Delegate): def resetMessages(self): self.weMissed = [] self.inMsg = 0 # next incoming message number - + self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256 + self.next = 0 # next outgoing message number + def datagramReceived(self, datagram): if not datagram: return response = 0 # if we know we have a response now (like resending missed packets) p = AirhookPacket(datagram) + # check to make sure sequence number isn't out of order + if (p.seq - self.inSeq) % 2**16 >= 256: + return + # check for state change if self.state == pending: if p.observed != None and p.session != None: @@ -163,21 +167,16 @@ class AirhookConnection(Delegate): if self.state != pending: msgs = [] missed = [] - - # check to make sure sequence number isn't out of wack - assert (p.seq - self.inSeq) % 2**16 < 256 # see if they need us to resend anything for i in p.missed: response = 1 if self.outMsgs[i] != None: - self.omsgq.insert(0, self.outMsgs[i]) + self.omsgq.append(self.outMsgs[i]) self.outMsgs[i] = None - # see if we need them to send anything + # see if we missed any messages if p.next != None: - if p.next == 0 and self.inMsg == -1: - missed = 255 missed_count = (p.next - self.inMsg) % 256 if missed_count: self.lastReceived = time() @@ -185,10 +184,11 @@ class AirhookConnection(Delegate): missed += [(self.outSeq, (self.inMsg + i) % 256)] response = 1 self.weMissed += missed + # record highest message number seen self.inMsg = (p.next + len(p.msgs)) % 256 - + + # append messages, update sequence self.imsgq += p.msgs - self.inSeq = p.seq if self.state == confirmed: # unpack the observed sequence @@ -197,6 +197,8 @@ class AirhookConnection(Delegate): tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0] self.obSeq = tseq + self.inSeq = p.seq + if response: reactor.callLater(0, self.sendNext) self.lastReceived = time() @@ -209,6 +211,7 @@ class AirhookConnection(Delegate): missed = "" msgs = "" + # session / observed logic if self.state == pending: flags = flags | FLAG_SESSION ids += pack("!L", self.sessionID) @@ -229,6 +232,7 @@ class AirhookConnection(Delegate): flags = flags | FLAG_SESSION | FLAG_OBSERVED ids += pack("!LL", self.observed, self.sessionID) + # missed header if self.obSeq >= 0: self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed) @@ -238,39 +242,37 @@ class AirhookConnection(Delegate): for i in self.weMissed: missed += chr(i[1]) + # append any outgoing messages if self.state == confirmed and self.omsgq: - first = (self.next + 1) % 256 - while len(self.omsgq) and (len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE) : - if self.obSeq == -1: - highest = 0 - else: - highest = self.outMsgNums[self.obSeq % 256] - if self.next != -1 and (self.next + 1) % 256 == (highest - 1) % 256: - break - else: - self.next = (self.next + 1) % 256 - msg = self.omsgq.pop() - msgs += chr(len(msg) - 1) + msg - self.outMsgs[self.next] = msg + first = self.next + outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 + while len(self.omsgq) and outstanding < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE: + msg = self.omsgq.pop() + msgs += chr(len(msg) - 1) + msg + self.outMsgs[self.next] = msg + self.next = (self.next + 1) % 256 + outstanding+=1 + # update outgoing message stat if msgs: flags = flags | FLAG_NEXT ids += chr(first) self.lastTransmitSeq = self.outSeq - self.outMsgNums[self.outSeq % 256] = first - else: - if self.next == -1: - self.outMsgNums[self.outSeq % 256] = 0 - else: - self.outMsgNums[self.outSeq % 256] = self.next - - if (self.obSeq - self.lastTransmitSeq) % 2**16 > 256 and self.outMsgNums[self.obSeq % 256] != self.next and not flags & FLAG_NEXT: + #self.outMsgNums[self.outSeq % 256] = first + #else: + self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256 + + # do we need a NEXT flag despite not having sent any messages? + if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0: flags = flags | FLAG_NEXT - ids += chr((self.next + 1) % 256) + ids += chr(self.next) + + # update stats and send packet packet = chr(flags) + header + ids + missed + msgs self.outSeq = (self.outSeq + 1) % 2**16 self.lastTransmit = time() self.transport.write(packet) + # call later if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256: reactor.callLater(0, self.sendNext) else: @@ -278,29 +280,42 @@ class AirhookConnection(Delegate): def dataCameIn(self): + """ + called when we get a packet bearing messages + delegate must do something with the messages or they will get dropped + """ self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq)) if hasattr(self, 'delegate') and self.delegate != None: self.imsgq = [] class ustr(str): + """ + this subclass of string encapsulates each ordered message, caches it's sequence number, + and has comparison functions to sort by sequence number + """ def getseq(self): if not hasattr(self, 'seq'): self.seq = unpack("!H", self[0:2])[0] return self.seq def __lt__(self, other): - return self.getseq() < other.getseq() + return (self.getseq() - other.getseq()) % 2**16 > 255 def __le__(self, other): - return self.getseq() <= other.getseq() + return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other) def __eq__(self, other): - return self.getseq() != other.getseq() + return self.getseq() == other.getseq() def __ne__(self, other): - return self.getseq() <= other.getseq() + return self.getseq() != other.getseq() def __gt__(self, other): - return self.getseq() > other.getseq() + return (self.getseq() - other.getseq()) % 2**16 < 256 and not self.__eq__(other) def __ge__(self, other): - return self.getseq() >= other.getseq() - + return (self.getseq() - other.getseq()) % 2**16 < 256 + class OrderedConnection(AirhookConnection): + """ + this implements a simple protocol for ordered messages over airhook + the first two octets of each message are interpreted as a 16-bit sequence number + 253 bytes are used for payload + """ def __init__(self, transport, addr, delegate): AirhookConnection.__init__(self, transport, addr, delegate) self.oseq = 0 @@ -315,16 +330,16 @@ class OrderedConnection(AirhookConnection): data = "" while self.q and self.iseq == self.q[0].getseq(): data += self.q[0][2:] - self.iseq = (self.iseq + 1) % 2**16 self.q = self.q[1:] - if data: + self.iseq = (self.iseq + 1) % 2**16 + if data != '': self.msgDelegate('dataCameIn', (self.host, self.port, data)) def sendSomeData(self, data): # chop it up and queue it up while data: - p = "%s%s" % (pack("!H", self.oseq), data[:253]) - self.omsgq.append(p) + p = pack("!H", self.oseq) + data[:253] + self.omsgq.insert(0, p) data = data[253:] self.oseq = (self.oseq + 1) % 2**16 diff --git a/test_airhook.py b/test_airhook.py index 669bee5..2fb68a2 100644 --- a/test_airhook.py +++ b/test_airhook.py @@ -71,6 +71,77 @@ def swap(a, dir="", noisy=0): print 6*dir + " " + pscope(msg) return msg + +class UstrTests(unittest.TestCase): + def u(self, seq): + return ustr("%s%s" % (pack("!H", seq), 'foobar')) + + def testLT(self): + self.failUnless(self.u(0) < self.u(1)) + self.failUnless(self.u(1) < self.u(2)) + self.failUnless(self.u(2**16 - 1) < self.u(0)) + self.failUnless(self.u(2**16 - 1) < self.u(1)) + + self.failIf(self.u(1) < self.u(0)) + self.failIf(self.u(2) < self.u(1)) + self.failIf(self.u(0) < self.u(2**16 - 1)) + self.failIf(self.u(1) < self.u(2**16 - 1)) + + def testLTE(self): + self.failUnless(self.u(0) <= self.u(1)) + self.failUnless(self.u(1) <= self.u(2)) + self.failUnless(self.u(2) <= self.u(2)) + self.failUnless(self.u(2**16 - 1) <= self.u(0)) + self.failUnless(self.u(2**16 - 1) <= self.u(1)) + self.failUnless(self.u(2**16 - 1) <= self.u(2**16)) + + self.failIf(self.u(1) <= self.u(0)) + self.failIf(self.u(2) <= self.u(1)) + self.failIf(self.u(0) <= self.u(2**16 - 1)) + self.failIf(self.u(1) <= self.u(2**16 - 1)) + + def testGT(self): + self.failUnless(self.u(1) > self.u(0)) + self.failUnless(self.u(2) > self.u(1)) + self.failUnless(self.u(0) > self.u(2**16 - 1)) + self.failUnless(self.u(1) > self.u(2**16 - 1)) + + self.failIf(self.u(0) > self.u(1)) + self.failIf(self.u(1) > self.u(2)) + self.failIf(self.u(2**16 - 1) > self.u(0)) + self.failIf(self.u(2**16 - 1) > self.u(1)) + + def testGTE(self): + self.failUnless(self.u(1) >= self.u(0)) + self.failUnless(self.u(2) >= self.u(1)) + self.failUnless(self.u(2) >= self.u(2)) + self.failUnless(self.u(0) >= self.u(0)) + self.failUnless(self.u(1) >= self.u(1)) + self.failUnless(self.u(2**16 - 1) >= self.u(2**16 - 1)) + + self.failIf(self.u(0) >= self.u(1)) + self.failIf(self.u(1) >= self.u(2)) + self.failIf(self.u(2**16 - 1) >= self.u(0)) + self.failIf(self.u(2**16 - 1) >= self.u(1)) + + def testEQ(self): + self.failUnless(self.u(0) == self.u(0)) + self.failUnless(self.u(1) == self.u(1)) + self.failUnless(self.u(2**16 - 1) == self.u(2**16-1)) + + self.failIf(self.u(0) == self.u(1)) + self.failIf(self.u(1) == self.u(0)) + self.failIf(self.u(2**16 - 1) == self.u(0)) + + def testNEQ(self): + self.failUnless(self.u(1) != self.u(0)) + self.failUnless(self.u(2) != self.u(1)) + self.failIf(self.u(2) != self.u(2)) + self.failIf(self.u(0) != self.u(0)) + self.failIf(self.u(1) != self.u(1)) + self.failIf(self.u(2**16 - 1) != self.u(2**16 - 1)) + + class SimpleTest(unittest.TestCase): def setUp(self): self.noisy = 0 @@ -78,25 +149,24 @@ class SimpleTest(unittest.TestCase): self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None) def testReallySimple(self): # connect to eachother and send a few packets, observe sequence incrementing - self.noisy = 0 a = self.a b = self.b self.assertEqual(a.state, pending) self.assertEqual(b.state, pending) self.assertEqual(a.outSeq, 0) self.assertEqual(b.outSeq, 0) - self.assertEqual(a.obSeq, -1) - self.assertEqual(b.obSeq, -1) + self.assertEqual(a.obSeq, 0) + self.assertEqual(b.obSeq, 0) msg = swap(a, '>', self.noisy) self.assertEqual(a.state, sent) self.assertEqual(a.outSeq, 1) - self.assertEqual(a.obSeq, -1) + self.assertEqual(a.obSeq, 0) b.datagramReceived(msg) self.assertEqual(b.state, sent) self.assertEqual(b.inSeq, 0) - self.assertEqual(b.obSeq, -1) + self.assertEqual(b.obSeq, 0) msg = swap(b, '<', self.noisy) self.assertEqual(b.outSeq, 1) @@ -121,9 +191,9 @@ class SimpleTest(unittest.TestCase): class BasicTests(unittest.TestCase): def setUp(self): + self.noisy = 0 self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None) self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None) - self.noisy = 0 def testSimple(self): a = self.a b = self.b @@ -365,7 +435,7 @@ class BasicTests(unittest.TestCase): l.sort();a.imsgq.sort() self.assertEqual(a.imsgq, l) - def testOneWayBlast(self, num = 2**8): + def testOneWayBlast(self, num = 2**12): a = self.a b = self.b import sha @@ -373,17 +443,17 @@ class BasicTests(unittest.TestCase): for i in xrange(num): a.omsgq.append(sha.sha(`i`).digest()) - msg = swap(a, '>', self.noisy) - while a.state != confirmed or ord(msg[0]) & FLAG_NEXT: - b.datagramReceived(msg) - msg = swap(b, '<', self.noisy) + msga = swap(a, '>', self.noisy) + while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED): + b.datagramReceived(msga) + msgb = swap(b, '<', self.noisy) - a.datagramReceived(msg) - msg = swap(a, '>', self.noisy) + a.datagramReceived(msgb) + msga = swap(a, '>', self.noisy) self.assertEqual(len(b.imsgq), num) - def testTwoWayBlast(self, num = 2**9, prob=0.5): + def testTwoWayBlast(self, num = 2**12, prob=0.5): a = self.a b = self.b import sha @@ -425,12 +495,11 @@ class BasicTests(unittest.TestCase): for i in range(5000): a.omsgq.append(sha.sha('a' + 'i').digest()) - for i in range(31): + for i in range(5000 / 255): msg = swap(a, noisy=self.noisy) self.assertEqual(a.obSeq, 0) - self.assertEqual(a.outMsgNums[a.obSeq], 0) - self.assertEqual(a.next, 254) - self.assertEqual(a.outMsgNums[19], 254) + self.assertEqual(a.next, 255) + self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254) class OrderedTests(unittest.TestCase): def setUp(self): @@ -445,7 +514,7 @@ class OrderedTests(unittest.TestCase): self.a = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.A) self.b = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.B) - def testOrderedSimple(self, num = 2**17, prob=1.0): + def testOrderedSimple(self, num = 2**18, prob=1.0): f = open('/dev/urandom', 'r') a = self.a b = self.b @@ -468,8 +537,10 @@ class OrderedTests(unittest.TestCase): a.datagramReceived(msgb) else: msgb = swap(b, '<', 0) + self.assertEqual(len(self.A.msg), len(MSGB)) + self.assertEqual(len(self.B.msg), len(MSGA)) self.assertEqual(self.A.msg, MSGB) self.assertEqual(self.B.msg, MSGA) - - def testOrderedLossy(self, num = 2**17, prob=0.5): + + def testOrderedLossy(self, num = 2**18, prob=0.5): self.testOrderedSimple(num, prob) -- 2.39.5