]> git.mxchange.org Git - quix0rs-apt-p2p.git/commitdiff
bug fixes, more tests, looking solid now
authorburris <burris>
Mon, 23 Dec 2002 02:59:11 +0000 (02:59 +0000)
committerburris <burris>
Mon, 23 Dec 2002 02:59:11 +0000 (02:59 +0000)
airhook.py
test_airhook.py

index 4053dfe5906493a09b012950b5f40e6c8add07f9..8d9443c52651c07cd3f24a1bd52b5c22cdb2e972 100644 (file)
@@ -19,7 +19,7 @@ FLAG_MISSED = 4
 FLAG_NEXT = 2
 FLAG_INTERVAL = 1
 
-MAX_PACKET_SIZE = 1480
+MAX_PACKET_SIZE = 1496
 
 pending = 0
 sent = 1
@@ -96,7 +96,7 @@ class AirhookConnection(Delegate):
                self.transport = transport
                
                self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
-               self.obSeq = -1   # highest sequence confirmed by remote
+               self.obSeq = 0   # highest sequence confirmed by remote
                self.inSeq = 0   # last received sequence
                self.observed = None  # their session id
                self.sessionID = long(rand(0, 2**32))  # our session id
@@ -107,8 +107,6 @@ class AirhookConnection(Delegate):
                self.state = pending
                
                self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
-               self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
-               self.next = -1  # next outgoing message number
                self.omsgq = [] # list of messages to go out
                self.imsgq = [] # list of messages coming in
                self.sendSession = None  # send session/observed fields until obSeq > sendSession
@@ -118,13 +116,19 @@ class AirhookConnection(Delegate):
        def resetMessages(self):
                self.weMissed = []
                self.inMsg = 0   # next incoming message number
-       
+               self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
+               self.next = 0  # next outgoing message number
+
        def datagramReceived(self, datagram):
                if not datagram:
                        return
                response = 0 # if we know we have a response now (like resending missed packets)
                p = AirhookPacket(datagram)
                
+               # check to make sure sequence number isn't out of order
+               if (p.seq - self.inSeq) % 2**16 >= 256:
+                       return
+                       
                # check for state change
                if self.state == pending:
                        if p.observed != None and p.session != None:
@@ -163,21 +167,16 @@ class AirhookConnection(Delegate):
                if self.state != pending:       
                        msgs = []               
                        missed = []
-
-                       # check to make sure sequence number isn't out of wack
-                       assert (p.seq - self.inSeq) % 2**16 < 256
                        
                        # see if they need us to resend anything
                        for i in p.missed:
                                response = 1
                                if self.outMsgs[i] != None:
-                                       self.omsgq.insert(0, self.outMsgs[i])
+                                       self.omsgq.append(self.outMsgs[i])
                                        self.outMsgs[i] = None
                                        
-                       # see if we need them to send anything
+                       # see if we missed any messages
                        if p.next != None:
-                               if p.next == 0 and self.inMsg == -1:
-                                       missed = 255
                                missed_count = (p.next - self.inMsg) % 256
                                if missed_count:
                                        self.lastReceived = time()
@@ -185,10 +184,11 @@ class AirhookConnection(Delegate):
                                                missed += [(self.outSeq, (self.inMsg + i) % 256)]
                                        response = 1
                                        self.weMissed += missed
+                               # record highest message number seen
                                self.inMsg = (p.next + len(p.msgs)) % 256
-                               
+                       
+                       # append messages, update sequence
                        self.imsgq += p.msgs
-                       self.inSeq = p.seq
                        
                if self.state == confirmed:
                        # unpack the observed sequence
@@ -197,6 +197,8 @@ class AirhookConnection(Delegate):
                                tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
                        self.obSeq = tseq
 
+               self.inSeq = p.seq
+
                if response:
                        reactor.callLater(0, self.sendNext)
                self.lastReceived = time()
@@ -209,6 +211,7 @@ class AirhookConnection(Delegate):
                missed = ""
                msgs = ""
                
+               # session / observed logic
                if self.state == pending:
                        flags = flags | FLAG_SESSION
                        ids +=  pack("!L", self.sessionID)
@@ -229,6 +232,7 @@ class AirhookConnection(Delegate):
                                        flags = flags | FLAG_SESSION | FLAG_OBSERVED
                                        ids +=  pack("!LL", self.observed, self.sessionID)
                
+               # missed header
                if self.obSeq >= 0:
                        self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
 
