Merge pull request #2851 from mhils/always-use-tls-settings

Use TLS options also for request replay
This commit is contained in:
Maximilian Hils 2018-02-10 12:40:39 +01:00 committed by GitHub
commit cda7c8d754
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 46 additions and 27 deletions

View File

@ -253,7 +253,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
address=address, address=address,
ip_address=address, ip_address=address,
cert=None, cert=None,
sni=None, sni=address[0],
alpn_proto_negotiated=None, alpn_proto_negotiated=None,
tls_version=None, tls_version=None,
source_address=('', 0), source_address=('', 0),
@ -276,21 +276,21 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.wfile.write(message) self.wfile.write(message)
self.wfile.flush() self.wfile.flush()
def establish_tls(self, clientcerts, sni, **kwargs): def establish_tls(self, *, sni=None, client_certs=None, **kwargs):
if sni and not isinstance(sni, str): if sni and not isinstance(sni, str):
raise ValueError("sni must be str, not " + type(sni).__name__) raise ValueError("sni must be str, not " + type(sni).__name__)
clientcert = None client_cert = None
if clientcerts: if client_certs:
if os.path.isfile(clientcerts): if os.path.isfile(client_certs):
clientcert = clientcerts client_cert = client_certs
else: else:
path = os.path.join( path = os.path.join(
clientcerts, client_certs,
self.address[0].encode("idna").decode()) + ".pem" self.address[0].encode("idna").decode()) + ".pem"
if os.path.exists(path): if os.path.exists(path):
clientcert = path client_cert = path
self.convert_to_tls(cert=clientcert, sni=sni, **kwargs) self.convert_to_tls(cert=client_cert, sni=sni, **kwargs)
self.sni = sni self.sni = sni
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated() self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
self.tls_version = self.connection.get_protocol_version_name() self.tls_version = self.connection.get_protocol_version_name()

View File

@ -13,6 +13,7 @@ import certifi
from OpenSSL import SSL from OpenSSL import SSL
from kaitaistruct import KaitaiStream from kaitaistruct import KaitaiStream
import mitmproxy.options # noqa
from mitmproxy import exceptions, certs from mitmproxy import exceptions, certs
from mitmproxy.contrib.kaitaistruct import tls_client_hello from mitmproxy.contrib.kaitaistruct import tls_client_hello
from mitmproxy.net import check from mitmproxy.net import check
@ -57,6 +58,26 @@ METHOD_NAMES = {
} }
def client_arguments_from_options(options: "mitmproxy.options.Options") -> dict:
if options.ssl_insecure:
verify = SSL.VERIFY_NONE
else:
verify = SSL.VERIFY_PEER
method, tls_options = VERSION_CHOICES[options.ssl_version_server]
return {
"verify": verify,
"method": method,
"options": tls_options,
"ca_path": options.ssl_verify_upstream_trusted_cadir,
"ca_pemfile": options.ssl_verify_upstream_trusted_ca,
"client_certs": options.client_certs,
"cipher_list": options.ciphers_server,
}
class MasterSecretLogger: class MasterSecretLogger:
def __init__(self, filename): def __init__(self, filename):
self.filename = filename self.filename = filename

View File

@ -9,7 +9,7 @@ from mitmproxy import http
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import options from mitmproxy import options
from mitmproxy import connections from mitmproxy import connections
from mitmproxy.net import server_spec from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1 from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread from mitmproxy.coretypes import basethread
from mitmproxy.utils import human from mitmproxy.utils import human
@ -76,8 +76,8 @@ class RequestReplayThread(basethread.BaseThread):
if resp.status_code != 200: if resp.status_code != 200:
raise exceptions.ReplayException("Upstream server refuses CONNECT request") raise exceptions.ReplayException("Upstream server refuses CONNECT request")
server.establish_tls( server.establish_tls(
self.options.client_certs, sni=self.f.server_conn.sni,
sni=self.f.server_conn.sni **tls.client_arguments_from_options(self.options)
) )
r.first_line_format = "relative" r.first_line_format = "relative"
else: else:
@ -91,8 +91,8 @@ class RequestReplayThread(basethread.BaseThread):
server.connect() server.connect()
if r.scheme == "https": if r.scheme == "https":
server.establish_tls( server.establish_tls(
self.options.client_certs, sni=self.f.server_conn.sni,
sni=self.f.server_conn.sni **tls.client_arguments_from_options(self.options)
) )
r.first_line_format = "relative" r.first_line_format = "relative"

