mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
remove old mitmproxy.net.tcp code
this is not needed anymore with sans-io
This commit is contained in:
parent
cdb0cf6c0a
commit
b05c13daa6
@ -1,683 +0,0 @@
|
||||
import os
|
||||
import errno
|
||||
import select
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from typing import Optional # noqa
|
||||
|
||||
from mitmproxy.net import tls
|
||||
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.coretypes import basethread
|
||||
|
||||
socket_fileobject = socket.SocketIO
|
||||
|
||||
# workaround for https://bugs.python.org/issue29515
|
||||
# Python 3.8 for Windows is missing a constant, fixed in 3.9
|
||||
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)
|
||||
|
||||
|
||||
class _FileLike:
|
||||
BLOCKSIZE = 1024 * 32
|
||||
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
self._log = None
|
||||
self.first_byte_timestamp = None
|
||||
|
||||
def set_descriptor(self, o):
|
||||
self.o = o
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.o, attr)
|
||||
|
||||
def start_log(self):
|
||||
"""
|
||||
Starts or resets the log.
|
||||
|
||||
This will store all bytes read or written.
|
||||
"""
|
||||
self._log = []
|
||||
|
||||
def stop_log(self):
|
||||
"""
|
||||
Stops the log.
|
||||
"""
|
||||
self._log = None
|
||||
|
||||
def is_logging(self):
|
||||
return self._log is not None
|
||||
|
||||
def get_log(self):
|
||||
"""
|
||||
Returns the log as a string.
|
||||
"""
|
||||
if not self.is_logging():
|
||||
raise ValueError("Not logging!")
|
||||
return b"".join(self._log)
|
||||
|
||||
def add_log(self, v):
|
||||
if self.is_logging():
|
||||
self._log.append(v)
|
||||
|
||||
def reset_timestamps(self):
|
||||
self.first_byte_timestamp = None
|
||||
|
||||
|
||||
class Writer(_FileLike):
|
||||
|
||||
def flush(self):
|
||||
"""
|
||||
May raise exceptions.TcpDisconnect
|
||||
"""
|
||||
if hasattr(self.o, "flush"):
|
||||
try:
|
||||
self.o.flush()
|
||||
except OSError as v:
|
||||
raise exceptions.TcpDisconnect(str(v))
|
||||
|
||||
def write(self, v):
|
||||
"""
|
||||
May raise exceptions.TcpDisconnect
|
||||
"""
|
||||
if v:
|
||||
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
||||
try:
|
||||
if hasattr(self.o, "sendall"):
|
||||
self.add_log(v)
|
||||
return self.o.sendall(v)
|
||||
else:
|
||||
r = self.o.write(v)
|
||||
self.add_log(v[:r])
|
||||
return r
|
||||
except (SSL.Error, OSError) as e:
|
||||
raise exceptions.TcpDisconnect(str(e))
|
||||
|
||||
|
||||
class Reader(_FileLike):
|
||||
|
||||
def read(self, length):
|
||||
"""
|
||||
If length is -1, we read until connection closes.
|
||||
"""
|
||||
result = b''
|
||||
start = time.time()
|
||||
while length == -1 or length > 0:
|
||||
if length == -1 or length > self.BLOCKSIZE:
|
||||
rlen = self.BLOCKSIZE
|
||||
else:
|
||||
rlen = length
|
||||
try:
|
||||
data = self.o.read(rlen)
|
||||
except SSL.ZeroReturnError:
|
||||
# TLS connection was shut down cleanly
|
||||
break
|
||||
except (SSL.WantWriteError, SSL.WantReadError):
|
||||
# From the OpenSSL docs:
|
||||
# If the underlying BIO is non-blocking, SSL_read() will also return when the
|
||||
# underlying BIO could not satisfy the needs of SSL_read() to continue the
|
||||
# operation. In this case a call to SSL_get_error with the return value of
|
||||
# SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
|
||||
# 300 is OpenSSL default timeout
|
||||
timeout = self.o.gettimeout() or 300
|
||||
if (time.time() - start) < timeout:
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
else:
|
||||
raise exceptions.TcpTimeout()
|
||||
except socket.timeout:
|
||||
raise exceptions.TcpTimeout()
|
||||
except OSError as e:
|
||||
raise exceptions.TcpDisconnect(str(e))
|
||||
except SSL.SysCallError as e:
|
||||
if e.args == (-1, 'Unexpected EOF'):
|
||||
break
|
||||
raise exceptions.TlsException(str(e))
|
||||
except SSL.Error as e:
|
||||
raise exceptions.TlsException(str(e))
|
||||
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
||||
if not data:
|
||||
break
|
||||
result += data
|
||||
if length != -1:
|
||||
length -= len(data)
|
||||
self.add_log(result)
|
||||
return result
|
||||
|
||||
def readline(self, size=None):
|
||||
result = b''
|
||||
bytes_read = 0
|
||||
while True:
|
||||
if size is not None and bytes_read >= size:
|
||||
break
|
||||
ch = self.read(1)
|
||||
bytes_read += 1
|
||||
if not ch:
|
||||
break
|
||||
else:
|
||||
result += ch
|
||||
if ch == b'\n':
|
||||
break
|
||||
return result
|
||||
|
||||
def safe_read(self, length):
|
||||
"""
|
||||
Like .read, but is guaranteed to either return length bytes, or
|
||||
raise an exception.
|
||||
"""
|
||||
result = self.read(length)
|
||||
if length != -1 and len(result) != length:
|
||||
if not result:
|
||||
raise exceptions.TcpDisconnect()
|
||||
else:
|
||||
raise exceptions.TcpReadIncomplete(
|
||||
"Expected {} bytes, got {}".format(length, len(result))
|
||||
)
|
||||
return result
|
||||
|
||||
def peek(self, length):
|
||||
"""
|
||||
Tries to peek into the underlying file object.
|
||||
|
||||
Returns:
|
||||
Up to the next N bytes if peeking is successful.
|
||||
|
||||
Raises:
|
||||
exceptions.TcpException if there was an error with the socket
|
||||
TlsException if there was an error with pyOpenSSL.
|
||||
NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
|
||||
"""
|
||||
if isinstance(self.o, socket_fileobject):
|
||||
try:
|
||||
return self.o._sock.recv(length, socket.MSG_PEEK)
|
||||
except OSError as e:
|
||||
raise exceptions.TcpException(repr(e))
|
||||
elif isinstance(self.o, SSL.Connection):
|
||||
try:
|
||||
return self.o.recv(length, socket.MSG_PEEK)
|
||||
except SSL.Error as e:
|
||||
raise exceptions.TlsException(str(e))
|
||||
else:
|
||||
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
|
||||
|
||||
|
||||
def ssl_read_select(rlist, timeout):
|
||||
"""
|
||||
This is a wrapper around select.select() which also works for SSL.Connections
|
||||
by taking ssl_connection.pending() into account.
|
||||
|
||||
Caveats:
|
||||
If .pending() > 0 for any of the connections in rlist, we avoid the select syscall
|
||||
and **will not include any other connections which may or may not be ready**.
|
||||
|
||||
Args:
|
||||
rlist: wait until ready for reading
|
||||
|
||||
Returns:
|
||||
subset of rlist which is ready for reading.
|
||||
"""
|
||||
return [
|
||||
conn for conn in rlist
|
||||
if isinstance(conn, SSL.Connection) and conn.pending() > 0
|
||||
] or select.select(rlist, (), (), timeout)[0]
|
||||
|
||||
|
||||
def close_socket(sock):
|
||||
"""
|
||||
Does a hard close of a socket, without emitting a RST.
|
||||
"""
|
||||
try:
|
||||
# We already indicate that we close our end.
|
||||
# may raise "Transport endpoint is not connected" on Linux
|
||||
sock.shutdown(socket.SHUT_WR)
|
||||
|
||||
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
|
||||
# readable data could lead to an immediate RST being sent (which is the
|
||||
# case on Windows).
|
||||
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
|
||||
#
|
||||
# This in turn results in the following issue: If we send an error page
|
||||
# to the client and then close the socket, the RST may be received by
|
||||
# the client before the error page and the users sees a connection
|
||||
# error rather than the error page. Thus, we try to empty the read
|
||||
# buffer on Windows first. (see
|
||||
# https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
|
||||
#
|
||||
|
||||
if os.name == "nt": # pragma: no cover
|
||||
# We cannot rely on the shutdown()-followed-by-read()-eof technique
|
||||
# proposed by the page above: Some remote machines just don't send
|
||||
# a TCP FIN, which would leave us in the unfortunate situation that
|
||||
# recv() would block infinitely. As a workaround, we set a timeout
|
||||
# here even if we are in blocking mode.
|
||||
sock.settimeout(sock.gettimeout() or 20)
|
||||
|
||||
# limit at a megabyte so that we don't read infinitely
|
||||
for _ in range(1024 ** 3 // 4096):
|
||||
# may raise a timeout/disconnect exception.
|
||||
if not sock.recv(4096):
|
||||
break
|
||||
|
||||
# Now we can close the other half as well.
|
||||
sock.shutdown(socket.SHUT_RD)
|
||||
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
sock.close()
|
||||
|
||||
|
||||
class _Connection:
|
||||
|
||||
rbufsize = -1
|
||||
wbufsize = -1
|
||||
|
||||
def _makefile(self):
|
||||
"""
|
||||
Set up .rfile and .wfile attributes from .connection
|
||||
"""
|
||||
# Ideally, we would use the Buffered IO in Python 3 by default.
|
||||
# Unfortunately, the implementation of .peek() is broken for n>1 bytes,
|
||||
# as it may just return what's left in the buffer and not all the bytes we want.
|
||||
# As a workaround, we just use unbuffered sockets directly.
|
||||
# https://mail.python.org/pipermail/python-dev/2009-June/089986.html
|
||||
self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
|
||||
self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
|
||||
|
||||
def __init__(self, connection):
|
||||
if connection:
|
||||
self.connection = connection
|
||||
self.ip_address = connection.getpeername()
|
||||
self._makefile()
|
||||
else:
|
||||
self.connection = None
|
||||
self.ip_address = None
|
||||
self.rfile = None
|
||||
self.wfile = None
|
||||
|
||||
self.tls_established = False
|
||||
self.finished = False
|
||||
|
||||
def get_current_cipher(self):
|
||||
if not self.tls_established:
|
||||
return None
|
||||
|
||||
name = self.connection.get_cipher_name()
|
||||
bits = self.connection.get_cipher_bits()
|
||||
version = self.connection.get_cipher_version()
|
||||
return name, bits, version
|
||||
|
||||
def finish(self):
|
||||
self.finished = True
|
||||
# If we have an SSL connection, wfile.close == connection.close
|
||||
# (We call _FileLike.set_descriptor(conn))
|
||||
# Closing the socket is not our task, therefore we don't call close
|
||||
# then.
|
||||
if not isinstance(self.connection, SSL.Connection):
|
||||
if not getattr(self.wfile, "closed", False):
|
||||
try:
|
||||
self.wfile.flush()
|
||||
self.wfile.close()
|
||||
except exceptions.TcpDisconnect:
|
||||
pass
|
||||
|
||||
self.rfile.close()
|
||||
else:
|
||||
try:
|
||||
self.connection.shutdown()
|
||||
except SSL.Error:
|
||||
pass
|
||||
|
||||
|
||||
class ConnectionCloser:
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
self._canceled = False
|
||||
|
||||
def pop(self):
|
||||
"""
|
||||
Cancel the current closer, and return a fresh one.
|
||||
"""
|
||||
self._canceled = True
|
||||
return ConnectionCloser(self.conn)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
if not self._canceled:
|
||||
self.conn.close()
|
||||
|
||||
|
||||
class TCPClient(_Connection):
|
||||
|
||||
def __init__(self, address, source_address=None, spoof_source_address=None):
|
||||
super().__init__(None)
|
||||
self.address = address
|
||||
self.source_address = source_address
|
||||
self.cert = None
|
||||
self.server_certs = []
|
||||
self.sni = None
|
||||
self.spoof_source_address = spoof_source_address
|
||||
|
||||
@property
|
||||
def ssl_verification_error(self) -> Optional[exceptions.InvalidCertificateException]:
|
||||
return getattr(self.connection, "cert_error", None)
|
||||
|
||||
def close(self):
|
||||
# Make sure to close the real socket, not the SSL proxy.
|
||||
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
|
||||
# it tries to renegotiate...
|
||||
if self.connection:
|
||||
if isinstance(self.connection, SSL.Connection):
|
||||
close_socket(self.connection._socket)
|
||||
else:
|
||||
close_socket(self.connection)
|
||||
|
||||
def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs):
|
||||
context = tls.create_client_context(
|
||||
alpn_protos=alpn_protos,
|
||||
sni=sni,
|
||||
**sslctx_kwargs
|
||||
)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
if sni:
|
||||
self.sni = sni
|
||||
self.connection.set_tlsext_host_name(sni.encode("idna"))
|
||||
self.connection.set_connect_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
except SSL.Error as v:
|
||||
if self.ssl_verification_error:
|
||||
raise self.ssl_verification_error
|
||||
else:
|
||||
raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
|
||||
|
||||
self.cert = certs.Cert(self.connection.get_peer_certificate())
|
||||
|
||||
# Keep all server certificates in a list
|
||||
for i in self.connection.get_peer_cert_chain():
|
||||
self.server_certs.append(certs.Cert(i))
|
||||
|
||||
self.tls_established = True
|
||||
self.rfile.set_descriptor(self.connection)
|
||||
self.wfile.set_descriptor(self.connection)
|
||||
|
||||
def makesocket(self, family, type, proto):
|
||||
# some parties (cuckoo sandbox) need to hook this
|
||||
return socket.socket(family, type, proto)
|
||||
|
||||
def create_connection(self, timeout=None):
|
||||
# Based on the official socket.create_connection implementation of Python 3.6.
|
||||
# https://github.com/python/cpython/blob/3cc5817cfaf5663645f4ee447eaed603d2ad290a/Lib/socket.py
|
||||
|
||||
err = None
|
||||
for res in socket.getaddrinfo(self.address[0], self.address[1], 0, socket.SOCK_STREAM):
|
||||
af, socktype, proto, canonname, sa = res
|
||||
sock = None
|
||||
try:
|
||||
sock = self.makesocket(af, socktype, proto)
|
||||
if timeout:
|
||||
sock.settimeout(timeout)
|
||||
if self.source_address:
|
||||
sock.bind(self.source_address)
|
||||
if self.spoof_source_address:
|
||||
try:
|
||||
if not sock.getsockopt(socket.SOL_IP, socket.IP_TRANSPARENT):
|
||||
sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover
|
||||
except Exception as e:
|
||||
# socket.IP_TRANSPARENT might not be available on every OS and Python version
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
raise exceptions.TcpException(
|
||||
"Failed to spoof the source address: " + str(e)
|
||||
)
|
||||
sock.connect(sa)
|
||||
return sock
|
||||
|
||||
except OSError as _:
|
||||
err = _
|
||||
if sock is not None:
|
||||
sock.close()
|
||||
|
||||
if err is not None:
|
||||
raise err
|
||||
else:
|
||||
raise OSError("getaddrinfo returns an empty list") # pragma: no cover
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
connection = self.create_connection()
|
||||
except OSError as err:
|
||||
raise exceptions.TcpException(
|
||||
'Error connecting to "%s": %s' %
|
||||
(self.address[0], err)
|
||||
)
|
||||
self.connection = connection
|
||||
self.source_address = connection.getsockname()
|
||||
self.ip_address = connection.getpeername()
|
||||
self._makefile()
|
||||
return ConnectionCloser(self)
|
||||
|
||||
def settimeout(self, n):
|
||||
self.connection.settimeout(n)
|
||||
|
||||
def gettimeout(self):
|
||||
return self.connection.gettimeout()
|
||||
|
||||
def get_alpn_proto_negotiated(self):
|
||||
if self.tls_established:
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
else:
|
||||
return b""
|
||||
|
||||
|
||||
class BaseHandler(_Connection):
|
||||
|
||||
"""
|
||||
The instantiator is expected to call the handle() and finish() methods.
|
||||
"""
|
||||
|
||||
def __init__(self, connection, address, server):
|
||||
super().__init__(connection)
|
||||
self.address = address
|
||||
self.server = server
|
||||
self.clientcert = None
|
||||
|
||||
def convert_to_tls(self, cert, key, **sslctx_kwargs):
|
||||
"""
|
||||
Convert connection to SSL.
|
||||
For a list of parameters, see tls.create_server_context(...)
|
||||
"""
|
||||
|
||||
context = tls.create_server_context(
|
||||
cert=cert,
|
||||
key=key,
|
||||
**sslctx_kwargs)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
try:
|
||||
self.connection.do_handshake()
|
||||
except SSL.Error as v:
|
||||
raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
|
||||
self.tls_established = True
|
||||
cert = self.connection.get_peer_certificate()
|
||||
if cert:
|
||||
self.clientcert = certs.Cert(cert)
|
||||
self.rfile.set_descriptor(self.connection)
|
||||
self.wfile.set_descriptor(self.connection)
|
||||
|
||||
def handle(self): # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
def settimeout(self, n):
|
||||
self.connection.settimeout(n)
|
||||
|
||||
def get_alpn_proto_negotiated(self):
|
||||
if self.tls_established:
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
else:
|
||||
return b""
|
||||
|
||||
|
||||
class Counter:
|
||||
def __init__(self):
|
||||
self._count = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
with self._lock:
|
||||
return self._count
|
||||
|
||||
def __enter__(self):
|
||||
with self._lock:
|
||||
self._count += 1
|
||||
|
||||
def __exit__(self, *args):
|
||||
with self._lock:
|
||||
self._count -= 1
|
||||
|
||||
|
||||
class TCPServer:
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
self.__is_shut_down = threading.Event()
|
||||
self.__is_shut_down.set()
|
||||
self.__shutdown_request = False
|
||||
|
||||
if self.address[0] == 'localhost':
|
||||
raise OSError("Binding to 'localhost' is prohibited. Please use '::1' or '127.0.0.1' directly.")
|
||||
|
||||
self.socket = None
|
||||
|
||||
try:
|
||||
# First try to bind an IPv6 socket, attempting to enable IPv4 support if the OS supports it.
|
||||
# This allows us to accept connections for ::1 and 127.0.0.1 on the same socket.
|
||||
# Only works if self.address == ""
|
||||
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
|
||||
self.socket.bind(self.address)
|
||||
except OSError:
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
if not self.socket:
|
||||
try:
|
||||
# Binding to an IPv6 + IPv4 socket failed, lets fall back to IPv4 only.
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
self.socket.bind(self.address)
|
||||
except OSError:
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
if not self.socket:
|
||||
# Binding to an IPv4 only socket failed, lets fall back to IPv6 only.
|
||||
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||
self.socket.bind(self.address)
|
||||
|
||||
self.address = self.socket.getsockname()
|
||||
self.socket.listen()
|
||||
self.handler_counter = Counter()
|
||||
|
||||
def connection_thread(self, connection, client_address):
|
||||
with self.handler_counter:
|
||||
try:
|
||||
self.handle_client_connection(connection, client_address)
|
||||
except OSError as e: # pragma: no cover
|
||||
# This catches situations where the underlying connection is
|
||||
# closed beneath us. Syscalls on the connection object at this
|
||||
# point returns EINVAL. If this happens, we close the socket and
|
||||
# move on.
|
||||
if not e.errno == errno.EINVAL:
|
||||
raise
|
||||
except:
|
||||
self.handle_error(connection, client_address)
|
||||
finally:
|
||||
close_socket(connection)
|
||||
|
||||
def serve_forever(self, poll_interval=0.1):
|
||||
self.__is_shut_down.clear()
|
||||
try:
|
||||
while not self.__shutdown_request:
|
||||
r, w_, e_ = select.select([self.socket], [], [], poll_interval)
|
||||
if self.socket in r:
|
||||
connection, client_address = self.socket.accept()
|
||||
t = basethread.BaseThread(
|
||||
"TCPConnectionHandler ({}: {}:{} -> {}:{})".format(
|
||||
self.__class__.__name__,
|
||||
client_address[0],
|
||||
client_address[1],
|
||||
self.address[0],
|
||||
self.address[1],
|
||||
),
|
||||
target=self.connection_thread,
|
||||
args=(connection, client_address),
|
||||
)
|
||||
t.setDaemon(1)
|
||||
try:
|
||||
t.start()
|
||||
except threading.ThreadError:
|
||||
self.handle_error(connection, client_address)
|
||||
connection.close()
|
||||
finally:
|
||||
self.__shutdown_request = False
|
||||
self.__is_shut_down.set()
|
||||
|
||||
def shutdown(self):
|
||||
self.__shutdown_request = True
|
||||
self.__is_shut_down.wait()
|
||||
self.socket.close()
|
||||
self.handle_shutdown()
|
||||
|
||||
def handle_error(self, connection_, client_address, fp=sys.stderr):
|
||||
"""
|
||||
Called when handle_client_connection raises an exception.
|
||||
"""
|
||||
# If a thread has persisted after interpreter exit, the module might be
|
||||
# none.
|
||||
if traceback:
|
||||
exc = str(traceback.format_exc())
|
||||
print('-' * 40, file=fp)
|
||||
print(
|
||||
"Error in processing of request from %s" % repr(client_address), file=fp)
|
||||
print(exc, file=fp)
|
||||
print('-' * 40, file=fp)
|
||||
|
||||
def handle_client_connection(self, conn, client_address): # pragma: no cover
|
||||
"""
|
||||
Called after client connection.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def handle_shutdown(self):
|
||||
"""
|
||||
Called after server shutdown.
|
||||
"""
|
||||
|
||||
def wait_for_silence(self, timeout=5):
|
||||
start = time.time()
|
||||
while 1:
|
||||
if time.time() - start >= timeout:
|
||||
raise exceptions.Timeout(
|
||||
"%s service threads still alive" %
|
||||
self.handler_counter.count
|
||||
)
|
||||
if self.handler_counter.count == 0:
|
||||
return
|
@ -4,126 +4,6 @@ Spec: https://tools.ietf.org/html/rfc6455
|
||||
"""
|
||||
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import struct
|
||||
|
||||
from wsproto.utilities import ACCEPT_GUID
|
||||
from wsproto.handshake import WEBSOCKET_VERSION
|
||||
from wsproto.frame_protocol import RsvBits, Header, Frame, XorMaskerSimple, XorMaskerNull
|
||||
|
||||
from mitmproxy.net import http
|
||||
from mitmproxy.utils import bits, strutils
|
||||
|
||||
|
||||
def read_frame(rfile, parse=True):
|
||||
"""
|
||||
Reads a full WebSocket frame from a file-like object.
|
||||
|
||||
Returns a parsed frame header, parsed frame, and the consumed bytes.
|
||||
"""
|
||||
|
||||
consumed_bytes = b''
|
||||
|
||||
def consume(len):
|
||||
nonlocal consumed_bytes
|
||||
d = rfile.safe_read(len)
|
||||
consumed_bytes += d
|
||||
return d
|
||||
|
||||
first_byte, second_byte = consume(2)
|
||||
fin = bits.getbit(first_byte, 7)
|
||||
rsv1 = bits.getbit(first_byte, 6)
|
||||
rsv2 = bits.getbit(first_byte, 5)
|
||||
rsv3 = bits.getbit(first_byte, 4)
|
||||
opcode = first_byte & 0xF
|
||||
mask_bit = bits.getbit(second_byte, 7)
|
||||
length_code = second_byte & 0x7F
|
||||
|
||||
# payload_len > 125 indicates you need to read more bytes
|
||||
# to get the actual payload length
|
||||
if length_code <= 125:
|
||||
payload_len = length_code
|
||||
elif length_code == 126:
|
||||
payload_len, = struct.unpack("!H", consume(2))
|
||||
else: # length_code == 127:
|
||||
payload_len, = struct.unpack("!Q", consume(8))
|
||||
|
||||
# masking key only present if mask bit set
|
||||
if mask_bit == 1:
|
||||
masking_key = consume(4)
|
||||
masker = XorMaskerSimple(masking_key)
|
||||
else:
|
||||
masking_key = None
|
||||
masker = XorMaskerNull()
|
||||
|
||||
masked_payload = consume(payload_len)
|
||||
|
||||
if parse:
|
||||
header = Header(
|
||||
fin=fin,
|
||||
rsv=RsvBits(rsv1, rsv2, rsv3),
|
||||
opcode=opcode,
|
||||
payload_len=payload_len,
|
||||
masking_key=masking_key,
|
||||
)
|
||||
frame = Frame(
|
||||
opcode=opcode,
|
||||
payload=masker.process(masked_payload),
|
||||
frame_finished=fin,
|
||||
message_finished=fin
|
||||
)
|
||||
else:
|
||||
header = None
|
||||
frame = None
|
||||
|
||||
return header, frame, consumed_bytes
|
||||
|
||||
|
||||
def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
|
||||
"""
|
||||
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||
specified, it is generated, and can be found in sec-websocket-key in
|
||||
the returned header set.
|
||||
|
||||
Returns an instance of http.Headers
|
||||
"""
|
||||
if version is None:
|
||||
version = WEBSOCKET_VERSION
|
||||
if key is None:
|
||||
key = base64.b64encode(os.urandom(16)).decode('ascii')
|
||||
h = http.Headers(
|
||||
connection="upgrade",
|
||||
upgrade="websocket",
|
||||
sec_websocket_version=version,
|
||||
sec_websocket_key=key,
|
||||
)
|
||||
if protocol is not None:
|
||||
h['sec-websocket-protocol'] = protocol
|
||||
if extensions is not None:
|
||||
h['sec-websocket-extensions'] = extensions
|
||||
return h
|
||||
|
||||
|
||||
def server_handshake_headers(client_key, protocol=None, extensions=None):
|
||||
"""
|
||||
The server response is a valid HTTP 101 response.
|
||||
|
||||
Returns an instance of http.Headers
|
||||
"""
|
||||
h = http.Headers(
|
||||
connection="upgrade",
|
||||
upgrade="websocket",
|
||||
sec_websocket_accept=create_server_nonce(client_key),
|
||||
)
|
||||
if protocol is not None:
|
||||
h['sec-websocket-protocol'] = protocol
|
||||
if extensions is not None:
|
||||
h['sec-websocket-extensions'] = extensions
|
||||
return h
|
||||
|
||||
|
||||
def check_handshake(headers):
|
||||
return (
|
||||
"upgrade" in headers.get("connection", "").lower() and
|
||||
@ -132,14 +12,6 @@ def check_handshake(headers):
|
||||
)
|
||||
|
||||
|
||||
def create_server_nonce(client_nonce):
|
||||
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + ACCEPT_GUID).digest())
|
||||
|
||||
|
||||
def check_client_version(headers):
|
||||
return headers.get("sec-websocket-version", "") == WEBSOCKET_VERSION
|
||||
|
||||
|
||||
def get_extensions(headers):
|
||||
return headers.get("sec-websocket-extensions", None)
|
||||
|
||||
|
@ -1,17 +1,6 @@
|
||||
from io import BytesIO
|
||||
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy.net import http
|
||||
|
||||
|
||||
def treader(bytes):
|
||||
"""
|
||||
Construct a tcp.Read object from bytes.
|
||||
"""
|
||||
fp = BytesIO(bytes)
|
||||
return tcp.Reader(fp)
|
||||
|
||||
|
||||
def treq(**kwargs) -> http.Request:
|
||||
"""
|
||||
Returns:
|
||||
|
@ -1,11 +1,21 @@
|
||||
import ipaddress
|
||||
from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from mitmproxy.net import socks
|
||||
from mitmproxy.test import tutils
|
||||
|
||||
|
||||
# this is a temporary placeholder here, we remove the file-based API when we transition socks proxying to sans-io.
|
||||
class tutils: # noqa
|
||||
@staticmethod
|
||||
def treader(data: bytes):
|
||||
io = BytesIO(data)
|
||||
io.safe_read = io.read
|
||||
return io
|
||||
|
||||
|
||||
def test_client_greeting():
|
||||
raw = tutils.treader(b"\x05\x02\x00\xBE\xEF")
|
||||
out = BytesIO()
|
||||
|
@ -1,813 +0,0 @@
|
||||
from io import BytesIO
|
||||
import re
|
||||
import queue
|
||||
import time
|
||||
import socket
|
||||
import random
|
||||
import threading
|
||||
import pytest
|
||||
from unittest import mock
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.utils import data
|
||||
from ...conftest import skip_no_ipv6
|
||||
|
||||
from . import tservers
|
||||
|
||||
|
||||
cdata = data.Data(__name__)
|
||||
|
||||
|
||||
class EchoHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
v = self.rfile.readline()
|
||||
self.wfile.write(v)
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class ClientCipherListHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(f"{self.connection.get_cipher_list()}\n".encode())
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class HangHandler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
# Hang as long as the client connection is alive
|
||||
while True:
|
||||
try:
|
||||
self.connection.setblocking(0)
|
||||
ret = self.connection.recv(1)
|
||||
# Client connection is dead...
|
||||
if ret == "" or ret == b"":
|
||||
return
|
||||
except OSError:
|
||||
pass
|
||||
except SSL.WantReadError:
|
||||
pass
|
||||
except Exception:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
class ALPNHandler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
alp = self.get_alpn_proto_negotiated()
|
||||
if alp:
|
||||
self.wfile.write(alp)
|
||||
else:
|
||||
self.wfile.write(b"NONE")
|
||||
self.wfile.flush()
|
||||
|
||||
|
||||
class TestServer(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
def test_echo(self):
|
||||
testval = b"echo!\n"
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_thread_start_error(self):
|
||||
with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m:
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
assert not c.rfile.read(1)
|
||||
assert m.called
|
||||
assert "nonewthread" in self.q.get_nowait()
|
||||
self.test_echo()
|
||||
|
||||
|
||||
class TestServerBind(tservers.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
# We may get an ipv4-mapped ipv6 address here, e.g. ::ffff:127.0.0.1.
|
||||
# Those still appear as "127.0.0.1" in the table, so we need to strip the prefix.
|
||||
peername = self.connection.getpeername()
|
||||
address = re.sub(r"^::ffff:(?=\d+.\d+.\d+.\d+$)", "", peername[0])
|
||||
port = peername[1]
|
||||
|
||||
self.wfile.write(str((address, port)).encode())
|
||||
self.wfile.flush()
|
||||
|
||||
def test_bind(self):
|
||||
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """
|
||||
for i in range(20):
|
||||
random_port = random.randrange(1024, 65535)
|
||||
try:
|
||||
c = tcp.TCPClient(
|
||||
("127.0.0.1", self.port), source_address=(
|
||||
"127.0.0.1", random_port))
|
||||
with c.connect():
|
||||
assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
|
||||
return
|
||||
except exceptions.TcpException: # port probably already in use
|
||||
pass
|
||||
|
||||
|
||||
@skip_no_ipv6
|
||||
class TestServerIPv6(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
addr = ("::1", 0)
|
||||
|
||||
def test_echo(self):
|
||||
testval = b"echo!\n"
|
||||
c = tcp.TCPClient(("::1", self.port))
|
||||
with c.connect():
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestEcho(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
def test_echo(self):
|
||||
testval = b"echo!\n"
|
||||
c = tcp.TCPClient(("localhost", self.port))
|
||||
with c.connect():
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class HardDisconnectHandler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class TestFinishFail(tservers.ServerTestBase):
|
||||
|
||||
"""
|
||||
This tests a difficult-to-trigger exception in the .finish() method of
|
||||
the handler.
|
||||
"""
|
||||
handler = EchoHandler
|
||||
|
||||
def test_disconnect_in_finish(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.wfile.write(b"foo\n")
|
||||
c.wfile.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
|
||||
c.finish()
|
||||
|
||||
|
||||
class TestServerSSL(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
cipher_list="AES256-SHA",
|
||||
chain_file=cdata.path("data/server.crt")
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(sni="foo.com", options=SSL.OP_ALL)
|
||||
testval = b"echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_get_current_cipher(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
assert not c.get_current_cipher()
|
||||
c.convert_to_tls(sni="foo.com")
|
||||
ret = c.get_current_cipher()
|
||||
assert ret
|
||||
assert "AES" in ret[0]
|
||||
|
||||
|
||||
class TestSSLv3Only(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
request_client_cert=False,
|
||||
v3_only=True
|
||||
)
|
||||
|
||||
def test_failure(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(exceptions.TlsException):
|
||||
c.convert_to_tls(sni="foo.com")
|
||||
|
||||
|
||||
class TestInvalidTrustFile(tservers.ServerTestBase):
|
||||
def test_invalid_trust_file_should_fail(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(exceptions.TlsException):
|
||||
c.convert_to_tls(
|
||||
sni="example.mitmproxy.org",
|
||||
verify=SSL.VERIFY_PEER,
|
||||
ca_pemfile=cdata.path("data/verificationcerts/generate.py")
|
||||
)
|
||||
|
||||
|
||||
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
ssl = dict(
|
||||
cert=cdata.path("data/verificationcerts/self-signed.crt"),
|
||||
key=cdata.path("data/verificationcerts/self-signed.key")
|
||||
)
|
||||
|
||||
def test_mode_default_should_pass(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls()
|
||||
|
||||
# Verification errors should be saved even if connection isn't aborted
|
||||
# aborted
|
||||
assert c.ssl_verification_error
|
||||
|
||||
testval = b"echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_mode_none_should_pass(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(verify=SSL.VERIFY_NONE)
|
||||
|
||||
# Verification errors should be saved even if connection isn't aborted
|
||||
assert c.ssl_verification_error
|
||||
|
||||
testval = b"echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_mode_strict_should_fail(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(exceptions.InvalidCertificateException):
|
||||
c.convert_to_tls(
|
||||
sni="example.mitmproxy.org",
|
||||
verify=SSL.VERIFY_PEER,
|
||||
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
|
||||
assert c.ssl_verification_error
|
||||
|
||||
# Unknown issuing certificate authority for first certificate
|
||||
assert "errno: 18" in str(c.ssl_verification_error)
|
||||
assert "depth: 0" in str(c.ssl_verification_error)
|
||||
|
||||
|
||||
class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
ssl = dict(
|
||||
cert=cdata.path("data/verificationcerts/trusted-leaf.crt"),
|
||||
key=cdata.path("data/verificationcerts/trusted-leaf.key")
|
||||
)
|
||||
|
||||
def test_should_fail_without_sni(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(exceptions.TlsException):
|
||||
c.convert_to_tls(
|
||||
verify=SSL.VERIFY_PEER,
|
||||
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
|
||||
def test_mode_none_should_pass_without_sni(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(
|
||||
verify=SSL.VERIFY_NONE,
|
||||
ca_path=cdata.path("data/verificationcerts/")
|
||||
)
|
||||
|
||||
assert "Cannot validate hostname, SNI missing." in str(c.ssl_verification_error)
|
||||
|
||||
def test_should_fail(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(exceptions.InvalidCertificateException):
|
||||
c.convert_to_tls(
|
||||
sni="mitmproxy.org",
|
||||
verify=SSL.VERIFY_PEER,
|
||||
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
assert c.ssl_verification_error
|
||||
|
||||
|
||||
class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
ssl = dict(
|
||||
cert=cdata.path("data/verificationcerts/trusted-leaf.crt"),
|
||||
key=cdata.path("data/verificationcerts/trusted-leaf.key")
|
||||
)
|
||||
|
||||
def test_mode_strict_w_pemfile_should_pass(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(
|
||||
sni="example.mitmproxy.org",
|
||||
verify=SSL.VERIFY_PEER,
|
||||
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
|
||||
)
|
||||
|
||||
assert c.ssl_verification_error is None
|
||||
|
||||
testval = b"echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
def test_mode_strict_w_confdir_should_pass(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(
|
||||
sni="example.mitmproxy.org",
|
||||
verify=SSL.VERIFY_PEER,
|
||||
ca_path=cdata.path("data/verificationcerts/")
|
||||
)
|
||||
|
||||
assert c.ssl_verification_error is None
|
||||
|
||||
testval = b"echo!\n"
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class TestSSLClientCert(tservers.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(b"%d\n" % self.clientcert.serial)
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = dict(
|
||||
request_client_cert=True,
|
||||
v3_only=False
|
||||
)
|
||||
|
||||
def test_clientcert(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(
|
||||
cert=cdata.path("data/clientcert/client.pem"))
|
||||
assert c.rfile.readline().strip() == b"1"
|
||||
|
||||
def test_clientcert_err(self, tdata):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(exceptions.TlsException):
|
||||
c.convert_to_tls(cert=cdata.path("data/clientcert/make"))
|
||||
|
||||
|
||||
class TestSNI(tservers.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle_sni(self, connection):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(self.sni)
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(sni="foo.com")
|
||||
assert c.sni == "foo.com"
|
||||
assert c.rfile.readline() == b"foo.com"
|
||||
|
||||
def test_idn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(sni="mitmproxyäöüß.example.com")
|
||||
assert c.tls_established
|
||||
assert "doesn't match" not in str(c.ssl_verification_error)
|
||||
|
||||
|
||||
class TestServerCipherList(tservers.ServerTestBase):
|
||||
handler = ClientCipherListHandler
|
||||
ssl = dict(
|
||||
cipher_list='AES256-GCM-SHA384'
|
||||
)
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_echo(self):
|
||||
# Not working for OpenSSL 1.1.1, see
|
||||
# https://github.com/pyca/pyopenssl/blob/fc802df5c10f0d1cd9749c94887d652fa26db6fb/src/OpenSSL/SSL.py#L1192-L1196
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(sni="foo.com")
|
||||
expected = b"['TLS_AES_256_GCM_SHA384']"
|
||||
assert c.rfile.readline() == expected
|
||||
|
||||
|
||||
class TestServerCurrentCipher(tservers.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
sni = None
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(str(self.get_current_cipher()).encode())
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = dict(
|
||||
cipher_list='AES256-GCM-SHA384'
|
||||
)
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_echo(self):
|
||||
# Not working for OpenSSL 1.1.1, see
|
||||
# https://github.com/pyca/pyopenssl/blob/fc802df5c10f0d1cd9749c94887d652fa26db6fb/src/OpenSSL/SSL.py#L1192-L1196
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(sni="foo.com")
|
||||
assert b'AES256-GCM-SHA384' in c.rfile.readline()
|
||||
|
||||
|
||||
class TestServerCipherListError(tservers.ServerTestBase):
|
||||
handler = ClientCipherListHandler
|
||||
ssl = dict(
|
||||
cipher_list=b'bogus'
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(Exception, match="handshake error"):
|
||||
c.convert_to_tls(sni="foo.com")
|
||||
|
||||
|
||||
class TestClientCipherListError(tservers.ServerTestBase):
|
||||
handler = ClientCipherListHandler
|
||||
ssl = dict(
|
||||
cipher_list='RC4-SHA'
|
||||
)
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
with pytest.raises(Exception, match="cipher specification"):
|
||||
c.convert_to_tls(sni="foo.com", cipher_list="bogus")
|
||||
|
||||
|
||||
class TestSSLDisconnect(tservers.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.finish()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls()
|
||||
# Exercise SSL.ZeroReturnError
|
||||
c.rfile.read(10)
|
||||
c.close()
|
||||
with pytest.raises(exceptions.TcpDisconnect):
|
||||
c.wfile.write(b"foo")
|
||||
with pytest.raises(queue.Empty):
|
||||
self.q.get_nowait()
|
||||
|
||||
|
||||
class TestSSLHardDisconnect(tservers.ServerTestBase):
|
||||
handler = HardDisconnectHandler
|
||||
ssl = True
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls()
|
||||
# Exercise SSL.SysCallError
|
||||
c.rfile.read(10)
|
||||
c.close()
|
||||
with pytest.raises(exceptions.TcpDisconnect):
|
||||
c.wfile.write(b"foo")
|
||||
|
||||
|
||||
class TestDisconnect(tservers.ServerTestBase):
|
||||
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.rfile.read(10)
|
||||
c.wfile.write(b"foo")
|
||||
c.close()
|
||||
c.close()
|
||||
|
||||
|
||||
class TestServerTimeOut(tservers.ServerTestBase):
|
||||
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.timeout = False
|
||||
self.settimeout(0.01)
|
||||
try:
|
||||
self.rfile.read(10)
|
||||
except exceptions.TcpTimeout:
|
||||
self.timeout = True
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
time.sleep(0.3)
|
||||
assert self.last_handler.timeout
|
||||
|
||||
|
||||
class TestTimeOut(tservers.ServerTestBase):
|
||||
handler = HangHandler
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.settimeout(0.1)
|
||||
assert c.gettimeout() == 0.1
|
||||
with pytest.raises(exceptions.TcpTimeout):
|
||||
c.rfile.read(10)
|
||||
|
||||
|
||||
class TestALPNClient(tservers.ServerTestBase):
|
||||
handler = ALPNHandler
|
||||
ssl = dict(
|
||||
alpn_select=b"bar"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize('alpn_protos, expected_negotiated, expected_response', [
|
||||
([b"foo", b"bar", b"fasel"], b'bar', b'bar'),
|
||||
([], b'', b'NONE'),
|
||||
(None, b'', b'NONE'),
|
||||
])
|
||||
def test_alpn(self, monkeypatch, alpn_protos, expected_negotiated, expected_response):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(alpn_protos=alpn_protos)
|
||||
assert c.get_alpn_proto_negotiated() == expected_negotiated
|
||||
assert c.rfile.readline().strip() == expected_response
|
||||
|
||||
|
||||
class TestNoSSLNoALPNClient(tservers.ServerTestBase):
|
||||
handler = ALPNHandler
|
||||
|
||||
def test_no_ssl_no_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
assert c.get_alpn_proto_negotiated() == b""
|
||||
assert c.rfile.readline().strip() == b"NONE"
|
||||
|
||||
|
||||
class TestSSLTimeOut(tservers.ServerTestBase):
|
||||
handler = HangHandler
|
||||
ssl = True
|
||||
|
||||
def test_timeout_client(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls()
|
||||
c.settimeout(0.1)
|
||||
with pytest.raises(exceptions.TcpTimeout):
|
||||
c.rfile.read(10)
|
||||
|
||||
|
||||
class TestDHParams(tservers.ServerTestBase):
|
||||
handler = HangHandler
|
||||
ssl = dict(
|
||||
dhparams=certs.CertStore.load_dhparam(
|
||||
cdata.path("data/dhparam.pem"),
|
||||
),
|
||||
cipher_list="DHE-RSA-AES256-SHA"
|
||||
)
|
||||
|
||||
def test_dhparams(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.connect():
|
||||
c.convert_to_tls(method=SSL.TLSv1_2_METHOD)
|
||||
ret = c.get_current_cipher()
|
||||
assert ret[0] == "DHE-RSA-AES256-SHA"
|
||||
|
||||
|
||||
class TestTCPClient(tservers.ServerTestBase):
|
||||
|
||||
def test_conerr(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
with pytest.raises(exceptions.TcpException, match="Error connecting"):
|
||||
c.connect()
|
||||
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with c.create_connection(timeout=20) as conn:
|
||||
assert conn.gettimeout() == 20
|
||||
|
||||
def test_spoof_address(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port), spoof_source_address=("127.0.0.1", 0))
|
||||
with pytest.raises(exceptions.TcpException, match="Failed to spoof"):
|
||||
c.connect()
|
||||
|
||||
|
||||
class TestTCPServer:
|
||||
|
||||
def test_binderr(self):
|
||||
with pytest.raises(socket.error, match="prohibited"):
|
||||
tcp.TCPServer(("localhost", 8080))
|
||||
|
||||
def test_wait_for_silence(self):
|
||||
s = tcp.TCPServer(("127.0.0.1", 0))
|
||||
with s.handler_counter:
|
||||
with pytest.raises(exceptions.Timeout):
|
||||
s.wait_for_silence()
|
||||
s.shutdown()
|
||||
|
||||
|
||||
class TestFileLike:
|
||||
|
||||
def test_blocksize(self):
|
||||
s = BytesIO(b"1234567890abcdefghijklmnopqrstuvwxyz")
|
||||
s = tcp.Reader(s)
|
||||
s.BLOCKSIZE = 2
|
||||
assert s.read(1) == b"1"
|
||||
assert s.read(2) == b"23"
|
||||
assert s.read(3) == b"456"
|
||||
assert s.read(4) == b"7890"
|
||||
d = s.read(-1)
|
||||
assert d.startswith(b"abc") and d.endswith(b"xyz")
|
||||
|
||||
def test_wrap(self):
|
||||
s = BytesIO(b"foobar\nfoobar")
|
||||
s.flush()
|
||||
s = tcp.Reader(s)
|
||||
assert s.readline() == b"foobar\n"
|
||||
assert s.readline() == b"foobar"
|
||||
# Test __getattr__
|
||||
assert s.isatty
|
||||
|
||||
def test_limit(self):
|
||||
s = BytesIO(b"foobar\nfoobar")
|
||||
s = tcp.Reader(s)
|
||||
assert s.readline(3) == b"foo"
|
||||
|
||||
def test_limitless(self):
|
||||
s = BytesIO(b"f" * (50 * 1024))
|
||||
s = tcp.Reader(s)
|
||||
ret = s.read(-1)
|
||||
assert len(ret) == 50 * 1024
|
||||
|
||||
def test_readlog(self):
|
||||
s = BytesIO(b"foobar\nfoobar")
|
||||
s = tcp.Reader(s)
|
||||
assert not s.is_logging()
|
||||
s.start_log()
|
||||
assert s.is_logging()
|
||||
s.readline()
|
||||
assert s.get_log() == b"foobar\n"
|
||||
s.read(1)
|
||||
assert s.get_log() == b"foobar\nf"
|
||||
s.start_log()
|
||||
assert s.get_log() == b""
|
||||
s.read(1)
|
||||
assert s.get_log() == b"o"
|
||||
s.stop_log()
|
||||
with pytest.raises(ValueError):
|
||||
s.get_log()
|
||||
|
||||
def test_writelog(self):
|
||||
s = BytesIO()
|
||||
s = tcp.Writer(s)
|
||||
s.start_log()
|
||||
assert s.is_logging()
|
||||
s.write(b"x")
|
||||
assert s.get_log() == b"x"
|
||||
s.write(b"x")
|
||||
assert s.get_log() == b"xx"
|
||||
|
||||
def test_writer_flush_error(self):
|
||||
s = BytesIO()
|
||||
s = tcp.Writer(s)
|
||||
o = mock.MagicMock()
|
||||
o.flush = mock.MagicMock(side_effect=socket.error)
|
||||
s.o = o
|
||||
with pytest.raises(exceptions.TcpDisconnect):
|
||||
s.flush()
|
||||
|
||||
def test_reader_read_error(self):
|
||||
s = BytesIO(b"foobar\nfoobar")
|
||||
s = tcp.Reader(s)
|
||||
o = mock.MagicMock()
|
||||
o.read = mock.MagicMock(side_effect=socket.error)
|
||||
s.o = o
|
||||
with pytest.raises(exceptions.TcpDisconnect):
|
||||
s.read(10)
|
||||
|
||||
def test_reset_timestamps(self):
|
||||
s = BytesIO(b"foobar\nfoobar")
|
||||
s = tcp.Reader(s)
|
||||
s.first_byte_timestamp = 500
|
||||
s.reset_timestamps()
|
||||
assert not s.first_byte_timestamp
|
||||
|
||||
def test_first_byte_timestamp_updated_on_read(self):
|
||||
s = BytesIO(b"foobar\nfoobar")
|
||||
s = tcp.Reader(s)
|
||||
s.read(1)
|
||||
assert s.first_byte_timestamp
|
||||
expected = s.first_byte_timestamp
|
||||
s.read(5)
|
||||
assert s.first_byte_timestamp == expected
|
||||
|
||||
def test_first_byte_timestamp_updated_on_readline(self):
|
||||
s = BytesIO(b"foobar\nfoobar\nfoobar")
|
||||
s = tcp.Reader(s)
|
||||
s.readline()
|
||||
assert s.first_byte_timestamp
|
||||
expected = s.first_byte_timestamp
|
||||
s.readline()
|
||||
assert s.first_byte_timestamp == expected
|
||||
|
||||
def test_read_ssl_error(self):
|
||||
s = mock.MagicMock()
|
||||
s.read = mock.MagicMock(side_effect=SSL.Error())
|
||||
s = tcp.Reader(s)
|
||||
with pytest.raises(exceptions.TlsException):
|
||||
s.read(1)
|
||||
|
||||
def test_read_syscall_ssl_error(self):
|
||||
s = mock.MagicMock()
|
||||
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
|
||||
s = tcp.Reader(s)
|
||||
with pytest.raises(exceptions.TlsException):
|
||||
s.read(1)
|
||||
|
||||
def test_reader_readline_disconnect(self):
|
||||
o = mock.MagicMock()
|
||||
o.read = mock.MagicMock(side_effect=socket.error)
|
||||
s = tcp.Reader(o)
|
||||
with pytest.raises(exceptions.TcpDisconnect):
|
||||
s.readline(10)
|
||||
|
||||
def test_reader_incomplete_error(self):
|
||||
s = BytesIO(b"foobar")
|
||||
s = tcp.Reader(s)
|
||||
with pytest.raises(exceptions.TcpReadIncomplete):
|
||||
s.safe_read(10)
|
||||
|
||||
|
||||
class TestPeek(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
|
||||
def _connect(self, c):
|
||||
return c.connect()
|
||||
|
||||
def test_peek(self):
|
||||
testval = b"peek!\n"
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
with self._connect(c):
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
|
||||
assert c.rfile.peek(4) == b"peek"
|
||||
assert c.rfile.peek(6) == b"peek!\n"
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
c.close()
|
||||
with pytest.raises(exceptions.NetlibException):
|
||||
c.rfile.peek(1)
|
||||
|
||||
|
||||
class TestPeekSSL(TestPeek):
|
||||
ssl = True
|
||||
|
||||
def _connect(self, c):
|
||||
with c.connect() as conn:
|
||||
c.convert_to_tls()
|
||||
return conn.pop()
|
@ -4,9 +4,6 @@ import pytest
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.net import tls
|
||||
from mitmproxy.net.tcp import TCPClient
|
||||
from test.mitmproxy.net.test_tcp import EchoHandler
|
||||
from . import tservers
|
||||
|
||||
CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex(
|
||||
"03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637"
|
||||
@ -19,7 +16,7 @@ FULL_CLIENT_HELLO_NO_EXTENSIONS = (
|
||||
CLIENT_HELLO_NO_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
"""
|
||||
class TestMasterSecretLogger(tservers.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
@ -52,6 +49,7 @@ class TestMasterSecretLogger(tservers.ServerTestBase):
|
||||
tls.MasterSecretLogger.create_logfun("test"),
|
||||
tls.MasterSecretLogger)
|
||||
assert not tls.MasterSecretLogger.create_logfun(False)
|
||||
"""
|
||||
|
||||
|
||||
class TestTLSInvalid:
|
||||
|
@ -1,77 +1,4 @@
|
||||
import pytest
|
||||
from io import BytesIO
|
||||
from unittest import mock
|
||||
|
||||
from wsproto.frame_protocol import Opcode, RsvBits, Header, Frame
|
||||
|
||||
from mitmproxy.net import http, websocket
|
||||
|
||||
|
||||
@pytest.mark.parametrize("input,masking_key,payload_length", [
|
||||
(b'\x01\rserver-foobar', None, 13),
|
||||
(b'\x01\x8dasdf\x12\x16\x16\x10\x04\x01I\x00\x0e\x1c\x06\x07\x13', b'asdf', 13),
|
||||
(b'\x01~\x04\x00server-foobar', None, 1024),
|
||||
(b'\x01\x7f\x00\x00\x00\x00\x00\x02\x00\x00server-foobar', None, 131072),
|
||||
])
|
||||
def test_read_frame(input, masking_key, payload_length):
|
||||
bio = BytesIO(input)
|
||||
bio.safe_read = bio.read
|
||||
|
||||
header, frame, consumed_bytes = websocket.read_frame(bio)
|
||||
assert header == \
|
||||
Header(
|
||||
fin=False,
|
||||
rsv=RsvBits(rsv1=False, rsv2=False, rsv3=False),
|
||||
opcode=Opcode.TEXT,
|
||||
payload_len=payload_length,
|
||||
masking_key=masking_key,
|
||||
)
|
||||
assert frame == \
|
||||
Frame(
|
||||
opcode=Opcode.TEXT,
|
||||
payload=b'server-foobar',
|
||||
frame_finished=False,
|
||||
message_finished=False,
|
||||
)
|
||||
assert consumed_bytes == input
|
||||
|
||||
bio = BytesIO(input)
|
||||
bio.safe_read = bio.read
|
||||
header, frame, consumed_bytes = websocket.read_frame(bio, False)
|
||||
assert header is None
|
||||
assert frame is None
|
||||
assert consumed_bytes == input
|
||||
|
||||
|
||||
@mock.patch('os.urandom', return_value=b'pumpkinspumpkins')
|
||||
def test_client_handshake_headers(_):
|
||||
assert websocket.client_handshake_headers() == \
|
||||
http.Headers([
|
||||
(b'connection', b'upgrade'),
|
||||
(b'upgrade', b'websocket'),
|
||||
(b'sec-websocket-version', b'13'),
|
||||
(b'sec-websocket-key', b'cHVtcGtpbnNwdW1wa2lucw=='),
|
||||
])
|
||||
assert websocket.client_handshake_headers(b"13", b"foobar", b"foo", b"bar") == \
|
||||
http.Headers([
|
||||
(b'connection', b'upgrade'),
|
||||
(b'upgrade', b'websocket'),
|
||||
(b'sec-websocket-version', b'13'),
|
||||
(b'sec-websocket-key', b'foobar'),
|
||||
(b'sec-websocket-protocol', b'foo'),
|
||||
(b'sec-websocket-extensions', b'bar')
|
||||
])
|
||||
|
||||
|
||||
def test_server_handshake_headers():
|
||||
assert websocket.server_handshake_headers("foobar", "foo", "bar") == \
|
||||
http.Headers([
|
||||
(b'connection', b'upgrade'),
|
||||
(b'upgrade', b'websocket'),
|
||||
(b'sec-websocket-accept', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
|
||||
(b'sec-websocket-protocol', b'foo'),
|
||||
(b'sec-websocket-extensions', b'bar'),
|
||||
])
|
||||
from mitmproxy.net import websocket
|
||||
|
||||
|
||||
def test_check_handshake():
|
||||
@ -92,16 +19,6 @@ def test_check_handshake():
|
||||
})
|
||||
|
||||
|
||||
def test_create_server_nonce():
|
||||
assert websocket.create_server_nonce(b"foobar") == b"AzhRPA4TNwR6I/riJheN0TfR7+I="
|
||||
|
||||
|
||||
def test_check_client_version():
|
||||
assert not websocket.check_client_version({})
|
||||
assert not websocket.check_client_version({"sec-websocket-version": b"42"})
|
||||
assert websocket.check_client_version({"sec-websocket-version": b"13"})
|
||||
|
||||
|
||||
def test_get_extensions():
|
||||
assert websocket.get_extensions({}) is None
|
||||
assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo"
|
||||
|
@ -1,113 +0,0 @@
|
||||
import threading
|
||||
import queue
|
||||
import io
|
||||
import OpenSSL
|
||||
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy.utils import data
|
||||
|
||||
cdata = data.Data(__name__)
|
||||
|
||||
|
||||
class _ServerThread(threading.Thread):
|
||||
|
||||
def __init__(self, server):
|
||||
self.server = server
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
self.server.serve_forever()
|
||||
|
||||
|
||||
class _TServer(tcp.TCPServer):
|
||||
|
||||
def __init__(self, ssl, q, handler_klass, addr, **kwargs):
|
||||
"""
|
||||
ssl: A dictionary of SSL parameters:
|
||||
|
||||
cert, key, request_client_cert, cipher_list,
|
||||
dhparams, v3_only
|
||||
"""
|
||||
tcp.TCPServer.__init__(self, addr)
|
||||
|
||||
if ssl is True:
|
||||
self.ssl = dict()
|
||||
elif isinstance(ssl, dict):
|
||||
self.ssl = ssl
|
||||
else:
|
||||
self.ssl = None
|
||||
|
||||
self.q = q
|
||||
self.handler_klass = handler_klass
|
||||
if self.handler_klass is not None:
|
||||
self.handler_klass.kwargs = kwargs
|
||||
self.last_handler = None
|
||||
|
||||
def handle_client_connection(self, request, client_address):
|
||||
h = self.handler_klass(request, client_address, self)
|
||||
self.last_handler = h
|
||||
if self.ssl is not None:
|
||||
cert = self.ssl.get(
|
||||
"cert",
|
||||
cdata.path("data/server.crt"))
|
||||
raw_key = self.ssl.get(
|
||||
"key",
|
||||
cdata.path("data/server.key"))
|
||||
with open(raw_key) as f:
|
||||
raw_key = f.read()
|
||||
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw_key)
|
||||
if self.ssl.get("v3_only", False):
|
||||
method = OpenSSL.SSL.SSLv3_METHOD
|
||||
options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1
|
||||
else:
|
||||
method = OpenSSL.SSL.SSLv23_METHOD
|
||||
options = None
|
||||
h.convert_to_tls(
|
||||
cert,
|
||||
key,
|
||||
method=method,
|
||||
options=options,
|
||||
handle_sni=getattr(h, "handle_sni", None),
|
||||
request_client_cert=self.ssl.get("request_client_cert", None),
|
||||
cipher_list=self.ssl.get("cipher_list", None),
|
||||
dhparams=self.ssl.get("dhparams", None),
|
||||
chain_file=self.ssl.get("chain_file", None),
|
||||
alpn_select=self.ssl.get("alpn_select", None)
|
||||
)
|
||||
h.handle()
|
||||
h.finish()
|
||||
|
||||
def handle_error(self, connection, client_address, fp=None):
|
||||
s = io.StringIO()
|
||||
tcp.TCPServer.handle_error(self, connection, client_address, s)
|
||||
self.q.put(s.getvalue())
|
||||
|
||||
|
||||
class ServerTestBase:
|
||||
ssl = None
|
||||
handler = None
|
||||
addr = ("127.0.0.1", 0)
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls, **kwargs):
|
||||
cls.q = queue.Queue()
|
||||
s = cls.makeserver(**kwargs)
|
||||
cls.port = s.address[1]
|
||||
cls.server = _ServerThread(s)
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def makeserver(cls, **kwargs):
|
||||
ssl = kwargs.pop('ssl', cls.ssl)
|
||||
return _TServer(ssl, cls.q, cls.handler, cls.addr, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
cls.server.server.shutdown()
|
||||
|
||||
def teardown(self):
|
||||
self.server.server.wait_for_silence()
|
||||
|
||||
@property
|
||||
def last_handler(self):
|
||||
return self.server.server.last_handler
|
@ -1,8 +1,8 @@
|
||||
from unittest import mock
|
||||
|
||||
from mitmproxy import controller
|
||||
from mitmproxy import io
|
||||
from mitmproxy import eventsequence
|
||||
from mitmproxy import io
|
||||
from mitmproxy.test import tflow
|
||||
from mitmproxy.test import tutils
|
||||
|
||||
@ -30,182 +30,3 @@ class MasterTest:
|
||||
fw = io.FlowWriter(f)
|
||||
t = tflow.tflow(resp=True)
|
||||
fw.add(t)
|
||||
|
||||
|
||||
# class TestMaster(taddons.RecordingMaster):
|
||||
|
||||
# def __init__(self, opts):
|
||||
# super().__init__(opts)
|
||||
# config = ProxyConfig(opts)
|
||||
# self.server = ProxyServer(config)
|
||||
|
||||
# def clear_addons(self, addons):
|
||||
# self.addons.clear()
|
||||
# self.state = TestState()
|
||||
# self.addons.add(self.state)
|
||||
# self.addons.add(*addons)
|
||||
# self.addons.trigger("configure", self.options.keys())
|
||||
# self.addons.trigger("running")
|
||||
|
||||
# def reset(self, addons):
|
||||
# self.clear_addons(addons)
|
||||
# self.clear()
|
||||
|
||||
|
||||
# class ProxyThread(threading.Thread):
|
||||
|
||||
# def __init__(self, masterclass, options):
|
||||
# threading.Thread.__init__(self)
|
||||
# self.masterclass = masterclass
|
||||
# self.options = options
|
||||
# self.tmaster = None
|
||||
# self.event_loop = None
|
||||
# controller.should_exit = False
|
||||
|
||||
# @property
|
||||
# def port(self):
|
||||
# return self.tmaster.server.address[1]
|
||||
|
||||
# @property
|
||||
# def tlog(self):
|
||||
# return self.tmaster.logs
|
||||
|
||||
# def shutdown(self):
|
||||
# self.tmaster.shutdown()
|
||||
|
||||
# def run(self):
|
||||
# self.event_loop = asyncio.new_event_loop()
|
||||
# asyncio.set_event_loop(self.event_loop)
|
||||
# self.tmaster = self.masterclass(self.options)
|
||||
# self.tmaster.addons.add(core.Core())
|
||||
# self.name = "ProxyThread (%s)" % human.format_address(self.tmaster.server.address)
|
||||
# self.tmaster.run()
|
||||
|
||||
# def set_addons(self, *addons):
|
||||
# self.tmaster.reset(addons)
|
||||
|
||||
# def start(self):
|
||||
# super().start()
|
||||
# while True:
|
||||
# if self.tmaster:
|
||||
# break
|
||||
# time.sleep(0.01)
|
||||
|
||||
|
||||
# class ProxyTestBase:
|
||||
# # Test Configuration
|
||||
# ssl = None
|
||||
# ssloptions = False
|
||||
# masterclass = TestMaster
|
||||
|
||||
# add_upstream_certs_to_client_chain = False
|
||||
|
||||
# @classmethod
|
||||
# def setup_class(cls):
|
||||
# cls.server = pathod.test.Daemon(
|
||||
# ssl=cls.ssl,
|
||||
# ssloptions=cls.ssloptions)
|
||||
# cls.server2 = pathod.test.Daemon(
|
||||
# ssl=cls.ssl,
|
||||
# ssloptions=cls.ssloptions)
|
||||
|
||||
# cls.options = cls.get_options()
|
||||
# cls.proxy = ProxyThread(cls.masterclass, cls.options)
|
||||
# cls.proxy.start()
|
||||
|
||||
# @classmethod
|
||||
# def teardown_class(cls):
|
||||
# # perf: we want to run tests in parallel
|
||||
# # should this ever cause an error, travis should catch it.
|
||||
# # shutil.rmtree(cls.confdir)
|
||||
# cls.proxy.shutdown()
|
||||
# cls.server.shutdown()
|
||||
# cls.server2.shutdown()
|
||||
|
||||
# def teardown(self):
|
||||
# try:
|
||||
# self.server.wait_for_silence()
|
||||
# except exceptions.Timeout:
|
||||
# # FIXME: Track down the Windows sync issues
|
||||
# if sys.platform != "win32":
|
||||
# raise
|
||||
|
||||
# def setup(self):
|
||||
# self.master.reset(self.addons())
|
||||
# self.server.clear_log()
|
||||
# self.server2.clear_log()
|
||||
|
||||
# @property
|
||||
# def master(self):
|
||||
# return self.proxy.tmaster
|
||||
|
||||
# @classmethod
|
||||
# def get_options(cls):
|
||||
# cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy")
|
||||
# return options.Options(
|
||||
# listen_port=0,
|
||||
# confdir=cls.confdir,
|
||||
# add_upstream_certs_to_client_chain=cls.add_upstream_certs_to_client_chain,
|
||||
# ssl_insecure=True,
|
||||
# )
|
||||
|
||||
# def set_addons(self, *addons):
|
||||
# self.proxy.set_addons(*addons)
|
||||
|
||||
# def addons(self):
|
||||
# """
|
||||
# Can be over-ridden to add a standard set of addons to tests.
|
||||
# """
|
||||
# return []
|
||||
|
||||
|
||||
# class LazyPathoc(pathod.pathoc.Pathoc):
|
||||
# def __init__(self, lazy_connect, *args, **kwargs):
|
||||
# self.lazy_connect = lazy_connect
|
||||
# pathod.pathoc.Pathoc.__init__(self, *args, **kwargs)
|
||||
|
||||
# def connect(self):
|
||||
# return pathod.pathoc.Pathoc.connect(self, self.lazy_connect)
|
||||
|
||||
|
||||
# class HTTPProxyTest(ProxyTestBase):
|
||||
|
||||
# def pathoc_raw(self):
|
||||
# return pathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), fp=None)
|
||||
|
||||
# def pathoc(self, sni=None):
|
||||
# """
|
||||
# Returns a connected Pathoc instance.
|
||||
# """
|
||||
# if self.ssl:
|
||||
# conn = ("127.0.0.1", self.server.port)
|
||||
# else:
|
||||
# conn = None
|
||||
# return LazyPathoc(
|
||||
# conn,
|
||||
# ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
|
||||
# )
|
||||
|
||||
# def pathod(self, spec, sni=None):
|
||||
# """
|
||||
# Constructs a pathod GET request, with the appropriate base and proxy.
|
||||
# """
|
||||
# p = self.pathoc(sni=sni)
|
||||
# if self.ssl:
|
||||
# q = "get:'/p/%s'" % spec
|
||||
# else:
|
||||
# q = f"get:'{self.server.urlbase}/p/{spec}'"
|
||||
# with p.connect():
|
||||
# return p.request(q)
|
||||
|
||||
# def app(self, page):
|
||||
# if self.ssl:
|
||||
# p = pathod.pathoc.Pathoc(
|
||||
# ("127.0.0.1", self.proxy.port), True, fp=None
|
||||
# )
|
||||
# with p.connect((self.master.options.onboarding_host, self.master.options.onbarding_port)):
|
||||
# return p.request("get:'%s'" % page)
|
||||
# else:
|
||||
# p = self.pathoc()
|
||||
# with p.connect():
|
||||
# return p.request(f"get:'http://{self.master.options.onboarding_host}{page}'")
|
||||
|
Loading…
Reference in New Issue
Block a user