@@ -238,39 +242,37 @@ class AirhookConnection(Delegate):
                        for i in self.weMissed:
                                missed += chr(i[1])
                                
+               # append any outgoing messages
                if self.state == confirmed and self.omsgq:
-                       first = (self.next + 1) % 256
-                       while len(self.omsgq) and (len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE) :
-                               if self.obSeq == -1:
-                                       highest = 0
-                               else:
-                                       highest = self.outMsgNums[self.obSeq % 256]
-                               if self.next != -1 and (self.next + 1) % 256 == (highest - 1) % 256:
-                                       break
-                               else:
-                                       self.next = (self.next + 1) % 256
-                                       msg = self.omsgq.pop()
-                                       msgs += chr(len(msg) - 1) + msg
-                                       self.outMsgs[self.next] = msg
+                       first = self.next
+                       outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
+                       while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
+                               msg = self.omsgq.pop()
+                               msgs += chr(len(msg) - 1) + msg
+                               self.outMsgs[self.next] = msg
+                               self.next = (self.next + 1) % 256
+                               outstanding+=1
+               # update outgoing message stat
                if msgs:
                        flags = flags | FLAG_NEXT
                        ids += chr(first)
                        self.lastTransmitSeq = self.outSeq
-                       self.outMsgNums[self.outSeq % 256] = first
-               else:
-                       if self.next == -1:
-                               self.outMsgNums[self.outSeq % 256] = 0
-                       else:
-                               self.outMsgNums[self.outSeq % 256] = self.next
-                       
-               if (self.obSeq - self.lastTransmitSeq) % 2**16 > 256 and self.outMsgNums[self.obSeq % 256] != self.next and  not flags & FLAG_NEXT:
+                       #self.outMsgNums[self.outSeq % 256] = first
+               #else:
+               self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
+               
+               # do we need a NEXT flag despite not having sent any messages?
+               if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
                                flags = flags | FLAG_NEXT
-                               ids += chr((self.next + 1) % 256)
+                               ids += chr(self.next)
+               
+               # update stats and send packet
                packet = chr(flags) + header + ids + missed + msgs
                self.outSeq = (self.outSeq + 1) % 2**16
                self.lastTransmit = time()
                self.transport.write(packet)
                
+               # call later
                if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
                        reactor.callLater(0, self.sendNext)
                else:
@@ -278,29 +280,42 @@ class AirhookConnection(Delegate):
 
 
        def dataCameIn(self):
+               """
+               called when we get a packet bearing messages
+               delegate must do something with the messages or they will get dropped 
+               """
                self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
                if hasattr(self, 'delegate') and self.delegate != None:
                        self.imsgq = []
 
 class ustr(str):
+       """
+               this subclass of string encapsulates each ordered message, caches it's sequence number,
+               and has comparison functions to sort by sequence number
+       """
        def getseq(self):
                if not hasattr(self, 'seq'):
                        self.seq = unpack("!H", self[0:2])[0]
                return self.seq
        def __lt__(self, other):
-               return self.getseq() < other.getseq()
+               return (self.getseq() - other.getseq()) % 2**16 > 255
        def __le__(self, other):
-               return self.getseq() <= other.getseq()
+               return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
        def __eq__(self, other):
-               return self.getseq() != other.getseq()
+               return self.getseq() == other.getseq()
        def __ne__(self, other):
-               return self.getseq() <= other.getseq()
+               return self.getseq() != other.getseq()
        def __gt__(self, other):
-               return self.getseq() > other.getseq()
+               return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
        def __ge__(self, other):
-               return self.getseq() >= other.getseq()
-
+               return (self.getseq() - other.getseq()) % 2**16 < 256
+               
 class OrderedConnection(AirhookConnection):
+       """
+               this implements a simple protocol for ordered messages over airhook
+               the first two octets of each message are interpreted as a 16-bit sequence number
+               253 bytes are used for payload
+       """
        def __init__(self, transport, addr, delegate):
                AirhookConnection.__init__(self, transport, addr, delegate)
                self.oseq = 0
@@ -315,16 +330,16 @@ class OrderedConnection(AirhookConnection):
                data = ""
                while self.q and self.iseq == self.q[0].getseq():
                        data += self.q[0][2:]
-                       self.iseq = (self.iseq + 1) % 2**16
                        self.q = self.q[1:]
-               if data:
+                       self.iseq = (self.iseq + 1) % 2**16
+               if data != '':
                        self.msgDelegate('dataCameIn', (self.host, self.port, data))
                
        def sendSomeData(self, data):
                # chop it up and queue it up
                while data:
