]> git.mxchange.org Git - quix0rs-apt-p2p.git/commitdiff
twistified airhook, added reactor based tests
authorburris <burris>
Mon, 13 Jan 2003 07:08:01 +0000 (07:08 +0000)
committerburris <burris>
Mon, 13 Jan 2003 07:08:01 +0000 (07:08 +0000)
airhook.py
test_airhook.py

index b72f9ea23afd67ad3f4abd89743bbdccedcceeee..96009d4a751b2fe42ab4b9feaae36a6d3a671384 100644 (file)
@@ -4,12 +4,14 @@
 from random import uniform as rand
 from struct import pack, unpack
 from time import time
-from StringIO import StringIO
 import unittest
 from bisect import insort_left
 
 from twisted.internet import protocol
+from twisted.internet import abstract
 from twisted.internet import reactor
+from twisted.internet import app
+from twisted.internet import interfaces
 
 # flags
 FLAG_AIRHOOK = 128
@@ -19,333 +21,360 @@ FLAG_MISSED = 4
 FLAG_NEXT = 2
 FLAG_INTERVAL = 1
 
-MAX_PACKET_SIZE = 1496
+MAX_PACKET_SIZE = 1450
 
 pending = 0
 sent = 1
 confirmed = 2
 
-class Delegate:
-       def setDelegate(self, delegate):
-               self.delegate = delegate
-       def getDelegate(self):
-               return self.delegate
-       def msgDelegate(self, method, args=(), kwargs={}):
-               if hasattr(self, 'delegate') and hasattr(self.delegate, method) and callable(getattr(self.delegate, method)):
-                       apply(getattr(self.delegate, method) , args, kwargs)
 
-class Airhook(protocol.DatagramProtocol):
 
-       def __init__(self, connection_class):
-               self.connection_class = connection_class
-       def startProtocol(self):
-               self.connections = {}
-                               
-       def datagramReceived(self, datagram, addr):
-               flag = datagram[0]
-               if not flag & FLAG_AIRHOOK:  # first bit always must be 0
-                       conn = self.connectionForAddr(addr)
-                       conn.datagramReceieved(datagram)
+class Airhook(protocol.DatagramProtocol):       
+    def __init__(self):
+        self.noisy = None
+        # this should be changed to storage that drops old entries
+        self.connections = {}
+        
+    def datagramReceived(self, datagram, addr):
+        self.connectionForAddr(addr).datagramReceived(datagram)
 
-       def connectionForAddr(self, addr):
-               if not self.connections.has_key(addr):
-                       conn = connection_class(self.transport, addr, self.delegate)
-                       self.connections[addr] = conn
-               return self.connections[addr]
-
-               
+    def connectionForAddr(self, addr):
+        if not self.connections.has_key(addr):
+            conn = self.connection()
+            conn.protocol = self.factory.buildProtocol(addr)
+            conn.protocol.makeConnection(conn)
+            conn.makeConnection(self.transport)
+            conn.addr = addr
+            self.connections[addr] = conn
+        else:
+            conn = self.connections[addr]
+        return conn
+    
 class AirhookPacket:
-       def __init__(self, msg):
-               self.datagram = msg
-               self.oseq =  ord(msg[1])
-               self.seq = unpack("!H", msg[2:4])[0]
-               self.flags = ord(msg[0])
-               self.session = None
-               self.observed = None
-               self.next = None
-               self.missed = []
-               self.msgs = []
-               skip = 4
-               if self.flags & FLAG_OBSERVED:
-                       self.observed = unpack("!L", msg[skip:skip+4])[0]
-                       skip += 4
-               if self.flags & FLAG_SESSION:
-                       self.session =  unpack("!L", msg[skip:skip+4])[0]
-                       skip += 4
-               if self.flags & FLAG_NEXT:
-                       self.next =  ord(msg[skip])
-                       skip += 1
-               if self.flags & FLAG_MISSED:
-                       num = ord(msg[skip]) + 1
-                       skip += 1
-                       for i in range(num):
-                               self.missed.append( ord(msg[skip+i]))
-                       skip += num
-               if self.flags & FLAG_NEXT:
-                       while len(msg) - skip > 0:
-                               n = ord(msg[skip]) + 1
-                               skip+=1
-                               self.msgs.append( msg[skip:skip+n])
-                               skip += n
-
-class AirhookConnection(Delegate):
-       def __init__(self, transport, addr, delegate):
-               self.delegate = delegate
-               self.addr = addr
-               type, self.host, self.port = addr
-               self.transport = transport
-               
-               self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
-               self.obSeq = 0   # highest sequence confirmed by remote
-               self.inSeq = 0   # last received sequence
-               self.observed = None  # their session id
-               self.sessionID = long(rand(0, 2**32))  # our session id
-               
-               self.lastTransmit = -1  # time we last sent a packet with messages
-               self.lastReceieved = 0 # time we last received a packet with messages
-               self.lastTransmitSeq = -1 # last sequence we sent a packet
-               self.state = pending
-               
-               self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
-               self.omsgq = [] # list of messages to go out
-               self.imsgq = [] # list of messages coming in
-               self.sendSession = None  # send session/observed fields until obSeq > sendSession
+    def __init__(self, msg):
+        self.datagram = msg
+        self.oseq =  ord(msg[1])
+        self.seq = unpack("!H", msg[2:4])[0]
+        self.flags = ord(msg[0])
+        self.session = None
+        self.observed = None
+        self.next = None
+        self.missed = []
+        self.msgs = []
+        skip = 4
+        if self.flags & FLAG_OBSERVED:
+            self.observed = unpack("!L", msg[skip:skip+4])[0]
+            skip += 4
+        if self.flags & FLAG_SESSION:
+            self.session =  unpack("!L", msg[skip:skip+4])[0]
+            skip += 4
+        if self.flags & FLAG_NEXT:
+            self.next =  ord(msg[skip])
+            skip += 1
+        if self.flags & FLAG_MISSED:
+            num = ord(msg[skip]) + 1
+            skip += 1
+            for i in range(num):
+                self.missed.append( ord(msg[skip+i]))
+            skip += num
+        if self.flags & FLAG_NEXT:
+            while len(msg) - skip > 0:
+                n = ord(msg[skip]) + 1
+                skip+=1
+                self.msgs.append( msg[skip:skip+n])
+                skip += n
 
-               self.resetMessages()
-       
-       def resetMessages(self):
-               self.weMissed = []
-               self.inMsg = 0   # next incoming message number
-               self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
-               self.next = 0  # next outgoing message number
+class AirhookConnection(protocol.ConnectedDatagramProtocol, interfaces.IUDPConnectedTransport):
+    def __init__(self):        
+        self.outSeq = 0  # highest sequence we have sent, can't be 255 more than obSeq
+        self.obSeq = 0   # highest sequence confirmed by remote
+        self.inSeq = 0   # last received sequence
+        self.observed = None  # their session id
+        self.sessionID = long(rand(0, 2**32))  # our session id
+        
+        self.lastTransmit = -1  # time we last sent a packet with messages
+        self.lastReceieved = 0 # time we last received a packet with messages
+        self.lastTransmitSeq = -1 # last sequence we sent a packet
+        self.state = pending
+        
+        self.outMsgs = [None] * 256  # outgoing messages  (seq sent, message), index = message number
+        self.omsgq = [] # list of messages to go out
+        self.imsgq = [] # list of messages coming in
+        self.sendSession = None  # send session/observed fields until obSeq > sendSession
+        self.response = 0 # if we know we have a response now (like resending missed packets)
+        self.noisy = 0
+        self.resetMessages()
+    
+    def resetMessages(self):
+        self.weMissed = []
+        self.inMsg = 0   # next incoming message number
+        self.outMsgNums = [None] * 256 # outgoing message numbers i = outNum % 256
+        self.next = 0  # next outgoing message number
 
