FLAG_NEXT = 2
FLAG_INTERVAL = 1
-MAX_PACKET_SIZE = 1480
+MAX_PACKET_SIZE = 1496
pending = 0
sent = 1
self.transport = transport
self.outSeq = 0 # highest sequence we have sent, can't be 255 more than obSeq
- self.obSeq = -1 # highest sequence confirmed by remote
+ self.obSeq = 0 # highest sequence confirmed by remote
self.inSeq = 0 # last received sequence
self.observed = None # their session id
self.sessionID = long(rand(0, 2**32)) # our session id
self.state = pending
self.outMsgs = [None] * 256 # outgoing messages (seq sent, message), index = message number
- self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
- self.next = -1 # next outgoing message number
self.omsgq = [] # list of messages to go out
self.imsgq = [] # list of messages coming in
self.sendSession = None # send session/observed fields until obSeq > sendSession
def resetMessages(self):
self.weMissed = []
self.inMsg = 0 # next incoming message number
-
+ self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
+ self.next = 0 # next outgoing message number
+
def datagramReceived(self, datagram):
if not datagram:
return
response = 0 # if we know we have a response now (like resending missed packets)
p = AirhookPacket(datagram)
+ # check to make sure sequence number isn't out of order
+ if (p.seq - self.inSeq) % 2**16 >= 256:
+ return
+
# check for state change
if self.state == pending:
if p.observed != None and p.session != None:
if self.state != pending:
msgs = []
missed = []
-
- # check to make sure sequence number isn't out of wack
- assert (p.seq - self.inSeq) % 2**16 < 256
# see if they need us to resend anything
for i in p.missed:
response = 1
if self.outMsgs[i] != None:
- self.omsgq.insert(0, self.outMsgs[i])
+ self.omsgq.append(self.outMsgs[i])
self.outMsgs[i] = None
- # see if we need them to send anything
+ # see if we missed any messages
if p.next != None:
- if p.next == 0 and self.inMsg == -1:
- missed = 255
missed_count = (p.next - self.inMsg) % 256
if missed_count:
self.lastReceived = time()
missed += [(self.outSeq, (self.inMsg + i) % 256)]
response = 1
self.weMissed += missed
+ # record highest message number seen
self.inMsg = (p.next + len(p.msgs)) % 256
-
+
+ # append messages, update sequence
self.imsgq += p.msgs
- self.inSeq = p.seq
if self.state == confirmed:
# unpack the observed sequence
tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
self.obSeq = tseq
+ self.inSeq = p.seq
+
if response:
reactor.callLater(0, self.sendNext)
self.lastReceived = time()
missed = ""
msgs = ""
+ # session / observed logic
if self.state == pending:
flags = flags | FLAG_SESSION
ids += pack("!L", self.sessionID)
flags = flags | FLAG_SESSION | FLAG_OBSERVED
ids += pack("!LL", self.observed, self.sessionID)
+ # missed header
if self.obSeq >= 0:
self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
for i in self.weMissed:
missed += chr(i[1])
+ # append any outgoing messages
if self.state == confirmed and self.omsgq:
- first = (self.next + 1) % 256
- while len(self.omsgq) and (len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE) :
- if self.obSeq == -1:
- highest = 0
- else:
- highest = self.outMsgNums[self.obSeq % 256]
- if self.next != -1 and (self.next + 1) % 256 == (highest - 1) % 256:
- break
- else:
- self.next = (self.next + 1) % 256
- msg = self.omsgq.pop()
- msgs += chr(len(msg) - 1) + msg
- self.outMsgs[self.next] = msg
+ first = self.next
+ outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
+ while len(self.omsgq) and outstanding < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
+ msg = self.omsgq.pop()
+ msgs += chr(len(msg) - 1) + msg
+ self.outMsgs[self.next] = msg
+ self.next = (self.next + 1) % 256
+ outstanding+=1
+ # update outgoing message stat
if msgs:
flags = flags | FLAG_NEXT
ids += chr(first)
self.lastTransmitSeq = self.outSeq
- self.outMsgNums[self.outSeq % 256] = first
- else:
- if self.next == -1:
- self.outMsgNums[self.outSeq % 256] = 0
- else:
- self.outMsgNums[self.outSeq % 256] = self.next
-
- if (self.obSeq - self.lastTransmitSeq) % 2**16 > 256 and self.outMsgNums[self.obSeq % 256] != self.next and not flags & FLAG_NEXT:
+ #self.outMsgNums[self.outSeq % 256] = first
+ #else:
+ self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
+
+ # do we need a NEXT flag despite not having sent any messages?
+ if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
flags = flags | FLAG_NEXT
- ids += chr((self.next + 1) % 256)
+ ids += chr(self.next)
+
+ # update stats and send packet
packet = chr(flags) + header + ids + missed + msgs
self.outSeq = (self.outSeq + 1) % 2**16
self.lastTransmit = time()
self.transport.write(packet)
+ # call later
if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
reactor.callLater(0, self.sendNext)
else:
def dataCameIn(self):
+ """
+ called when we get a packet bearing messages
+ delegate must do something with the messages or they will get dropped
+ """
self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
if hasattr(self, 'delegate') and self.delegate != None:
self.imsgq = []
class ustr(str):
+ """
+ this subclass of string encapsulates each ordered message, caches it's sequence number,
+ and has comparison functions to sort by sequence number
+ """
def getseq(self):
if not hasattr(self, 'seq'):
self.seq = unpack("!H", self[0:2])[0]
return self.seq
def __lt__(self, other):
- return self.getseq() < other.getseq()
+ return (self.getseq() - other.getseq()) % 2**16 > 255
def __le__(self, other):
- return self.getseq() <= other.getseq()
+ return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
def __eq__(self, other):
- return self.getseq() != other.getseq()
+ return self.getseq() == other.getseq()
def __ne__(self, other):
- return self.getseq() <= other.getseq()
+ return self.getseq() != other.getseq()
def __gt__(self, other):
- return self.getseq() > other.getseq()
+ return (self.getseq() - other.getseq()) % 2**16 < 256 and not self.__eq__(other)
def __ge__(self, other):
- return self.getseq() >= other.getseq()
-
+ return (self.getseq() - other.getseq()) % 2**16 < 256
+
class OrderedConnection(AirhookConnection):
+ """
+ this implements a simple protocol for ordered messages over airhook
+ the first two octets of each message are interpreted as a 16-bit sequence number
+ 253 bytes are used for payload
+ """
def __init__(self, transport, addr, delegate):
AirhookConnection.__init__(self, transport, addr, delegate)
self.oseq = 0
data = ""
while self.q and self.iseq == self.q[0].getseq():
data += self.q[0][2:]
- self.iseq = (self.iseq + 1) % 2**16
self.q = self.q[1:]
- if data:
+ self.iseq = (self.iseq + 1) % 2**16
+ if data != '':
self.msgDelegate('dataCameIn', (self.host, self.port, data))
def sendSomeData(self, data):
# chop it up and queue it up
while data:
- p = "%s%s" % (pack("!H", self.oseq), data[:253])
- self.omsgq.append(p)
+ p = pack("!H", self.oseq) + data[:253]
+ self.omsgq.insert(0, p)
data = data[253:]
self.oseq = (self.oseq + 1) % 2**16
print 6*dir + " " + pscope(msg)
return msg
+
+class UstrTests(unittest.TestCase):
+ def u(self, seq):
+ return ustr("%s%s" % (pack("!H", seq), 'foobar'))
+
+ def testLT(self):
+ self.failUnless(self.u(0) < self.u(1))
+ self.failUnless(self.u(1) < self.u(2))
+ self.failUnless(self.u(2**16 - 1) < self.u(0))
+ self.failUnless(self.u(2**16 - 1) < self.u(1))
+
+ self.failIf(self.u(1) < self.u(0))
+ self.failIf(self.u(2) < self.u(1))
+ self.failIf(self.u(0) < self.u(2**16 - 1))
+ self.failIf(self.u(1) < self.u(2**16 - 1))
+
+ def testLTE(self):
+ self.failUnless(self.u(0) <= self.u(1))
+ self.failUnless(self.u(1) <= self.u(2))
+ self.failUnless(self.u(2) <= self.u(2))
+ self.failUnless(self.u(2**16 - 1) <= self.u(0))
+ self.failUnless(self.u(2**16 - 1) <= self.u(1))
+ self.failUnless(self.u(2**16 - 1) <= self.u(2**16))
+
+ self.failIf(self.u(1) <= self.u(0))
+ self.failIf(self.u(2) <= self.u(1))
+ self.failIf(self.u(0) <= self.u(2**16 - 1))
+ self.failIf(self.u(1) <= self.u(2**16 - 1))
+
+ def testGT(self):
+ self.failUnless(self.u(1) > self.u(0))
+ self.failUnless(self.u(2) > self.u(1))
+ self.failUnless(self.u(0) > self.u(2**16 - 1))
+ self.failUnless(self.u(1) > self.u(2**16 - 1))
+
+ self.failIf(self.u(0) > self.u(1))
+ self.failIf(self.u(1) > self.u(2))
+ self.failIf(self.u(2**16 - 1) > self.u(0))
+ self.failIf(self.u(2**16 - 1) > self.u(1))
+
+ def testGTE(self):
+ self.failUnless(self.u(1) >= self.u(0))
+ self.failUnless(self.u(2) >= self.u(1))
+ self.failUnless(self.u(2) >= self.u(2))
+ self.failUnless(self.u(0) >= self.u(0))
+ self.failUnless(self.u(1) >= self.u(1))
+ self.failUnless(self.u(2**16 - 1) >= self.u(2**16 - 1))
+
+ self.failIf(self.u(0) >= self.u(1))
+ self.failIf(self.u(1) >= self.u(2))
+ self.failIf(self.u(2**16 - 1) >= self.u(0))
+ self.failIf(self.u(2**16 - 1) >= self.u(1))
+
+ def testEQ(self):
+ self.failUnless(self.u(0) == self.u(0))
+ self.failUnless(self.u(1) == self.u(1))
+ self.failUnless(self.u(2**16 - 1) == self.u(2**16-1))
+
+ self.failIf(self.u(0) == self.u(1))
+ self.failIf(self.u(1) == self.u(0))
+ self.failIf(self.u(2**16 - 1) == self.u(0))
+
+ def testNEQ(self):
+ self.failUnless(self.u(1) != self.u(0))
+ self.failUnless(self.u(2) != self.u(1))
+ self.failIf(self.u(2) != self.u(2))
+ self.failIf(self.u(0) != self.u(0))
+ self.failIf(self.u(1) != self.u(1))
+ self.failIf(self.u(2**16 - 1) != self.u(2**16 - 1))
+
+
class SimpleTest(unittest.TestCase):
def setUp(self):
self.noisy = 0
self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
def testReallySimple(self):
# connect to eachother and send a few packets, observe sequence incrementing
- self.noisy = 0
a = self.a
b = self.b
self.assertEqual(a.state, pending)
self.assertEqual(b.state, pending)
self.assertEqual(a.outSeq, 0)
self.assertEqual(b.outSeq, 0)
- self.assertEqual(a.obSeq, -1)
- self.assertEqual(b.obSeq, -1)
+ self.assertEqual(a.obSeq, 0)
+ self.assertEqual(b.obSeq, 0)
msg = swap(a, '>', self.noisy)
self.assertEqual(a.state, sent)
self.assertEqual(a.outSeq, 1)
- self.assertEqual(a.obSeq, -1)
+ self.assertEqual(a.obSeq, 0)
b.datagramReceived(msg)
self.assertEqual(b.state, sent)
self.assertEqual(b.inSeq, 0)
- self.assertEqual(b.obSeq, -1)
+ self.assertEqual(b.obSeq, 0)
msg = swap(b, '<', self.noisy)
self.assertEqual(b.outSeq, 1)
class BasicTests(unittest.TestCase):
def setUp(self):
+ self.noisy = 0
self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
- self.noisy = 0
def testSimple(self):
a = self.a
b = self.b
l.sort();a.imsgq.sort()
self.assertEqual(a.imsgq, l)
- def testOneWayBlast(self, num = 2**8):
+ def testOneWayBlast(self, num = 2**12):
a = self.a
b = self.b
import sha
for i in xrange(num):
a.omsgq.append(sha.sha(`i`).digest())
- msg = swap(a, '>', self.noisy)
- while a.state != confirmed or ord(msg[0]) & FLAG_NEXT:
- b.datagramReceived(msg)
- msg = swap(b, '<', self.noisy)
+ msga = swap(a, '>', self.noisy)
+ while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+ b.datagramReceived(msga)
+ msgb = swap(b, '<', self.noisy)
- a.datagramReceived(msg)
- msg = swap(a, '>', self.noisy)
+ a.datagramReceived(msgb)
+ msga = swap(a, '>', self.noisy)
self.assertEqual(len(b.imsgq), num)
- def testTwoWayBlast(self, num = 2**9, prob=0.5):
+ def testTwoWayBlast(self, num = 2**12, prob=0.5):
a = self.a
b = self.b
import sha
for i in range(5000):
a.omsgq.append(sha.sha('a' + 'i').digest())
- for i in range(31):
+ for i in range(5000 / 255):
msg = swap(a, noisy=self.noisy)
self.assertEqual(a.obSeq, 0)
- self.assertEqual(a.outMsgNums[a.obSeq], 0)
- self.assertEqual(a.next, 254)
- self.assertEqual(a.outMsgNums[19], 254)
+ self.assertEqual(a.next, 255)
+ self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254)
class OrderedTests(unittest.TestCase):
def setUp(self):
self.a = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.A)
self.b = OrderedConnection(StringIO(), (None, 'localhost', 4040), self.B)
- def testOrderedSimple(self, num = 2**17, prob=1.0):
+ def testOrderedSimple(self, num = 2**18, prob=1.0):
f = open('/dev/urandom', 'r')
a = self.a
b = self.b
a.datagramReceived(msgb)
else:
msgb = swap(b, '<', 0)
+ self.assertEqual(len(self.A.msg), len(MSGB))
+ self.assertEqual(len(self.B.msg), len(MSGA))
self.assertEqual(self.A.msg, MSGB)
self.assertEqual(self.B.msg, MSGA)
-
- def testOrderedLossy(self, num = 2**17, prob=0.5):
+
+ def testOrderedLossy(self, num = 2**18, prob=0.5):
self.testOrderedSimple(num, prob)