diff --git a/netlib/tcp.py b/netlib/tcp.py index 7010eef0c..c6e0075e5 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,5 @@ from __future__ import (absolute_import, print_function, division) +import os import select import socket import sys @@ -26,6 +27,37 @@ class NetLibTimeout(NetLibError): pass class NetLibSSLError(NetLibError): pass +class SSLKeyLogger(object): + def __init__(self, filename): + self.filename = filename + self.f = None + self.lock = threading.Lock() + + __name__ = "SSLKeyLogger" # required for functools.wraps, which pyOpenSSL uses. + + def __call__(self, connection, where, ret): + if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: + with self.lock: + if not self.f: + self.f = open(self.filename, "ab") + self.f.write("\r\n") + client_random = connection.client_random().encode("hex") + masterkey = connection.master_key().encode("hex") + self.f.write("CLIENT_RANDOM {} {}\r\n".format(client_random, masterkey)) + self.f.flush() + + def close(self): + with self.lock: + if self.f: + self.f.close() + +_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE") +if _logfile: + log_ssl_key = SSLKeyLogger(_logfile) +else: + log_ssl_key = False + + class _FileLike: BLOCKSIZE = 1024 * 32 def __init__(self, o): @@ -314,6 +346,8 @@ class TCPClient(_Connection): if sni: self.sni = sni self.connection.set_tlsext_host_name(sni) + if log_ssl_key: + context.set_info_callback(log_ssl_key) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -418,6 +452,8 @@ class BaseHandler(_Connection): # Return true to prevent cert verification error return True ctx.set_verify(SSL.VERIFY_PEER, ver) + if log_ssl_key: + ctx.set_info_callback(log_ssl_key) return ctx def convert_to_ssl(self, cert, key, **sslctx_kwargs):