-       def datagramReceived(self, datagram):
-               if not datagram:
-                       return
-               response = 0 # if we know we have a response now (like resending missed packets)
-               p = AirhookPacket(datagram)
-               
-               # check to make sure sequence number isn't out of order
-               if (p.seq - self.inSeq) % 2**16 >= 256:
-                       return
-                       
-               # check for state change
-               if self.state == pending:
-                       if p.observed != None and p.session != None:
-                               if p.observed == self.sessionID:
-                                       self.observed = p.session
-                                       self.state = confirmed
-                               else:
-                                       # bogus packet!
-                                       return
-                       elif p.session != None:
-                               self.observed = p.session
-                               self.state = sent
-                               response = 1
-               elif self.state == sent:
-                       if p.observed != None and p.session != None:
-                               if p.observed == self.sessionID:
-                                       self.observed = p.session
-                                       self.sendSession = self.outSeq
-                                       self.state = confirmed
-                       if p.session != None:
-                               if not self.observed:
-                                       self.observed = p.session
-                               elif self.observed != p.session:
-                                       self.state = pending
-                                       self.resetMessages()
-                                       self.inSeq = p.seq
-                       response = 1
-               elif self.state == confirmed:
-                       if p.session != None or p.observed != None :
-                               if p.session != self.observed or p.observed != self.sessionID:
-                                       self.state = pending
-                                       if seq == 0:
-                                               self.resetMessages()
-                                               self.inSeq = p.seq
-       
-               if self.state != pending:       
-                       msgs = []               
-                       missed = []
-                       
-                       # see if they need us to resend anything
-                       for i in p.missed:
-                               response = 1
-                               if self.outMsgs[i] != None:
-                                       self.omsgq.append(self.outMsgs[i])
-                                       self.outMsgs[i] = None
-                                       
-                       # see if we missed any messages
-                       if p.next != None:
-                               missed_count = (p.next - self.inMsg) % 256
-                               if missed_count:
-                                       self.lastReceived = time()
-                                       for i in range(missed_count):
-                                               missed += [(self.outSeq, (self.inMsg + i) % 256)]
-                                       response = 1
-                                       self.weMissed += missed
-                               # record highest message number seen
-                               self.inMsg = (p.next + len(p.msgs)) % 256
-                       
-                       # append messages, update sequence
-                       self.imsgq += p.msgs
-                       
-               if self.state == confirmed:
-                       # unpack the observed sequence
-                       tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
-                       if ((self.outSeq - tseq)) % 2**16 > 255:
-                               tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
-                       self.obSeq = tseq
+    def datagramReceived(self, datagram):
+        if not datagram:
+            return
+        p = AirhookPacket(datagram)
+        
+        # check to make sure sequence number isn't out of order
+        if (p.seq - self.inSeq) % 2**16 >= 256:
+            return
+            
+        # check for state change
+        if self.state == pending:
+            if p.observed != None and p.session != None:
+                if p.observed == self.sessionID:
+                    self.observed = p.session
+                    self.state = confirmed
+                else:
+                    # bogus packet!
+                    return
+            elif p.session != None:
+                self.observed = p.session
+                self.state = sent
+                self.response = 1
+        elif self.state == sent:
+            if p.observed != None and p.session != None:
+                if p.observed == self.sessionID:
+                    self.observed = p.session
+                    self.sendSession = self.outSeq
+                    self.state = confirmed
+            if p.session != None:
+                if not self.observed:
+                    self.observed = p.session
+                elif self.observed != p.session:
+                    self.state = pending
+                    self.resetMessages()
+                    self.inSeq = p.seq
+            self.response = 1
+        elif self.state == confirmed:
+            if p.session != None or p.observed != None :
+                if p.session != self.observed or p.observed != self.sessionID:
+                    self.state = pending
+                    if seq == 0:
+                        self.resetMessages()
+                        self.inSeq = p.seq
+    
+        if self.state != pending:      
+            msgs = []          
+            missed = []
+            
+            # see if they need us to resend anything
+            for i in p.missed:
+                if self.outMsgs[i] != None:
+                    self.omsgq.append(self.outMsgs[i])
+                    self.outMsgs[i] = None
+                    
+            # see if we missed any messages
+            if p.next != None:
+                missed_count = (p.next - self.inMsg) % 256
+                if missed_count:
+                    self.lastReceived = time()
+                    for i in range(missed_count):
+                        missed += [(self.outSeq, (self.inMsg + i) % 256)]
+                    self.weMissed += missed
+                    self.response = 1
+                # record highest message number seen
+                self.inMsg = (p.next + len(p.msgs)) % 256
+            
+            # append messages, update sequence
+            self.imsgq += p.msgs
+            
+        if self.state == confirmed:
+            # unpack the observed sequence
+            tseq = unpack('!H', pack('!H', self.outSeq)[0] +  chr(p.oseq))[0]
+            if ((self.outSeq - tseq)) % 2**16 > 255:
+                tseq = unpack('!H', chr(ord(pack('!H', self.outSeq)[0]) - 1) + chr(p.oseq))[0]
+            self.obSeq = tseq
 
-               self.inSeq = p.seq
+        self.inSeq = p.seq
 
-               if response:
-                       reactor.callLater(0, self.sendNext)
-               self.lastReceived = time()
-               self.dataCameIn()
-               
-       def sendNext(self):
-               flags = 0
-               header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
-               ids = ""
-               missed = ""
-               msgs = ""
-               
-               # session / observed logic
-               if self.state == pending:
-                       flags = flags | FLAG_SESSION
-                       ids +=  pack("!L", self.sessionID)
-                       self.state = sent
-               elif self.state == sent:
-                       if self.observed != None:
-                               flags = flags | FLAG_SESSION | FLAG_OBSERVED
-                               ids +=  pack("!LL", self.observed, self.sessionID)
-                       else:
-                               flags = flags | FLAG_SESSION
-                               ids +=  pack("!L", self.sessionID)
+        self.lastReceived = time()
+        self.dataCameIn()
+        
+        self.schedule()
+        
+    def sendNext(self):
+        flags = 0
+        header = chr(self.inSeq & 255) + pack("!H", self.outSeq)
+        ids = ""
+        missed = ""
+        msgs = ""
+        
+        # session / observed logic
+        if self.state == pending:
+            flags = flags | FLAG_SESSION
+            ids +=  pack("!L", self.sessionID)
+            self.state = sent
+        elif self.state == sent:
+            if self.observed != None:
+                flags = flags | FLAG_SESSION | FLAG_OBSERVED
+                ids +=  pack("!LL", self.observed, self.sessionID)
+            else:
+                flags = flags | FLAG_SESSION
+                ids +=  pack("!L", self.sessionID)
 
-               else:
-                       if self.state == sent or self.sendSession:
-                               if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
-                                       self.sendSession = None
-                               else:
-                                       flags = flags | FLAG_SESSION | FLAG_OBSERVED
-                                       ids +=  pack("!LL", self.observed, self.sessionID)
-               
-               # missed header
-               if self.obSeq >= 0:
-                       self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
+        else:
+            if self.state == sent or self.sendSession:
+                if self.state == confirmed and (self.obSeq - self.sendSession) % 2**16 < 256:
+                    self.sendSession = None
+                else:
+                    flags = flags | FLAG_SESSION | FLAG_OBSERVED
+                    ids +=  pack("!LL", self.observed, self.sessionID)
+        
+        # missed header
+        if self.obSeq >= 0:
+            self.weMissed = filter(lambda a: a[0] > self.obSeq, self.weMissed)
 
