Set a new peer's ranking values so they don't get an unfair advantage.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPDownloader.py
index f1488e540481a51a75b445a2bf2fa48ca465ec4c..8ecac193f6cba22c71ac1747d5df89d327a4ea4f 100644 (file)
@@ -9,7 +9,8 @@ from twisted.internet.protocol import ClientFactory
 from twisted import version as twisted_version
 from twisted.python import log
 from twisted.web2.client.interfaces import IHTTPClientManager
-from twisted.web2.client.http import ProtocolError, ClientRequest, HTTPClientProtocol
+from twisted.web2.client.http import ProtocolError, ClientRequest, HTTPClientProtocol, HTTPClientChannelRequest
+from twisted.web2.channel.http import PERSIST_NO_PIPELINE, PERSIST_PIPELINE
 from twisted.web2 import stream as stream_mod, http_headers
 from twisted.web2 import version as web2_version
 from twisted.trial import unittest
@@ -17,6 +18,33 @@ from zope.interface import implements
 
 from apt_p2p_conf import version
 
+class PipelineError(Exception):
+    """An error has occurred in pipelining requests."""
+
+class FixedHTTPClientChannelRequest(HTTPClientChannelRequest):
+    """Fix the broken _error function."""
+
+    def __init__(self, channel, request, closeAfter):
+        HTTPClientChannelRequest.__init__(self, channel, request, closeAfter)
+        self.started = False
+
+    def _error(self, err):
+        """
+        Abort parsing, and depending of the status of the request, either fire
+        the C{responseDefer} if no response has been sent yet, or close the
+        stream.
+        """
+        if self.started:
+            self.abortParse()
+        if hasattr(self, 'stream') and self.stream is not None:
+            self.stream.finish(err)
+        else:
+            self.responseDefer.errback(err)
+
+    def gotInitialLine(self, initialLine):
+        self.started = True
+        HTTPClientChannelRequest.gotInitialLine(self, initialLine)
+    
 class LoggingHTTPClientProtocol(HTTPClientProtocol):
     """A modified client protocol that logs the number of bytes received."""
     
@@ -35,6 +63,65 @@ class LoggingHTTPClientProtocol(HTTPClientProtocol):
             self.stats.receivedBytes(len(data), self.mirror)
         HTTPClientProtocol.rawDataReceived(self, data)
 