-                       p = "%s%s" % (pack("!H", self.oseq), data[:253])
-                       self.omsgq.append(p)
+                       p = pack("!H", self.oseq) + data[:253]
+                       self.omsgq.insert(0, p)
                        data = data[253:]
                        self.oseq = (self.oseq + 1) % 2**16
 
index 669bee5a202adfb103b9a0e3ed751abf176c2cce..2fb68a22c0c2ef4c4e0f8905575df34cc3be43b8 100644 (file)
@@ -71,6 +71,77 @@ def swap(a, dir="", noisy=0):
                                print 6*dir + " " + pscope(msg)
        return msg
        
+
+class UstrTests(unittest.TestCase):
+       def u(self, seq):
+               return ustr("%s%s" % (pack("!H", seq), 'foobar'))
+               
+       def testLT(self):
+               self.failUnless(self.u(0) < self.u(1))
+               self.failUnless(self.u(1) < self.u(2))
+               self.failUnless(self.u(2**16 - 1) < self.u(0))
+               self.failUnless(self.u(2**16 - 1) < self.u(1))
+               
+               self.failIf(self.u(1) < self.u(0))
+               self.failIf(self.u(2) < self.u(1))
+               self.failIf(self.u(0) < self.u(2**16 - 1))
+               self.failIf(self.u(1) < self.u(2**16 - 1))
+               
+       def testLTE(self):
+               self.failUnless(self.u(0) <= self.u(1))
+               self.failUnless(self.u(1) <= self.u(2))
+               self.failUnless(self.u(2) <= self.u(2))
+               self.failUnless(self.u(2**16 - 1) <= self.u(0))
+               self.failUnless(self.u(2**16 - 1) <= self.u(1))
+               self.failUnless(self.u(2**16 - 1) <= self.u(2**16))
+
+               self.failIf(self.u(1) <= self.u(0))
+               self.failIf(self.u(2) <= self.u(1))
+               self.failIf(self.u(0) <= self.u(2**16 - 1))
+               self.failIf(self.u(1) <= self.u(2**16 - 1))
+               
+       def testGT(self):
+               self.failUnless(self.u(1) > self.u(0))
+               self.failUnless(self.u(2) > self.u(1))
+               self.failUnless(self.u(0) > self.u(2**16 - 1))
+               self.failUnless(self.u(1) > self.u(2**16 - 1))
+
+               self.failIf(self.u(0) > self.u(1))
+               self.failIf(self.u(1) > self.u(2))
+               self.failIf(self.u(2**16 - 1) > self.u(0))
+               self.failIf(self.u(2**16 - 1) > self.u(1))
+
+       def testGTE(self):
+               self.failUnless(self.u(1) >= self.u(0))
+               self.failUnless(self.u(2) >= self.u(1))
+               self.failUnless(self.u(2) >= self.u(2))
+               self.failUnless(self.u(0) >= self.u(0))
+               self.failUnless(self.u(1) >= self.u(1))
+               self.failUnless(self.u(2**16 - 1) >= self.u(2**16 - 1))
+
+               self.failIf(self.u(0) >= self.u(1))
+               self.failIf(self.u(1) >= self.u(2))
+               self.failIf(self.u(2**16 - 1) >= self.u(0))
+               self.failIf(self.u(2**16 - 1) >= self.u(1))
+               
+       def testEQ(self):
+               self.failUnless(self.u(0) == self.u(0))
+               self.failUnless(self.u(1) == self.u(1))
+               self.failUnless(self.u(2**16 - 1) == self.u(2**16-1))
+       
+               self.failIf(self.u(0) == self.u(1))
+               self.failIf(self.u(1) == self.u(0))
+               self.failIf(self.u(2**16 - 1) == self.u(0))
+
+       def testNEQ(self):
+               self.failUnless(self.u(1) != self.u(0))
+               self.failUnless(self.u(2) != self.u(1))
+               self.failIf(self.u(2) != self.u(2))
+               self.failIf(self.u(0) != self.u(0))
+               self.failIf(self.u(1) != self.u(1))
+               self.failIf(self.u(2**16 - 1) != self.u(2**16 - 1))
+
+
 class SimpleTest(unittest.TestCase):
        def setUp(self):
                self.noisy = 0
@@ -78,25 +149,24 @@ class SimpleTest(unittest.TestCase):
                self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
        def testReallySimple(self):
                # connect to eachother and send a few packets, observe sequence incrementing
