fix bugs, fix tests

This commit is contained in:
Maximilian Hils 2015-08-24 16:52:03 +02:00
parent 05d26545e4
commit f1f34e7713
5 changed files with 55 additions and 53 deletions

View File

@ -266,12 +266,15 @@ class HttpLayer(Layer):
self.handle_upstream_mode_connect(flow.request.copy()) self.handle_upstream_mode_connect(flow.request.copy())
return return
except (HttpErrorConnClosed, NetLibError, HttpError) as e: except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e:
self.send_to_client(make_error_response( self.send_to_client(make_error_response(
getattr(e, "code", 502), getattr(e, "code", 502),
repr(e) repr(e)
)) ))
raise ProtocolException(repr(e), e) if isinstance(e, ProtocolException):
raise e
else:
raise ProtocolException(repr(e), e)
finally: finally:
flow.live = False flow.live = False
@ -468,7 +471,7 @@ class HttpLayer(Layer):
def validate_request(self, request): def validate_request(self, request):
if request.form_in == "absolute" and request.scheme != "http": if request.form_in == "absolute" and request.scheme != "http":
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme)) self.send_to_client(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
raise HttpException("Invalid request scheme: %s" % request.scheme) raise HttpException("Invalid request scheme: %s" % request.scheme)
expected_request_forms = { expected_request_forms = {

View File

@ -41,7 +41,7 @@ class RootContext(object):
d = top_layer.client_conn.rfile.peek(3) d = top_layer.client_conn.rfile.peek(3)
is_ascii = ( is_ascii = (
len(d) == 3 and len(d) == 3 and
all(x in string.ascii_uppercase for x in d) all(x in string.ascii_letters for x in d) # better be safe here and don't expect uppercase...
) )
d = top_layer.client_conn.rfile.peek(len(HTTP2Protocol.CLIENT_CONNECTION_PREFACE)) d = top_layer.client_conn.rfile.peek(len(HTTP2Protocol.CLIENT_CONNECTION_PREFACE))

View File

@ -17,6 +17,7 @@ class TlsLayer(Layer):
self.client_sni = None self.client_sni = None
self._sni_from_server_change = None self._sni_from_server_change = None
self.client_alpn_protos = None self.client_alpn_protos = None
self.__server_tls_exception = None
# foo alpn protos = [netlib.http.http1.HTTP1Protocol.ALPN_PROTO_HTTP1, netlib.http.http2.HTTP2Protocol.ALPN_PROTO_H2], # TODO: read this from client_conn first # foo alpn protos = [netlib.http.http1.HTTP1Protocol.ALPN_PROTO_HTTP1, netlib.http.http2.HTTP2Protocol.ALPN_PROTO_H2], # TODO: read this from client_conn first
@ -107,49 +108,48 @@ class TlsLayer(Layer):
This callback gets called during the TLS handshake with the client. This callback gets called during the TLS handshake with the client.
The client has just sent the Sever Name Indication (SNI). The client has just sent the Sever Name Indication (SNI).
""" """
try: old_upstream_sni = self.sni_for_upstream_connection
old_upstream_sni = self.sni_for_upstream_connection
sn = connection.get_servername() sn = connection.get_servername()
if not sn: if not sn:
return return
self.client_sni = sn.decode("utf8").encode("idna")
if old_upstream_sni != self.sni_for_upstream_connection: self.client_sni = sn.decode("utf8").encode("idna")
# Perform reconnect
if self.server_conn and self._server_tls:
self.reconnect()
if self.client_sni: server_sni_changed = (old_upstream_sni != self.sni_for_upstream_connection)
# Now, change client context to reflect possibly changed certificate: server_conn_with_tls_exists = (self.server_conn and self._server_tls)
cert, key, chain_file = self._find_cert() if server_sni_changed and server_conn_with_tls_exists:
new_context = self.client_conn.create_ssl_context( try:
cert, key, self.reconnect()
method=self.config.openssl_method_client, except Exception as e:
options=self.config.openssl_options_client, self.__server_tls_exception = e
cipher_list=self.config.ciphers_client,
dhparams=self.config.certstore.dhparams, # Now, change client context to reflect possibly changed certificate:
chain_file=chain_file, cert, key, chain_file = self._find_cert()
alpn_select_callback=self.__handle_alpn_select, new_context = self.client_conn.create_ssl_context(
) cert, key,
connection.set_context(new_context) method=self.config.openssl_method_client,
# An unhandled exception in this method will core dump PyOpenSSL, so options=self.config.openssl_options_client,
# make dang sure it doesn't happen. cipher_list=self.config.ciphers_client,
except: # pragma: no cover dhparams=self.config.certstore.dhparams,
self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error") chain_file=chain_file,
alpn_select_callback=self.__handle_alpn_select,
)
connection.set_context(new_context)
def __handle_alpn_select(self, conn_, options): def __handle_alpn_select(self, conn_, options):
# TODO: change to something meaningful? # TODO: change to something meaningful?
alpn_preference = netlib.http.http1.HTTP1Protocol.ALPN_PROTO_HTTP1 # alpn_preference = netlib.http.http1.HTTP1Protocol.ALPN_PROTO_HTTP1
alpn_preference = netlib.http.http2.HTTP2Protocol.ALPN_PROTO_H2 alpn_preference = netlib.http.http2.HTTP2Protocol.ALPN_PROTO_H2
###
# TODO: Not # TODO: Don't reconnect twice?
if self.client_alpn_protos != options: upstream_alpn_changed = (self.client_alpn_protos != options)
# Perform reconnect server_conn_with_tls_exists = (self.server_conn and self._server_tls)
# TODO: Avoid double reconnect. if upstream_alpn_changed and server_conn_with_tls_exists:
if self.server_conn and self._server_tls: try:
self.reconnect() self.reconnect()
except Exception as e:
self.__server_tls_exception = e
self.client_alpn_protos = options self.client_alpn_protos = options
@ -177,6 +177,11 @@ class TlsLayer(Layer):
print("alpn: %s" % self.client_alpn_protos) print("alpn: %s" % self.client_alpn_protos)
raise ProtocolException(repr(e), e) raise ProtocolException(repr(e), e)
# Do not raise server tls exceptions immediately.
# We want to try to finish the client handshake so that other layers can send error messages over it.
if self.__server_tls_exception:
raise self.__server_tls_exception
def _establish_tls_with_server(self): def _establish_tls_with_server(self):
self.log("Establish TLS with server", "debug") self.log("Establish TLS with server", "debug")
try: try:

View File

@ -36,7 +36,7 @@ class TestServerConnection:
sc.send(protocol.assemble(f.request)) sc.send(protocol.assemble(f.request))
protocol = http.http1.HTTP1Protocol(rfile=sc.rfile) protocol = http.http1.HTTP1Protocol(rfile=sc.rfile)
assert protocol.read_response(f.request.method, 1000) assert protocol.read_response(f.request, 1000)
assert self.d.last_log() assert self.d.last_log()
sc.finish() sc.finish()

View File

@ -319,17 +319,6 @@ class TestHTTPAuth(tservers.HTTPProxTest):
assert ret.status_code == 202 assert ret.status_code == 202
class TestHTTPConnectSSLError(tservers.HTTPProxTest):
certfile = True
def test_go(self):
self.config.ssl_ports.append(self.proxy.port)
p = self.pathoc_raw()
dst = ("localhost", self.proxy.port)
p.connect(connect_to=dst)
tutils.raises("502 - Bad Gateway", p.http_connect, dst)
class TestHTTPS(tservers.HTTPProxTest, CommonMixin, TcpMixin): class TestHTTPS(tservers.HTTPProxTest, CommonMixin, TcpMixin):
ssl = True ssl = True
ssloptions = pathod.SSLOptions(request_client_cert=True) ssloptions = pathod.SSLOptions(request_client_cert=True)
@ -390,26 +379,31 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxTest):
("untrusted-cert", tutils.test_data.path("data/untrusted-server.crt")) ("untrusted-cert", tutils.test_data.path("data/untrusted-server.crt"))
]) ])
def _request(self):
p = self.pathoc()
# We need to make an actual request because the upstream connection is lazy-loaded.
return p.request("get:/p/242")
def test_default_verification_w_bad_cert(self): def test_default_verification_w_bad_cert(self):
"""Should use no verification.""" """Should use no verification."""
self.config.openssl_trusted_ca_server = tutils.test_data.path( self.config.openssl_trusted_ca_server = tutils.test_data.path(
"data/trusted-cadir/trusted-ca.pem") "data/trusted-cadir/trusted-ca.pem")
self.pathoc() assert self._request().status_code == 242
def test_no_verification_w_bad_cert(self): def test_no_verification_w_bad_cert(self):
self.config.openssl_verification_mode_server = SSL.VERIFY_NONE self.config.openssl_verification_mode_server = SSL.VERIFY_NONE
self.config.openssl_trusted_ca_server = tutils.test_data.path( self.config.openssl_trusted_ca_server = tutils.test_data.path(
"data/trusted-cadir/trusted-ca.pem") "data/trusted-cadir/trusted-ca.pem")
self.pathoc() assert self._request().status_code == 242
def test_verification_w_bad_cert(self): def test_verification_w_bad_cert(self):
self.config.openssl_verification_mode_server = SSL.VERIFY_PEER self.config.openssl_verification_mode_server = SSL.VERIFY_PEER
self.config.openssl_trusted_ca_server = tutils.test_data.path( self.config.openssl_trusted_ca_server = tutils.test_data.path(
"data/trusted-cadir/trusted-ca.pem") "data/trusted-cadir/trusted-ca.pem")
tutils.raises("SSL handshake error", self.pathoc) assert self._request().status_code == 502
class TestHTTPSNoCommonName(tservers.HTTPProxTest): class TestHTTPSNoCommonName(tservers.HTTPProxTest):