+    def submitRequest(self, request, closeAfter=True):
+        """
+        @param request: The request to send to a remote server.
+        @type request: L{ClientRequest}
+
+        @param closeAfter: If True the 'Connection: close' header will be sent,
+            otherwise 'Connection: keep-alive'
+        @type closeAfter: C{bool}
+
+        @rtype: L{twisted.internet.defer.Deferred}
+        @return: A Deferred which will be called back with the
+            L{twisted.web2.http.Response} from the server.
+        """
+
+        # Assert we're in a valid state to submit more
+        assert self.outRequest is None
+        assert ((self.readPersistent is PERSIST_NO_PIPELINE
+                 and not self.inRequests)
+                or self.readPersistent is PERSIST_PIPELINE)
+
+        self.manager.clientBusy(self)
+        if closeAfter:
+            self.readPersistent = False
+
+        self.outRequest = chanRequest = FixedHTTPClientChannelRequest(self,
+                                            request, closeAfter)
+        self.inRequests.append(chanRequest)
+
+        chanRequest.submit()
+        return chanRequest.responseDefer
+
+    def setReadPersistent(self, persist):
+        oldPersist = self.readPersistent
+        self.readPersistent = persist
+        if not persist:
+            # Tell all requests but first to abort.
+            lostRequests = self.inRequests[1:]
+            del self.inRequests[1:]
+            for request in lostRequests:
+                request.connectionLost(PipelineError('Pipelined connection was closed.'))
+        elif (oldPersist is PERSIST_NO_PIPELINE and
+              persist is PERSIST_PIPELINE and
+              self.outRequest is None):
+            self.manager.clientPipelining(self)
+
+    def connectionLost(self, reason):
+        self.readPersistent = False
+        self.setTimeout(None)
+        self.manager.clientGone(self)
+        # Cancel the current request
+        if self.inRequests and self.inRequests[0] is not None:
+            self.inRequests[0].connectionLost(reason)
+        # Tell all remaining requests to abort.
+        lostRequests = self.inRequests[1:]
+        del self.inRequests[1:]
+        for request in lostRequests:
+            if request is not None:
+                request.connectionLost(PipelineError('Pipelined connection was closed.'))
+                
 class Peer(ClientFactory):
     """A manager for all HTTP requests to a single peer.
     
@@ -51,13 +138,13 @@ class Peer(ClientFactory):
         self.port = port
         self.stats = stats
         self.mirror = False
-        self.rank = 0.5
+        self.rank = 0.01
         self.busy = False
         self.pipeline = False
         self.closed = True
         self.connecting = False
         self.request_queue = []
-        self.response_queue = []
+        self.outstanding = 0
         self.proto = None
         self.connector = None
         self._errors = 0
@@ -73,17 +160,19 @@ class Peer(ClientFactory):
     def connect(self):
         """Connect to the peer."""
         assert self.closed and not self.connecting
+        log.msg('Connecting to (%s, %d)' % (self.host, self.port))
         self.connecting = True
         d = protocol.ClientCreator(reactor, LoggingHTTPClientProtocol, self,
-                                   stats = self.stats, mirror = self.mirror).connectTCP(self.host, self.port)
+                                   stats = self.stats, mirror = self.mirror).connectTCP(self.host, self.port, timeout = 10)
         d.addCallbacks(self.connected, self.connectionError)
 
     def connected(self, proto):
         """Begin processing the queued requests."""
+        log.msg('Connected to (%s, %d)' % (self.host, self.port))
         self.closed = False
         self.connecting = False
         self.proto = proto
-        self.processQueue()
+        reactor.callLater(0, self.processQueue)
         
     def connectionError(self, err):
         """Cancel the requests."""
@@ -92,8 +181,8 @@ class Peer(ClientFactory):
 
         # Remove one request so that we don't loop indefinitely
         if self.request_queue:
-            req = self.request_queue.pop(0)
-            req.deferRequest.errback(err)
+            req, deferRequest, submissionTime = self.request_queue.pop(0)
+            deferRequest.errback(err)
             
         self._completed += 1
         self._errors += 1
@@ -113,12 +202,12 @@ class Peer(ClientFactory):
         @type request: L{twisted.web2.client.http.ClientRequest}
         @return: deferred that will fire with the completed request
         """
-        request.submissionTime = datetime.now()
-        request.deferRequest = defer.Deferred()
-        self.request_queue.append(request)
+        submissionTime = datetime.now()
+        deferRequest = defer.Deferred()
+        self.request_queue.append((request, deferRequest, submissionTime))
         self.rerank()
-        self.processQueue()
-        return request.deferRequest
+        reactor.callLater(0, self.processQueue)
+        return deferRequest
 
     def processQueue(self):
         """Check the queue to see if new requests can be sent to the peer."""
@@ -131,36 +220,55 @@ class Peer(ClientFactory):
             return
         if self.busy and not self.pipeline:
             return
-        if self.response_queue and not self.pipeline:
+        if self.outstanding and not self.pipeline:
+            return
+        if not ((self.proto.readPersistent is PERSIST_NO_PIPELINE
+                 and not self.proto.inRequests)
+                 or self.proto.readPersistent is PERSIST_PIPELINE):
+            log.msg('HTTP protocol is not ready though we were told to pipeline: %r, %r' %
+                    (self.proto.readPersistent, self.proto.inRequests))
             return
 
-        req = self.request_queue.pop(0)
-        self.response_queue.append(req)
+        req, deferRequest, submissionTime = self.request_queue.pop(0)
+        try:
+            deferResponse = self.proto.submitRequest(req, False)
+        except:
+            # Try again later
+            log.msg('Got an error trying to submit a new HTTP request %s' % (request.uri, ))
+            log.err()
+            self.request_queue.insert(0, (request, deferRequest, submissionTime))
+            ractor.callLater(1, self.processQueue)
+            return
+            
+        self.outstanding += 1
         self.rerank()
