mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-02-02 08:15:22 +00:00
Merge pull request #2851 from mhils/always-use-tls-settings
Use TLS options also for request replay
This commit is contained in:
commit
cda7c8d754
@ -253,7 +253,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
|||||||
address=address,
|
address=address,
|
||||||
ip_address=address,
|
ip_address=address,
|
||||||
cert=None,
|
cert=None,
|
||||||
sni=None,
|
sni=address[0],
|
||||||
alpn_proto_negotiated=None,
|
alpn_proto_negotiated=None,
|
||||||
tls_version=None,
|
tls_version=None,
|
||||||
source_address=('', 0),
|
source_address=('', 0),
|
||||||
@ -276,21 +276,21 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
|||||||
self.wfile.write(message)
|
self.wfile.write(message)
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
def establish_tls(self, clientcerts, sni, **kwargs):
|
def establish_tls(self, *, sni=None, client_certs=None, **kwargs):
|
||||||
if sni and not isinstance(sni, str):
|
if sni and not isinstance(sni, str):
|
||||||
raise ValueError("sni must be str, not " + type(sni).__name__)
|
raise ValueError("sni must be str, not " + type(sni).__name__)
|
||||||
clientcert = None
|
client_cert = None
|
||||||
if clientcerts:
|
if client_certs:
|
||||||
if os.path.isfile(clientcerts):
|
if os.path.isfile(client_certs):
|
||||||
clientcert = clientcerts
|
client_cert = client_certs
|
||||||
else:
|
else:
|
||||||
path = os.path.join(
|
path = os.path.join(
|
||||||
clientcerts,
|
client_certs,
|
||||||
self.address[0].encode("idna").decode()) + ".pem"
|
self.address[0].encode("idna").decode()) + ".pem"
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
clientcert = path
|
client_cert = path
|
||||||
|
|
||||||
self.convert_to_tls(cert=clientcert, sni=sni, **kwargs)
|
self.convert_to_tls(cert=client_cert, sni=sni, **kwargs)
|
||||||
self.sni = sni
|
self.sni = sni
|
||||||
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
|
self.alpn_proto_negotiated = self.get_alpn_proto_negotiated()
|
||||||
self.tls_version = self.connection.get_protocol_version_name()
|
self.tls_version = self.connection.get_protocol_version_name()
|
||||||
|
@ -13,6 +13,7 @@ import certifi
|
|||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
from kaitaistruct import KaitaiStream
|
from kaitaistruct import KaitaiStream
|
||||||
|
|
||||||
|
import mitmproxy.options # noqa
|
||||||
from mitmproxy import exceptions, certs
|
from mitmproxy import exceptions, certs
|
||||||
from mitmproxy.contrib.kaitaistruct import tls_client_hello
|
from mitmproxy.contrib.kaitaistruct import tls_client_hello
|
||||||
from mitmproxy.net import check
|
from mitmproxy.net import check
|
||||||
@ -57,6 +58,26 @@ METHOD_NAMES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def client_arguments_from_options(options: "mitmproxy.options.Options") -> dict:
|
||||||
|
|
||||||
|
if options.ssl_insecure:
|
||||||
|
verify = SSL.VERIFY_NONE
|
||||||
|
else:
|
||||||
|
verify = SSL.VERIFY_PEER
|
||||||
|
|
||||||
|
method, tls_options = VERSION_CHOICES[options.ssl_version_server]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"verify": verify,
|
||||||
|
"method": method,
|
||||||
|
"options": tls_options,
|
||||||
|
"ca_path": options.ssl_verify_upstream_trusted_cadir,
|
||||||
|
"ca_pemfile": options.ssl_verify_upstream_trusted_ca,
|
||||||
|
"client_certs": options.client_certs,
|
||||||
|
"cipher_list": options.ciphers_server,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class MasterSecretLogger:
|
class MasterSecretLogger:
|
||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
self.filename = filename
|
self.filename = filename
|
||||||
|
@ -9,7 +9,7 @@ from mitmproxy import http
|
|||||||
from mitmproxy import flow
|
from mitmproxy import flow
|
||||||
from mitmproxy import options
|
from mitmproxy import options
|
||||||
from mitmproxy import connections
|
from mitmproxy import connections
|
||||||
from mitmproxy.net import server_spec
|
from mitmproxy.net import server_spec, tls
|
||||||
from mitmproxy.net.http import http1
|
from mitmproxy.net.http import http1
|
||||||
from mitmproxy.coretypes import basethread
|
from mitmproxy.coretypes import basethread
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
@ -76,8 +76,8 @@ class RequestReplayThread(basethread.BaseThread):
|
|||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
|
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
|
||||||
server.establish_tls(
|
server.establish_tls(
|
||||||
self.options.client_certs,
|
sni=self.f.server_conn.sni,
|
||||||
sni=self.f.server_conn.sni
|
**tls.client_arguments_from_options(self.options)
|
||||||
)
|
)
|
||||||
r.first_line_format = "relative"
|
r.first_line_format = "relative"
|
||||||
else:
|
else:
|
||||||
@ -91,8 +91,8 @@ class RequestReplayThread(basethread.BaseThread):
|
|||||||
server.connect()
|
server.connect()
|
||||||
if r.scheme == "https":
|
if r.scheme == "https":
|
||||||
server.establish_tls(
|
server.establish_tls(
|
||||||
self.options.client_certs,
|
sni=self.f.server_conn.sni,
|
||||||
sni=self.f.server_conn.sni
|
**tls.client_arguments_from_options(self.options)
|
||||||
)
|
)
|
||||||
r.first_line_format = "relative"
|
r.first_line_format = "relative"
|
||||||
|
|
||||||
|
@ -424,6 +424,9 @@ class TlsLayer(base.Layer):
|
|||||||
# * which results in garbage because the layers don' match.
|
# * which results in garbage because the layers don' match.
|
||||||
alpn = [self.client_conn.get_alpn_proto_negotiated()]
|
alpn = [self.client_conn.get_alpn_proto_negotiated()]
|
||||||
|
|
||||||
|
# We pass through the list of ciphers send by the client, because some HTTP/2 servers
|
||||||
|
# will select a non-HTTP/2 compatible cipher from our default list and then hang up
|
||||||
|
# because it's incompatible with h2. :-)
|
||||||
ciphers_server = self.config.options.ciphers_server
|
ciphers_server = self.config.options.ciphers_server
|
||||||
if not ciphers_server and self._client_tls:
|
if not ciphers_server and self._client_tls:
|
||||||
ciphers_server = []
|
ciphers_server = []
|
||||||
@ -432,16 +435,12 @@ class TlsLayer(base.Layer):
|
|||||||
ciphers_server.append(CIPHER_ID_NAME_MAP[id])
|
ciphers_server.append(CIPHER_ID_NAME_MAP[id])
|
||||||
ciphers_server = ':'.join(ciphers_server)
|
ciphers_server = ':'.join(ciphers_server)
|
||||||
|
|
||||||
|
args = net_tls.client_arguments_from_options(self.config.options)
|
||||||
|
args["cipher_list"] = ciphers_server
|
||||||
self.server_conn.establish_tls(
|
self.server_conn.establish_tls(
|
||||||
self.config.client_certs,
|
sni=self.server_sni,
|
||||||
self.server_sni,
|
|
||||||
method=self.config.openssl_method_server,
|
|
||||||
options=self.config.openssl_options_server,
|
|
||||||
verify=self.config.openssl_verification_mode_server,
|
|
||||||
ca_path=self.config.options.ssl_verify_upstream_trusted_cadir,
|
|
||||||
ca_pemfile=self.config.options.ssl_verify_upstream_trusted_ca,
|
|
||||||
cipher_list=ciphers_server,
|
|
||||||
alpn_protos=alpn,
|
alpn_protos=alpn,
|
||||||
|
**args
|
||||||
)
|
)
|
||||||
tls_cert_err = self.server_conn.ssl_verification_error
|
tls_cert_err = self.server_conn.ssl_verification_error
|
||||||
if tls_cert_err is not None:
|
if tls_cert_err is not None:
|
||||||
|
@ -155,7 +155,7 @@ class TestServerConnection:
|
|||||||
def test_sni(self):
|
def test_sni(self):
|
||||||
c = connections.ServerConnection(('', 1234))
|
c = connections.ServerConnection(('', 1234))
|
||||||
with pytest.raises(ValueError, matches='sni must be str, not '):
|
with pytest.raises(ValueError, matches='sni must be str, not '):
|
||||||
c.establish_tls(None, b'foobar')
|
c.establish_tls(sni=b'foobar')
|
||||||
|
|
||||||
def test_state(self):
|
def test_state(self):
|
||||||
c = tflow.tserver_conn()
|
c = tflow.tserver_conn()
|
||||||
@ -222,17 +222,16 @@ class TestServerConnectionTLS(tservers.ServerTestBase):
|
|||||||
def handle(self):
|
def handle(self):
|
||||||
self.finish()
|
self.finish()
|
||||||
|
|
||||||
@pytest.mark.parametrize("clientcert", [
|
@pytest.mark.parametrize("client_certs", [
|
||||||
None,
|
None,
|
||||||
tutils.test_data.path("mitmproxy/data/clientcert"),
|
tutils.test_data.path("mitmproxy/data/clientcert"),
|
||||||
tutils.test_data.path("mitmproxy/data/clientcert/client.pem"),
|
tutils.test_data.path("mitmproxy/data/clientcert/client.pem"),
|
||||||
])
|
])
|
||||||
def test_tls(self, clientcert):
|
def test_tls(self, client_certs):
|
||||||
c = connections.ServerConnection(("127.0.0.1", self.port))
|
c = connections.ServerConnection(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.establish_tls(clientcert, "foo.com")
|
c.establish_tls(client_certs=client_certs)
|
||||||
assert c.connected()
|
assert c.connected()
|
||||||
assert c.sni == "foo.com"
|
|
||||||
assert c.tls_established
|
assert c.tls_established
|
||||||
c.close()
|
c.close()
|
||||||
c.finish()
|
c.finish()
|
||||||
|
Loading…
Reference in New Issue
Block a user