-               if self.weMissed:
-                       flags = flags | FLAG_MISSED
-                       missed += chr(len(self.weMissed) - 1)
-                       for i in self.weMissed:
-                               missed += chr(i[1])
-                               
-               # append any outgoing messages
-               if self.state == confirmed and self.omsgq:
-                       first = self.next
-                       outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
-                       while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
-                               msg = self.omsgq.pop()
-                               msgs += chr(len(msg) - 1) + msg
-                               self.outMsgs[self.next] = msg
-                               self.next = (self.next + 1) % 256
-                               outstanding+=1
-               # update outgoing message stat
-               if msgs:
-                       flags = flags | FLAG_NEXT
-                       ids += chr(first)
-                       self.lastTransmitSeq = self.outSeq
-                       #self.outMsgNums[self.outSeq % 256] = first
-               #else:
-               self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
-               
-               # do we need a NEXT flag despite not having sent any messages?
-               if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
-                               flags = flags | FLAG_NEXT
-                               ids += chr(self.next)
-               
-               # update stats and send packet
-               packet = chr(flags) + header + ids + missed + msgs
-               self.outSeq = (self.outSeq + 1) % 2**16
-               self.lastTransmit = time()
-               self.transport.write(packet)
-               
-               # call later
-               if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
-                       reactor.callLater(0, self.sendNext)
-               else:
-                       reactor.callLater(1, self.sendNext)
+        if self.weMissed:
+            flags = flags | FLAG_MISSED
+            missed += chr(len(self.weMissed) - 1)
+            for i in self.weMissed:
+                missed += chr(i[1])
+                
+        # append any outgoing messages
+        if self.state == confirmed and self.omsgq:
+            first = self.next
+            outstanding = (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256
+            while len(self.omsgq) and outstanding  < 255 and len(self.omsgq[-1]) + len(msgs) + len(missed) + len(ids) + len(header) + 1 <= MAX_PACKET_SIZE:
+                msg = self.omsgq.pop()
+                msgs += chr(len(msg) - 1) + msg
+                self.outMsgs[self.next] = msg
+                self.next = (self.next + 1) % 256
+                outstanding+=1
+        # update outgoing message stat
+        if msgs:
+            flags = flags | FLAG_NEXT
+            ids += chr(first)
+            self.lastTransmitSeq = self.outSeq
+            #self.outMsgNums[self.outSeq % 256] = first
+        #else:
+        self.outMsgNums[self.outSeq % 256] = (self.next - 1) % 256
+        
+        # do we need a NEXT flag despite not having sent any messages?
+        if not flags & FLAG_NEXT and (256 + (((self.next - 1) % 256) - self.outMsgNums[self.obSeq % 256])) % 256 > 0:
+            flags = flags | FLAG_NEXT
+            ids += chr(self.next)
+        
+        # update stats and send packet
+        packet = chr(flags) + header + ids + missed + msgs
+        self.outSeq = (self.outSeq + 1) % 2**16
+        self.lastTransmit = time()
+        self.transport.write(packet, self.addr)
+        
+        self.schedule()
+        
+    def timeToSend(self):
+        # any outstanding messages and are we not too far ahead of our counterparty?
+        if self.omsgq and (self.next + 1) % 256 != self.outMsgNums[self.obSeq % 256] and (self.outSeq - self.obSeq) % 2**16 < 256:
+            return 1
+        # do we explicitly need to send a response?
+        elif self.response:
+            self.response = 0
+            return 1
+        # have we not sent anything in a while?
+        elif time() - self.lastTransmit > 1.0:
+            return 1
+        
+        # nothing to send
+        return 0
 
-
-       def dataCameIn(self):
-               """
-               called when we get a packet bearing messages
-               delegate must do something with the messages or they will get dropped 
-               """
-               self.msgDelegate('dataCameIn', (self.host, self.port, self.imsgq))
-               if hasattr(self, 'delegate') and self.delegate != None:
-                       self.imsgq = []
+    def schedule(self):
+        if self.timeToSend():
+            reactor.callLater(0, self.sendNext)
+        else:
+            reactor.callLater(1, self.sendNext)
+        
+    def write(self, data):
+        # micropackets can only be 255 bytes or less
+        if len(data) <= 255:
+            self.omsgq.insert(0, data)
+        self.schedule()
+        
+    def dataCameIn(self):
+        """
+        called when we get a packet bearing messages
+        """
+        for msg in self.imsgq:
+            self.protocol.dataReceived(msg)
+        self.imsgq = []
 
 class ustr(str):
-       """
-               this subclass of string encapsulates each ordered message, caches it's sequence number,
-               and has comparison functions to sort by sequence number
-       """
-       def getseq(self):
-               if not hasattr(self, 'seq'):
-                       self.seq = unpack("!H", self[0:2])[0]
-               return self.seq
-       def __lt__(self, other):
-               return (self.getseq() - other.getseq()) % 2**16 > 255
-       def __le__(self, other):
-               return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
-       def __eq__(self, other):
-               return self.getseq() == other.getseq()
-       def __ne__(self, other):
-               return self.getseq() != other.getseq()
-       def __gt__(self, other):
-               return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
-       def __ge__(self, other):
-               return (self.getseq() - other.getseq()) % 2**16 < 256
-               
+    """
+        this subclass of string encapsulates each ordered message, caches it's sequence number,
+        and has comparison functions to sort by sequence number
+    """
+    def getseq(self):
+        if not hasattr(self, 'seq'):
+            self.seq = unpack("!H", self[0:2])[0]
+        return self.seq
+    def __lt__(self, other):
+        return (self.getseq() - other.getseq()) % 2**16 > 255
+    def __le__(self, other):
+        return (self.getseq() - other.getseq()) % 2**16 > 255 or self.__eq__(other)
+    def __eq__(self, other):
+        return self.getseq() == other.getseq()
+    def __ne__(self, other):
+        return self.getseq() != other.getseq()
+    def __gt__(self, other):
+        return (self.getseq() - other.getseq()) % 2**16 < 256  and not self.__eq__(other)
+    def __ge__(self, other):
+        return (self.getseq() - other.getseq()) % 2**16 < 256
+        
 class StreamConnection(AirhookConnection):
-       """
-               this implements a simple protocol for a stream over airhook
-               the first two octets of each message are interpreted as a 16-bit sequence number
-               253 bytes are used for payload
-               
-               delegate should implement method:
-                 def dataCameIn(self, host, port, data):
-                       [...]
-       """
-       def __init__(self, transport, addr, delegate):
-               AirhookConnection.__init__(self, transport, addr, delegate)
-               self.oseq = 0
-               self.iseq = 0
-               self.q = []
+    """
+        this implements a simple protocol for a stream over airhook
+        this is done for convenience, instead of making it a twisted.internet.protocol....
+        the first two octets of each message are interpreted as a 16-bit sequence number
+        253 bytes are used for payload
+        
+    """
+    def __init__(self):
+        AirhookConnection.__init__(self)
+        self.oseq = 0
+        self.iseq = 0
+        self.q = []
+
+    def dataCameIn(self):
+        # put 'em together
+        for msg in self.imsgq:
+            insort_left(self.q, ustr(msg))
+        self.imsgq = []
+        data = ""
+        while self.q and self.iseq == self.q[0].getseq():
+            data += self.q[0][2:]
+            self.q = self.q[1:]
+            self.iseq = (self.iseq + 1) % 2**16
+        if data != '':
+            self.protocol.dataReceived(data)
+        
+    def write(self, data):
+        # chop it up and queue it up
+        while data:
+            p = pack("!H", self.oseq) + data[:253]
+            self.omsgq.insert(0, p)
+            data = data[253:]
+            self.oseq = (self.oseq + 1) % 2**16
+
+        self.schedule()
+        
+    def writeSequence(self, sequence):
+        for data in sequence:
+            self.write(data)
+
+
+def listenAirhook(port, factory):
+    ah = Airhook()
+    ah.connection = AirhookConnection
+    ah.factory = factory
+    reactor.listenUDP(port, ah)
+    return ah
 
-       def dataCameIn(self):
-               # put 'em together
-               for msg in self.imsgq:
-                       insort_left(self.q, ustr(msg))
-               self.imsgq = []
-               data = ""
-               while self.q and self.iseq == self.q[0].getseq():
-                       data += self.q[0][2:]
-                       self.q = self.q[1:]
-                       self.iseq = (self.iseq + 1) % 2**16
-               if data != '':
-                       self.msgDelegate('dataCameIn', (self.host, self.port, data))
-               
-       def sendSomeData(self, data):
-               # chop it up and queue it up
-               while data:
-                       p = pack("!H", self.oseq) + data[:253]
-                       self.omsgq.insert(0, p)
-                       data = data[253:]
-                       self.oseq = (self.oseq + 1) % 2**16
+def listenAirhookStream(port, factory):
+    ah = Airhook()
+    ah.connection = StreamConnection
+    ah.factory = factory
+    reactor.listenUDP(port, ah)
+    return ah
 
-               if self.omsgq:
-                       self.sendNext()
+    
index 014e13bd05ae311c060ffcc3da3315dc7b364985..aabaebc0befb924d406f3ca12b636c0050ea739e 100644 (file)
 import unittest
 from airhook import *
 from random import uniform as rand
-
-if __name__ =="__main__":
-       tests = unittest.defaultTestLoader.loadTestsFromNames(['test_airhook'])
-       result = unittest.TextTestRunner().run(tests)
-
+from cStringIO import StringIO
 
 
+if __name__ =="__main__":
+    tests = unittest.defaultTestLoader.loadTestsFromNames(['test_airhook'])
+    result = unittest.TextTestRunner().run(tests)
+
+class Echo(protocol.Protocol):
+    def dataReceived(self, data):
+        self.transport.write(data)
+        
+class Noisy(protocol.Protocol):
+    def dataReceived(self, data):
+        print `data`
+
+class Receiver(protocol.Protocol):
+    def __init__(self):
+        self.q = []
+    def dataReceived(self, data):
+        self.q.append(data)
+
+class StreamReceiver(protocol.Protocol):
+    def __init__(self):
+        self.buf = ""
+    def dataReceived(self, data):
+        self.buf += data
+        
+class EchoFactory(protocol.Factory):
+    def buildProtocol(self, addr):
+        return Echo()
+class NoisyFactory(protocol.Factory):
+    def buildProtocol(self, addr):
+        return Noisy()
+class ReceiverFactory(protocol.Factory):
+    def buildProtocol(self, addr):
+        return Receiver()
+class StreamReceiverFactory(protocol.Factory):
+    def buildProtocol(self, addr):
+        return StreamReceiver()
+        
+def makeEcho(port):
+    return listenAirhookStream(port, EchoFactory())
+def makeNoisy(port):
+    return listenAirhookStream(port, NoisyFactory())
+def makeReceiver(port):
+    return listenAirhookStream(port, ReceiverFactory())
+def makeStreamReceiver(port):
+    return listenAirhookStream(port, StreamReceiverFactory())
+
+class DummyTransport:
+    def __init__(self):
+        self.s = StringIO()
+    def write(self, data, addr):
+        self.s.write(data)
+    def seek(self, num):
+        return self.s.seek(num)
+    def read(self):
+        return self.s.read()
+        
 def test_createStartPacket():
-       flags = 0 | FLAG_AIRHOOK | FLAG_SESSION 
-       packet = chr(flags) + "\xff" + "\x00\x00" + pack("!L", long(rand(0, 2**32)))
-       return packet
+    flags = 0 | FLAG_AIRHOOK | FLAG_SESSION 
+    packet = chr(flags) + "\xff" + "\x00\x00" + pack("!L", long(rand(0, 2**32)))
+    return packet
 
 def test_createReply(session, observed, obseq, seq):
-       flags = 0 | FLAG_AIRHOOK | FLAG_SESSION | FLAG_OBSERVED
-       packet = chr(flags) + pack("!H", seq)[1] + pack("!H", obseq + 1) + pack("!L", session) + pack("!L", observed)
-       return packet
-
+    flags = 0 | FLAG_AIRHOOK | FLAG_SESSION | FLAG_OBSERVED
+    packet = chr(flags) + pack("!H", seq)[1] + pack("!H", obseq + 1) + pack("!L", session) + pack("!L", observed)
+    return packet
 
 def pscope(msg, noisy=0):
-       # packet scope
-       str = ""
-       p = AirhookPacket(msg)
-       str += "oseq: %s  seq: %s " %  (p.oseq, p.seq)
-       if noisy:
-               str += "packet: %s  \n" % (`p.datagram`)
-       flags = p.flags
-       str += "flags: "
-       if flags & FLAG_SESSION:
-               str += "FLAG_SESSION "
-       if flags & FLAG_OBSERVED:
-               str += "FLAG_OBSERVED "
-       if flags & FLAG_MISSED:
-               str += "FLAG_MISSED "
-       if flags & FLAG_NEXT:
-               str += "FLAG_NEXT "
-       str += "\n"
-       
-       if p.observed != None:
-               str += "OBSERVED: %s\n" % p.observed
-       if p.session != None:
-               str += "SESSION: %s\n" % p.session
-       if p.next != None:
-               str += "NEXT: %s\n" % p.next
-       if p.missed:
-               if noisy:
-                       str += "MISSED: " + `p.missed`
-               else:
-                       str += "MISSED: " + `len(p.missed)`
-               str += "\n"
-       if p.msgs:
-               if noisy:
-                       str += "MSGS: " + `p.msgs` + "\n"
-               else:
-                       str += "MSGS: <%s> " % len(p.msgs)
-               str += "\n"
-       return str
-                       
+    # packet scope
+    str = ""
+    p = AirhookPacket(msg)
+    str += "oseq: %s  seq: %s " %  (p.oseq, p.seq)
+    if noisy:
+        str += "packet: %s  \n" % (`p.datagram`)
+    flags = p.flags
+    str += "flags: "
+    if flags & FLAG_SESSION:
+        str += "FLAG_SESSION "
+    if flags & FLAG_OBSERVED:
+        str += "FLAG_OBSERVED "
+    if flags & FLAG_MISSED:
+        str += "FLAG_MISSED "
+    if flags & FLAG_NEXT:
+        str += "FLAG_NEXT "
+    str += "\n"
+    
+    if p.observed != None:
+        str += "OBSERVED: %s\n" % p.observed
+    if p.session != None:
+        str += "SESSION: %s\n" % p.session
+    if p.next != None:
+        str += "NEXT: %s\n" % p.next
+    if p.missed:
+        if noisy:
+            str += "MISSED: " + `p.missed`
+        else:
+            str += "MISSED: " + `len(p.missed)`
+        str += "\n"
+    if p.msgs:
+        if noisy:
+            str += "MSGS: " + `p.msgs` + "\n"
+        else:
+            str += "MSGS: <%s> " % len(p.msgs)
+        str += "\n"
+    return str
+            
 # testing function
 def swap(a, dir="", noisy=0):
-       msg = ""
-       while not msg:
-               a.transport.seek(0)
-               msg= a.transport.read()
-               a.transport = StringIO()
-               if not msg:
-                       a.sendNext()
-       if noisy:
-                               print 6*dir + " " + pscope(msg)
-       return msg
-       
+    msg = ""
+    while not msg:
+        a.transport.seek(0)
+        msg= a.transport.read()
+        a.transport = DummyTransport()
+        if not msg:
+            a.sendNext()
+    if noisy:
+                print 6*dir + " " + pscope(msg)
+    return msg
+    
+def runTillEmpty(a, b, prob=1.0, noisy=0):
+    msga = ''
+    msgb = ''
+    while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
+        if rand(0,1) < prob:
+            msga = swap(a, '>', noisy)
+            b.datagramReceived(msga)
+        else:
+            msga = swap(a, '>', 0)
+        if rand(0,1) < prob:
+            msgb = swap(b, '<', noisy)
+            a.datagramReceived(msgb)
+        else:
+            msgb = swap(b, '<', 0)
 
 class UstrTests(unittest.TestCase):
-       def u(self, seq):
-               return ustr("%s%s" % (pack("!H", seq), 'foobar'))
-               
-       def testLT(self):
-               self.failUnless(self.u(0) < self.u(1))
-               self.failUnless(self.u(1) < self.u(2))
-               self.failUnless(self.u(2**16 - 1) < self.u(0))
-               self.failUnless(self.u(2**16 - 1) < self.u(1))
-               
-               self.failIf(self.u(1) < self.u(0))
-               self.failIf(self.u(2) < self.u(1))
-               self.failIf(self.u(0) < self.u(2**16 - 1))
-               self.failIf(self.u(1) < self.u(2**16 - 1))
-               
-       def testLTE(self):
-               self.failUnless(self.u(0) <= self.u(1))
-               self.failUnless(self.u(1) <= self.u(2))
-               self.failUnless(self.u(2) <= self.u(2))
-               self.failUnless(self.u(2**16 - 1) <= self.u(0))
-               self.failUnless(self.u(2**16 - 1) <= self.u(1))
-               self.failUnless(self.u(2**16 - 1) <= self.u(2**16))
-
-               self.failIf(self.u(1) <= self.u(0))
-               self.failIf(self.u(2) <= self.u(1))
-               self.failIf(self.u(0) <= self.u(2**16 - 1))
-               self.failIf(self.u(1) <= self.u(2**16 - 1))
-               
-       def testGT(self):
-               self.failUnless(self.u(1) > self.u(0))
-               self.failUnless(self.u(2) > self.u(1))
-               self.failUnless(self.u(0) > self.u(2**16 - 1))
-               self.failUnless(self.u(1) > self.u(2**16 - 1))
-
-               self.failIf(self.u(0) > self.u(1))
-               self.failIf(self.u(1) > self.u(2))
-               self.failIf(self.u(2**16 - 1) > self.u(0))
-               self.failIf(self.u(2**16 - 1) > self.u(1))
-
-       def testGTE(self):
-               self.failUnless(self.u(1) >= self.u(0))
-               self.failUnless(self.u(2) >= self.u(1))
-               self.failUnless(self.u(2) >= self.u(2))
-               self.failUnless(self.u(0) >= self.u(0))
-               self.failUnless(self.u(1) >= self.u(1))
-               self.failUnless(self.u(2**16 - 1) >= self.u(2**16 - 1))
-
-               self.failIf(self.u(0) >= self.u(1))
-               self.failIf(self.u(1) >= self.u(2))
-               self.failIf(self.u(2**16 - 1) >= self.u(0))
-               self.failIf(self.u(2**16 - 1) >= self.u(1))
-               
-       def testEQ(self):
-               self.failUnless(self.u(0) == self.u(0))
-               self.failUnless(self.u(1) == self.u(1))
-               self.failUnless(self.u(2**16 - 1) == self.u(2**16-1))
-       
-               self.failIf(self.u(0) == self.u(1))
-               self.failIf(self.u(1) == self.u(0))
-               self.failIf(self.u(2**16 - 1) == self.u(0))
-
-       def testNEQ(self):
-               self.failUnless(self.u(1) != self.u(0))
-               self.failUnless(self.u(2) != self.u(1))
-               self.failIf(self.u(2) != self.u(2))
-               self.failIf(self.u(0) != self.u(0))
-               self.failIf(self.u(1) != self.u(1))
-               self.failIf(self.u(2**16 - 1) != self.u(2**16 - 1))
+    def u(self, seq):
+        return ustr("%s%s" % (pack("!H", seq), 'foobar'))
+        
+    def testLT(self):
+        self.failUnless(self.u(0) < self.u(1))
+        self.failUnless(self.u(1) < self.u(2))
+        self.failUnless(self.u(2**16 - 1) < self.u(0))
+        self.failUnless(self.u(2**16 - 1) < self.u(1))
+        
+        self.failIf(self.u(1) < self.u(0))
+        self.failIf(self.u(2) < self.u(1))
+        self.failIf(self.u(0) < self.u(2**16 - 1))
+        self.failIf(self.u(1) < self.u(2**16 - 1))
+        
+    def testLTE(self):
+        self.failUnless(self.u(0) <= self.u(1))
+        self.failUnless(self.u(1) <= self.u(2))
+        self.failUnless(self.u(2) <= self.u(2))
+        self.failUnless(self.u(2**16 - 1) <= self.u(0))
+        self.failUnless(self.u(2**16 - 1) <= self.u(1))
+        self.failUnless(self.u(2**16 - 1) <= self.u(2**16))
+
+        self.failIf(self.u(1) <= self.u(0))
+        self.failIf(self.u(2) <= self.u(1))
+        self.failIf(self.u(0) <= self.u(2**16 - 1))
+        self.failIf(self.u(1) <= self.u(2**16 - 1))
+        
+    def testGT(self):
+        self.failUnless(self.u(1) > self.u(0))
+        self.failUnless(self.u(2) > self.u(1))
+        self.failUnless(self.u(0) > self.u(2**16 - 1))
+        self.failUnless(self.u(1) > self.u(2**16 - 1))
+
+        self.failIf(self.u(0) > self.u(1))
+        self.failIf(self.u(1) > self.u(2))
+        self.failIf(self.u(2**16 - 1) > self.u(0))
+        self.failIf(self.u(2**16 - 1) > self.u(1))
+
+    def testGTE(self):
+        self.failUnless(self.u(1) >= self.u(0))
+        self.failUnless(self.u(2) >= self.u(1))
+        self.failUnless(self.u(2) >= self.u(2))
+        self.failUnless(self.u(0) >= self.u(0))
+        self.failUnless(self.u(1) >= self.u(1))
+        self.failUnless(self.u(2**16 - 1) >= self.u(2**16 - 1))
+
+        self.failIf(self.u(0) >= self.u(1))
+        self.failIf(self.u(1) >= self.u(2))
+        self.failIf(self.u(2**16 - 1) >= self.u(0))
+        self.failIf(self.u(2**16 - 1) >= self.u(1))
+        
+    def testEQ(self):
+        self.failUnless(self.u(0) == self.u(0))
+        self.failUnless(self.u(1) == self.u(1))
+        self.failUnless(self.u(2**16 - 1) == self.u(2**16-1))
+    
+        self.failIf(self.u(0) == self.u(1))
+        self.failIf(self.u(1) == self.u(0))
+        self.failIf(self.u(2**16 - 1) == self.u(0))
+
+    def testNEQ(self):
+        self.failUnless(self.u(1) != self.u(0))
+        self.failUnless(self.u(2) != self.u(1))
+        self.failIf(self.u(2) != self.u(2))
+        self.failIf(self.u(0) != self.u(0))
+        self.failIf(self.u(1) != self.u(1))
+        self.failIf(self.u(2**16 - 1) != self.u(2**16 - 1))
 
 
 class SimpleTest(unittest.TestCase):
-       def setUp(self):
-               self.noisy = 0
-               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
-               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
-       def testReallySimple(self):
-               # connect to eachother and send a few packets, observe sequence incrementing
-               a = self.a
-               b = self.b
-               self.assertEqual(a.state, pending)
-               self.assertEqual(b.state, pending)
-               self.assertEqual(a.outSeq, 0)
-               self.assertEqual(b.outSeq, 0)
-               self.assertEqual(a.obSeq, 0)
-               self.assertEqual(b.obSeq, 0)
-
-               msg = swap(a, '>', self.noisy)          
-               self.assertEqual(a.state, sent)
-               self.assertEqual(a.outSeq, 1)
-               self.assertEqual(a.obSeq, 0)
-
-               b.datagramReceived(msg)
-               self.assertEqual(b.state, sent)
-               self.assertEqual(b.inSeq, 0)
-               self.assertEqual(b.obSeq, 0)
-               msg = swap(b, '<', self.noisy)          
-               self.assertEqual(b.outSeq, 1)
-
-               a.datagramReceived(msg)
-               self.assertEqual(a.state, confirmed)
-               self.assertEqual(a.obSeq, 0)
-               self.assertEqual(a.inSeq, 0)
-               msg = swap(a, '>', self.noisy)          
-               self.assertEqual(a.outSeq, 2)
-
-               b.datagramReceived(msg)
-               self.assertEqual(b.state, confirmed)
-               self.assertEqual(b.obSeq, 0)
-               self.assertEqual(b.inSeq, 1)
-               msg = swap(b, '<', self.noisy)          
-               self.assertEqual(b.outSeq, 2)
-
-               a.datagramReceived(msg)
-               self.assertEqual(a.outSeq, 2)
-               self.assertEqual(a.inSeq, 1)
-               self.assertEqual(a.obSeq, 1)
+    def setUp(self):
+        self.noisy = 0
+        self.a = AirhookConnection()
+        self.a.makeConnection(DummyTransport())
+        self.a.addr = ('127.0.0.1', 4444)
+        self.b = AirhookConnection()
+        self.b.makeConnection(DummyTransport())
+        self.b.addr = ('127.0.0.1', 4444)
+
+    def testReallySimple(self):
+        # connect to eachother and send a few packets, observe sequence incrementing
+        a = self.a
+        b = self.b
+        self.assertEqual(a.state, pending)
+        self.assertEqual(b.state, pending)
+        self.assertEqual(a.outSeq, 0)
+        self.assertEqual(b.outSeq, 0)
+        self.assertEqual(a.obSeq, 0)
+        self.assertEqual(b.obSeq, 0)
+
+        msg = swap(a, '>', self.noisy)         
+        self.assertEqual(a.state, sent)
+        self.assertEqual(a.outSeq, 1)
+        self.assertEqual(a.obSeq, 0)
+
+        b.datagramReceived(msg)
+        self.assertEqual(b.state, sent)
+        self.assertEqual(b.inSeq, 0)
+        self.assertEqual(b.obSeq, 0)
+        msg = swap(b, '<', self.noisy)         
+        self.assertEqual(b.outSeq, 1)
+
+        a.datagramReceived(msg)
+        self.assertEqual(a.state, confirmed)
+        self.assertEqual(a.obSeq, 0)
+        self.assertEqual(a.inSeq, 0)
+        msg = swap(a, '>', self.noisy)         
+        self.assertEqual(a.outSeq, 2)
+
+        b.datagramReceived(msg)
+        self.assertEqual(b.state, confirmed)
+        self.assertEqual(b.obSeq, 0)
+        self.assertEqual(b.inSeq, 1)
+        msg = swap(b, '<', self.noisy)         
+        self.assertEqual(b.outSeq, 2)
+
+        a.datagramReceived(msg)
+        self.assertEqual(a.outSeq, 2)
+        self.assertEqual(a.inSeq, 1)
+        self.assertEqual(a.obSeq, 1)
 
 class BasicTests(unittest.TestCase):
-       def setUp(self):
-               self.noisy = 0
-               self.a = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
-               self.b = AirhookConnection(StringIO(), (None, 'localhost', 4040), None)
-       def testSimple(self):
-               a = self.a
-               b = self.b
-               
-               TESTMSG = "Howdy, Y'All!"
-               a.omsgq.append(TESTMSG)
-               a.sendNext()
-               msg = swap(a, '>', self.noisy)
-               
-               b.datagramReceived(msg)
-               msg = swap(b, '<', self.noisy)
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-               b.datagramReceived(msg)
-               
-               self.assertEqual(b.inMsg, 1)
-               self.assertEqual(len(b.imsgq), 1)
-               self.assertEqual(b.imsgq[0], TESTMSG)
-               
-               msg = swap(b, '<', self.noisy)
-               
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-               b.datagramReceived(msg)
-               
-       def testLostFirst(self):
-               a = self.a
-               b = self.b
-               
-               TESTMSG = "Howdy, Y'All!"
-               TESTMSG2 = "Yee Haw"
-               
-               a.omsgq.append(TESTMSG)
-               msg = swap(a, '>', self.noisy)
-               b.datagramReceived(msg)
-               msg = swap(b, '<', self.noisy)
-               self.assertEqual(b.state, sent)
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-
-               del(msg) # dropping first message
-               
-               a.omsgq.append(TESTMSG2)
-               msg = swap(a, '>', self.noisy)
-       
-               b.datagramReceived(msg)
-               self.assertEqual(b.state, confirmed)
-               self.assertEqual(len(b.imsgq), 1)
-               self.assertEqual(b.imsgq[0], TESTMSG2)
-               self.assertEqual(b.weMissed, [(1, 0)])
-               msg = swap(b, '<', self.noisy)
-               
-               a.datagramReceived(msg)
-                                                               
-               msg = swap(a, '>', self.noisy)
-               
-               b.datagramReceived(msg)
-               self.assertEqual(len(b.imsgq), 2)
-               b.imsgq.sort()
-               l = [TESTMSG2, TESTMSG]
-               l.sort()
-               self.assertEqual(b.imsgq,l)
-               
-               msg = swap(b, '<', self.noisy)
-               
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-               b.datagramReceived(msg)
-               
-               msg = swap(b, '<', self.noisy)
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-               b.datagramReceived(msg)
-
-               msg = swap(b, '<', self.noisy)
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-
-               self.assertEqual(len(b.imsgq), 2)
-               b.imsgq.sort()
-               l = [TESTMSG2, TESTMSG]
-               l.sort()
-               self.assertEqual(b.imsgq,l)
-
-       def testLostSecond(self):
-               a = self.a
-               b = self.b
-               
-               TESTMSG = "Howdy, Y'All!"
-               TESTMSG2 = "Yee Haw"
-               
-               a.omsgq.append(TESTMSG)
-               msg = swap(a, '>', self.noisy)
-               b.datagramReceived(msg)
-               msg = swap(b, '<', self.noisy)
-               self.assertEqual(b.state, sent)
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-
-               a.omsgq.append(TESTMSG2)
-               msg2 = swap(a, '>', self.noisy)
-               del(msg2) # dropping second message
-
-               assert(a.outMsgs[1] != None)
-
-               b.datagramReceived(msg)
-               self.assertEqual(b.state, confirmed)
-               self.assertEqual(len(b.imsgq), 1)
-               self.assertEqual(b.imsgq[0], TESTMSG)
-               self.assertEqual(b.inMsg, 1)
-               self.assertEqual(b.weMissed, [])
-               msg = swap(b, '<', self.noisy)
-               
-               a.datagramReceived(msg)
-               assert(a.outMsgs[1] != None)
-               msg = swap(a, '>', self.noisy)
-
-               b.datagramReceived(msg)
-               self.assertEqual(b.state, confirmed)
-               self.assertEqual(len(b.imsgq), 1)
-               self.assertEqual(b.imsgq[0], TESTMSG)
-               self.assertEqual(b.weMissed, [(2, 1)])
-               msg = swap(b, '<', self.noisy)
-
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-
-               b.datagramReceived(msg)
-               self.assertEqual(len(b.imsgq), 2)
-               b.imsgq.sort()
-               l = [TESTMSG2, TESTMSG]
-               l.sort()
-               self.assertEqual(b.imsgq,l)
-               
-               msg = swap(b, '<', self.noisy)
-
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-
-               b.datagramReceived(msg)
-               
-               msg = swap(b, '<', self.noisy)
-
-               a.datagramReceived(msg)
-               msg = swap(a, '>', self.noisy)
-
-               b.datagramReceived(msg)
-
-               msg = swap(b, '<', self.noisy)
-
-               a.datagramReceived(msg)
-
-
-               msg = swap(a, '>', self.noisy)
-
-               self.assertEqual(len(b.imsgq), 2)
-               b.imsgq.sort()
-               l = [TESTMSG2, TESTMSG]
-               l.sort()
-               self.assertEqual(b.imsgq,l)
-
-       def testDoubleDouble(self):
-               a = self.a
-               b = self.b
-               
-               TESTMSGA = "Howdy, Y'All!"
-               TESTMSGB = "Yee Haw"
-               TESTMSGC = "FOO BAR"
-               TESTMSGD = "WING WANG"
-               
-               a.omsgq.append(TESTMSGA)
-               a.omsgq.append(TESTMSGB)
-
-               b.omsgq.append(TESTMSGC)
-               b.omsgq.append(TESTMSGD)
-               
-               
-               msg = swap(a, '>', self.noisy)
-                       
-
-               b.datagramReceived(msg)
-               self.assertEqual(b.state, sent)
-               
-               msg = swap(b, '<', self.noisy)
-               a.datagramReceived(msg)
-
-               msg = swap(a, '>', self.noisy)
-
-               b.datagramReceived(msg)
-               self.assertEqual(len(b.imsgq), 2)
-               l = [TESTMSGA, TESTMSGB]
-               l.sort();b.imsgq.sort()
-               self.assertEqual(b.imsgq, l)
-               self.assertEqual(b.inMsg, 2)
-
-               msg = swap(b, '<', self.noisy)
-               a.datagramReceived(msg)
-               
-               self.assertEqual(len(a.imsgq), 2)
-               l = [TESTMSGC, TESTMSGD]
-               l.sort();a.imsgq.sort()
-               self.assertEqual(a.imsgq, l)
-               self.assertEqual(a.inMsg, 2)
-
-       def testDoubleDoubleProb(self, prob=0.25):
-               a = self.a
-               b = self.b
-               TESTMSGA = "Howdy, Y'All!"
-               TESTMSGB = "Yee Haw"
-               TESTMSGC = "FOO BAR"
-               TESTMSGD = "WING WANG"
-               
-               a.omsgq.append(TESTMSGA)
-               a.omsgq.append(TESTMSGB)
-
-               b.omsgq.append(TESTMSGC)
-               b.omsgq.append(TESTMSGD)
-               
-               while a.state != confirmed or b.state != confirmed or ord(msga[0]) & FLAG_NEXT or ord(msgb[0]) & FLAG_NEXT :
-                       msga = swap(a, '>', self.noisy)
-       
-                       if rand(0,1) < prob:
-                               b.datagramReceived(msga)
-                       
-                       msgb = swap(b, '<', self.noisy)
-
-                       if rand(0,1) < prob:
-                               a.datagramReceived(msgb)
-
-               self.assertEqual(a.state, confirmed)
-               self.assertEqual(b.state, confirmed)
-               self.assertEqual(len(b.imsgq), 2)
-               l = [TESTMSGA, TESTMSGB]
-               l.sort();b.imsgq.sort()
-               self.assertEqual(b.imsgq, l)
-                               
-               self.assertEqual(len(a.imsgq), 2)
-               l = [TESTMSGC, TESTMSGD]
-               l.sort();a.imsgq.sort()
-               self.assertEqual(a.imsgq, l)
-
-       def testOneWayBlast(self, num = 2**12):
-               a = self.a
-               b = self.b
-               import sha
-               
-               
-               for i in xrange(num):
-                       a.omsgq.append(sha.sha(`i`).digest())
-               msga = swap(a, '>', self.noisy)
-               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
-                       b.datagramReceived(msga)
-                       msgb = swap(b, '<', self.noisy)
-
-                       a.datagramReceived(msgb)
-                       msga = swap(a, '>', self.noisy)
-
-               self.assertEqual(len(b.imsgq), num)
-               
-       def testTwoWayBlast(self, num = 2**12, prob=0.5):
-               a = self.a
-               b = self.b
-               import sha
-               
-               
-               for i in xrange(num):
-                       a.omsgq.append(sha.sha('a' + `i`).digest())
-                       b.omsgq.append(sha.sha('b' + `i`).digest())
-                       
-               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
-                       if rand(0,1) < prob:
-                               msga = swap(a, '>', self.noisy)
-                               b.datagramReceived(msga)
-                       else:
-                               msga = swap(a, '>', 0)
-                       if rand(0,1) < prob:
-                               msgb = swap(b, '<', self.noisy)
-                               a.datagramReceived(msgb)
-                       else:
-                               msgb = swap(b, '<', 0)
-                                       
-
-
-               self.assertEqual(len(a.imsgq), num)
-               self.assertEqual(len(b.imsgq), num)
-               
-       def testLimitMessageNumbers(self):
-               a = self.a
-               b = self.b
-               import sha
-
-               msg = swap(a, noisy=self.noisy)
-               b.datagramReceived(msg)
-
-               msg = swap(b, noisy=self.noisy)
-               a.datagramReceived(msg)
-               
-               
-               for i in range(5000):
-                       a.omsgq.append(sha.sha('a' + 'i').digest())
-               
-               for i in range(5000 / 255):
-                       msg = swap(a, noisy=self.noisy)
-                       self.assertEqual(a.obSeq, 0)
-               self.assertEqual(a.next, 255)
-               self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254)
+    def setUp(self):
+        self.noisy = 0
+        self.a = AirhookConnection()
+        self.a.makeConnection(DummyTransport())
+        self.a.addr = ('127.0.0.1', 4444)
+        self.b = AirhookConnection()
+        self.b.makeConnection(DummyTransport())
+        self.b.addr = ('127.0.0.1', 4444)
+        self.a.protocol = Receiver()
+        self.b.protocol = Receiver()
+
+    def testSimple(self):
+        a = self.a
+        b = self.b
+        
+        TESTMSG = "Howdy, Y'All!"
+        a.omsgq.append(TESTMSG)
+        a.sendNext()
+        msg = swap(a, '>', self.noisy)
+        
+        b.datagramReceived(msg)
+        msg = swap(b, '<', self.noisy)
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+        b.datagramReceived(msg)
+        
+        self.assertEqual(b.inMsg, 1)
+        self.assertEqual(len(b.protocol.q), 1)
+        self.assertEqual(b.protocol.q[0], TESTMSG)
+        
+        msg = swap(b, '<', self.noisy)
+        
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+        b.datagramReceived(msg)
+        
+    def testLostFirst(self):
+        a = self.a
+        b = self.b
+        
+        TESTMSG = "Howdy, Y'All!"
+        TESTMSG2 = "Yee Haw"
+        
+        a.omsgq.append(TESTMSG)
+        msg = swap(a, '>', self.noisy)
+        b.datagramReceived(msg)
+        msg = swap(b, '<', self.noisy)
+        self.assertEqual(b.state, sent)
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+
+        del(msg) # dropping first message
+        
+        a.omsgq.append(TESTMSG2)
+        msg = swap(a, '>', self.noisy)
+    
+        b.datagramReceived(msg)
+        self.assertEqual(b.state, confirmed)
+        self.assertEqual(len(b.protocol.q), 1)
+        self.assertEqual(b.protocol.q[0], TESTMSG2)
+        self.assertEqual(b.weMissed, [(1, 0)])
+        msg = swap(b, '<', self.noisy)
+        
+        a.datagramReceived(msg)
+                                
+        msg = swap(a, '>', self.noisy)
+        
+        b.datagramReceived(msg)
+        self.assertEqual(len(b.protocol.q), 2)
+        b.protocol.q.sort()
+        l = [TESTMSG2, TESTMSG]
+        l.sort()
+        self.assertEqual(b.protocol.q,l)
+        
+        msg = swap(b, '<', self.noisy)
+        
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+        b.datagramReceived(msg)
+        
+        msg = swap(b, '<', self.noisy)
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+        b.datagramReceived(msg)
+
+        msg = swap(b, '<', self.noisy)
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+
+        self.assertEqual(len(b.protocol.q), 2)
+        b.protocol.q.sort()
+        l = [TESTMSG2, TESTMSG]
+        l.sort()
+        self.assertEqual(b.protocol.q,l)
+
+    def testLostSecond(self):
+        a = self.a
+        b = self.b
+        
+        TESTMSG = "Howdy, Y'All!"
+        TESTMSG2 = "Yee Haw"
+        
+        a.omsgq.append(TESTMSG)
+        msg = swap(a, '>', self.noisy)
+        b.datagramReceived(msg)
+        msg = swap(b, '<', self.noisy)
+        self.assertEqual(b.state, sent)
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+
+        a.omsgq.append(TESTMSG2)
+        msg2 = swap(a, '>', self.noisy)
+        del(msg2) # dropping second message
+
+        assert(a.outMsgs[1] != None)
+
+        b.datagramReceived(msg)
+        self.assertEqual(b.state, confirmed)
+        self.assertEqual(len(b.protocol.q), 1)
+        self.assertEqual(b.protocol.q[0], TESTMSG)
+        self.assertEqual(b.inMsg, 1)
+        self.assertEqual(b.weMissed, [])
+        msg = swap(b, '<', self.noisy)
+        
+        a.datagramReceived(msg)
+        assert(a.outMsgs[1] != None)
+        msg = swap(a, '>', self.noisy)
+
+        b.datagramReceived(msg)
+        self.assertEqual(b.state, confirmed)
+        self.assertEqual(len(b.protocol.q), 1)
+        self.assertEqual(b.protocol.q[0], TESTMSG)
+        self.assertEqual(b.weMissed, [(2, 1)])
+        msg = swap(b, '<', self.noisy)
+
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+
+        b.datagramReceived(msg)
+        self.assertEqual(len(b.protocol.q), 2)
+        b.protocol.q.sort()
+        l = [TESTMSG2, TESTMSG]
+        l.sort()
+        self.assertEqual(b.protocol.q,l)
+        
+        msg = swap(b, '<', self.noisy)
+
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+
+        b.datagramReceived(msg)
+        
+        msg = swap(b, '<', self.noisy)
+
+        a.datagramReceived(msg)
+        msg = swap(a, '>', self.noisy)
+
+        b.datagramReceived(msg)
+
+        msg = swap(b, '<', self.noisy)
+
+        a.datagramReceived(msg)
+
+
+        msg = swap(a, '>', self.noisy)
+
+        self.assertEqual(len(b.protocol.q), 2)
+        b.protocol.q.sort()
+        l = [TESTMSG2, TESTMSG]
+        l.sort()
+        self.assertEqual(b.protocol.q,l)
+
+    def testDoubleDouble(self):
+        a = self.a
+        b = self.b
+        
+        TESTMSGA = "Howdy, Y'All!"
+        TESTMSGB = "Yee Haw"
+        TESTMSGC = "FOO BAR"
+        TESTMSGD = "WING WANG"
+        
+        a.omsgq.append(TESTMSGA)
+        a.omsgq.append(TESTMSGB)
+
+        b.omsgq.append(TESTMSGC)
+        b.omsgq.append(TESTMSGD)
+        
+        
+        msg = swap(a, '>', self.noisy)
+            
+
+        b.datagramReceived(msg)
+        self.assertEqual(b.state, sent)
+        
+        msg = swap(b, '<', self.noisy)
+        a.datagramReceived(msg)
+
+        msg = swap(a, '>', self.noisy)
+
+        b.datagramReceived(msg)
+        self.assertEqual(len(b.protocol.q), 2)
+        l = [TESTMSGA, TESTMSGB]
+        l.sort();b.protocol.q.sort()
+        self.assertEqual(b.protocol.q, l)
+        self.assertEqual(b.inMsg, 2)
+
+        msg = swap(b, '<', self.noisy)
+        a.datagramReceived(msg)
+        
+        self.assertEqual(len(a.protocol.q), 2)
+        l = [TESTMSGC, TESTMSGD]
+        l.sort();a.protocol.q.sort()
+        self.assertEqual(a.protocol.q, l)
+        self.assertEqual(a.inMsg, 2)
+
+    def testDoubleDoubleProb(self, prob=0.25):
+        a = self.a
+        b = self.b
+
+        TESTMSGA = "Howdy, Y'All!"
+        TESTMSGB = "Yee Haw"
+        TESTMSGC = "FOO BAR"
+        TESTMSGD = "WING WANG"
+        
+        a.omsgq.append(TESTMSGA)
+        a.omsgq.append(TESTMSGB)
+
+        b.omsgq.append(TESTMSGC)
+        b.omsgq.append(TESTMSGD)
+        
+        runTillEmpty(a, b, prob, self.noisy)
+        
+        self.assertEqual(a.state, confirmed)
+        self.assertEqual(b.state, confirmed)
+        self.assertEqual(len(b.protocol.q), 2)
+        l = [TESTMSGA, TESTMSGB]
+        l.sort();b.protocol.q.sort()
+        self.assertEqual(b.protocol.q, l)
+                
+        self.assertEqual(len(a.protocol.q), 2)
+        l = [TESTMSGC, TESTMSGD]
+        l.sort();a.protocol.q.sort()
+        self.assertEqual(a.protocol.q, l)
+
+    def testOneWayBlast(self, num = 2**12):
+        a = self.a
+        b = self.b
+        
+        import sha
+        
+        
+        for i in xrange(num):
+            a.omsgq.append(sha.sha(`i`).digest())
+        runTillEmpty(a, b, noisy=self.noisy)
+
+        self.assertEqual(len(b.protocol.q), num)
+        
+    def testTwoWayBlast(self, num = 2**12, prob=0.5):
+        a = self.a
+        b = self.b
+
+        import sha
+        
+        
+        for i in xrange(num):
+            a.omsgq.append(sha.sha('a' + `i`).digest())
+            b.omsgq.append(sha.sha('b' + `i`).digest())
+            
+        runTillEmpty(a, b, prob, self.noisy)                    
+
+
+        self.assertEqual(len(a.protocol.q), num)
+        self.assertEqual(len(b.protocol.q), num)
+        
+    def testLimitMessageNumbers(self):
+        a = self.a
+        b = self.b
+        import sha
+
+        msg = swap(a, noisy=self.noisy)
+        b.datagramReceived(msg)
+
+        msg = swap(b, noisy=self.noisy)
+        a.datagramReceived(msg)
+        
+        
+        for i in range(5000):
+            a.omsgq.append(sha.sha('a' + 'i').digest())
+        
+        for i in range(5000 / 255):
+            msg = swap(a, noisy=self.noisy)
+            self.assertEqual(a.obSeq, 0)
+        self.assertEqual(a.next, 255)
+        self.assertEqual(a.outMsgNums[(a.outSeq-1) % 256], 254)
 
 class StreamTests(unittest.TestCase):