-        req.deferResponse = self.proto.submitRequest(req, False)
-        req.deferResponse.addCallbacks(self.requestComplete, self.requestError)
+        deferResponse.addCallbacks(self.requestComplete, self.requestError,
+                                   callbackArgs = (req, deferRequest, submissionTime),
+                                   errbackArgs = (req, deferRequest))
 
-    def requestComplete(self, resp):
+    def requestComplete(self, resp, req, deferRequest, submissionTime):
         """Process a completed request."""
         self._processLastResponse()
-        req = self.response_queue.pop(0)
-        log.msg('%s of %s completed with code %d' % (req.method, req.uri, resp.code))
+        self.outstanding -= 1
+        assert self.outstanding >= 0
+        log.msg('%s of %s completed with code %d (%r)' % (req.method, req.uri, resp.code, resp.headers))
         self._completed += 1
         now = datetime.now()
-        self._responseTimes.append((now, now - req.submissionTime))
+        self._responseTimes.append((now, now - submissionTime))
         self._lastResponse = (now, resp.stream.length)
         self.rerank()
-        req.deferRequest.callback(resp)
+        deferRequest.callback(resp)
 
-    def requestError(self, error):
+    def requestError(self, error, req, deferRequest):
         """Process a request that ended with an error."""
         self._processLastResponse()
-        req = self.response_queue.pop(0)
+        self.outstanding -= 1
+        assert self.outstanding >= 0
         log.msg('Download of %s generated error %r' % (req.uri, error))
         self._completed += 1
         self._errors += 1
         self.rerank()
-        req.deferRequest.errback(error)
+        deferRequest.errback(error)
         
     def hashError(self, error):
         """Log that a hash error occurred from the peer."""
@@ -177,28 +285,26 @@ class Peer(ClientFactory):
         """Try to send a new request."""
         self._processLastResponse()
         self.busy = False
-        self.processQueue()
+        reactor.callLater(0, self.processQueue)
         self.rerank()
 
     def clientPipelining(self, proto):
         """Try to send a new request."""
         self.pipeline = True
-        self.processQueue()
+        reactor.callLater(0, self.processQueue)
 
     def clientGone(self, proto):
         """Mark sent requests as errors."""
         self._processLastResponse()
-        for req in self.response_queue:
-            req.deferRequest.errback(ProtocolError('lost connection'))
+        log.msg('Lost the connection to (%s, %d)' % (self.host, self.port))
         self.busy = False
         self.pipeline = False
         self.closed = True
         self.connecting = False
-        self.response_queue = []
         self.proto = None
         self.rerank()
         if self.request_queue:
-            self.processQueue()
+            reactor.callLater(0, self.processQueue)
             
     #{ Downloading request interface
     def setCommonHeaders(self):
@@ -247,13 +353,14 @@ class Peer(ClientFactory):
     #{ Peer information
     def isIdle(self):
         """Check whether the peer is idle or not."""
-        return not self.busy and not self.request_queue and not self.response_queue
+        return not self.busy and not self.request_queue and not self.outstanding
     
     def _processLastResponse(self):
         """Save the download time of the last request for speed calculations."""
         if self._lastResponse is not None:
-            now = datetime.now()
-            self._downloadSpeeds.append((now, now - self._lastResponse[0], self._lastResponse[1]))
+            if self._lastResponse[1] is not None:
+                now = datetime.now()
+                self._downloadSpeeds.append((now, now - self._lastResponse[0], self._lastResponse[1]))
             self._lastResponse = None
             
     def downloadSpeed(self):
@@ -270,7 +377,7 @@ class Peer(ClientFactory):
 
         # If there are none, then you get 0
         if not self._downloadSpeeds:
-            return 0.0
+            return 150000.0
         
         for download in self._downloadSpeeds:
             total_time += download[1].days*86400.0 + download[1].seconds + download[1].microseconds/1000000.0
@@ -293,7 +400,7 @@ class Peer(ClientFactory):
 
         # If there are none, give it the benefit of the doubt
         if not self._responseTimes:
-            return 0.0
+            return 0.1
 
         for response in self._responseTimes:
             total_response += response[1].days*86400.0 + response[1].seconds + response[1].microseconds/1000000.0
