remove old mitmproxy.net.tcp code

this is not needed anymore with sans-io
This commit is contained in:
Maximilian Hils 2020-12-20 00:12:21 +01:00
parent cdb0cf6c0a
commit b05c13daa6
9 changed files with 17 additions and 2019 deletions

View File

@ -1,683 +0,0 @@
import os
import errno
import select
import socket
import sys
import threading
import time
import traceback
from typing import Optional # noqa
from mitmproxy.net import tls
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy import exceptions
from mitmproxy.coretypes import basethread
socket_fileobject = socket.SocketIO
# workaround for https://bugs.python.org/issue29515
# Python 3.8 for Windows is missing a constant, fixed in 3.9
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)
class _FileLike:
BLOCKSIZE = 1024 * 32
def __init__(self, o):
self.o = o
self._log = None
self.first_byte_timestamp = None
def set_descriptor(self, o):
self.o = o
def __getattr__(self, attr):
return getattr(self.o, attr)
def start_log(self):
"""
Starts or resets the log.
This will store all bytes read or written.
"""
self._log = []
def stop_log(self):
"""
Stops the log.
"""
self._log = None
def is_logging(self):
return self._log is not None
def get_log(self):
"""
Returns the log as a string.
"""
if not self.is_logging():
raise ValueError("Not logging!")
return b"".join(self._log)
def add_log(self, v):
if self.is_logging():
self._log.append(v)
def reset_timestamps(self):
self.first_byte_timestamp = None
class Writer(_FileLike):
def flush(self):
"""
May raise exceptions.TcpDisconnect
"""
if hasattr(self.o, "flush"):
try:
self.o.flush()
except OSError as v:
raise exceptions.TcpDisconnect(str(v))
def write(self, v):
"""
May raise exceptions.TcpDisconnect
"""
if v:
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
try:
if hasattr(self.o, "sendall"):
self.add_log(v)
return self.o.sendall(v)
else:
r = self.o.write(v)
self.add_log(v[:r])
return r
except (SSL.Error, OSError) as e:
raise exceptions.TcpDisconnect(str(e))
class Reader(_FileLike):
def read(self, length):
"""
If length is -1, we read until connection closes.
"""
result = b''
start = time.time()
while length == -1 or length > 0:
if length == -1 or length > self.BLOCKSIZE:
rlen = self.BLOCKSIZE
else:
rlen = length
try:
data = self.o.read(rlen)
except SSL.ZeroReturnError:
# TLS connection was shut down cleanly
break
except (SSL.WantWriteError, SSL.WantReadError):
# From the OpenSSL docs:
# If the underlying BIO is non-blocking, SSL_read() will also return when the
# underlying BIO could not satisfy the needs of SSL_read() to continue the
# operation. In this case a call to SSL_get_error with the return value of
# SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
# 300 is OpenSSL default timeout
timeout = self.o.gettimeout() or 300
if (time.time() - start) < timeout:
time.sleep(0.1)
continue
else:
raise exceptions.TcpTimeout()
except socket.timeout:
raise exceptions.TcpTimeout()
except OSError as e:
raise exceptions.TcpDisconnect(str(e))
except SSL.SysCallError as e:
if e.args == (-1, 'Unexpected EOF'):
break
raise exceptions.TlsException(str(e))
except SSL.Error as e:
raise exceptions.TlsException(str(e))
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
if not data:
break
result += data
if length != -1:
length -= len(data)
self.add_log(result)
return result
def readline(self, size=None):
result = b''
bytes_read = 0
while True:
if size is not None and bytes_read >= size:
break
ch = self.read(1)
bytes_read += 1
if not ch:
break
else:
result += ch
if ch == b'\n':
break
return result
def safe_read(self, length):
"""
Like .read, but is guaranteed to either return length bytes, or
raise an exception.
"""
result = self.read(length)
if length != -1 and len(result) != length:
if not result:
raise exceptions.TcpDisconnect()
else:
raise exceptions.TcpReadIncomplete(
"Expected {} bytes, got {}".format(length, len(result))
)
return result
def peek(self, length):
"""
Tries to peek into the underlying file object.
Returns:
Up to the next N bytes if peeking is successful.
Raises:
exceptions.TcpException if there was an error with the socket
TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
"""
if isinstance(self.o, socket_fileobject):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except OSError as e:
raise exceptions.TcpException(repr(e))
elif isinstance(self.o, SSL.Connection):
try:
return self.o.recv(length, socket.MSG_PEEK)
except SSL.Error as e:
raise exceptions.TlsException(str(e))
else:
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
def ssl_read_select(rlist, timeout):
"""
This is a wrapper around select.select() which also works for SSL.Connections
by taking ssl_connection.pending() into account.
Caveats:
If .pending() > 0 for any of the connections in rlist, we avoid the select syscall
and **will not include any other connections which may or may not be ready**.
Args:
rlist: wait until ready for reading
Returns:
subset of rlist which is ready for reading.
"""
return [
conn for conn in rlist
if isinstance(conn, SSL.Connection) and conn.pending() > 0
] or select.select(rlist, (), (), timeout)[0]
def close_socket(sock):
"""
Does a hard close of a socket, without emitting a RST.
"""
try:
# We already indicate that we close our end.
# may raise "Transport endpoint is not connected" on Linux
sock.shutdown(socket.SHUT_WR)
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
# readable data could lead to an immediate RST being sent (which is the
# case on Windows).
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
#
# This in turn results in the following issue: If we send an error page
# to the client and then close the socket, the RST may be received by
# the client before the error page and the users sees a connection
# error rather than the error page. Thus, we try to empty the read
# buffer on Windows first. (see
# https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
#
if os.name == "nt": # pragma: no cover
# We cannot rely on the shutdown()-followed-by-read()-eof technique
# proposed by the page above: Some remote machines just don't send
# a TCP FIN, which would leave us in the unfortunate situation that
# recv() would block infinitely. As a workaround, we set a timeout
# here even if we are in blocking mode.
sock.settimeout(sock.gettimeout() or 20)
# limit at a megabyte so that we don't read infinitely
for _ in range(1024 ** 3 // 4096):
# may raise a timeout/disconnect exception.
if not sock.recv(4096):
break
# Now we can close the other half as well.
sock.shutdown(socket.SHUT_RD)
except OSError:
pass
sock.close()
class _Connection:
rbufsize = -1
wbufsize = -1
def _makefile(self):
"""
Set up .rfile and .wfile attributes from .connection
"""
# Ideally, we would use the Buffered IO in Python 3 by default.
# Unfortunately, the implementation of .peek() is broken for n>1 bytes,
# as it may just return what's left in the buffer and not all the bytes we want.
# As a workaround, we just use unbuffered sockets directly.
# https://mail.python.org/pipermail/python-dev/2009-June/089986.html
self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
def __init__(self, connection):
if connection:
self.connection = connection
self.ip_address = connection.getpeername()
self._makefile()
else:
self.connection = None
self.ip_address = None
self.rfile = None
self.wfile = None
self.tls_established = False
self.finished = False
def get_current_cipher(self):
if not self.tls_established:
return None
name = self.connection.get_cipher_name()
bits = self.connection.get_cipher_bits()
version = self.connection.get_cipher_version()
return name, bits, version
def finish(self):
self.finished = True
# If we have an SSL connection, wfile.close == connection.close
# (We call _FileLike.set_descriptor(conn))
# Closing the socket is not our task, therefore we don't call close
# then.
if not isinstance(self.connection, SSL.Connection):
if not getattr(self.wfile, "closed", False):
try:
self.wfile.flush()
self.wfile.close()
except exceptions.TcpDisconnect:
pass
self.rfile.close()
else:
try:
self.connection.shutdown()
except SSL.Error:
pass
class ConnectionCloser:
def __init__(self, conn):
self.conn = conn
self._canceled = False
def pop(self):
"""
Cancel the current closer, and return a fresh one.
"""
self._canceled = True
return ConnectionCloser(self.conn)
def __enter__(self):
return self
def __exit__(self, *args):
if not self._canceled:
self.conn.close()
class TCPClient(_Connection):
def __init__(self, address, source_address=None, spoof_source_address=None):
super().__init__(None)
self.address = address
self.source_address = source_address
self.cert = None
self.server_certs = []
self.sni = None
self.spoof_source_address = spoof_source_address
@property
def ssl_verification_error(self) -> Optional[exceptions.InvalidCertificateException]:
return getattr(self.connection, "cert_error", None)
def close(self):
# Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
# it tries to renegotiate...
if self.connection:
if isinstance(self.connection, SSL.Connection):
close_socket(self.connection._socket)
else:
close_socket(self.connection)
def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs):
context = tls.create_client_context(
alpn_protos=alpn_protos,
sni=sni,
**sslctx_kwargs
)
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
self.connection.set_tlsext_host_name(sni.encode("idna"))
self.connection.set_connect_state()
try:
self.connection.do_handshake()
except SSL.Error as v:
if self.ssl_verification_error:
raise self.ssl_verification_error
else:
raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
self.cert = certs.Cert(self.connection.get_peer_certificate())
# Keep all server certificates in a list
for i in self.connection.get_peer_cert_chain():
self.server_certs.append(certs.Cert(i))
self.tls_established = True
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
def makesocket(self, family, type, proto):
# some parties (cuckoo sandbox) need to hook this
return socket.socket(family, type, proto)
def create_connection(self, timeout=None):
# Based on the official socket.create_connection implementation of Python 3.6.
# https://github.com/python/cpython/blob/3cc5817cfaf5663645f4ee447eaed603d2ad290a/Lib/socket.py
err = None
for res in socket.getaddrinfo(self.address[0], self.address[1], 0, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = self.makesocket(af, socktype, proto)
if timeout:
sock.settimeout(timeout)
if self.source_address:
sock.bind(self.source_address)
if self.spoof_source_address:
try:
if not sock.getsockopt(socket.SOL_IP, socket.IP_TRANSPARENT):
sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover
except Exception as e:
# socket.IP_TRANSPARENT might not be available on every OS and Python version
if sock is not None:
sock.close()
raise exceptions.TcpException(
"Failed to spoof the source address: " + str(e)
)
sock.connect(sa)
return sock
except OSError as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise OSError("getaddrinfo returns an empty list") # pragma: no cover
def connect(self):
try:
connection = self.create_connection()
except OSError as err:
raise exceptions.TcpException(
'Error connecting to "%s": %s' %
(self.address[0], err)
)
self.connection = connection
self.source_address = connection.getsockname()
self.ip_address = connection.getpeername()
self._makefile()
return ConnectionCloser(self)
def settimeout(self, n):
self.connection.settimeout(n)
def gettimeout(self):
return self.connection.gettimeout()
def get_alpn_proto_negotiated(self):
if self.tls_established:
return self.connection.get_alpn_proto_negotiated()
else:
return b""
class BaseHandler(_Connection):
"""
The instantiator is expected to call the handle() and finish() methods.
"""
def __init__(self, connection, address, server):
super().__init__(connection)
self.address = address
self.server = server
self.clientcert = None
def convert_to_tls(self, cert, key, **sslctx_kwargs):
"""
Convert connection to SSL.
For a list of parameters, see tls.create_server_context(...)
"""
context = tls.create_server_context(
cert=cert,
key=key,
**sslctx_kwargs)
self.connection = SSL.Connection(context, self.connection)
self.connection.set_accept_state()
try:
self.connection.do_handshake()
except SSL.Error as v:
raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
self.tls_established = True
cert = self.connection.get_peer_certificate()
if cert:
self.clientcert = certs.Cert(cert)
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)
def handle(self): # pragma: no cover
raise NotImplementedError
def settimeout(self, n):
self.connection.settimeout(n)
def get_alpn_proto_negotiated(self):
if self.tls_established:
return self.connection.get_alpn_proto_negotiated()
else:
return b""
class Counter:
def __init__(self):
self._count = 0
self._lock = threading.Lock()
@property
def count(self):
with self._lock:
return self._count
def __enter__(self):
with self._lock:
self._count += 1
def __exit__(self, *args):
with self._lock:
self._count -= 1
class TCPServer:
def __init__(self, address):
self.address = address
self.__is_shut_down = threading.Event()
self.__is_shut_down.set()
self.__shutdown_request = False
if self.address[0] == 'localhost':
raise OSError("Binding to 'localhost' is prohibited. Please use '::1' or '127.0.0.1' directly.")
self.socket = None
try:
# First try to bind an IPv6 socket, attempting to enable IPv4 support if the OS supports it.
# This allows us to accept connections for ::1 and 127.0.0.1 on the same socket.
# Only works if self.address == ""
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
self.socket.bind(self.address)
except OSError:
if self.socket:
self.socket.close()
self.socket = None
if not self.socket:
try:
# Binding to an IPv6 + IPv4 socket failed, lets fall back to IPv4 only.
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket.bind(self.address)
except OSError:
if self.socket:
self.socket.close()
self.socket = None
if not self.socket:
# Binding to an IPv4 only socket failed, lets fall back to IPv6 only.
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket.bind(self.address)
self.address = self.socket.getsockname()
self.socket.listen()
self.handler_counter = Counter()
def connection_thread(self, connection, client_address):
with self.handler_counter:
try:
self.handle_client_connection(connection, client_address)
except OSError as e: # pragma: no cover
# This catches situations where the underlying connection is
# closed beneath us. Syscalls on the connection object at this
# point returns EINVAL. If this happens, we close the socket and
# move on.
if not e.errno == errno.EINVAL:
raise
except:
self.handle_error(connection, client_address)
finally:
close_socket(connection)
def serve_forever(self, poll_interval=0.1):
self.__is_shut_down.clear()
try:
while not self.__shutdown_request:
r, w_, e_ = select.select([self.socket], [], [], poll_interval)
if self.socket in r:
connection, client_address = self.socket.accept()
t = basethread.BaseThread(
"TCPConnectionHandler ({}: {}:{} -> {}:{})".format(
self.__class__.__name__,
client_address[0],
client_address[1],
self.address[0],
self.address[1],
),
target=self.connection_thread,
args=(connection, client_address),
)
t.setDaemon(1)
try:
t.start()
except threading.ThreadError:
self.handle_error(connection, client_address)
connection.close()
finally:
self.__shutdown_request = False
self.__is_shut_down.set()
def shutdown(self):
self.__shutdown_request = True
self.__is_shut_down.wait()
self.socket.close()
self.handle_shutdown()
def handle_error(self, connection_, client_address, fp=sys.stderr):
"""
Called when handle_client_connection raises an exception.
"""
# If a thread has persisted after interpreter exit, the module might be
# none.
if traceback:
exc = str(traceback.format_exc())
print('-' * 40, file=fp)
print(
"Error in processing of request from %s" % repr(client_address), file=fp)
print(exc, file=fp)
print('-' * 40, file=fp)
def handle_client_connection(self, conn, client_address): # pragma: no cover
"""
Called after client connection.
"""
raise NotImplementedError
def handle_shutdown(self):
"""
Called after server shutdown.
"""
def wait_for_silence(self, timeout=5):
start = time.time()
while 1:
if time.time() - start >= timeout:
raise exceptions.Timeout(
"%s service threads still alive" %
self.handler_counter.count
)
if self.handler_counter.count == 0:
return

