]> git.mxchange.org Git - quix0rs-apt-p2p.git/blobdiff - test_airhook.py
bug fixes, more tests, looking solid now
[quix0rs-apt-p2p.git] / test_airhook.py
index b00681ecce8a78c0afb29098b1ba45db9b0232e0..2fb68a22c0c2ef4c4e0f8905575df34cc3be43b8 100644 (file)
@@ -71,32 +71,102 @@ def swap(a, dir="", noisy=0):
                                print 6*dir + " " + pscope(msg)
        return msg
        
+
+class UstrTests(unittest.TestCase):
+       def u(self, seq):
+               return ustr("%s%s" % (pack("!H", seq), 'foobar'))
+               
+       def testLT(self):
+               self.failUnless(self.u(0) < self.u(1))
+               self.failUnless(self.u(1) < self.u(2))
+               self.failUnless(self.u(2**16 - 1) < self.u(0))
+               self.failUnless(self.u(2**16 - 1) < self.u(1))
+               
+               self.failIf(self.u(1) < self.u(0))
+               self.failIf(self.u(2) < self.u(1))
+               self.failIf(self.u(0) < self.u(2**16 - 1))
+               self.failIf(self.u(1) < self.u(2**16 - 1))
+               
+       def testLTE(self):
+               self.failUnless(self.u(0) <= self.u(1))
+               self.failUnless(self.u(1) <= self.u(2))
+               self.failUnless(self.u(2) <= self.u(2))
+               self.failUnless(self.u(2**16 - 1) <= self.u(0))
+               self.failUnless(self.u(2**16 - 1) <= self.u(1))
+               self.failUnless(self.u(2**16 - 1) <= self.u(2**16))
+
+               self.failIf(self.u(1) <= self.u(0))
+               self.failIf(self.u(2) <= self.u(1))
+               self.failIf(self.u(0) <= self.u(2**16 - 1))
+               self.failIf(self.u(1) <= self.u(2**16 - 1))
+               
+       def testGT(self):
+               self.failUnless(self.u(1) > self.u(0))
+               self.failUnless(self.u(2) > self.u(1))
+               self.failUnless(self.u(0) > self.u(2**16 - 1))
+               self.failUnless(self.u(1) > self.u(2**16 - 1))
+
+               self.failIf(self.u(0) > self.u(1))
+               self.failIf(self.u(1) > self.u(2))
+               self.failIf(self.u(2**16 - 1) > self.u(0))
+               self.failIf(self.u(2**16 - 1) > self.u(1))
+
+       def testGTE(self):
+               self.failUnless(self.u(1) >= self.u(0))
+               self.failUnless(self.u(2) >= self.u(1))
+               self.failUnless(self.u(2) >= self.u(2))
+               self.failUnless(self.u(0) >= self.u(0))
+               self.failUnless(self.u(1) >= self.u(1))
+               self.failUnless(self.u(2**16 - 1) >= self.u(2**16 - 1))
+
+               self.failIf(self.u(0) >= self.u(1))
+               self.failIf(self.u(1) >= self.u(2))
+               self.failIf(self.u(2**16 - 1) >= self.u(0))
+               self.failIf(self.u(2**16 - 1) >= self.u(1))
+               
+       def testEQ(self):
+               self.failUnless(self.u(0) == self.u(0))
+               self.failUnless(self.u(1) == self.u(1))
+               self.failUnless(self.u(2**16 - 1) == self.u(2**16-1))
+       
+               self.failIf(self.u(0) == self.u(1))
+               self.failIf(self.u(1) == self.u(0))
+               self.failIf(self.u(2**16 - 1) == self.u(0))
+
+       def testNEQ(self):
+               self.failUnless(self.u(1) != self.u(0))
+               self.failUnless(self.u(2) != self.u(1))
+               self.failIf(self.u(2) != self.u(2))
+               self.failIf(self.u(0) != self.u(0))
+               self.failIf(self.u(1) != self.u(1))
+               self.failIf(self.u(2**16 - 1) != self.u(2**16 - 1))
+
+
 class SimpleTest(unittest.TestCase):
        def setUp(self):
                self.noisy = 0
-               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
-               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
+               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
        def testReallySimple(self):
                # connect to eachother and send a few packets, observe sequence incrementing
-               self.noisy = 0
                a = self.a
                b = self.b
                self.assertEqual(a.state, pending)
                self.assertEqual(b.state, pending)
                self.assertEqual(a.outSeq, 0)
                self.assertEqual(b.outSeq, 0)