@@ -314,12 +421,12 @@ class Peer(ClientFactory):
         rank = 1.0
         if self.closed:
             rank *= 0.9
-        rank *= exp(-(len(self.request_queue) - len(self.response_queue)))
+        rank *= exp(-(len(self.request_queue) + self.outstanding))
         speed = self.downloadSpeed()
         if speed > 0.0:
             rank *= exp(-512.0*1024 / speed)
         if self._completed:
-            rank *= exp(-float(self._errors) / self._completed)
+            rank *= exp(-10.0 * self._errors / self._completed)
         rank *= exp(-self.responseTime() / 5.0)
         self.rank = rank
         
@@ -328,16 +435,23 @@ class TestClientManager(unittest.TestCase):
     
     client = None
     pending_calls = []
+    length = []
     
     def gotResp(self, resp, num, expect):
         self.failUnless(resp.code >= 200 and resp.code < 300, "Got a non-200 response: %r" % resp.code)
         if expect is not None:
             self.failUnless(resp.stream.length == expect, "Length was incorrect, got %r, expected %r" % (resp.stream.length, expect))
-        def print_(n):
-            pass
-        def printdone(n):
-            pass
-        stream_mod.readStream(resp.stream, print_).addCallback(printdone)
+        while len(self.length) <= num:
+            self.length.append(0)
+        self.length[num] = 0
+        def addData(data, self = self, num = num):
+            self.length[num] += len(data)
+        def checkLength(resp, self = self, num = num, length = resp.stream.length):
+            self.failUnlessEqual(self.length[num], length)
+            return resp
+        df = stream_mod.readStream(resp.stream, addData)
+        df.addCallback(checkLength)
+        return df
     
     def test_download(self):
         """Tests a normal download."""
@@ -378,7 +492,7 @@ class TestClientManager(unittest.TestCase):
         newRequest("/rfc/rfc0801.txt", 3, 40824)
         
         # This one will probably be queued
-        self.pending_calls.append(reactor.callLater(1, newRequest, '/rfc/rfc0013.txt', 4, 1070))
+        self.pending_calls.append(reactor.callLater(6, newRequest, '/rfc/rfc0013.txt', 4, 1070))
         
         # Connection should still be open, but idle
         self.pending_calls.append(reactor.callLater(10, newRequest, '/rfc/rfc0022.txt', 5, 4606))
@@ -462,6 +576,45 @@ class TestClientManager(unittest.TestCase):
         d.addCallback(self.gotResp, 1, 100)
         return d
         
+    def test_timeout(self):
+        """Tests a connection timeout."""
+        from twisted.internet.error import TimeoutError
+        host = 'steveholt.hopto.org'
+        self.client = Peer(host, 80)
+        self.timeout = 60
+        
+        d = self.client.get('/rfc/rfc0013.txt')
+        d.addCallback(self.gotResp, 1, 1070)
+        d = self.failUnlessFailure(d, TimeoutError)
+        d.addCallback(lambda a: self.flushLoggedErrors(TimeoutError))
+        return d
+        
+    def test_dnserror(self):
+        """Tests a connection timeout."""
+        from twisted.internet.error import DNSLookupError
+        host = 'hureyfnvbfha.debian.net'
+        self.client = Peer(host, 80)
+        self.timeout = 5
+        
+        d = self.client.get('/rfc/rfc0013.txt')
+        d.addCallback(self.gotResp, 1, 1070)
+        d = self.failUnlessFailure(d, DNSLookupError)
+        d.addCallback(lambda a: self.flushLoggedErrors(DNSLookupError))
+        return d
+        
+    def test_noroute(self):
+        """Tests a connection timeout."""
+        from twisted.internet.error import NoRouteError
+        host = '1.2.3.4'
+        self.client = Peer(host, 80)
+        self.timeout = 5
+        
+        d = self.client.get('/rfc/rfc0013.txt')
+        d.addCallback(self.gotResp, 1, 1070)
+        d = self.failUnlessFailure(d, NoRouteError)
+        d.addCallback(lambda a: self.flushLoggedErrors(NoRouteError))
+        return d
+        
     def tearDown(self):
         for p in self.pending_calls:
             if p.active():