View File

@ -4,126 +4,6 @@ Spec: https://tools.ietf.org/html/rfc6455
""" """
import base64
import hashlib
import os
import struct
from wsproto.utilities import ACCEPT_GUID
from wsproto.handshake import WEBSOCKET_VERSION
from wsproto.frame_protocol import RsvBits, Header, Frame, XorMaskerSimple, XorMaskerNull
from mitmproxy.net import http
from mitmproxy.utils import bits, strutils
def read_frame(rfile, parse=True):
"""
Reads a full WebSocket frame from a file-like object.
Returns a parsed frame header, parsed frame, and the consumed bytes.
"""
consumed_bytes = b''
def consume(len):
nonlocal consumed_bytes
d = rfile.safe_read(len)
consumed_bytes += d
return d
first_byte, second_byte = consume(2)
fin = bits.getbit(first_byte, 7)
rsv1 = bits.getbit(first_byte, 6)
rsv2 = bits.getbit(first_byte, 5)
rsv3 = bits.getbit(first_byte, 4)
opcode = first_byte & 0xF
mask_bit = bits.getbit(second_byte, 7)
length_code = second_byte & 0x7F
# payload_len > 125 indicates you need to read more bytes
# to get the actual payload length
if length_code <= 125:
payload_len = length_code
elif length_code == 126:
payload_len, = struct.unpack("!H", consume(2))
else: # length_code == 127:
payload_len, = struct.unpack("!Q", consume(8))
# masking key only present if mask bit set
if mask_bit == 1:
masking_key = consume(4)
masker = XorMaskerSimple(masking_key)
else:
masking_key = None
masker = XorMaskerNull()
masked_payload = consume(payload_len)
if parse:
header = Header(
fin=fin,
rsv=RsvBits(rsv1, rsv2, rsv3),
opcode=opcode,
payload_len=payload_len,
masking_key=masking_key,
)
frame = Frame(
opcode=opcode,
payload=masker.process(masked_payload),
frame_finished=fin,
message_finished=fin
)
else:
header = None
frame = None
return header, frame, consumed_bytes
def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
"""
Create the headers for a valid HTTP upgrade request. If Key is not
specified, it is generated, and can be found in sec-websocket-key in
the returned header set.
Returns an instance of http.Headers
"""
if version is None:
version = WEBSOCKET_VERSION
if key is None:
key = base64.b64encode(os.urandom(16)).decode('ascii')
h = http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_version=version,
sec_websocket_key=key,
)
if protocol is not None:
h['sec-websocket-protocol'] = protocol
if extensions is not None:
h['sec-websocket-extensions'] = extensions
return h
def server_handshake_headers(client_key, protocol=None, extensions=None):
"""
The server response is a valid HTTP 101 response.
Returns an instance of http.Headers
"""
h = http.Headers(
connection="upgrade",
upgrade="websocket",
sec_websocket_accept=create_server_nonce(client_key),
)
if protocol is not None:
h['sec-websocket-protocol'] = protocol
if extensions is not None:
h['sec-websocket-extensions'] = extensions
return h
def check_handshake(headers): def check_handshake(headers):
return ( return (
"upgrade" in headers.get("connection", "").lower() and "upgrade" in headers.get("connection", "").lower() and
@ -132,14 +12,6 @@ def check_handshake(headers):
) )
def create_server_nonce(client_nonce):
return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + ACCEPT_GUID).digest())
def check_client_version(headers):
return headers.get("sec-websocket-version", "") == WEBSOCKET_VERSION
def get_extensions(headers): def get_extensions(headers):
return headers.get("sec-websocket-extensions", None) return headers.get("sec-websocket-extensions", None)