-               self.noisy = 0
                a = self.a
                b = self.b
                self.assertEqual(a.state, pending)
                self.assertEqual(b.state, pending)
                self.assertEqual(a.outSeq, 0)
                self.assertEqual(b.outSeq, 0)
-               self.assertEqual(a.obSeq, -1)
-               self.assertEqual(b.obSeq, -1)
+               self.assertEqual(a.obSeq, 0)
+               self.assertEqual(b.obSeq, 0)
 
                msg = swap(a, '>', self.noisy)          
                self.assertEqual(a.state, sent)
                self.assertEqual(a.outSeq, 1)
-               self.assertEqual(a.obSeq, -1)
+               self.assertEqual(a.obSeq, 0)
 
                b.datagramReceived(msg)
                self.assertEqual(b.state, sent)
                self.assertEqual(b.inSeq, 0)
-               self.assertEqual(b.obSeq, -1)
+               self.assertEqual(b.obSeq, 0)
                msg = swap(b, '<', self.noisy)          
                self.assertEqual(b.outSeq, 1)
 
@@ -121,9 +191,9 @@ class SimpleTest(unittest.TestCase):
 
 class BasicTests(unittest.TestCase):
        def setUp(self):
+               self.noisy = 0
                self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
                self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
-               self.noisy = 0
        def testSimple(self):
                a = self.a
                b = self.b
@@ -365,7 +435,7 @@ class BasicTests(unittest.TestCase):
                l.sort();a.imsgq.sort()
                self.assertEqual(a.imsgq, l)
 
-       def testOneWayBlast(self, num = 2**8):
+       def testOneWayBlast(self, num = 2**12):
                a = self.a
                b = self.b
                import sha
@@ -373,17 +443,17 @@ class BasicTests(unittest.TestCase):
                
                for i in xrange(num):
                        a.omsgq.append(sha.sha(`i`).digest())
-               msg = swap(a, '>', self.noisy)
-               while a.state != confirmed or ord(msg[0]) & FLAG_NEXT:
-                       b.datagramReceived(msg)
-                       msg = swap(b, '<', self.noisy)
+               msga = swap(a, '>', self.noisy)
+               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+                       b.datagramReceived(msga)
+                       msgb = swap(b, '<', self.noisy)
 
-                       a.datagramReceived(msg)
-                       msg = swap(a, '>', self.noisy)
+                       a.datagramReceived(msgb)
+                       msga = swap(a, '>', self.noisy)
 
                self.assertEqual(len(b.imsgq), num)
                
-       def testTwoWayBlast(self, num = 2**9, prob=0.5):
+       def testTwoWayBlast(self, num = 2**12, prob=0.5):
                a = self.a
                b = self.b
                import sha
@@ -425,12 +495,11 @@ class BasicTests(unittest.TestCase):
                for i in range(5000):
                        a.omsgq.append(sha.sha('a' + 'i').digest())
                
-               for i in range(31):
+               for i in range(5000 / 255):
                        msg = swap(a, noisy=self.noisy)
                        self.assertEqual(a.obSeq, 0)
-                       self.assertEqual(a.outMsgNums[a.obSeq], 0)
-               self.assertEqual(a.next, 254)
-               self.assertEqual(a.outMsgNums[19], 254)
+               self.assertEqual(a.next, 255)
+               self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254)
 
 class OrderedTests(unittest.TestCase):
        def setUp(self):
@@ -445,7 +514,7 @@ class OrderedTests(unittest.TestCase):
                self.a = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.A)
                self.b = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.B)
 
-       def testOrderedSimple(self, num = 2**17, prob=1.0):
+       def testOrderedSimple(self, num = 2**18, prob=1.0):
                f = open('/dev/urandom', 'r')
                a = self.a
                b = self.b
@@ -468,8 +537,10 @@ class OrderedTests(unittest.TestCase):
                                a.datagramReceived(msgb)
                        else:
                                msgb = swap(b, '<', 0)
+               self.assertEqual(len(self.A.msg), len(MSGB))
+               self.assertEqual(len(self.B.msg), len(MSGA))
                self.assertEqual(self.A.msg, MSGB)
                self.assertEqual(self.B.msg, MSGA)
-               
-       def testOrderedLossy(self, num = 2**17, prob=0.5):
+
+       def testOrderedLossy(self, num = 2**18, prob=0.5):
                self.testOrderedSimple(num, prob)