View File

@ -424,6 +424,9 @@ class TlsLayer(base.Layer):
# * which results in garbage because the layers don' match. # * which results in garbage because the layers don' match.
alpn = [self.client_conn.get_alpn_proto_negotiated()] alpn = [self.client_conn.get_alpn_proto_negotiated()]
# We pass through the list of ciphers send by the client, because some HTTP/2 servers
# will select a non-HTTP/2 compatible cipher from our default list and then hang up
# because it's incompatible with h2. :-)
ciphers_server = self.config.options.ciphers_server ciphers_server = self.config.options.ciphers_server
if not ciphers_server and self._client_tls: if not ciphers_server and self._client_tls:
ciphers_server = [] ciphers_server = []
@ -432,16 +435,12 @@ class TlsLayer(base.Layer):
ciphers_server.append(CIPHER_ID_NAME_MAP[id]) ciphers_server.append(CIPHER_ID_NAME_MAP[id])
ciphers_server = ':'.join(ciphers_server) ciphers_server = ':'.join(ciphers_server)
args = net_tls.client_arguments_from_options(self.config.options)
args["cipher_list"] = ciphers_server
self.server_conn.establish_tls( self.server_conn.establish_tls(
self.config.client_certs, sni=self.server_sni,
self.server_sni,
method=self.config.openssl_method_server,
options=self.config.openssl_options_server,
verify=self.config.openssl_verification_mode_server,
ca_path=self.config.options.ssl_verify_upstream_trusted_cadir,
ca_pemfile=self.config.options.ssl_verify_upstream_trusted_ca,
cipher_list=ciphers_server,
alpn_protos=alpn, alpn_protos=alpn,
**args
) )
tls_cert_err = self.server_conn.ssl_verification_error tls_cert_err = self.server_conn.ssl_verification_error
if tls_cert_err is not None: if tls_cert_err is not None:

View File

@ -155,7 +155,7 @@ class TestServerConnection:
def test_sni(self): def test_sni(self):
c = connections.ServerConnection(('', 1234)) c = connections.ServerConnection(('', 1234))
with pytest.raises(ValueError, matches='sni must be str, not '): with pytest.raises(ValueError, matches='sni must be str, not '):
c.establish_tls(None, b'foobar') c.establish_tls(sni=b'foobar')
def test_state(self): def test_state(self):
c = tflow.tserver_conn() c = tflow.tserver_conn()
@ -222,17 +222,16 @@ class TestServerConnectionTLS(tservers.ServerTestBase):
def handle(self): def handle(self):
self.finish() self.finish()
@pytest.mark.parametrize("clientcert", [ @pytest.mark.parametrize("client_certs", [
None, None,
tutils.test_data.path("mitmproxy/data/clientcert"), tutils.test_data.path("mitmproxy/data/clientcert"),
tutils.test_data.path("mitmproxy/data/clientcert/client.pem"), tutils.test_data.path("mitmproxy/data/clientcert/client.pem"),
]) ])
def test_tls(self, clientcert): def test_tls(self, client_certs):
c = connections.ServerConnection(("127.0.0.1", self.port)) c = connections.ServerConnection(("127.0.0.1", self.port))
c.connect() c.connect()
c.establish_tls(clientcert, "foo.com") c.establish_tls(client_certs=client_certs)
assert c.connected() assert c.connected()
assert c.sni == "foo.com"
assert c.tls_established assert c.tls_established
c.close() c.close()
c.finish() c.finish()