-       def setUp(self):
-               self.noisy = 0
-               class queuer:
-                       def __init__(self):
-                               self.msg = ""
-                       def dataCameIn(self, host, port, data):
-                               self.msg+= data
-               self.A = queuer()
-               self.B = queuer()
-               self.a = StreamConnection(StringIO(), (None, 'localhost', 4040), self.A)
-               self.b = StreamConnection(StringIO(), (None, 'localhost', 4040), self.B)
-
-       def testStreamSimple(self, num = 2**18, prob=1.0):
-               f = open('/dev/urandom', 'r')
-               a = self.a
-               b = self.b
-               A = self.A
-               B = self.B
-
-               MSGA = f.read(num)
-               MSGB = f.read(num)
-               self.a.sendSomeData(MSGA)
-               self.b.sendSomeData(MSGB)
-               
-               while a.omsgq or b.omsgq or a.weMissed or b.weMissed or ord(msga[0]) & (FLAG_NEXT | FLAG_MISSED) or ord(msgb[0]) & (FLAG_NEXT | FLAG_MISSED):
-                       if rand(0,1) < prob:
-                               msga = swap(a, '>', self.noisy)
-                               b.datagramReceived(msga)
-                       else:
-                               msga = swap(a, '>', 0)
-                       if rand(0,1) < prob:
-                               msgb = swap(b, '<', self.noisy)
-                               a.datagramReceived(msgb)
-                       else:
-                               msgb = swap(b, '<', 0)
-               self.assertEqual(len(self.A.msg), len(MSGB))
-               self.assertEqual(len(self.B.msg), len(MSGA))
-               self.assertEqual(self.A.msg, MSGB)
-               self.assertEqual(self.B.msg, MSGA)
-
-       def testStreamLossy(self, num = 2**18, prob=0.5):
-               self.testStreamSimple(num, prob)
+    def setUp(self):
+        self.noisy = 0
+        self.a = StreamConnection()
+        self.a.makeConnection(DummyTransport())
+        self.a.addr = ('127.0.0.1', 4444)
+        self.b = StreamConnection()
+        self.b.makeConnection(DummyTransport())
+        self.b.addr = ('127.0.0.1', 4444)
+        self.a.protocol = StreamReceiver()
+        self.b.protocol = StreamReceiver()
+
+    def testStreamSimple(self, num = 2**12, prob=1.0):
+        f = open('/dev/urandom', 'r')
+        a = self.a
+        b = self.b
+
+        MSGA = f.read(num)
+        MSGB = f.read(num)
+        self.a.write(MSGA)
+        self.b.write(MSGB)
+        
+        runTillEmpty(a, b, prob, self.noisy)
+                
+        self.assertEqual(len(a.protocol.buf), len(MSGB))
+        self.assertEqual(len(b.protocol.buf), len(MSGA))
+        self.assertEqual(a.protocol.buf, MSGB)
+        self.assertEqual(b.protocol.buf, MSGA)
+
+    def testStreamLossy(self, num = 2**12, prob=0.5):
+        self.testStreamSimple(num, prob)
+
+class SimpleReactor(unittest.TestCase):
+    def setUp(self):
+        self.noisy = 0
+        self.a = makeReceiver(2020)
+        self.b = makeReceiver(2021)
+        self.ac = self.a.connectionForAddr(('127.0.0.1', 2021))
+        self.bc = self.b.connectionForAddr(('127.0.0.1', 2020))
+    def testSimple(self):
+        msg = "Testing 1, 2, 3"
+        self.ac.write(msg)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.assertEqual(self.bc.protocol.q, [msg])
+
+class SimpleReactorEcho(unittest.TestCase):
+    def setUp(self):
+        self.noisy = 0
+        self.a = makeReceiver(2022)
+        self.b = makeEcho(2023)
+        self.ac = self.a.connectionForAddr(('127.0.0.1', 2023))
+        self.bc = self.b.connectionForAddr(('127.0.0.1', 2022))
+    def testSimple(self):
+        msg = "Testing 1, 2, 3"
+        self.ac.write(msg)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.assertEqual(self.ac.protocol.q, [msg])
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.assertEqual(self.ac.protocol.q, [msg])
+
+
+class SimpleReactorStream(unittest.TestCase):
+    def setUp(self):
+        self.noisy = 0
+        self.a = makeStreamReceiver(2024)
+        self.b = makeStreamReceiver(2025)
+        self.ac = self.a.connectionForAddr(('127.0.0.1', 2025))
+        self.bc = self.b.connectionForAddr(('127.0.0.1', 2024))
+    def testSimple(self):
+        msg = "Testing 1, 2, 3"
+        self.ac.write(msg)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.assertEqual(self.bc.protocol.buf, msg)
+        
+class SimpleReactorStreamBig(unittest.TestCase):
+    def setUp(self):
+        self.noisy = 0
+        self.a = makeStreamReceiver(2026)
+        self.b = makeStreamReceiver(2027)
+        self.ac = self.a.connectionForAddr(('127.0.0.1', 2027))
+        self.bc = self.b.connectionForAddr(('127.0.0.1', 2026))
+    def testBig(self):
+        msg = open('/dev/urandom').read(4096)
+        self.ac.write(msg)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.assertEqual(self.bc.protocol.buf, msg)
+
+class EchoReactorStreamBig(unittest.TestCase):
+    def setUp(self):
+        self.noisy = 0
+        self.a = makeStreamReceiver(2028)
+        self.b = makeEcho(2029)
+        self.ac = self.a.connectionForAddr(('127.0.0.1', 2028))
+        self.bc = self.b.connectionForAddr(('127.0.0.1', 2029))
+    def testBig(self):
+        msg = open('/dev/urandom').read(4096)
+        self.ac.write(msg)
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        reactor.iterate()
+        self.assertEqual(self.ac.protocol.buf, msg)
+
+        
\ No newline at end of file