]> git.mxchange.org Git - quix0rs-apt-p2p.git/commitdiff
stream connection class
authorburris <burris>
Sun, 22 Dec 2002 11:02:32 +0000 (11:02 +0000)
committerburris <burris>
Sun, 22 Dec 2002 11:02:32 +0000 (11:02 +0000)
airhook.py
test_airhook.py

index ec847bd13e3003e2bb191404cf751ed8762cef82..4053dfe5906493a09b012950b5f40e6c8add07f9 100644 (file)
@@ -6,6 +6,7 @@ from struct import pack, unpack
 from time import time
 from StringIO import StringIO
 import unittest
+from bisect import insort_left
 
 from twisted.internet import protocol
 from twisted.internet import reactor
@@ -24,8 +25,19 @@ pending = 0
 sent = 1
 confirmed = 2
 
+class Delegate:
+       def setDelegate(self, delegate):
+               self.delegate = delegate
+       def getDelegate(self):
+               return self.delegate
+       def msgDelegate(self, method, args=(), kwargs={}):
+               if hasattr(self, 'delegate') and hasattr(self.delegate, method) and callable(getattr(self.delegate, method)):
+                       apply(getattr(self.delegate, method) , args, kwargs)
+
 class Airhook(protocol.DatagramProtocol):
 
+       def __init__(self, connection_class):
+               self.connection_class = connection_class
        def startProtocol(self):
                self.connections = {}
                                
@@ -37,10 +49,11 @@ class Airhook(protocol.DatagramProtocol):
 
        def connectionForAddr(self, addr):
                if not self.connections.has_key(addr):
-                       conn = AirhookConnection(self.transport, addr)
+                       conn = connection_class(self.transport, addr, self.delegate)
                        self.connections[addr] = conn
                return self.connections[addr]
 
+               
 class AirhookPacket:
        def __init__(self, msg):
                self.datagram = msg
@@ -74,11 +87,10 @@ class AirhookPacket:
                                skip+=1
                                self.msgs.append( msg[skip:skip+n])
                                skip += n
-               
-               
 
-class AirhookConnection:
-       def __init__(self, transport, addr):
+class AirhookConnection(Delegate):
+       def __init__(self, transport, addr, delegate):
+               self.delegate = delegate
                self.addr = addr
                type, self.host, self.port = addr
                self.transport = transport
@@ -102,7 +114,7 @@ class AirhookConnection:
                self.sendSession = None  # send session/observed fields until obSeq > sendSession
 
                self.resetMessages()
-               
+       
        def resetMessages(self):
                self.weMissed = []
                self.inMsg = 0   # next incoming message number
@@ -188,8 +200,8 @@ class AirhookConnection:
                if response:
                        reactor.callLater(0, self.sendNext)
                self.lastReceived = time()
-
-
+               self.dataCameIn()
+               
        def sendNext(self):
                flags = 0
                header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
@@ -264,3 +276,57 @@ class AirhookConnection:
                else:
                        reactor.callLater(1, self.sendNext)
 
+
+       def dataCameIn(self):
+               self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
+               if hasattr(self, 'delegate') and self.delegate != None:
+                       self.imsgq = []
+
+class ustr(str):
+       def getseq(self):
+               if not hasattr(self, 'seq'):
+                       self.seq = unpack("!H", self[0:2])[0]
+               return self.seq
+       def __lt__(self, other):
+               return self.getseq() < other.getseq()
+       def __le__(self, other):
+               return self.getseq() <= other.getseq()
+       def __eq__(self, other):
+               return self.getseq() != other.getseq()
+       def __ne__(self, other):
+               return self.getseq() <= other.getseq()
+       def __gt__(self, other):
+               return self.getseq() > other.getseq()
+       def __ge__(self, other):
+               return self.getseq() >= other.getseq()
+
+class OrderedConnection(AirhookConnection):
+       def __init__(self, transport, addr, delegate):
+               AirhookConnection.__init__(self, transport, addr, delegate)
+               self.oseq = 0
+               self.iseq = 0
+               self.q = []
+
+       def dataCameIn(self):
+               # put 'em together
+               for msg in self.imsgq:
+                       insort_left(self.q, ustr(msg))
+               self.imsgq = []
+               data = ""
+               while self.q and self.iseq == self.q[0].getseq():
+                       data += self.q[0][2:]
+                       self.iseq = (self.iseq + 1) % 2**16
+                       self.q = self.q[1:]
+               if data:
+                       self.msgDelegate('dataCameIn', (self.host, self.port, data))
+               
+       def sendSomeData(self, data):
+               # chop it up and queue it up
+               while data:
+                       p = "%s%s" % (pack("!H", self.oseq), data[:253])
+                       self.omsgq.append(p)
+                       data = data[253:]
+                       self.oseq = (self.oseq + 1) % 2**16
+
+               if self.omsgq:
+                       self.sendNext()
index b00681ecce8a78c0afb29098b1ba45db9b0232e0..669bee5a202adfb103b9a0e3ed751abf176c2cce 100644 (file)
@@ -74,8 +74,8 @@ def swap(a, dir="", noisy=0):
 class SimpleTest(unittest.TestCase):
        def setUp(self):
                self.noisy = 0
-               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
-               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
+               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
        def testReallySimple(self):
                # connect to eachother and send a few packets, observe sequence incrementing
                self.noisy = 0
@@ -121,8 +121,8 @@ class SimpleTest(unittest.TestCase):
 
 class BasicTests(unittest.TestCase):
        def setUp(self):
-               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
-               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
+               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
                self.noisy = 0
        def testSimple(self):
                a = self.a
@@ -383,7 +383,7 @@ class BasicTests(unittest.TestCase):
 
                self.assertEqual(len(b.imsgq), num)
                
-       def testTwoWayBlast(self, num = 2**15, prob=0.5):
+       def testTwoWayBlast(self, num = 2**9, prob=0.5):
                a = self.a
                b = self.b
                import sha
@@ -431,3 +431,45 @@ class BasicTests(unittest.TestCase):
                        self.assertEqual(a.outMsgNums[a.obSeq], 0)
                self.assertEqual(a.next, 254)
                self.assertEqual(a.outMsgNums[19], 254)
+
+class OrderedTests(unittest.TestCase):
+       def setUp(self):
+               self.noisy = 0
+               class queuer:
+                       def __init__(self):
+                               self.msg = ""
+                       def dataCameIn(self, host, port, data):
+                               self.msg+= data
+               self.A = queuer()
+               self.B = queuer()
+               self.a = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.A)
+               self.b = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.B)
+
+       def testOrderedSimple(self, num = 2**17, prob=1.0):
+               f = open('/dev/urandom', 'r')
+               a = self.a
+               b = self.b
+               A = self.A
+               B = self.B
+
+               MSGA = f.read(num)
+               MSGB = f.read(num)
+               self.a.sendSomeData(MSGA)
+               self.b.sendSomeData(MSGB)
+               
+               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+                       if rand(0,1) < prob:
+                               msga = swap(a, '>', self.noisy)
+                               b.datagramReceived(msga)
+                       else:
+                               msga = swap(a, '>', 0)
+                       if rand(0,1) < prob:
+                               msgb = swap(b, '<', self.noisy)
+                               a.datagramReceived(msgb)
+                       else:
+                               msgb = swap(b, '<', 0)
+               self.assertEqual(self.A.msg, MSGB)
+               self.assertEqual(self.B.msg, MSGA)
+               
+       def testOrderedLossy(self, num = 2**17, prob=0.5):
+               self.testOrderedSimple(num, prob)