diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py deleted file mode 100644 index d4cd0c018..000000000 --- a/mitmproxy/net/tcp.py +++ /dev/null @@ -1,683 +0,0 @@ -import os -import errno -import select -import socket -import sys -import threading -import time -import traceback - -from typing import Optional # noqa - -from mitmproxy.net import tls - -from OpenSSL import SSL - -from mitmproxy import certs -from mitmproxy import exceptions -from mitmproxy.coretypes import basethread - -socket_fileobject = socket.SocketIO - -# workaround for https://bugs.python.org/issue29515 -# Python 3.8 for Windows is missing a constant, fixed in 3.9 -IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41) - - -class _FileLike: - BLOCKSIZE = 1024 * 32 - - def __init__(self, o): - self.o = o - self._log = None - self.first_byte_timestamp = None - - def set_descriptor(self, o): - self.o = o - - def __getattr__(self, attr): - return getattr(self.o, attr) - - def start_log(self): - """ - Starts or resets the log. - - This will store all bytes read or written. - """ - self._log = [] - - def stop_log(self): - """ - Stops the log. - """ - self._log = None - - def is_logging(self): - return self._log is not None - - def get_log(self): - """ - Returns the log as a string. - """ - if not self.is_logging(): - raise ValueError("Not logging!") - return b"".join(self._log) - - def add_log(self, v): - if self.is_logging(): - self._log.append(v) - - def reset_timestamps(self): - self.first_byte_timestamp = None - - -class Writer(_FileLike): - - def flush(self): - """ - May raise exceptions.TcpDisconnect - """ - if hasattr(self.o, "flush"): - try: - self.o.flush() - except OSError as v: - raise exceptions.TcpDisconnect(str(v)) - - def write(self, v): - """ - May raise exceptions.TcpDisconnect - """ - if v: - self.first_byte_timestamp = self.first_byte_timestamp or time.time() - try: - if hasattr(self.o, "sendall"): - self.add_log(v) - return self.o.sendall(v) - else: - r = self.o.write(v) - self.add_log(v[:r]) - return r - except (SSL.Error, OSError) as e: - raise exceptions.TcpDisconnect(str(e)) - - -class Reader(_FileLike): - - def read(self, length): - """ - If length is -1, we read until connection closes. - """ - result = b'' - start = time.time() - while length == -1 or length > 0: - if length == -1 or length > self.BLOCKSIZE: - rlen = self.BLOCKSIZE - else: - rlen = length - try: - data = self.o.read(rlen) - except SSL.ZeroReturnError: - # TLS connection was shut down cleanly - break - except (SSL.WantWriteError, SSL.WantReadError): - # From the OpenSSL docs: - # If the underlying BIO is non-blocking, SSL_read() will also return when the - # underlying BIO could not satisfy the needs of SSL_read() to continue the - # operation. In this case a call to SSL_get_error with the return value of - # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. - # 300 is OpenSSL default timeout - timeout = self.o.gettimeout() or 300 - if (time.time() - start) < timeout: - time.sleep(0.1) - continue - else: - raise exceptions.TcpTimeout() - except socket.timeout: - raise exceptions.TcpTimeout() - except OSError as e: - raise exceptions.TcpDisconnect(str(e)) - except SSL.SysCallError as e: - if e.args == (-1, 'Unexpected EOF'): - break - raise exceptions.TlsException(str(e)) - except SSL.Error as e: - raise exceptions.TlsException(str(e)) - self.first_byte_timestamp = self.first_byte_timestamp or time.time() - if not data: - break - result += data - if length != -1: - length -= len(data) - self.add_log(result) - return result - - def readline(self, size=None): - result = b'' - bytes_read = 0 - while True: - if size is not None and bytes_read >= size: - break - ch = self.read(1) - bytes_read += 1 - if not ch: - break - else: - result += ch - if ch == b'\n': - break - return result - - def safe_read(self, length): - """ - Like .read, but is guaranteed to either return length bytes, or - raise an exception. - """ - result = self.read(length) - if length != -1 and len(result) != length: - if not result: - raise exceptions.TcpDisconnect() - else: - raise exceptions.TcpReadIncomplete( - "Expected {} bytes, got {}".format(length, len(result)) - ) - return result - - def peek(self, length): - """ - Tries to peek into the underlying file object. - - Returns: - Up to the next N bytes if peeking is successful. - - Raises: - exceptions.TcpException if there was an error with the socket - TlsException if there was an error with pyOpenSSL. - NotImplementedError if the underlying file object is not a [pyOpenSSL] socket - """ - if isinstance(self.o, socket_fileobject): - try: - return self.o._sock.recv(length, socket.MSG_PEEK) - except OSError as e: - raise exceptions.TcpException(repr(e)) - elif isinstance(self.o, SSL.Connection): - try: - return self.o.recv(length, socket.MSG_PEEK) - except SSL.Error as e: - raise exceptions.TlsException(str(e)) - else: - raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") - - -def ssl_read_select(rlist, timeout): - """ - This is a wrapper around select.select() which also works for SSL.Connections - by taking ssl_connection.pending() into account. - - Caveats: - If .pending() > 0 for any of the connections in rlist, we avoid the select syscall - and **will not include any other connections which may or may not be ready**. - - Args: - rlist: wait until ready for reading - - Returns: - subset of rlist which is ready for reading. - """ - return [ - conn for conn in rlist - if isinstance(conn, SSL.Connection) and conn.pending() > 0 - ] or select.select(rlist, (), (), timeout)[0] - - -def close_socket(sock): - """ - Does a hard close of a socket, without emitting a RST. - """ - try: - # We already indicate that we close our end. - # may raise "Transport endpoint is not connected" on Linux - sock.shutdown(socket.SHUT_WR) - - # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending - # readable data could lead to an immediate RST being sent (which is the - # case on Windows). - # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html - # - # This in turn results in the following issue: If we send an error page - # to the client and then close the socket, the RST may be received by - # the client before the error page and the users sees a connection - # error rather than the error page. Thus, we try to empty the read - # buffer on Windows first. (see - # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) - # - - if os.name == "nt": # pragma: no cover - # We cannot rely on the shutdown()-followed-by-read()-eof technique - # proposed by the page above: Some remote machines just don't send - # a TCP FIN, which would leave us in the unfortunate situation that - # recv() would block infinitely. As a workaround, we set a timeout - # here even if we are in blocking mode. - sock.settimeout(sock.gettimeout() or 20) - - # limit at a megabyte so that we don't read infinitely - for _ in range(1024 ** 3 // 4096): - # may raise a timeout/disconnect exception. - if not sock.recv(4096): - break - - # Now we can close the other half as well. - sock.shutdown(socket.SHUT_RD) - - except OSError: - pass - - sock.close() - - -class _Connection: - - rbufsize = -1 - wbufsize = -1 - - def _makefile(self): - """ - Set up .rfile and .wfile attributes from .connection - """ - # Ideally, we would use the Buffered IO in Python 3 by default. - # Unfortunately, the implementation of .peek() is broken for n>1 bytes, - # as it may just return what's left in the buffer and not all the bytes we want. - # As a workaround, we just use unbuffered sockets directly. - # https://mail.python.org/pipermail/python-dev/2009-June/089986.html - self.rfile = Reader(socket.SocketIO(self.connection, "rb")) - self.wfile = Writer(socket.SocketIO(self.connection, "wb")) - - def __init__(self, connection): - if connection: - self.connection = connection - self.ip_address = connection.getpeername() - self._makefile() - else: - self.connection = None - self.ip_address = None - self.rfile = None - self.wfile = None - - self.tls_established = False - self.finished = False - - def get_current_cipher(self): - if not self.tls_established: - return None - - name = self.connection.get_cipher_name() - bits = self.connection.get_cipher_bits() - version = self.connection.get_cipher_version() - return name, bits, version - - def finish(self): - self.finished = True - # If we have an SSL connection, wfile.close == connection.close - # (We call _FileLike.set_descriptor(conn)) - # Closing the socket is not our task, therefore we don't call close - # then. - if not isinstance(self.connection, SSL.Connection): - if not getattr(self.wfile, "closed", False): - try: - self.wfile.flush() - self.wfile.close() - except exceptions.TcpDisconnect: - pass - - self.rfile.close() - else: - try: - self.connection.shutdown() - except SSL.Error: - pass - - -class ConnectionCloser: - def __init__(self, conn): - self.conn = conn - self._canceled = False - - def pop(self): - """ - Cancel the current closer, and return a fresh one. - """ - self._canceled = True - return ConnectionCloser(self.conn) - - def __enter__(self): - return self - - def __exit__(self, *args): - if not self._canceled: - self.conn.close() - - -class TCPClient(_Connection): - - def __init__(self, address, source_address=None, spoof_source_address=None): - super().__init__(None) - self.address = address - self.source_address = source_address - self.cert = None - self.server_certs = [] - self.sni = None - self.spoof_source_address = spoof_source_address - - @property - def ssl_verification_error(self) -> Optional[exceptions.InvalidCertificateException]: - return getattr(self.connection, "cert_error", None) - - def close(self): - # Make sure to close the real socket, not the SSL proxy. - # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, - # it tries to renegotiate... - if self.connection: - if isinstance(self.connection, SSL.Connection): - close_socket(self.connection._socket) - else: - close_socket(self.connection) - - def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs): - context = tls.create_client_context( - alpn_protos=alpn_protos, - sni=sni, - **sslctx_kwargs - ) - self.connection = SSL.Connection(context, self.connection) - if sni: - self.sni = sni - self.connection.set_tlsext_host_name(sni.encode("idna")) - self.connection.set_connect_state() - try: - self.connection.do_handshake() - except SSL.Error as v: - if self.ssl_verification_error: - raise self.ssl_verification_error - else: - raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) - - self.cert = certs.Cert(self.connection.get_peer_certificate()) - - # Keep all server certificates in a list - for i in self.connection.get_peer_cert_chain(): - self.server_certs.append(certs.Cert(i)) - - self.tls_established = True - self.rfile.set_descriptor(self.connection) - self.wfile.set_descriptor(self.connection) - - def makesocket(self, family, type, proto): - # some parties (cuckoo sandbox) need to hook this - return socket.socket(family, type, proto) - - def create_connection(self, timeout=None): - # Based on the official socket.create_connection implementation of Python 3.6. - # https://github.com/python/cpython/blob/3cc5817cfaf5663645f4ee447eaed603d2ad290a/Lib/socket.py - - err = None - for res in socket.getaddrinfo(self.address[0], self.address[1], 0, socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - sock = None - try: - sock = self.makesocket(af, socktype, proto) - if timeout: - sock.settimeout(timeout) - if self.source_address: - sock.bind(self.source_address) - if self.spoof_source_address: - try: - if not sock.getsockopt(socket.SOL_IP, socket.IP_TRANSPARENT): - sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover - except Exception as e: - # socket.IP_TRANSPARENT might not be available on every OS and Python version - if sock is not None: - sock.close() - raise exceptions.TcpException( - "Failed to spoof the source address: " + str(e) - ) - sock.connect(sa) - return sock - - except OSError as _: - err = _ - if sock is not None: - sock.close() - - if err is not None: - raise err - else: - raise OSError("getaddrinfo returns an empty list") # pragma: no cover - - def connect(self): - try: - connection = self.create_connection() - except OSError as err: - raise exceptions.TcpException( - 'Error connecting to "%s": %s' % - (self.address[0], err) - ) - self.connection = connection - self.source_address = connection.getsockname() - self.ip_address = connection.getpeername() - self._makefile() - return ConnectionCloser(self) - - def settimeout(self, n): - self.connection.settimeout(n) - - def gettimeout(self): - return self.connection.gettimeout() - - def get_alpn_proto_negotiated(self): - if self.tls_established: - return self.connection.get_alpn_proto_negotiated() - else: - return b"" - - -class BaseHandler(_Connection): - - """ - The instantiator is expected to call the handle() and finish() methods. - """ - - def __init__(self, connection, address, server): - super().__init__(connection) - self.address = address - self.server = server - self.clientcert = None - - def convert_to_tls(self, cert, key, **sslctx_kwargs): - """ - Convert connection to SSL. - For a list of parameters, see tls.create_server_context(...) - """ - - context = tls.create_server_context( - cert=cert, - key=key, - **sslctx_kwargs) - self.connection = SSL.Connection(context, self.connection) - self.connection.set_accept_state() - try: - self.connection.do_handshake() - except SSL.Error as v: - raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) - self.tls_established = True - cert = self.connection.get_peer_certificate() - if cert: - self.clientcert = certs.Cert(cert) - self.rfile.set_descriptor(self.connection) - self.wfile.set_descriptor(self.connection) - - def handle(self): # pragma: no cover - raise NotImplementedError - - def settimeout(self, n): - self.connection.settimeout(n) - - def get_alpn_proto_negotiated(self): - if self.tls_established: - return self.connection.get_alpn_proto_negotiated() - else: - return b"" - - -class Counter: - def __init__(self): - self._count = 0 - self._lock = threading.Lock() - - @property - def count(self): - with self._lock: - return self._count - - def __enter__(self): - with self._lock: - self._count += 1 - - def __exit__(self, *args): - with self._lock: - self._count -= 1 - - -class TCPServer: - - def __init__(self, address): - self.address = address - self.__is_shut_down = threading.Event() - self.__is_shut_down.set() - self.__shutdown_request = False - - if self.address[0] == 'localhost': - raise OSError("Binding to 'localhost' is prohibited. Please use '::1' or '127.0.0.1' directly.") - - self.socket = None - - try: - # First try to bind an IPv6 socket, attempting to enable IPv4 support if the OS supports it. - # This allows us to accept connections for ::1 and 127.0.0.1 on the same socket. - # Only works if self.address == "" - self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - self.socket.bind(self.address) - except OSError: - if self.socket: - self.socket.close() - self.socket = None - - if not self.socket: - try: - # Binding to an IPv6 + IPv4 socket failed, lets fall back to IPv4 only. - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - self.socket.bind(self.address) - except OSError: - if self.socket: - self.socket.close() - self.socket = None - - if not self.socket: - # Binding to an IPv4 only socket failed, lets fall back to IPv6 only. - self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - self.socket.bind(self.address) - - self.address = self.socket.getsockname() - self.socket.listen() - self.handler_counter = Counter() - - def connection_thread(self, connection, client_address): - with self.handler_counter: - try: - self.handle_client_connection(connection, client_address) - except OSError as e: # pragma: no cover - # This catches situations where the underlying connection is - # closed beneath us. Syscalls on the connection object at this - # point returns EINVAL. If this happens, we close the socket and - # move on. - if not e.errno == errno.EINVAL: - raise - except: - self.handle_error(connection, client_address) - finally: - close_socket(connection) - - def serve_forever(self, poll_interval=0.1): - self.__is_shut_down.clear() - try: - while not self.__shutdown_request: - r, w_, e_ = select.select([self.socket], [], [], poll_interval) - if self.socket in r: - connection, client_address = self.socket.accept() - t = basethread.BaseThread( - "TCPConnectionHandler ({}: {}:{} -> {}:{})".format( - self.__class__.__name__, - client_address[0], - client_address[1], - self.address[0], - self.address[1], - ), - target=self.connection_thread, - args=(connection, client_address), - ) - t.setDaemon(1) - try: - t.start() - except threading.ThreadError: - self.handle_error(connection, client_address) - connection.close() - finally: - self.__shutdown_request = False - self.__is_shut_down.set() - - def shutdown(self): - self.__shutdown_request = True - self.__is_shut_down.wait() - self.socket.close() - self.handle_shutdown() - - def handle_error(self, connection_, client_address, fp=sys.stderr): - """ - Called when handle_client_connection raises an exception. - """ - # If a thread has persisted after interpreter exit, the module might be - # none. - if traceback: - exc = str(traceback.format_exc()) - print('-' * 40, file=fp) - print( - "Error in processing of request from %s" % repr(client_address), file=fp) - print(exc, file=fp) - print('-' * 40, file=fp) - - def handle_client_connection(self, conn, client_address): # pragma: no cover - """ - Called after client connection. - """ - raise NotImplementedError - - def handle_shutdown(self): - """ - Called after server shutdown. - """ - - def wait_for_silence(self, timeout=5): - start = time.time() - while 1: - if time.time() - start >= timeout: - raise exceptions.Timeout( - "%s service threads still alive" % - self.handler_counter.count - ) - if self.handler_counter.count == 0: - return diff --git a/mitmproxy/net/websocket.py b/mitmproxy/net/websocket.py index 930b33d4c..4758db0ca 100644 --- a/mitmproxy/net/websocket.py +++ b/mitmproxy/net/websocket.py @@ -4,142 +4,14 @@ Spec: https://tools.ietf.org/html/rfc6455 """ -import base64 -import hashlib -import os -import struct - -from wsproto.utilities import ACCEPT_GUID -from wsproto.handshake import WEBSOCKET_VERSION -from wsproto.frame_protocol import RsvBits, Header, Frame, XorMaskerSimple, XorMaskerNull - -from mitmproxy.net import http -from mitmproxy.utils import bits, strutils - - -def read_frame(rfile, parse=True): - """ - Reads a full WebSocket frame from a file-like object. - - Returns a parsed frame header, parsed frame, and the consumed bytes. - """ - - consumed_bytes = b'' - - def consume(len): - nonlocal consumed_bytes - d = rfile.safe_read(len) - consumed_bytes += d - return d - - first_byte, second_byte = consume(2) - fin = bits.getbit(first_byte, 7) - rsv1 = bits.getbit(first_byte, 6) - rsv2 = bits.getbit(first_byte, 5) - rsv3 = bits.getbit(first_byte, 4) - opcode = first_byte & 0xF - mask_bit = bits.getbit(second_byte, 7) - length_code = second_byte & 0x7F - - # payload_len > 125 indicates you need to read more bytes - # to get the actual payload length - if length_code <= 125: - payload_len = length_code - elif length_code == 126: - payload_len, = struct.unpack("!H", consume(2)) - else: # length_code == 127: - payload_len, = struct.unpack("!Q", consume(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = consume(4) - masker = XorMaskerSimple(masking_key) - else: - masking_key = None - masker = XorMaskerNull() - - masked_payload = consume(payload_len) - - if parse: - header = Header( - fin=fin, - rsv=RsvBits(rsv1, rsv2, rsv3), - opcode=opcode, - payload_len=payload_len, - masking_key=masking_key, - ) - frame = Frame( - opcode=opcode, - payload=masker.process(masked_payload), - frame_finished=fin, - message_finished=fin - ) - else: - header = None - frame = None - - return header, frame, consumed_bytes - - -def client_handshake_headers(version=None, key=None, protocol=None, extensions=None): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of http.Headers - """ - if version is None: - version = WEBSOCKET_VERSION - if key is None: - key = base64.b64encode(os.urandom(16)).decode('ascii') - h = http.Headers( - connection="upgrade", - upgrade="websocket", - sec_websocket_version=version, - sec_websocket_key=key, - ) - if protocol is not None: - h['sec-websocket-protocol'] = protocol - if extensions is not None: - h['sec-websocket-extensions'] = extensions - return h - - -def server_handshake_headers(client_key, protocol=None, extensions=None): - """ - The server response is a valid HTTP 101 response. - - Returns an instance of http.Headers - """ - h = http.Headers( - connection="upgrade", - upgrade="websocket", - sec_websocket_accept=create_server_nonce(client_key), - ) - if protocol is not None: - h['sec-websocket-protocol'] = protocol - if extensions is not None: - h['sec-websocket-extensions'] = extensions - return h - - def check_handshake(headers): return ( - "upgrade" in headers.get("connection", "").lower() and - headers.get("upgrade", "").lower() == "websocket" and - (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None) + "upgrade" in headers.get("connection", "").lower() and + headers.get("upgrade", "").lower() == "websocket" and + (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None) ) -def create_server_nonce(client_nonce): - return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + ACCEPT_GUID).digest()) - - -def check_client_version(headers): - return headers.get("sec-websocket-version", "") == WEBSOCKET_VERSION - - def get_extensions(headers): return headers.get("sec-websocket-extensions", None) diff --git a/mitmproxy/test/tutils.py b/mitmproxy/test/tutils.py index 79751060e..6ace1056e 100644 --- a/mitmproxy/test/tutils.py +++ b/mitmproxy/test/tutils.py @@ -1,17 +1,6 @@ -from io import BytesIO - -from mitmproxy.net import tcp from mitmproxy.net import http -def treader(bytes): - """ - Construct a tcp.Read object from bytes. - """ - fp = BytesIO(bytes) - return tcp.Reader(fp) - - def treq(**kwargs) -> http.Request: """ Returns: diff --git a/test/mitmproxy/net/test_socks.py b/test/mitmproxy/net/test_socks.py index 65fd85855..c6e2d1530 100644 --- a/test/mitmproxy/net/test_socks.py +++ b/test/mitmproxy/net/test_socks.py @@ -1,11 +1,21 @@ import ipaddress from io import BytesIO + import pytest from mitmproxy.net import socks from mitmproxy.test import tutils +# this is a temporary placeholder here, we remove the file-based API when we transition socks proxying to sans-io. +class tutils: # noqa + @staticmethod + def treader(data: bytes): + io = BytesIO(data) + io.safe_read = io.read + return io + + def test_client_greeting(): raw = tutils.treader(b"\x05\x02\x00\xBE\xEF") out = BytesIO() diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py deleted file mode 100644 index b96a88811..000000000 --- a/test/mitmproxy/net/test_tcp.py +++ /dev/null @@ -1,813 +0,0 @@ -from io import BytesIO -import re -import queue -import time -import socket -import random -import threading -import pytest -from unittest import mock -from OpenSSL import SSL - -from mitmproxy import certs -from mitmproxy.net import tcp -from mitmproxy import exceptions -from mitmproxy.utils import data -from ...conftest import skip_no_ipv6 - -from . import tservers - - -cdata = data.Data(__name__) - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - v = self.rfile.readline() - self.wfile.write(v) - self.wfile.flush() - - -class ClientCipherListHandler(tcp.BaseHandler): - sni = None - - def handle(self): - self.wfile.write(f"{self.connection.get_cipher_list()}\n".encode()) - self.wfile.flush() - - -class HangHandler(tcp.BaseHandler): - - def handle(self): - # Hang as long as the client connection is alive - while True: - try: - self.connection.setblocking(0) - ret = self.connection.recv(1) - # Client connection is dead... - if ret == "" or ret == b"": - return - except OSError: - pass - except SSL.WantReadError: - pass - except Exception: - return - time.sleep(0.1) - - -class ALPNHandler(tcp.BaseHandler): - sni = None - - def handle(self): - alp = self.get_alpn_proto_negotiated() - if alp: - self.wfile.write(alp) - else: - self.wfile.write(b"NONE") - self.wfile.flush() - - -class TestServer(tservers.ServerTestBase): - handler = EchoHandler - - def test_echo(self): - testval = b"echo!\n" - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - def test_thread_start_error(self): - with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m: - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - assert not c.rfile.read(1) - assert m.called - assert "nonewthread" in self.q.get_nowait() - self.test_echo() - - -class TestServerBind(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - - def handle(self): - # We may get an ipv4-mapped ipv6 address here, e.g. ::ffff:127.0.0.1. - # Those still appear as "127.0.0.1" in the table, so we need to strip the prefix. - peername = self.connection.getpeername() - address = re.sub(r"^::ffff:(?=\d+.\d+.\d+.\d+$)", "", peername[0]) - port = peername[1] - - self.wfile.write(str((address, port)).encode()) - self.wfile.flush() - - def test_bind(self): - """ Test to bind to a given random port. Try again if the random port turned out to be blocked. """ - for i in range(20): - random_port = random.randrange(1024, 65535) - try: - c = tcp.TCPClient( - ("127.0.0.1", self.port), source_address=( - "127.0.0.1", random_port)) - with c.connect(): - assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode() - return - except exceptions.TcpException: # port probably already in use - pass - - -@skip_no_ipv6 -class TestServerIPv6(tservers.ServerTestBase): - handler = EchoHandler - addr = ("::1", 0) - - def test_echo(self): - testval = b"echo!\n" - c = tcp.TCPClient(("::1", self.port)) - with c.connect(): - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - -class TestEcho(tservers.ServerTestBase): - handler = EchoHandler - - def test_echo(self): - testval = b"echo!\n" - c = tcp.TCPClient(("localhost", self.port)) - with c.connect(): - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - -class HardDisconnectHandler(tcp.BaseHandler): - - def handle(self): - self.connection.close() - - -class TestFinishFail(tservers.ServerTestBase): - - """ - This tests a difficult-to-trigger exception in the .finish() method of - the handler. - """ - handler = EchoHandler - - def test_disconnect_in_finish(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.wfile.write(b"foo\n") - c.wfile.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) - c.finish() - - -class TestServerSSL(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - cipher_list="AES256-SHA", - chain_file=cdata.path("data/server.crt") - ) - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(sni="foo.com", options=SSL.OP_ALL) - testval = b"echo!\n" - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - def test_get_current_cipher(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - assert not c.get_current_cipher() - c.convert_to_tls(sni="foo.com") - ret = c.get_current_cipher() - assert ret - assert "AES" in ret[0] - - -class TestSSLv3Only(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - request_client_cert=False, - v3_only=True - ) - - def test_failure(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(exceptions.TlsException): - c.convert_to_tls(sni="foo.com") - - -class TestInvalidTrustFile(tservers.ServerTestBase): - def test_invalid_trust_file_should_fail(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(exceptions.TlsException): - c.convert_to_tls( - sni="example.mitmproxy.org", - verify=SSL.VERIFY_PEER, - ca_pemfile=cdata.path("data/verificationcerts/generate.py") - ) - - -class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): - handler = EchoHandler - - ssl = dict( - cert=cdata.path("data/verificationcerts/self-signed.crt"), - key=cdata.path("data/verificationcerts/self-signed.key") - ) - - def test_mode_default_should_pass(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls() - - # Verification errors should be saved even if connection isn't aborted - # aborted - assert c.ssl_verification_error - - testval = b"echo!\n" - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - def test_mode_none_should_pass(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(verify=SSL.VERIFY_NONE) - - # Verification errors should be saved even if connection isn't aborted - assert c.ssl_verification_error - - testval = b"echo!\n" - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - def test_mode_strict_should_fail(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(exceptions.InvalidCertificateException): - c.convert_to_tls( - sni="example.mitmproxy.org", - verify=SSL.VERIFY_PEER, - ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt") - ) - - assert c.ssl_verification_error - - # Unknown issuing certificate authority for first certificate - assert "errno: 18" in str(c.ssl_verification_error) - assert "depth: 0" in str(c.ssl_verification_error) - - -class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): - handler = EchoHandler - - ssl = dict( - cert=cdata.path("data/verificationcerts/trusted-leaf.crt"), - key=cdata.path("data/verificationcerts/trusted-leaf.key") - ) - - def test_should_fail_without_sni(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(exceptions.TlsException): - c.convert_to_tls( - verify=SSL.VERIFY_PEER, - ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt") - ) - - def test_mode_none_should_pass_without_sni(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls( - verify=SSL.VERIFY_NONE, - ca_path=cdata.path("data/verificationcerts/") - ) - - assert "Cannot validate hostname, SNI missing." in str(c.ssl_verification_error) - - def test_should_fail(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(exceptions.InvalidCertificateException): - c.convert_to_tls( - sni="mitmproxy.org", - verify=SSL.VERIFY_PEER, - ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt") - ) - assert c.ssl_verification_error - - -class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): - handler = EchoHandler - - ssl = dict( - cert=cdata.path("data/verificationcerts/trusted-leaf.crt"), - key=cdata.path("data/verificationcerts/trusted-leaf.key") - ) - - def test_mode_strict_w_pemfile_should_pass(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls( - sni="example.mitmproxy.org", - verify=SSL.VERIFY_PEER, - ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt") - ) - - assert c.ssl_verification_error is None - - testval = b"echo!\n" - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - def test_mode_strict_w_confdir_should_pass(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls( - sni="example.mitmproxy.org", - verify=SSL.VERIFY_PEER, - ca_path=cdata.path("data/verificationcerts/") - ) - - assert c.ssl_verification_error is None - - testval = b"echo!\n" - c.wfile.write(testval) - c.wfile.flush() - assert c.rfile.readline() == testval - - -class TestSSLClientCert(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - sni = None - - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write(b"%d\n" % self.clientcert.serial) - self.wfile.flush() - - ssl = dict( - request_client_cert=True, - v3_only=False - ) - - def test_clientcert(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls( - cert=cdata.path("data/clientcert/client.pem")) - assert c.rfile.readline().strip() == b"1" - - def test_clientcert_err(self, tdata): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(exceptions.TlsException): - c.convert_to_tls(cert=cdata.path("data/clientcert/make")) - - -class TestSNI(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - sni = None - - def handle_sni(self, connection): - self.sni = connection.get_servername() - - def handle(self): - self.wfile.write(self.sni) - self.wfile.flush() - - ssl = True - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(sni="foo.com") - assert c.sni == "foo.com" - assert c.rfile.readline() == b"foo.com" - - def test_idn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(sni="mitmproxyäöüß.example.com") - assert c.tls_established - assert "doesn't match" not in str(c.ssl_verification_error) - - -class TestServerCipherList(tservers.ServerTestBase): - handler = ClientCipherListHandler - ssl = dict( - cipher_list='AES256-GCM-SHA384' - ) - - @pytest.mark.xfail - def test_echo(self): - # Not working for OpenSSL 1.1.1, see - # https://github.com/pyca/pyopenssl/blob/fc802df5c10f0d1cd9749c94887d652fa26db6fb/src/OpenSSL/SSL.py#L1192-L1196 - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(sni="foo.com") - expected = b"['TLS_AES_256_GCM_SHA384']" - assert c.rfile.readline() == expected - - -class TestServerCurrentCipher(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - sni = None - - def handle(self): - self.wfile.write(str(self.get_current_cipher()).encode()) - self.wfile.flush() - - ssl = dict( - cipher_list='AES256-GCM-SHA384' - ) - - @pytest.mark.xfail - def test_echo(self): - # Not working for OpenSSL 1.1.1, see - # https://github.com/pyca/pyopenssl/blob/fc802df5c10f0d1cd9749c94887d652fa26db6fb/src/OpenSSL/SSL.py#L1192-L1196 - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(sni="foo.com") - assert b'AES256-GCM-SHA384' in c.rfile.readline() - - -class TestServerCipherListError(tservers.ServerTestBase): - handler = ClientCipherListHandler - ssl = dict( - cipher_list=b'bogus' - ) - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(Exception, match="handshake error"): - c.convert_to_tls(sni="foo.com") - - -class TestClientCipherListError(tservers.ServerTestBase): - handler = ClientCipherListHandler - ssl = dict( - cipher_list='RC4-SHA' - ) - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - with pytest.raises(Exception, match="cipher specification"): - c.convert_to_tls(sni="foo.com", cipher_list="bogus") - - -class TestSSLDisconnect(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - - def handle(self): - self.finish() - - ssl = True - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls() - # Exercise SSL.ZeroReturnError - c.rfile.read(10) - c.close() - with pytest.raises(exceptions.TcpDisconnect): - c.wfile.write(b"foo") - with pytest.raises(queue.Empty): - self.q.get_nowait() - - -class TestSSLHardDisconnect(tservers.ServerTestBase): - handler = HardDisconnectHandler - ssl = True - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls() - # Exercise SSL.SysCallError - c.rfile.read(10) - c.close() - with pytest.raises(exceptions.TcpDisconnect): - c.wfile.write(b"foo") - - -class TestDisconnect(tservers.ServerTestBase): - - def test_echo(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.rfile.read(10) - c.wfile.write(b"foo") - c.close() - c.close() - - -class TestServerTimeOut(tservers.ServerTestBase): - - class handler(tcp.BaseHandler): - - def handle(self): - self.timeout = False - self.settimeout(0.01) - try: - self.rfile.read(10) - except exceptions.TcpTimeout: - self.timeout = True - - def test_timeout(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - time.sleep(0.3) - assert self.last_handler.timeout - - -class TestTimeOut(tservers.ServerTestBase): - handler = HangHandler - - def test_timeout(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.settimeout(0.1) - assert c.gettimeout() == 0.1 - with pytest.raises(exceptions.TcpTimeout): - c.rfile.read(10) - - -class TestALPNClient(tservers.ServerTestBase): - handler = ALPNHandler - ssl = dict( - alpn_select=b"bar" - ) - - @pytest.mark.parametrize('alpn_protos, expected_negotiated, expected_response', [ - ([b"foo", b"bar", b"fasel"], b'bar', b'bar'), - ([], b'', b'NONE'), - (None, b'', b'NONE'), - ]) - def test_alpn(self, monkeypatch, alpn_protos, expected_negotiated, expected_response): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(alpn_protos=alpn_protos) - assert c.get_alpn_proto_negotiated() == expected_negotiated - assert c.rfile.readline().strip() == expected_response - - -class TestNoSSLNoALPNClient(tservers.ServerTestBase): - handler = ALPNHandler - - def test_no_ssl_no_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - assert c.get_alpn_proto_negotiated() == b"" - assert c.rfile.readline().strip() == b"NONE" - - -class TestSSLTimeOut(tservers.ServerTestBase): - handler = HangHandler - ssl = True - - def test_timeout_client(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls() - c.settimeout(0.1) - with pytest.raises(exceptions.TcpTimeout): - c.rfile.read(10) - - -class TestDHParams(tservers.ServerTestBase): - handler = HangHandler - ssl = dict( - dhparams=certs.CertStore.load_dhparam( - cdata.path("data/dhparam.pem"), - ), - cipher_list="DHE-RSA-AES256-SHA" - ) - - def test_dhparams(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_tls(method=SSL.TLSv1_2_METHOD) - ret = c.get_current_cipher() - assert ret[0] == "DHE-RSA-AES256-SHA" - - -class TestTCPClient(tservers.ServerTestBase): - - def test_conerr(self): - c = tcp.TCPClient(("127.0.0.1", 0)) - with pytest.raises(exceptions.TcpException, match="Error connecting"): - c.connect() - - def test_timeout(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.create_connection(timeout=20) as conn: - assert conn.gettimeout() == 20 - - def test_spoof_address(self): - c = tcp.TCPClient(("127.0.0.1", self.port), spoof_source_address=("127.0.0.1", 0)) - with pytest.raises(exceptions.TcpException, match="Failed to spoof"): - c.connect() - - -class TestTCPServer: - - def test_binderr(self): - with pytest.raises(socket.error, match="prohibited"): - tcp.TCPServer(("localhost", 8080)) - - def test_wait_for_silence(self): - s = tcp.TCPServer(("127.0.0.1", 0)) - with s.handler_counter: - with pytest.raises(exceptions.Timeout): - s.wait_for_silence() - s.shutdown() - - -class TestFileLike: - - def test_blocksize(self): - s = BytesIO(b"1234567890abcdefghijklmnopqrstuvwxyz") - s = tcp.Reader(s) - s.BLOCKSIZE = 2 - assert s.read(1) == b"1" - assert s.read(2) == b"23" - assert s.read(3) == b"456" - assert s.read(4) == b"7890" - d = s.read(-1) - assert d.startswith(b"abc") and d.endswith(b"xyz") - - def test_wrap(self): - s = BytesIO(b"foobar\nfoobar") - s.flush() - s = tcp.Reader(s) - assert s.readline() == b"foobar\n" - assert s.readline() == b"foobar" - # Test __getattr__ - assert s.isatty - - def test_limit(self): - s = BytesIO(b"foobar\nfoobar") - s = tcp.Reader(s) - assert s.readline(3) == b"foo" - - def test_limitless(self): - s = BytesIO(b"f" * (50 * 1024)) - s = tcp.Reader(s) - ret = s.read(-1) - assert len(ret) == 50 * 1024 - - def test_readlog(self): - s = BytesIO(b"foobar\nfoobar") - s = tcp.Reader(s) - assert not s.is_logging() - s.start_log() - assert s.is_logging() - s.readline() - assert s.get_log() == b"foobar\n" - s.read(1) - assert s.get_log() == b"foobar\nf" - s.start_log() - assert s.get_log() == b"" - s.read(1) - assert s.get_log() == b"o" - s.stop_log() - with pytest.raises(ValueError): - s.get_log() - - def test_writelog(self): - s = BytesIO() - s = tcp.Writer(s) - s.start_log() - assert s.is_logging() - s.write(b"x") - assert s.get_log() == b"x" - s.write(b"x") - assert s.get_log() == b"xx" - - def test_writer_flush_error(self): - s = BytesIO() - s = tcp.Writer(s) - o = mock.MagicMock() - o.flush = mock.MagicMock(side_effect=socket.error) - s.o = o - with pytest.raises(exceptions.TcpDisconnect): - s.flush() - - def test_reader_read_error(self): - s = BytesIO(b"foobar\nfoobar") - s = tcp.Reader(s) - o = mock.MagicMock() - o.read = mock.MagicMock(side_effect=socket.error) - s.o = o - with pytest.raises(exceptions.TcpDisconnect): - s.read(10) - - def test_reset_timestamps(self): - s = BytesIO(b"foobar\nfoobar") - s = tcp.Reader(s) - s.first_byte_timestamp = 500 - s.reset_timestamps() - assert not s.first_byte_timestamp - - def test_first_byte_timestamp_updated_on_read(self): - s = BytesIO(b"foobar\nfoobar") - s = tcp.Reader(s) - s.read(1) - assert s.first_byte_timestamp - expected = s.first_byte_timestamp - s.read(5) - assert s.first_byte_timestamp == expected - - def test_first_byte_timestamp_updated_on_readline(self): - s = BytesIO(b"foobar\nfoobar\nfoobar") - s = tcp.Reader(s) - s.readline() - assert s.first_byte_timestamp - expected = s.first_byte_timestamp - s.readline() - assert s.first_byte_timestamp == expected - - def test_read_ssl_error(self): - s = mock.MagicMock() - s.read = mock.MagicMock(side_effect=SSL.Error()) - s = tcp.Reader(s) - with pytest.raises(exceptions.TlsException): - s.read(1) - - def test_read_syscall_ssl_error(self): - s = mock.MagicMock() - s.read = mock.MagicMock(side_effect=SSL.SysCallError()) - s = tcp.Reader(s) - with pytest.raises(exceptions.TlsException): - s.read(1) - - def test_reader_readline_disconnect(self): - o = mock.MagicMock() - o.read = mock.MagicMock(side_effect=socket.error) - s = tcp.Reader(o) - with pytest.raises(exceptions.TcpDisconnect): - s.readline(10) - - def test_reader_incomplete_error(self): - s = BytesIO(b"foobar") - s = tcp.Reader(s) - with pytest.raises(exceptions.TcpReadIncomplete): - s.safe_read(10) - - -class TestPeek(tservers.ServerTestBase): - handler = EchoHandler - - def _connect(self, c): - return c.connect() - - def test_peek(self): - testval = b"peek!\n" - c = tcp.TCPClient(("127.0.0.1", self.port)) - with self._connect(c): - c.wfile.write(testval) - c.wfile.flush() - - assert c.rfile.peek(4) == b"peek" - assert c.rfile.peek(6) == b"peek!\n" - assert c.rfile.readline() == testval - - c.close() - with pytest.raises(exceptions.NetlibException): - c.rfile.peek(1) - - -class TestPeekSSL(TestPeek): - ssl = True - - def _connect(self, c): - with c.connect() as conn: - c.convert_to_tls() - return conn.pop() diff --git a/test/mitmproxy/net/test_tls.py b/test/mitmproxy/net/test_tls.py index 951573fe0..d0573ce56 100644 --- a/test/mitmproxy/net/test_tls.py +++ b/test/mitmproxy/net/test_tls.py @@ -4,9 +4,6 @@ import pytest from mitmproxy import exceptions from mitmproxy.net import tls -from mitmproxy.net.tcp import TCPClient -from test.mitmproxy.net.test_tcp import EchoHandler -from . import tservers CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" @@ -19,7 +16,7 @@ FULL_CLIENT_HELLO_NO_EXTENSIONS = ( CLIENT_HELLO_NO_EXTENSIONS ) - +""" class TestMasterSecretLogger(tservers.ServerTestBase): handler = EchoHandler ssl = dict( @@ -52,6 +49,7 @@ class TestMasterSecretLogger(tservers.ServerTestBase): tls.MasterSecretLogger.create_logfun("test"), tls.MasterSecretLogger) assert not tls.MasterSecretLogger.create_logfun(False) +""" class TestTLSInvalid: diff --git a/test/mitmproxy/net/test_websocket.py b/test/mitmproxy/net/test_websocket.py index c38f9375d..06ea6581a 100644 --- a/test/mitmproxy/net/test_websocket.py +++ b/test/mitmproxy/net/test_websocket.py @@ -1,77 +1,4 @@ -import pytest -from io import BytesIO -from unittest import mock - -from wsproto.frame_protocol import Opcode, RsvBits, Header, Frame - -from mitmproxy.net import http, websocket - - -@pytest.mark.parametrize("input,masking_key,payload_length", [ - (b'\x01\rserver-foobar', None, 13), - (b'\x01\x8dasdf\x12\x16\x16\x10\x04\x01I\x00\x0e\x1c\x06\x07\x13', b'asdf', 13), - (b'\x01~\x04\x00server-foobar', None, 1024), - (b'\x01\x7f\x00\x00\x00\x00\x00\x02\x00\x00server-foobar', None, 131072), -]) -def test_read_frame(input, masking_key, payload_length): - bio = BytesIO(input) - bio.safe_read = bio.read - - header, frame, consumed_bytes = websocket.read_frame(bio) - assert header == \ - Header( - fin=False, - rsv=RsvBits(rsv1=False, rsv2=False, rsv3=False), - opcode=Opcode.TEXT, - payload_len=payload_length, - masking_key=masking_key, - ) - assert frame == \ - Frame( - opcode=Opcode.TEXT, - payload=b'server-foobar', - frame_finished=False, - message_finished=False, - ) - assert consumed_bytes == input - - bio = BytesIO(input) - bio.safe_read = bio.read - header, frame, consumed_bytes = websocket.read_frame(bio, False) - assert header is None - assert frame is None - assert consumed_bytes == input - - -@mock.patch('os.urandom', return_value=b'pumpkinspumpkins') -def test_client_handshake_headers(_): - assert websocket.client_handshake_headers() == \ - http.Headers([ - (b'connection', b'upgrade'), - (b'upgrade', b'websocket'), - (b'sec-websocket-version', b'13'), - (b'sec-websocket-key', b'cHVtcGtpbnNwdW1wa2lucw=='), - ]) - assert websocket.client_handshake_headers(b"13", b"foobar", b"foo", b"bar") == \ - http.Headers([ - (b'connection', b'upgrade'), - (b'upgrade', b'websocket'), - (b'sec-websocket-version', b'13'), - (b'sec-websocket-key', b'foobar'), - (b'sec-websocket-protocol', b'foo'), - (b'sec-websocket-extensions', b'bar') - ]) - - -def test_server_handshake_headers(): - assert websocket.server_handshake_headers("foobar", "foo", "bar") == \ - http.Headers([ - (b'connection', b'upgrade'), - (b'upgrade', b'websocket'), - (b'sec-websocket-accept', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), - (b'sec-websocket-protocol', b'foo'), - (b'sec-websocket-extensions', b'bar'), - ]) +from mitmproxy.net import websocket def test_check_handshake(): @@ -92,16 +19,6 @@ def test_check_handshake(): }) -def test_create_server_nonce(): - assert websocket.create_server_nonce(b"foobar") == b"AzhRPA4TNwR6I/riJheN0TfR7+I=" - - -def test_check_client_version(): - assert not websocket.check_client_version({}) - assert not websocket.check_client_version({"sec-websocket-version": b"42"}) - assert websocket.check_client_version({"sec-websocket-version": b"13"}) - - def test_get_extensions(): assert websocket.get_extensions({}) is None assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo" diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py deleted file mode 100644 index fea4a73a9..000000000 --- a/test/mitmproxy/net/tservers.py +++ /dev/null @@ -1,113 +0,0 @@ -import threading -import queue -import io -import OpenSSL - -from mitmproxy.net import tcp -from mitmproxy.utils import data - -cdata = data.Data(__name__) - - -class _ServerThread(threading.Thread): - - def __init__(self, server): - self.server = server - threading.Thread.__init__(self) - - def run(self): - self.server.serve_forever() - - -class _TServer(tcp.TCPServer): - - def __init__(self, ssl, q, handler_klass, addr, **kwargs): - """ - ssl: A dictionary of SSL parameters: - - cert, key, request_client_cert, cipher_list, - dhparams, v3_only - """ - tcp.TCPServer.__init__(self, addr) - - if ssl is True: - self.ssl = dict() - elif isinstance(ssl, dict): - self.ssl = ssl - else: - self.ssl = None - - self.q = q - self.handler_klass = handler_klass - if self.handler_klass is not None: - self.handler_klass.kwargs = kwargs - self.last_handler = None - - def handle_client_connection(self, request, client_address): - h = self.handler_klass(request, client_address, self) - self.last_handler = h - if self.ssl is not None: - cert = self.ssl.get( - "cert", - cdata.path("data/server.crt")) - raw_key = self.ssl.get( - "key", - cdata.path("data/server.key")) - with open(raw_key) as f: - raw_key = f.read() - key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw_key) - if self.ssl.get("v3_only", False): - method = OpenSSL.SSL.SSLv3_METHOD - options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 - else: - method = OpenSSL.SSL.SSLv23_METHOD - options = None - h.convert_to_tls( - cert, - key, - method=method, - options=options, - handle_sni=getattr(h, "handle_sni", None), - request_client_cert=self.ssl.get("request_client_cert", None), - cipher_list=self.ssl.get("cipher_list", None), - dhparams=self.ssl.get("dhparams", None), - chain_file=self.ssl.get("chain_file", None), - alpn_select=self.ssl.get("alpn_select", None) - ) - h.handle() - h.finish() - - def handle_error(self, connection, client_address, fp=None): - s = io.StringIO() - tcp.TCPServer.handle_error(self, connection, client_address, s) - self.q.put(s.getvalue()) - - -class ServerTestBase: - ssl = None - handler = None - addr = ("127.0.0.1", 0) - - @classmethod - def setup_class(cls, **kwargs): - cls.q = queue.Queue() - s = cls.makeserver(**kwargs) - cls.port = s.address[1] - cls.server = _ServerThread(s) - cls.server.start() - - @classmethod - def makeserver(cls, **kwargs): - ssl = kwargs.pop('ssl', cls.ssl) - return _TServer(ssl, cls.q, cls.handler, cls.addr, **kwargs) - - @classmethod - def teardown_class(cls): - cls.server.server.shutdown() - - def teardown(self): - self.server.server.wait_for_silence() - - @property - def last_handler(self): - return self.server.server.last_handler diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 5387f8a6f..e2d3c481b 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -1,8 +1,8 @@ from unittest import mock from mitmproxy import controller -from mitmproxy import io from mitmproxy import eventsequence +from mitmproxy import io from mitmproxy.test import tflow from mitmproxy.test import tutils @@ -30,182 +30,3 @@ class MasterTest: fw = io.FlowWriter(f) t = tflow.tflow(resp=True) fw.add(t) - - -# class TestMaster(taddons.RecordingMaster): - -# def __init__(self, opts): -# super().__init__(opts) -# config = ProxyConfig(opts) -# self.server = ProxyServer(config) - -# def clear_addons(self, addons): -# self.addons.clear() -# self.state = TestState() -# self.addons.add(self.state) -# self.addons.add(*addons) -# self.addons.trigger("configure", self.options.keys()) -# self.addons.trigger("running") - -# def reset(self, addons): -# self.clear_addons(addons) -# self.clear() - - -# class ProxyThread(threading.Thread): - -# def __init__(self, masterclass, options): -# threading.Thread.__init__(self) -# self.masterclass = masterclass -# self.options = options -# self.tmaster = None -# self.event_loop = None -# controller.should_exit = False - -# @property -# def port(self): -# return self.tmaster.server.address[1] - -# @property -# def tlog(self): -# return self.tmaster.logs - -# def shutdown(self): -# self.tmaster.shutdown() - -# def run(self): -# self.event_loop = asyncio.new_event_loop() -# asyncio.set_event_loop(self.event_loop) -# self.tmaster = self.masterclass(self.options) -# self.tmaster.addons.add(core.Core()) -# self.name = "ProxyThread (%s)" % human.format_address(self.tmaster.server.address) -# self.tmaster.run() - -# def set_addons(self, *addons): -# self.tmaster.reset(addons) - -# def start(self): -# super().start() -# while True: -# if self.tmaster: -# break -# time.sleep(0.01) - - -# class ProxyTestBase: -# # Test Configuration -# ssl = None -# ssloptions = False -# masterclass = TestMaster - -# add_upstream_certs_to_client_chain = False - -# @classmethod -# def setup_class(cls): -# cls.server = pathod.test.Daemon( -# ssl=cls.ssl, -# ssloptions=cls.ssloptions) -# cls.server2 = pathod.test.Daemon( -# ssl=cls.ssl, -# ssloptions=cls.ssloptions) - -# cls.options = cls.get_options() -# cls.proxy = ProxyThread(cls.masterclass, cls.options) -# cls.proxy.start() - -# @classmethod -# def teardown_class(cls): -# # perf: we want to run tests in parallel -# # should this ever cause an error, travis should catch it. -# # shutil.rmtree(cls.confdir) -# cls.proxy.shutdown() -# cls.server.shutdown() -# cls.server2.shutdown() - -# def teardown(self): -# try: -# self.server.wait_for_silence() -# except exceptions.Timeout: -# # FIXME: Track down the Windows sync issues -# if sys.platform != "win32": -# raise - -# def setup(self): -# self.master.reset(self.addons()) -# self.server.clear_log() -# self.server2.clear_log() - -# @property -# def master(self): -# return self.proxy.tmaster - -# @classmethod -# def get_options(cls): -# cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy") -# return options.Options( -# listen_port=0, -# confdir=cls.confdir, -# add_upstream_certs_to_client_chain=cls.add_upstream_certs_to_client_chain, -# ssl_insecure=True, -# ) - -# def set_addons(self, *addons): -# self.proxy.set_addons(*addons) - -# def addons(self): -# """ -# Can be over-ridden to add a standard set of addons to tests. -# """ -# return [] - - -# class LazyPathoc(pathod.pathoc.Pathoc): -# def __init__(self, lazy_connect, *args, **kwargs): -# self.lazy_connect = lazy_connect -# pathod.pathoc.Pathoc.__init__(self, *args, **kwargs) - -# def connect(self): -# return pathod.pathoc.Pathoc.connect(self, self.lazy_connect) - - -# class HTTPProxyTest(ProxyTestBase): - -# def pathoc_raw(self): -# return pathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), fp=None) - -# def pathoc(self, sni=None): -# """ -# Returns a connected Pathoc instance. -# """ -# if self.ssl: -# conn = ("127.0.0.1", self.server.port) -# else: -# conn = None -# return LazyPathoc( -# conn, -# ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None -# ) - -# def pathod(self, spec, sni=None): -# """ -# Constructs a pathod GET request, with the appropriate base and proxy. -# """ -# p = self.pathoc(sni=sni) -# if self.ssl: -# q = "get:'/p/%s'" % spec -# else: -# q = f"get:'{self.server.urlbase}/p/{spec}'" -# with p.connect(): -# return p.request(q) - -# def app(self, page): -# if self.ssl: -# p = pathod.pathoc.Pathoc( -# ("127.0.0.1", self.proxy.port), True, fp=None -# ) -# with p.connect((self.master.options.onboarding_host, self.master.options.onbarding_port)): -# return p.request("get:'%s'" % page) -# else: -# p = self.pathoc() -# with p.connect(): -# return p.request(f"get:'http://{self.master.options.onboarding_host}{page}'")