from time import time
from StringIO import StringIO
import unittest
+from bisect import insort_left
from twisted.internet import protocol
from twisted.internet import reactor
sent = 1
confirmed = 2
+class Delegate:
+ def setDelegate(self, delegate):
+ self.delegate = delegate
+ def getDelegate(self):
+ return self.delegate
+ def msgDelegate(self, method, args=(), kwargs={}):
+ if hasattr(self, 'delegate') and hasattr(self.delegate, method) and callable(getattr(self.delegate, method)):
+ apply(getattr(self.delegate, method) , args, kwargs)
+
class Airhook(protocol.DatagramProtocol):
+ def __init__(self, connection_class):
+ self.connection_class = connection_class
def startProtocol(self):
self.connections = {}
def connectionForAddr(self, addr):
if not self.connections.has_key(addr):
- conn = AirhookConnection(self.transport, addr)
+ conn = connection_class(self.transport, addr, self.delegate)
self.connections[addr] = conn
return self.connections[addr]
+
class AirhookPacket:
def __init__(self, msg):
self.datagram = msg
skip+=1
self.msgs.append( msg[skip:skip+n])
skip += n
-
-
-class AirhookConnection:
- def __init__(self, transport, addr):
+class AirhookConnection(Delegate):
+ def __init__(self, transport, addr, delegate):
+ self.delegate = delegate
self.addr = addr
type, self.host, self.port = addr
self.transport = transport
self.sendSession = None # send session/observed fields until obSeq > sendSession
self.resetMessages()
-
+
def resetMessages(self):
self.weMissed = []
self.inMsg = 0 # next incoming message number
if response:
reactor.callLater(0, self.sendNext)
self.lastReceived = time()
-
-
+ self.dataCameIn()
+
def sendNext(self):
flags = 0
header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
else:
reactor.callLater(1, self.sendNext)
+
+ def dataCameIn(self):
+ self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
+ if hasattr(self, 'delegate') and self.delegate != None:
+ self.imsgq = []
+
+class ustr(str):
+ def getseq(self):
+ if not hasattr(self, 'seq'):
+ self.seq = unpack("!H", self[0:2])[0]
+ return self.seq
+ def __lt__(self, other):
+ return self.getseq() < other.getseq()
+ def __le__(self, other):
+ return self.getseq() <= other.getseq()
+ def __eq__(self, other):
+ return self.getseq() != other.getseq()
+ def __ne__(self, other):
+ return self.getseq() <= other.getseq()
+ def __gt__(self, other):
+ return self.getseq() > other.getseq()
+ def __ge__(self, other):
+ return self.getseq() >= other.getseq()
+
+class OrderedConnection(AirhookConnection):
+ def __init__(self, transport, addr, delegate):
+ AirhookConnection.__init__(self, transport, addr, delegate)
+ self.oseq = 0
+ self.iseq = 0
+ self.q = []
+
+ def dataCameIn(self):
+ # put 'em together
+ for msg in self.imsgq:
+ insort_left(self.q, ustr(msg))
+ self.imsgq = []
+ data = ""
+ while self.q and self.iseq == self.q[0].getseq():
+ data += self.q[0][2:]
+ self.iseq = (self.iseq + 1) % 2**16
+ self.q = self.q[1:]
+ if data:
+ self.msgDelegate('dataCameIn', (self.host, self.port, data))
+
+ def sendSomeData(self, data):
+ # chop it up and queue it up
+ while data:
+ p = "%s%s" % (pack("!H", self.oseq), data[:253])
+ self.omsgq.append(p)
+ data = data[253:]
+ self.oseq = (self.oseq + 1) % 2**16
+
+ if self.omsgq:
+ self.sendNext()
class SimpleTest(unittest.TestCase):
def setUp(self):
self.noisy = 0
- self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
- self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+ self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
+ self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
def testReallySimple(self):
# connect to eachother and send a few packets, observe sequence incrementing
self.noisy = 0
class BasicTests(unittest.TestCase):
def setUp(self):
- self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040))
- self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040))
+ self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
+ self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
self.noisy = 0
def testSimple(self):
a = self.a
self.assertEqual(len(b.imsgq), num)
- def testTwoWayBlast(self, num = 2**15, prob=0.5):
+ def testTwoWayBlast(self, num = 2**9, prob=0.5):
a = self.a
b = self.b
import sha
self.assertEqual(a.outMsgNums[a.obSeq], 0)
self.assertEqual(a.next, 254)
self.assertEqual(a.outMsgNums[19], 254)
+
+class OrderedTests(unittest.TestCase):
+ def setUp(self):
+ self.noisy = 0
+ class queuer:
+ def __init__(self):
+ self.msg = ""
+ def dataCameIn(self, host, port, data):
+ self.msg+= data
+ self.A = queuer()
+ self.B = queuer()
+ self.a = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.A)
+ self.b = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.B)
+
+ def testOrderedSimple(self, num = 2**17, prob=1.0):
+ f = open('/dev/urandom', 'r')
+ a = self.a
+ b = self.b
+ A = self.A
+ B = self.B
+
+ MSGA = f.read(num)
+ MSGB = f.read(num)
+ self.a.sendSomeData(MSGA)
+ self.b.sendSomeData(MSGB)
+
+ while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+ if rand(0,1) < prob:
+ msga = swap(a, '>', self.noisy)
+ b.datagramReceived(msga)
+ else:
+ msga = swap(a, '>', 0)
+ if rand(0,1) < prob:
+ msgb = swap(b, '<', self.noisy)
+ a.datagramReceived(msgb)
+ else:
+ msgb = swap(b, '<', 0)
+ self.assertEqual(self.A.msg, MSGB)
+ self.assertEqual(self.B.msg, MSGA)
+
+ def testOrderedLossy(self, num = 2**17, prob=0.5):
+ self.testOrderedSimple(num, prob)