]> git.mxchange.org Git - quix0rs-apt-p2p.git/commitdiff
Read the response in the unit tests to prevent RST packets.
authorCameron Dale <camrdale@gmail.com>
Wed, 12 Dec 2007 21:31:57 +0000 (13:31 -0800)
committerCameron Dale <camrdale@gmail.com>
Wed, 12 Dec 2007 21:31:57 +0000 (13:31 -0800)
HTTPDownloader.py

index 290c7f9a8a56f0a77e3bd27025d52912fc8a6da5..a7c50442308f6b4223cdb14bdbeb11426614f8f3 100644 (file)
@@ -5,6 +5,7 @@ from twisted.web2.client.interfaces import IHTTPClientManager
 from twisted.web2.client.http import ProtocolError, ClientRequest, HTTPClientProtocol
 from twisted.trial import unittest
 from zope.interface import implements
+from twisted.web2 import stream as stream_mod, http, http_headers, responsecode
 
 class HTTPClientManager(ClientFactory):
     """A manager for all HTTP requests to a single site.
@@ -129,8 +130,13 @@ class TestClientManager(unittest.TestCase):
     
     def gotResp(self, resp, num, expect):
         self.failUnless(resp.code >= 200 and resp.code < 300, "Got a non-200 response: %r" % resp.code)
-        self.failUnless(resp.stream.length == expect, "Length was incorrect, got %r, expected %r" % (resp.stream.length, expect))
-        resp.stream.close()
+        if expect is not None:
+            self.failUnless(resp.stream.length == expect, "Length was incorrect, got %r, expected %r" % (resp.stream.length, expect))
+        def print_(n):
+            pass
+        def printdone(n):
+            pass
+        stream_mod.readStream(resp.stream, print_).addCallback(printdone)
     
     def test_download(self):
         host = 'www.camrdale.org'
@@ -208,7 +214,11 @@ class TestDownloader(unittest.TestCase):
         self.failUnless(resp.code >= 200 and resp.code < 300, "Got a non-200 response: %r" % resp.code)
         if expect is not None:
             self.failUnless(resp.stream.length == expect, "Length was incorrect, got %r, expected %r" % (resp.stream.length, expect))
-        resp.stream.close()
+        def print_(n):
+            pass
+        def printdone(n):
+            pass
+        stream_mod.readStream(resp.stream, print_).addCallback(printdone)
     
     def test_download(self):
         self.manager = HTTPDownloader()