fixed reset connection handling
authorburris <burris>
Mon, 20 Jan 2003 04:22:35 +0000 (04:22 +0000)
committerburris <burris>
Mon, 20 Jan 2003 04:22:35 +0000 (04:22 +0000)
airhook.py
test_airhook.py

index 3f1f80beb6a0dccb24a618b3dabe019fecfc658e..c100cdf12faba2265f639132d08c55e8296995d6 100644 (file)
@@ -115,17 +115,16 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
     def resetMessages(self):
         self.weMissed = []
         self.inMsg = 0   # next incoming message number
-        self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
+        self.outMsgNums = [0] * 256 # outgoing message numbers i = outNum % 256
         self.next = 0  # next outgoing message number
 
     def datagramReceived(self, datagram):
         if not datagram:
             return
+        if self.noisy:
+            print `datagram`
         p = AirhookPacket(datagram)
         
-        # check to make sure sequence number isn't out of order
-        if (p.seq - self.inSeq) % 2**16 >= 256:
-            return
             
         # check for state change
         if self.state == pending:
@@ -138,7 +137,6 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
                     return
             elif p.session != None:
                 self.observed = p.session
-                self.state = sent
                 self.response = 1
         elif self.state == sent:
             if p.observed != None and p.session != None:
@@ -157,10 +155,15 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
             if p.session != None or p.observed != None :
                 if (p.session != None and p.session != self.observed) or (p.observed != None and p.observed != self.sessionID):
                     self.state = pending
+                    self.observed = p.session
                     self.resetMessages()
                     self.inSeq = p.seq
+
+        # check to make sure sequence number isn't out of order
+        if (p.seq - self.inSeq) % 2**16 >= 256:
+            return
     
-        if self.state != pending:      
+        if self.state == confirmed:    
             msgs = []          
             missed = []
             
@@ -208,6 +211,9 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
         
         # session / observed logic
         if self.state == pending:
+            if self.observed != None:
+                flags = flags | FLAG_OBSERVED
+                ids +=  pack("!L", self.observed)
             flags = flags | FLAG_SESSION
             ids +=  pack("!L", self.sessionID)
             self.state = sent
@@ -231,7 +237,7 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
         if self.obSeq >= 0:
             self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
 
-        if self.weMissed:
+        if len(self.weMissed) > 0:
             flags = flags | FLAG_MISSED
             missed += chr(len(self.weMissed) - 1)
             for i in self.weMissed:
@@ -271,8 +277,10 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
         self.schedule()
         
     def timeToSend(self):
+        if self.state == pending:
+            return (1, 0)
         # any outstanding messages and are we not too far ahead of our counterparty?
-        if len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
+        elif len(self.omsgq) > 0 and self.state != sent and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
             return (1, 0)
         # do we explicitly need to send a response?
         elif self.response:
@@ -281,8 +289,6 @@ class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConne
         # have we not sent anything in a while?
         elif time() - self.lastTransmit > 1.0:
             return (1, 1)
-        elif self.state == pending:
-            return (1, 1)
             
         # nothing to send
         return (0, 0)
index e491d840232cf5f53c228d8fce6abdb01315039a..d6cfb13b7f7cfd173557cbc98e1c80f51b0beb3f 100644 (file)
@@ -225,10 +225,10 @@ class SimpleTest(unittest.TestCase):
         self.assertEqual(a.obSeq, 0)
 
         b.datagramReceived(msg)
-        self.assertEqual(b.state, sent)
         self.assertEqual(b.inSeq, 0)
         self.assertEqual(b.obSeq, 0)
         msg = swap(b, '<', self.noisy)         
+        self.assertEqual(b.state, sent)
         self.assertEqual(b.outSeq, 1)
 
         a.datagramReceived(msg)
@@ -443,9 +443,9 @@ class BasicTests(unittest.TestCase):
             
 
         b.datagramReceived(msg)
-        self.assertEqual(b.state, sent)
         
         msg = swap(b, '<', self.noisy)
+        self.assertEqual(b.state, sent)
         a.datagramReceived(msg)
 
         msg = swap(a, '>', self.noisy)
@@ -546,6 +546,48 @@ class BasicTests(unittest.TestCase):
         self.assertEqual(a.next, 255)
         self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254)
 
+    def testConnectionReset(self):
+        a = self.a
+        b = self.b
+        msg = swap(a, noisy=self.noisy)
+        b.datagramReceived(msg)
+
+        msg = swap(b, noisy=self.noisy)
+        a.datagramReceived(msg)
+
+        a.omsgq.append("TESTING")
+        msg = swap(a, noisy=self.noisy)
+        b.datagramReceived(msg)
+
+        msg = swap(b, noisy=self.noisy)
+        a.datagramReceived(msg)
+
+        self.assertEqual(b.protocol.q[0], "TESTING")
+        self.assertEqual(b.state, confirmed)
+        
+        self.a = AirhookConnection()
+        self.a.makeConnection(DummyTransport())
+        self.a.addr = ('127.0.0.1', 4444)
+        a = self.a
+        
+        a.omsgq.append("TESTING2")
+        msg = swap(a, noisy=self.noisy)
+        b.datagramReceived(msg)
+
+        msg = swap(b, noisy=self.noisy)
+        a.datagramReceived(msg)
+        
+        self.assertEqual(len(b.protocol.q), 1)
+        msg = swap(a, noisy=self.noisy)
+        b.datagramReceived(msg)
+
+        msg = swap(b, noisy=self.noisy)
+        a.datagramReceived(msg)
+
+        self.assertEqual(len(b.protocol.q), 2)
+        self.assertEqual(b.protocol.q[1], "TESTING2")
+
+        
 class StreamTests(unittest.TestCase):
     def setUp(self):
         self.noisy = 0
@@ -585,12 +627,15 @@ class SimpleReactor(unittest.TestCase):
         self.b = makeReceiver(2021)
         self.ac = self.a.connectionForAddr(('127.0.0.1', 2021))
         self.bc = self.b.connectionForAddr(('127.0.0.1', 2020))
+        self.ac.noisy = self.noisy
+        self.bc.noisy = self.noisy
     def testSimple(self):
         msg = "Testing 1, 2, 3"
         self.ac.write(msg)
         reactor.iterate()
         reactor.iterate()
         reactor.iterate()
+        self.assertEqual(self.bc.state, confirmed)
         self.assertEqual(self.bc.protocol.q, [msg])
 
 class SimpleReactorEcho(unittest.TestCase):