View File

@ -1,17 +1,6 @@
from io import BytesIO
from mitmproxy.net import tcp
from mitmproxy.net import http from mitmproxy.net import http
def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
fp = BytesIO(bytes)
return tcp.Reader(fp)
def treq(**kwargs) -> http.Request: def treq(**kwargs) -> http.Request:
""" """
Returns: Returns:

View File

@ -1,11 +1,21 @@
import ipaddress import ipaddress
from io import BytesIO from io import BytesIO
import pytest import pytest
from mitmproxy.net import socks from mitmproxy.net import socks
from mitmproxy.test import tutils from mitmproxy.test import tutils
# this is a temporary placeholder here, we remove the file-based API when we transition socks proxying to sans-io.
class tutils: # noqa
@staticmethod
def treader(data: bytes):
io = BytesIO(data)
io.safe_read = io.read
return io
def test_client_greeting(): def test_client_greeting():
raw = tutils.treader(b"\x05\x02\x00\xBE\xEF") raw = tutils.treader(b"\x05\x02\x00\xBE\xEF")
out = BytesIO() out = BytesIO()

View File

@ -1,813 +0,0 @@
from io import BytesIO
import re
import queue
import time
import socket
import random
import threading
import pytest
from unittest import mock
from OpenSSL import SSL
from mitmproxy import certs
from mitmproxy.net import tcp
from mitmproxy import exceptions
from mitmproxy.utils import data
from ...conftest import skip_no_ipv6
from . import tservers
cdata = data.Data(__name__)
class EchoHandler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
v = self.rfile.readline()
self.wfile.write(v)
self.wfile.flush()
class ClientCipherListHandler(tcp.BaseHandler):
sni = None
def handle(self):
self.wfile.write(f"{self.connection.get_cipher_list()}\n".encode())
self.wfile.flush()
class HangHandler(tcp.BaseHandler):
def handle(self):
# Hang as long as the client connection is alive
while True:
try:
self.connection.setblocking(0)
ret = self.connection.recv(1)
# Client connection is dead...
if ret == "" or ret == b"":
return
except OSError:
pass
except SSL.WantReadError:
pass
except Exception:
return
time.sleep(0.1)
class ALPNHandler(tcp.BaseHandler):
sni = None
def handle(self):
alp = self.get_alpn_proto_negotiated()
if alp:
self.wfile.write(alp)
else:
self.wfile.write(b"NONE")
self.wfile.flush()
class TestServer(tservers.ServerTestBase):
handler = EchoHandler
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
def test_thread_start_error(self):
with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m:
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
assert not c.rfile.read(1)
assert m.called
assert "nonewthread" in self.q.get_nowait()
self.test_echo()
class TestServerBind(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
# We may get an ipv4-mapped ipv6 address here, e.g. ::ffff:127.0.0.1.
# Those still appear as "127.0.0.1" in the table, so we need to strip the prefix.
peername = self.connection.getpeername()
address = re.sub(r"^::ffff:(?=\d+.\d+.\d+.\d+$)", "", peername[0])
port = peername[1]
self.wfile.write(str((address, port)).encode())
self.wfile.flush()
def test_bind(self):
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """
for i in range(20):
random_port = random.randrange(1024, 65535)
try:
c = tcp.TCPClient(
("127.0.0.1", self.port), source_address=(
"127.0.0.1", random_port))
with c.connect():
assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
return
except exceptions.TcpException: # port probably already in use
pass
@skip_no_ipv6
class TestServerIPv6(tservers.ServerTestBase):
handler = EchoHandler
addr = ("::1", 0)
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("::1", self.port))
with c.connect():
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
class TestEcho(tservers.ServerTestBase):
handler = EchoHandler
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("localhost", self.port))
with c.connect():
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
class HardDisconnectHandler(tcp.BaseHandler):
def handle(self):
self.connection.close()
class TestFinishFail(tservers.ServerTestBase):
"""
This tests a difficult-to-trigger exception in the .finish() method of
the handler.
"""
handler = EchoHandler
def test_disconnect_in_finish(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.wfile.write(b"foo\n")
c.wfile.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
c.finish()
class TestServerSSL(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
cipher_list="AES256-SHA",
chain_file=cdata.path("data/server.crt")
)
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(sni="foo.com", options=SSL.OP_ALL)
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
def test_get_current_cipher(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
assert not c.get_current_cipher()
c.convert_to_tls(sni="foo.com")
ret = c.get_current_cipher()
assert ret
assert "AES" in ret[0]
class TestSSLv3Only(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
request_client_cert=False,
v3_only=True
)
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(exceptions.TlsException):
c.convert_to_tls(sni="foo.com")
class TestInvalidTrustFile(tservers.ServerTestBase):
def test_invalid_trust_file_should_fail(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(exceptions.TlsException):
c.convert_to_tls(
sni="example.mitmproxy.org",
verify=SSL.VERIFY_PEER,
ca_pemfile=cdata.path("data/verificationcerts/generate.py")
)
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
cert=cdata.path("data/verificationcerts/self-signed.crt"),
key=cdata.path("data/verificationcerts/self-signed.key")
)
def test_mode_default_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls()
# Verification errors should be saved even if connection isn't aborted
# aborted
assert c.ssl_verification_error
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
def test_mode_none_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(verify=SSL.VERIFY_NONE)
# Verification errors should be saved even if connection isn't aborted
assert c.ssl_verification_error
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
def test_mode_strict_should_fail(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(exceptions.InvalidCertificateException):
c.convert_to_tls(
sni="example.mitmproxy.org",
verify=SSL.VERIFY_PEER,
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
)
assert c.ssl_verification_error
# Unknown issuing certificate authority for first certificate
assert "errno: 18" in str(c.ssl_verification_error)
assert "depth: 0" in str(c.ssl_verification_error)
class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
cert=cdata.path("data/verificationcerts/trusted-leaf.crt"),
key=cdata.path("data/verificationcerts/trusted-leaf.key")
)
def test_should_fail_without_sni(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(exceptions.TlsException):
c.convert_to_tls(
verify=SSL.VERIFY_PEER,
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
)
def test_mode_none_should_pass_without_sni(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(
verify=SSL.VERIFY_NONE,
ca_path=cdata.path("data/verificationcerts/")
)
assert "Cannot validate hostname, SNI missing." in str(c.ssl_verification_error)
def test_should_fail(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(exceptions.InvalidCertificateException):
c.convert_to_tls(
sni="mitmproxy.org",
verify=SSL.VERIFY_PEER,
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
)
assert c.ssl_verification_error
class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
cert=cdata.path("data/verificationcerts/trusted-leaf.crt"),
key=cdata.path("data/verificationcerts/trusted-leaf.key")
)
def test_mode_strict_w_pemfile_should_pass(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(
sni="example.mitmproxy.org",
verify=SSL.VERIFY_PEER,
ca_pemfile=cdata.path("data/verificationcerts/trusted-root.crt")
)
assert c.ssl_verification_error is None
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
def test_mode_strict_w_confdir_should_pass(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(
sni="example.mitmproxy.org",
verify=SSL.VERIFY_PEER,
ca_path=cdata.path("data/verificationcerts/")
)
assert c.ssl_verification_error is None
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
class TestSSLClientCert(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
self.wfile.write(b"%d\n" % self.clientcert.serial)
self.wfile.flush()
ssl = dict(
request_client_cert=True,
v3_only=False
)
def test_clientcert(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(
cert=cdata.path("data/clientcert/client.pem"))
assert c.rfile.readline().strip() == b"1"
def test_clientcert_err(self, tdata):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(exceptions.TlsException):
c.convert_to_tls(cert=cdata.path("data/clientcert/make"))
class TestSNI(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
sni = None
def handle_sni(self, connection):
self.sni = connection.get_servername()
def handle(self):
self.wfile.write(self.sni)
self.wfile.flush()
ssl = True
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(sni="foo.com")
assert c.sni == "foo.com"
assert c.rfile.readline() == b"foo.com"
def test_idn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(sni="mitmproxyäöüß.example.com")
assert c.tls_established
assert "doesn't match" not in str(c.ssl_verification_error)
class TestServerCipherList(tservers.ServerTestBase):
handler = ClientCipherListHandler
ssl = dict(
cipher_list='AES256-GCM-SHA384'
)
@pytest.mark.xfail
def test_echo(self):
# Not working for OpenSSL 1.1.1, see
# https://github.com/pyca/pyopenssl/blob/fc802df5c10f0d1cd9749c94887d652fa26db6fb/src/OpenSSL/SSL.py#L1192-L1196
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(sni="foo.com")
expected = b"['TLS_AES_256_GCM_SHA384']"
assert c.rfile.readline() == expected
class TestServerCurrentCipher(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
sni = None
def handle(self):
self.wfile.write(str(self.get_current_cipher()).encode())
self.wfile.flush()
ssl = dict(
cipher_list='AES256-GCM-SHA384'
)
@pytest.mark.xfail
def test_echo(self):
# Not working for OpenSSL 1.1.1, see
# https://github.com/pyca/pyopenssl/blob/fc802df5c10f0d1cd9749c94887d652fa26db6fb/src/OpenSSL/SSL.py#L1192-L1196
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(sni="foo.com")
assert b'AES256-GCM-SHA384' in c.rfile.readline()
class TestServerCipherListError(tservers.ServerTestBase):
handler = ClientCipherListHandler
ssl = dict(
cipher_list=b'bogus'
)
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(Exception, match="handshake error"):
c.convert_to_tls(sni="foo.com")
class TestClientCipherListError(tservers.ServerTestBase):
handler = ClientCipherListHandler
ssl = dict(
cipher_list='RC4-SHA'
)
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
with pytest.raises(Exception, match="cipher specification"):
c.convert_to_tls(sni="foo.com", cipher_list="bogus")
class TestSSLDisconnect(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.finish()
ssl = True
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls()
# Exercise SSL.ZeroReturnError
c.rfile.read(10)
c.close()
with pytest.raises(exceptions.TcpDisconnect):
c.wfile.write(b"foo")
with pytest.raises(queue.Empty):
self.q.get_nowait()
class TestSSLHardDisconnect(tservers.ServerTestBase):
handler = HardDisconnectHandler
ssl = True
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls()
# Exercise SSL.SysCallError
c.rfile.read(10)
c.close()
with pytest.raises(exceptions.TcpDisconnect):
c.wfile.write(b"foo")
class TestDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.rfile.read(10)
c.wfile.write(b"foo")
c.close()
c.close()
class TestServerTimeOut(tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.timeout = False
self.settimeout(0.01)
try:
self.rfile.read(10)
except exceptions.TcpTimeout:
self.timeout = True
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
time.sleep(0.3)
assert self.last_handler.timeout
class TestTimeOut(tservers.ServerTestBase):
handler = HangHandler
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.settimeout(0.1)
assert c.gettimeout() == 0.1
with pytest.raises(exceptions.TcpTimeout):
c.rfile.read(10)
class TestALPNClient(tservers.ServerTestBase):
handler = ALPNHandler
ssl = dict(
alpn_select=b"bar"
)
@pytest.mark.parametrize('alpn_protos, expected_negotiated, expected_response', [
([b"foo", b"bar", b"fasel"], b'bar', b'bar'),
([], b'', b'NONE'),
(None, b'', b'NONE'),
])
def test_alpn(self, monkeypatch, alpn_protos, expected_negotiated, expected_response):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(alpn_protos=alpn_protos)
assert c.get_alpn_proto_negotiated() == expected_negotiated
assert c.rfile.readline().strip() == expected_response
class TestNoSSLNoALPNClient(tservers.ServerTestBase):
handler = ALPNHandler
def test_no_ssl_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
assert c.get_alpn_proto_negotiated() == b""
assert c.rfile.readline().strip() == b"NONE"
class TestSSLTimeOut(tservers.ServerTestBase):
handler = HangHandler
ssl = True
def test_timeout_client(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls()
c.settimeout(0.1)
with pytest.raises(exceptions.TcpTimeout):
c.rfile.read(10)
class TestDHParams(tservers.ServerTestBase):
handler = HangHandler
ssl = dict(
dhparams=certs.CertStore.load_dhparam(
cdata.path("data/dhparam.pem"),
),
cipher_list="DHE-RSA-AES256-SHA"
)
def test_dhparams(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_tls(method=SSL.TLSv1_2_METHOD)
ret = c.get_current_cipher()
assert ret[0] == "DHE-RSA-AES256-SHA"
class TestTCPClient(tservers.ServerTestBase):
def test_conerr(self):
c = tcp.TCPClient(("127.0.0.1", 0))
with pytest.raises(exceptions.TcpException, match="Error connecting"):
c.connect()
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.create_connection(timeout=20) as conn:
assert conn.gettimeout() == 20
def test_spoof_address(self):
c = tcp.TCPClient(("127.0.0.1", self.port), spoof_source_address=("127.0.0.1", 0))
with pytest.raises(exceptions.TcpException, match="Failed to spoof"):
c.connect()
class TestTCPServer:
def test_binderr(self):
with pytest.raises(socket.error, match="prohibited"):
tcp.TCPServer(("localhost", 8080))
def test_wait_for_silence(self):
s = tcp.TCPServer(("127.0.0.1", 0))
with s.handler_counter:
with pytest.raises(exceptions.Timeout):
s.wait_for_silence()
s.shutdown()
class TestFileLike:
def test_blocksize(self):
s = BytesIO(b"1234567890abcdefghijklmnopqrstuvwxyz")
s = tcp.Reader(s)
s.BLOCKSIZE = 2
assert s.read(1) == b"1"
assert s.read(2) == b"23"
assert s.read(3) == b"456"
assert s.read(4) == b"7890"
d = s.read(-1)
assert d.startswith(b"abc") and d.endswith(b"xyz")
def test_wrap(self):
s = BytesIO(b"foobar\nfoobar")
s.flush()
s = tcp.Reader(s)
assert s.readline() == b"foobar\n"
assert s.readline() == b"foobar"
# Test __getattr__
assert s.isatty
def test_limit(self):
s = BytesIO(b"foobar\nfoobar")
s = tcp.Reader(s)
assert s.readline(3) == b"foo"
def test_limitless(self):
s = BytesIO(b"f" * (50 * 1024))
s = tcp.Reader(s)
ret = s.read(-1)
assert len(ret) == 50 * 1024
def test_readlog(self):
s = BytesIO(b"foobar\nfoobar")
s = tcp.Reader(s)
assert not s.is_logging()
s.start_log()
assert s.is_logging()
s.readline()
assert s.get_log() == b"foobar\n"
s.read(1)
assert s.get_log() == b"foobar\nf"
s.start_log()
assert s.get_log() == b""
s.read(1)
assert s.get_log() == b"o"
s.stop_log()
with pytest.raises(ValueError):
s.get_log()
def test_writelog(self):
s = BytesIO()
s = tcp.Writer(s)
s.start_log()
assert s.is_logging()
s.write(b"x")
assert s.get_log() == b"x"
s.write(b"x")
assert s.get_log() == b"xx"
def test_writer_flush_error(self):
s = BytesIO()
s = tcp.Writer(s)
o = mock.MagicMock()
o.flush = mock.MagicMock(side_effect=socket.error)
s.o = o
with pytest.raises(exceptions.TcpDisconnect):
s.flush()
def test_reader_read_error(self):
s = BytesIO(b"foobar\nfoobar")
s = tcp.Reader(s)
o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error)
s.o = o
with pytest.raises(exceptions.TcpDisconnect):
s.read(10)
def test_reset_timestamps(self):
s = BytesIO(b"foobar\nfoobar")
s = tcp.Reader(s)
s.first_byte_timestamp = 500
s.reset_timestamps()
assert not s.first_byte_timestamp
def test_first_byte_timestamp_updated_on_read(self):
s = BytesIO(b"foobar\nfoobar")
s = tcp.Reader(s)
s.read(1)
assert s.first_byte_timestamp
expected = s.first_byte_timestamp
s.read(5)
assert s.first_byte_timestamp == expected
def test_first_byte_timestamp_updated_on_readline(self):
s = BytesIO(b"foobar\nfoobar\nfoobar")
s = tcp.Reader(s)
s.readline()
assert s.first_byte_timestamp
expected = s.first_byte_timestamp
s.readline()
assert s.first_byte_timestamp == expected
def test_read_ssl_error(self):
s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.Error())
s = tcp.Reader(s)
with pytest.raises(exceptions.TlsException):
s.read(1)
def test_read_syscall_ssl_error(self):
s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
s = tcp.Reader(s)
with pytest.raises(exceptions.TlsException):
s.read(1)
def test_reader_readline_disconnect(self):
o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error)
s = tcp.Reader(o)
with pytest.raises(exceptions.TcpDisconnect):
s.readline(10)
def test_reader_incomplete_error(self):
s = BytesIO(b"foobar")
s = tcp.Reader(s)
with pytest.raises(exceptions.TcpReadIncomplete):
s.safe_read(10)
class TestPeek(tservers.ServerTestBase):
handler = EchoHandler
def _connect(self, c):
return c.connect()
def test_peek(self):
testval = b"peek!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
with self._connect(c):
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.peek(4) == b"peek"
assert c.rfile.peek(6) == b"peek!\n"
assert c.rfile.readline() == testval
c.close()
with pytest.raises(exceptions.NetlibException):
c.rfile.peek(1)
class TestPeekSSL(TestPeek):
ssl = True
def _connect(self, c):
with c.connect() as conn:
c.convert_to_tls()
return conn.pop()

View File

@ -4,9 +4,6 @@ import pytest
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.net import tls from mitmproxy.net import tls
from mitmproxy.net.tcp import TCPClient
from test.mitmproxy.net.test_tcp import EchoHandler
from . import tservers
CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex( CLIENT_HELLO_NO_EXTENSIONS = bytes.fromhex(
"03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637"
@ -19,7 +16,7 @@ FULL_CLIENT_HELLO_NO_EXTENSIONS = (
CLIENT_HELLO_NO_EXTENSIONS CLIENT_HELLO_NO_EXTENSIONS
) )
"""
class TestMasterSecretLogger(tservers.ServerTestBase): class TestMasterSecretLogger(tservers.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(
@ -52,6 +49,7 @@ class TestMasterSecretLogger(tservers.ServerTestBase):
tls.MasterSecretLogger.create_logfun("test"), tls.MasterSecretLogger.create_logfun("test"),
tls.MasterSecretLogger) tls.MasterSecretLogger)
assert not tls.MasterSecretLogger.create_logfun(False) assert not tls.MasterSecretLogger.create_logfun(False)
"""
class TestTLSInvalid: class TestTLSInvalid:

View File

@ -1,77 +1,4 @@
import pytest from mitmproxy.net import websocket
from io import BytesIO
from unittest import mock
from wsproto.frame_protocol import Opcode, RsvBits, Header, Frame
from mitmproxy.net import http, websocket
@pytest.mark.parametrize("input,masking_key,payload_length", [
(b'\x01\rserver-foobar', None, 13),
(b'\x01\x8dasdf\x12\x16\x16\x10\x04\x01I\x00\x0e\x1c\x06\x07\x13', b'asdf', 13),
(b'\x01~\x04\x00server-foobar', None, 1024),
(b'\x01\x7f\x00\x00\x00\x00\x00\x02\x00\x00server-foobar', None, 131072),
])
def test_read_frame(input, masking_key, payload_length):
bio = BytesIO(input)
bio.safe_read = bio.read
header, frame, consumed_bytes = websocket.read_frame(bio)
assert header == \
Header(
fin=False,
rsv=RsvBits(rsv1=False, rsv2=False, rsv3=False),
opcode=Opcode.TEXT,
payload_len=payload_length,
masking_key=masking_key,
)
assert frame == \
Frame(
opcode=Opcode.TEXT,
payload=b'server-foobar',
frame_finished=False,
message_finished=False,
)
assert consumed_bytes == input
bio = BytesIO(input)
bio.safe_read = bio.read
header, frame, consumed_bytes = websocket.read_frame(bio, False)
assert header is None
assert frame is None
assert consumed_bytes == input
@mock.patch('os.urandom', return_value=b'pumpkinspumpkins')
def test_client_handshake_headers(_):
assert websocket.client_handshake_headers() == \
http.Headers([
(b'connection', b'upgrade'),
(b'upgrade', b'websocket'),
(b'sec-websocket-version', b'13'),
(b'sec-websocket-key', b'cHVtcGtpbnNwdW1wa2lucw=='),
])
assert websocket.client_handshake_headers(b"13", b"foobar", b"foo", b"bar") == \
http.Headers([
(b'connection', b'upgrade'),
(b'upgrade', b'websocket'),
(b'sec-websocket-version', b'13'),
(b'sec-websocket-key', b'foobar'),
(b'sec-websocket-protocol', b'foo'),
(b'sec-websocket-extensions', b'bar')
])
def test_server_handshake_headers():
assert websocket.server_handshake_headers("foobar", "foo", "bar") == \
http.Headers([
(b'connection', b'upgrade'),
(b'upgrade', b'websocket'),
(b'sec-websocket-accept', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
(b'sec-websocket-protocol', b'foo'),
(b'sec-websocket-extensions', b'bar'),
])
def test_check_handshake(): def test_check_handshake():
@ -92,16 +19,6 @@ def test_check_handshake():
}) })
def test_create_server_nonce():
assert websocket.create_server_nonce(b"foobar") == b"AzhRPA4TNwR6I/riJheN0TfR7+I="
def test_check_client_version():
assert not websocket.check_client_version({})
assert not websocket.check_client_version({"sec-websocket-version": b"42"})
assert websocket.check_client_version({"sec-websocket-version": b"13"})
def test_get_extensions(): def test_get_extensions():
assert websocket.get_extensions({}) is None assert websocket.get_extensions({}) is None
assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo" assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo"

View File

@ -1,113 +0,0 @@
import threading
import queue
import io
import OpenSSL
from mitmproxy.net import tcp
from mitmproxy.utils import data
cdata = data.Data(__name__)
class _ServerThread(threading.Thread):
def __init__(self, server):
self.server = server
threading.Thread.__init__(self)
def run(self):
self.server.serve_forever()
class _TServer(tcp.TCPServer):
def __init__(self, ssl, q, handler_klass, addr, **kwargs):
"""
ssl: A dictionary of SSL parameters:
cert, key, request_client_cert, cipher_list,
dhparams, v3_only
"""
tcp.TCPServer.__init__(self, addr)
if ssl is True:
self.ssl = dict()
elif isinstance(ssl, dict):
self.ssl = ssl
else:
self.ssl = None
self.q = q
self.handler_klass = handler_klass
if self.handler_klass is not None:
self.handler_klass.kwargs = kwargs
self.last_handler = None
def handle_client_connection(self, request, client_address):
h = self.handler_klass(request, client_address, self)
self.last_handler = h
if self.ssl is not None:
cert = self.ssl.get(
"cert",
cdata.path("data/server.crt"))
raw_key = self.ssl.get(
"key",
cdata.path("data/server.key"))
with open(raw_key) as f:
raw_key = f.read()
key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw_key)
if self.ssl.get("v3_only", False):
method = OpenSSL.SSL.SSLv3_METHOD
options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1
else:
method = OpenSSL.SSL.SSLv23_METHOD
options = None
h.convert_to_tls(
cert,
key,
method=method,
options=options,
handle_sni=getattr(h, "handle_sni", None),
request_client_cert=self.ssl.get("request_client_cert", None),
cipher_list=self.ssl.get("cipher_list", None),
dhparams=self.ssl.get("dhparams", None),
chain_file=self.ssl.get("chain_file", None),
alpn_select=self.ssl.get("alpn_select", None)
)
h.handle()
h.finish()
def handle_error(self, connection, client_address, fp=None):
s = io.StringIO()
tcp.TCPServer.handle_error(self, connection, client_address, s)
self.q.put(s.getvalue())
class ServerTestBase:
ssl = None
handler = None
addr = ("127.0.0.1", 0)
@classmethod
def setup_class(cls, **kwargs):
cls.q = queue.Queue()
s = cls.makeserver(**kwargs)
cls.port = s.address[1]
cls.server = _ServerThread(s)
cls.server.start()
@classmethod
def makeserver(cls, **kwargs):
ssl = kwargs.pop('ssl', cls.ssl)
return _TServer(ssl, cls.q, cls.handler, cls.addr, **kwargs)
@classmethod
def teardown_class(cls):
cls.server.server.shutdown()
def teardown(self):
self.server.server.wait_for_silence()
@property
def last_handler(self):
return self.server.server.last_handler

View File

@ -1,8 +1,8 @@
from unittest import mock from unittest import mock
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import io
from mitmproxy import eventsequence from mitmproxy import eventsequence
from mitmproxy import io
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy.test import tutils from mitmproxy.test import tutils
@ -30,182 +30,3 @@ class MasterTest:
fw = io.FlowWriter(f) fw = io.FlowWriter(f)
t = tflow.tflow(resp=True) t = tflow.tflow(resp=True)
fw.add(t) fw.add(t)
# class TestMaster(taddons.RecordingMaster):
# def __init__(self, opts):
# super().__init__(opts)
# config = ProxyConfig(opts)
# self.server = ProxyServer(config)
# def clear_addons(self, addons):
# self.addons.clear()
# self.state = TestState()
# self.addons.add(self.state)
# self.addons.add(*addons)
# self.addons.trigger("configure", self.options.keys())
# self.addons.trigger("running")
# def reset(self, addons):
# self.clear_addons(addons)
# self.clear()
# class ProxyThread(threading.Thread):
# def __init__(self, masterclass, options):
# threading.Thread.__init__(self)
# self.masterclass = masterclass
# self.options = options
# self.tmaster = None
# self.event_loop = None
# controller.should_exit = False
# @property
# def port(self):
# return self.tmaster.server.address[1]
# @property
# def tlog(self):
# return self.tmaster.logs
# def shutdown(self):
# self.tmaster.shutdown()
# def run(self):
# self.event_loop = asyncio.new_event_loop()
# asyncio.set_event_loop(self.event_loop)
# self.tmaster = self.masterclass(self.options)
# self.tmaster.addons.add(core.Core())
# self.name = "ProxyThread (%s)" % human.format_address(self.tmaster.server.address)
# self.tmaster.run()
# def set_addons(self, *addons):
# self.tmaster.reset(addons)
# def start(self):
# super().start()
# while True:
# if self.tmaster:
# break
# time.sleep(0.01)
# class ProxyTestBase:
# # Test Configuration
# ssl = None
# ssloptions = False
# masterclass = TestMaster
# add_upstream_certs_to_client_chain = False
# @classmethod
# def setup_class(cls):
# cls.server = pathod.test.Daemon(
# ssl=cls.ssl,
# ssloptions=cls.ssloptions)
# cls.server2 = pathod.test.Daemon(
# ssl=cls.ssl,
# ssloptions=cls.ssloptions)
# cls.options = cls.get_options()
# cls.proxy = ProxyThread(cls.masterclass, cls.options)
# cls.proxy.start()
# @classmethod
# def teardown_class(cls):
# # perf: we want to run tests in parallel
# # should this ever cause an error, travis should catch it.
# # shutil.rmtree(cls.confdir)
# cls.proxy.shutdown()
# cls.server.shutdown()
# cls.server2.shutdown()
# def teardown(self):
# try:
# self.server.wait_for_silence()
# except exceptions.Timeout:
# # FIXME: Track down the Windows sync issues
# if sys.platform != "win32":
# raise
# def setup(self):
# self.master.reset(self.addons())
# self.server.clear_log()
# self.server2.clear_log()
# @property
# def master(self):
# return self.proxy.tmaster
# @classmethod
# def get_options(cls):
# cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy")
# return options.Options(
# listen_port=0,
# confdir=cls.confdir,
# add_upstream_certs_to_client_chain=cls.add_upstream_certs_to_client_chain,
# ssl_insecure=True,
# )
# def set_addons(self, *addons):
# self.proxy.set_addons(*addons)
# def addons(self):
# """
# Can be over-ridden to add a standard set of addons to tests.
# """
# return []
# class LazyPathoc(pathod.pathoc.Pathoc):
# def __init__(self, lazy_connect, *args, **kwargs):
# self.lazy_connect = lazy_connect
# pathod.pathoc.Pathoc.__init__(self, *args, **kwargs)
# def connect(self):
# return pathod.pathoc.Pathoc.connect(self, self.lazy_connect)
# class HTTPProxyTest(ProxyTestBase):
# def pathoc_raw(self):
# return pathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), fp=None)
# def pathoc(self, sni=None):
# """
# Returns a connected Pathoc instance.
# """
# if self.ssl:
# conn = ("127.0.0.1", self.server.port)
# else:
# conn = None
# return LazyPathoc(
# conn,
# ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
# )
# def pathod(self, spec, sni=None):
# """
# Constructs a pathod GET request, with the appropriate base and proxy.
# """
# p = self.pathoc(sni=sni)
# if self.ssl:
# q = "get:'/p/%s'" % spec
# else:
# q = f"get:'{self.server.urlbase}/p/{spec}'"
# with p.connect():
# return p.request(q)
# def app(self, page):
# if self.ssl:
# p = pathod.pathoc.Pathoc(
# ("127.0.0.1", self.proxy.port), True, fp=None
# )
# with p.connect((self.master.options.onboarding_host, self.master.options.onbarding_port)):
# return p.request("get:'%s'" % page)
# else:
# p = self.pathoc()
# with p.connect():
# return p.request(f"get:'http://{self.master.options.onboarding_host}{page}'")