-               self.assertEqual(a.obSeq, -1)
-               self.assertEqual(b.obSeq, -1)
+               self.assertEqual(a.obSeq, 0)
+               self.assertEqual(b.obSeq, 0)
 
                msg = swap(a, '>', self.noisy)          
                self.assertEqual(a.state, sent)
                self.assertEqual(a.outSeq, 1)
-               self.assertEqual(a.obSeq, -1)
+               self.assertEqual(a.obSeq, 0)
 
                b.datagramReceived(msg)
                self.assertEqual(b.state, sent)
                self.assertEqual(b.inSeq, 0)
-               self.assertEqual(b.obSeq, -1)
+               self.assertEqual(b.obSeq, 0)
                msg = swap(b, '<', self.noisy)          
                self.assertEqual(b.outSeq, 1)
 
@@ -121,9 +191,9 @@ class SimpleTest(unittest.TestCase):
 
 class BasicTests(unittest.TestCase):
        def setUp(self):
-               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
-               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
                self.noisy = 0
+               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
+               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
        def testSimple(self):
                a = self.a
                b = self.b
@@ -365,7 +435,7 @@ class BasicTests(unittest.TestCase):
                l.sort();a.imsgq.sort()
                self.assertEqual(a.imsgq, l)
 
-       def testOneWayBlast(self, num = 2**8):
+       def testOneWayBlast(self, num = 2**12):
                a = self.a
                b = self.b
                import sha
@@ -373,17 +443,17 @@ class BasicTests(unittest.TestCase):
                
                for i in xrange(num):
                        a.omsgq.append(sha.sha(`i`).digest())
-               msg = swap(a, '>', self.noisy)
-               while a.state != confirmed or ord(msg[0]) & FLAG_NEXT:
-                       b.datagramReceived(msg)
-                       msg = swap(b, '<', self.noisy)
+               msga = swap(a, '>', self.noisy)
+               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+                       b.datagramReceived(msga)
+                       msgb = swap(b, '<', self.noisy)
 
-                       a.datagramReceived(msg)
-                       msg = swap(a, '>', self.noisy)
+                       a.datagramReceived(msgb)
+                       msga = swap(a, '>', self.noisy)
 
                self.assertEqual(len(b.imsgq), num)
                
-       def testTwoWayBlast(self, num = 2**15, prob=0.5):
+       def testTwoWayBlast(self, num = 2**12, prob=0.5):
                a = self.a
                b = self.b
                import sha
@@ -425,9 +495,52 @@ class BasicTests(unittest.TestCase):
                for i in range(5000):
                        a.omsgq.append(sha.sha('a' + 'i').digest())
                
-               for i in range(31):
+               for i in range(5000 / 255):
                        msg = swap(a, noisy=self.noisy)
                        self.assertEqual(a.obSeq, 0)
-                       self.assertEqual(a.outMsgNums[a.obSeq], 0)
-               self.assertEqual(a.next, 254)
-               self.assertEqual(a.outMsgNums[19], 254)
+               self.assertEqual(a.next, 255)
+               self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254)
+
+class OrderedTests(unittest.TestCase):
+       def setUp(self):
+               self.noisy = 0
+               class queuer:
+                       def __init__(self):
+                               self.msg = ""
+                       def dataCameIn(self, host, port, data):
+                               self.msg+= data
+               self.A = queuer()
+               self.B = queuer()
+               self.a = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.A)
+               self.b = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.B)
+
+       def testOrderedSimple(self, num = 2**18, prob=1.0):
+               f = open('/dev/urandom', 'r')
+               a = self.a
+               b = self.b
+               A = self.A
+               B = self.B
+
+               MSGA = f.read(num)
+               MSGB = f.read(num)
+               self.a.sendSomeData(MSGA)
+               self.b.sendSomeData(MSGB)
+               
+               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+                       if rand(0,1) < prob:
+                               msga = swap(a, '>', self.noisy)
+                               b.datagramReceived(msga)
+                       else:
+                               msga = swap(a, '>', 0)
+                       if rand(0,1) < prob:
+                               msgb = swap(b, '<', self.noisy)
+                               a.datagramReceived(msgb)
+                       else:
+                               msgb = swap(b, '<', 0)
+               self.assertEqual(len(self.A.msg), len(MSGB))
+               self.assertEqual(len(self.B.msg), len(MSGA))
+               self.assertEqual(self.A.msg, MSGB)
+               self.assertEqual(self.B.msg, MSGA)
+
+       def testOrderedLossy(self, num = 2**18, prob=0.5):
+               self.testOrderedSimple(num, prob)