Make the tcp connection closer cancellable

And use this to make pathoc error handling more sophisticated
This commit is contained in:
Aldo Cortesi 2016-06-12 11:17:05 +12:00
parent 9bea616441
commit dc545ca0f6
2 changed files with 54 additions and 41 deletions

View File

@ -6,7 +6,6 @@ import sys
import threading import threading
import time import time
import traceback import traceback
import contextlib
import binascii import binascii
from six.moves import range from six.moves import range
@ -582,12 +581,24 @@ class _Connection(object):
return context return context
@contextlib.contextmanager class ConnectionCloser(object):
def _closer(client): def __init__(self, conn):
try: self.conn = conn
yield self._canceled = False
finally:
client.close() def pop(self):
"""
Cancel the current closer, and return a fresh one.
"""
self._canceled = True
return ConnectionCloser(self.conn)
def __enter__(self):
return self
def __exit__(self, *args):
if not self._canceled:
self.conn.close()
class TCPClient(_Connection): class TCPClient(_Connection):
@ -717,11 +728,12 @@ class TCPClient(_Connection):
except (socket.error, IOError) as err: except (socket.error, IOError) as err:
raise exceptions.TcpException( raise exceptions.TcpException(
'Error connecting to "%s": %s' % 'Error connecting to "%s": %s' %
(self.address.host, err)) (self.address.host, err)
)
self.connection = connection self.connection = connection
self.ip_address = Address(connection.getpeername()) self.ip_address = Address(connection.getpeername())
self._makefile() self._makefile()
return _closer(self) return ConnectionCloser(self)
def settimeout(self, n): def settimeout(self, n):
self.connection.settimeout(n) self.connection.settimeout(n)

View File

@ -291,44 +291,45 @@ class Pathoc(tcp.TCPClient):
if self.use_http2 and not self.ssl: if self.use_http2 and not self.ssl:
raise NotImplementedError("HTTP2 without SSL is not supported.") raise NotImplementedError("HTTP2 without SSL is not supported.")
ret = tcp.TCPClient.connect(self) with tcp.TCPClient.connect(self) as closer:
if connect_to: if connect_to:
self.http_connect(connect_to) self.http_connect(connect_to)
self.sslinfo = None self.sslinfo = None
if self.ssl: if self.ssl:
try: try:
alpn_protos = [b'http/1.1'] alpn_protos = [b'http/1.1']
if self.use_http2: if self.use_http2:
alpn_protos.append(b'h2') alpn_protos.append(b'h2')
self.convert_to_ssl( self.convert_to_ssl(
sni=self.sni, sni=self.sni,
cert=self.clientcert, cert=self.clientcert,
method=self.ssl_version, method=self.ssl_version,
options=self.ssl_options, options=self.ssl_options,
cipher_list=self.ciphers, cipher_list=self.ciphers,
alpn_protos=alpn_protos alpn_protos=alpn_protos
)
except exceptions.TlsException as v:
raise PathocError(str(v))
self.sslinfo = SSLInfo(
self.connection.get_peer_cert_chain(),
self.get_current_cipher(),
self.get_alpn_proto_negotiated()
) )
except exceptions.TlsException as v: if showssl:
raise PathocError(str(v)) print(str(self.sslinfo), file=fp)
self.sslinfo = SSLInfo( if self.use_http2:
self.connection.get_peer_cert_chain(), self.protocol.check_alpn()
self.get_current_cipher(), if not self.http2_skip_connection_preface:
self.get_alpn_proto_negotiated() self.protocol.perform_client_connection_preface()
)
if showssl:
print(str(self.sslinfo), file=fp)
if self.use_http2: if self.timeout:
self.protocol.check_alpn() self.settimeout(self.timeout)
if not self.http2_skip_connection_preface:
self.protocol.perform_client_connection_preface()
if self.timeout: return closer.pop()
self.settimeout(self.timeout)
return ret
def stop(self): def stop(self):
if self.ws_framereader: if self.ws_framereader: