Set a new peer's ranking values so they don't get an unfair advantage.
[quix0rs-apt-p2p.git] / apt_p2p / HTTPDownloader.py
index 55a818f1a135624306494c077db9ced72fca7b95..8ecac193f6cba22c71ac1747d5df89d327a4ea4f 100644 (file)
@@ -95,13 +95,18 @@ class LoggingHTTPClientProtocol(HTTPClientProtocol):
         return chanRequest.responseDefer
 
     def setReadPersistent(self, persist):
+        oldPersist = self.readPersistent
         self.readPersistent = persist
         if not persist:
             # Tell all requests but first to abort.
             lostRequests = self.inRequests[1:]
             del self.inRequests[1:]
             for request in lostRequests:
-                request.connectionLost(PipelineError('The pipelined connection was lost'))
+                request.connectionLost(PipelineError('Pipelined connection was closed.'))
+        elif (oldPersist is PERSIST_NO_PIPELINE and
+              persist is PERSIST_PIPELINE and
+              self.outRequest is None):
+            self.manager.clientPipelining(self)
 
     def connectionLost(self, reason):
         self.readPersistent = False
@@ -115,7 +120,7 @@ class LoggingHTTPClientProtocol(HTTPClientProtocol):
         del self.inRequests[1:]
         for request in lostRequests:
             if request is not None:
-                request.connectionLost(reason)
+                request.connectionLost(PipelineError('Pipelined connection was closed.'))
                 
 class Peer(ClientFactory):
     """A manager for all HTTP requests to a single peer.
@@ -158,7 +163,7 @@ class Peer(ClientFactory):
         log.msg('Connecting to (%s, %d)' % (self.host, self.port))
         self.connecting = True
         d = protocol.ClientCreator(reactor, LoggingHTTPClientProtocol, self,
-                                   stats = self.stats, mirror = self.mirror).connectTCP(self.host, self.port)
+                                   stats = self.stats, mirror = self.mirror).connectTCP(self.host, self.port, timeout = 10)
         d.addCallbacks(self.connected, self.connectionError)
 
     def connected(self, proto):
@@ -353,8 +358,9 @@ class Peer(ClientFactory):
     def _processLastResponse(self):
         """Save the download time of the last request for speed calculations."""
         if self._lastResponse is not None:
-            now = datetime.now()
-            self._downloadSpeeds.append((now, now - self._lastResponse[0], self._lastResponse[1]))
+            if self._lastResponse[1] is not None:
+                now = datetime.now()
+                self._downloadSpeeds.append((now, now - self._lastResponse[0], self._lastResponse[1]))
             self._lastResponse = None
             
     def downloadSpeed(self):
@@ -371,7 +377,7 @@ class Peer(ClientFactory):
 
         # If there are none, then you get 0
         if not self._downloadSpeeds:
-            return 0.0
+            return 150000.0
         
         for download in self._downloadSpeeds:
             total_time += download[1].days*86400.0 + download[1].seconds + download[1].microseconds/1000000.0
@@ -394,7 +400,7 @@ class Peer(ClientFactory):
 
         # If there are none, give it the benefit of the doubt
         if not self._responseTimes:
-            return 0.0
+            return 0.1
 
         for response in self._responseTimes:
             total_response += response[1].days*86400.0 + response[1].seconds + response[1].microseconds/1000000.0
@@ -570,6 +576,45 @@ class TestClientManager(unittest.TestCase):
         d.addCallback(self.gotResp, 1, 100)
         return d
         
+    def test_timeout(self):
+        """Tests a connection timeout."""
+        from twisted.internet.error import TimeoutError
+        host = 'steveholt.hopto.org'
+        self.client = Peer(host, 80)
+        self.timeout = 60
+        
+        d = self.client.get('/rfc/rfc0013.txt')
+        d.addCallback(self.gotResp, 1, 1070)
+        d = self.failUnlessFailure(d, TimeoutError)
+        d.addCallback(lambda a: self.flushLoggedErrors(TimeoutError))
+        return d
+        
+    def test_dnserror(self):
+        """Tests a connection timeout."""
+        from twisted.internet.error import DNSLookupError
+        host = 'hureyfnvbfha.debian.net'
+        self.client = Peer(host, 80)
+        self.timeout = 5
+        
+        d = self.client.get('/rfc/rfc0013.txt')
+        d.addCallback(self.gotResp, 1, 1070)
+        d = self.failUnlessFailure(d, DNSLookupError)
+        d.addCallback(lambda a: self.flushLoggedErrors(DNSLookupError))
+        return d
+        
+    def test_noroute(self):
+        """Tests a connection timeout."""
+        from twisted.internet.error import NoRouteError
+        host = '1.2.3.4'
+        self.client = Peer(host, 80)
+        self.timeout = 5
+        
+        d = self.client.get('/rfc/rfc0013.txt')
+        d.addCallback(self.gotResp, 1, 1070)
+        d = self.failUnlessFailure(d, NoRouteError)
+        d.addCallback(lambda a: self.flushLoggedErrors(NoRouteError))
+        return d
+        
     def tearDown(self):
         for p in self.pending_calls:
             if p.active():