diff --git a/netlib/tcp.py b/netlib/tcp.py index acd67cad6..a8a681391 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -6,7 +6,6 @@ import sys import threading import time import traceback -import contextlib import binascii from six.moves import range @@ -582,12 +581,24 @@ class _Connection(object): return context -@contextlib.contextmanager -def _closer(client): - try: - yield - finally: - client.close() +class ConnectionCloser(object): + def __init__(self, conn): + self.conn = conn + self._canceled = False + + def pop(self): + """ + Cancel the current closer, and return a fresh one. + """ + self._canceled = True + return ConnectionCloser(self.conn) + + def __enter__(self): + return self + + def __exit__(self, *args): + if not self._canceled: + self.conn.close() class TCPClient(_Connection): @@ -717,11 +728,12 @@ class TCPClient(_Connection): except (socket.error, IOError) as err: raise exceptions.TcpException( 'Error connecting to "%s": %s' % - (self.address.host, err)) + (self.address.host, err) + ) self.connection = connection self.ip_address = Address(connection.getpeername()) self._makefile() - return _closer(self) + return ConnectionCloser(self) def settimeout(self, n): self.connection.settimeout(n) diff --git a/pathod/pathoc.py b/pathod/pathoc.py index b25639887..21fc9845e 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -291,44 +291,45 @@ class Pathoc(tcp.TCPClient): if self.use_http2 and not self.ssl: raise NotImplementedError("HTTP2 without SSL is not supported.") - ret = tcp.TCPClient.connect(self) - if connect_to: - self.http_connect(connect_to) + with tcp.TCPClient.connect(self) as closer: + if connect_to: + self.http_connect(connect_to) - self.sslinfo = None - if self.ssl: - try: - alpn_protos = [b'http/1.1'] - if self.use_http2: - alpn_protos.append(b'h2') + self.sslinfo = None + if self.ssl: + try: + alpn_protos = [b'http/1.1'] + if self.use_http2: + alpn_protos.append(b'h2') - self.convert_to_ssl( - sni=self.sni, - cert=self.clientcert, - method=self.ssl_version, - options=self.ssl_options, - cipher_list=self.ciphers, - alpn_protos=alpn_protos + self.convert_to_ssl( + sni=self.sni, + cert=self.clientcert, + method=self.ssl_version, + options=self.ssl_options, + cipher_list=self.ciphers, + alpn_protos=alpn_protos + ) + except exceptions.TlsException as v: + raise PathocError(str(v)) + + self.sslinfo = SSLInfo( + self.connection.get_peer_cert_chain(), + self.get_current_cipher(), + self.get_alpn_proto_negotiated() ) - except exceptions.TlsException as v: - raise PathocError(str(v)) + if showssl: + print(str(self.sslinfo), file=fp) - self.sslinfo = SSLInfo( - self.connection.get_peer_cert_chain(), - self.get_current_cipher(), - self.get_alpn_proto_negotiated() - ) - if showssl: - print(str(self.sslinfo), file=fp) + if self.use_http2: + self.protocol.check_alpn() + if not self.http2_skip_connection_preface: + self.protocol.perform_client_connection_preface() - if self.use_http2: - self.protocol.check_alpn() - if not self.http2_skip_connection_preface: - self.protocol.perform_client_connection_preface() + if self.timeout: + self.settimeout(self.timeout) - if self.timeout: - self.settimeout(self.timeout) - return ret + return closer.pop() def stop(self): if self.ws_framereader: diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index dc5011c0b..77d4721c6 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -83,7 +83,13 @@ class TestDaemon(PathocTestDaemon): def test_ssl_error(self): c = pathoc.Pathoc(("127.0.0.1", self.d.port), ssl=True, fp=None) - tutils.raises("ssl handshake", c.connect) + try: + with c.connect(): + pass + except Exception as e: + assert "SSL" in str(e) + else: + raise AssertionError("No exception raised.") def test_showssl(self): assert "certificate chain" not in self.tval(