diff --git a/mitmproxy/addons/__init__.py b/mitmproxy/addons/__init__.py index fd330a71a..2ed87d7ec 100644 --- a/mitmproxy/addons/__init__.py +++ b/mitmproxy/addons/__init__.py @@ -2,7 +2,6 @@ from mitmproxy.addons import anticache from mitmproxy.addons import anticomp from mitmproxy.addons import block from mitmproxy.addons import browser -from mitmproxy.addons import check_ca from mitmproxy.addons import clientplayback from mitmproxy.addons import command_history from mitmproxy.addons import core @@ -34,7 +33,6 @@ def default_addons(): block.Block(), anticache.AntiCache(), anticomp.AntiComp(), - check_ca.CheckCA(), clientplayback.ClientPlayback(), command_history.CommandHistory(), cut.Cut(), diff --git a/mitmproxy/addons/check_ca.py b/mitmproxy/addons/check_ca.py deleted file mode 100644 index 447ba64db..000000000 --- a/mitmproxy/addons/check_ca.py +++ /dev/null @@ -1,24 +0,0 @@ -import mitmproxy -from mitmproxy import ctx - - -class CheckCA: - def __init__(self): - self.failed = False - - def configure(self, updated): - has_ca = ( - mitmproxy.ctx.master.server and - mitmproxy.ctx.master.server.config and - mitmproxy.ctx.master.server.config.certstore and - mitmproxy.ctx.master.server.config.certstore.default_ca - ) - if has_ca: - self.failed = mitmproxy.ctx.master.server.config.certstore.default_ca.has_expired() - if self.failed: - ctx.log.warn( - "The mitmproxy certificate authority has expired!\n" - "Please delete all CA-related files in your ~/.mitmproxy folder.\n" - "The CA will be regenerated automatically after restarting mitmproxy.\n" - "Then make sure all your clients have the new CA installed.", - ) diff --git a/mitmproxy/addons/termstatus.py b/mitmproxy/addons/termstatus.py deleted file mode 100644 index da436d01d..000000000 --- a/mitmproxy/addons/termstatus.py +++ /dev/null @@ -1,17 +0,0 @@ -from mitmproxy import ctx -from mitmproxy.utils import human - -""" - A tiny addon to print the proxy status to terminal. Eventually this could - also print some stats on exit. -""" - - -class TermStatus: - def running(self): - if ctx.master.server.bound: - ctx.log.info( - "Proxy server listening at http://{}".format( - human.format_address(ctx.master.server.address) - ) - ) diff --git a/mitmproxy/addons/tlsconfig.py b/mitmproxy/addons/tlsconfig.py index fde669f05..0964478c6 100644 --- a/mitmproxy/addons/tlsconfig.py +++ b/mitmproxy/addons/tlsconfig.py @@ -214,6 +214,14 @@ class TlsConfig: key_size=ctx.options.key_size, passphrase=ctx.options.cert_passphrase.encode("utf8") if ctx.options.cert_passphrase else None, ) + if self.certstore.default_ca.has_expired(): + ctx.log.warn( + "The mitmproxy certificate authority has expired!\n" + "Please delete all CA-related files in your ~/.mitmproxy folder.\n" + "The CA will be regenerated automatically after restarting mitmproxy.\n" + "Then make sure all your clients have the new CA installed.", + ) + for certspec in ctx.options.certs: parts = certspec.split("=", 1) if len(parts) == 1: diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 8d2373e33..4367dfd08 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -1,50 +1,7 @@ import queue -import asyncio + from mitmproxy import exceptions - -class Channel: - """ - The only way for the proxy server to communicate with the master - is to use the channel it has been given. - """ - def __init__(self, master, loop, should_exit): - self.master = master - self.loop = loop - self.should_exit = should_exit - - def ask(self, mtype, m): - """ - Decorate a message with a reply attribute, and send it to the master. - Then wait for a response. - - Raises: - exceptions.Kill: All connections should be closed immediately. - """ - if not self.should_exit.is_set(): - m.reply = Reply(m) - asyncio.run_coroutine_threadsafe( - self.master.addons.handle_lifecycle(mtype, m), - self.loop, - ) - g = m.reply.q.get() - if g == exceptions.Kill: - raise exceptions.Kill() - return g - - def tell(self, mtype, m): - """ - Decorate a message with a dummy reply attribute, send it to the master, - then return immediately. - """ - if not self.should_exit.is_set(): - m.reply = DummyReply() - asyncio.run_coroutine_threadsafe( - self.master.addons.handle_lifecycle(mtype, m), - self.loop, - ) - - NO_REPLY = object() # special object we can distinguish from a valid "None" reply. @@ -53,6 +10,7 @@ class Reply: Messages sent through a channel are decorated with a "reply" attribute. This object is used to respond to the message through the return channel. """ + def __init__(self, obj): self.obj = obj # Spawn an event loop in the current thread @@ -138,6 +96,7 @@ class DummyReply(Reply): handler so that they can be used multiple times. Useful when we need an object to seem like it has a channel, and during testing. """ + def __init__(self): super().__init__(None) self._should_reset = False diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py index 9f0a8c303..d81b09fc0 100644 --- a/mitmproxy/exceptions.py +++ b/mitmproxy/exceptions.py @@ -1,4 +1,10 @@ """ + +Edit 2020-12 @mhils: + The advice below hasn't paid off in any form. We now just use builtin exceptions and specialize where necessary. + +--- + We try to be very hygienic regarding the exceptions we throw: - Every exception that might be externally visible to users shall be a subclass @@ -11,7 +17,6 @@ See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ class MitmproxyException(Exception): - """ Base class for all exceptions thrown by mitmproxy. """ @@ -21,58 +26,12 @@ class MitmproxyException(Exception): class Kill(MitmproxyException): - """ Signal that both client and server connection(s) should be killed immediately. """ pass -class ProtocolException(MitmproxyException): - """ - ProtocolExceptions are caused by invalid user input, unavailable network resources, - or other events that are outside of our influence. - """ - pass - - -class TlsProtocolException(ProtocolException): - pass - - -class ClientHandshakeException(TlsProtocolException): - - def __init__(self, message, server): - super().__init__(message) - self.server = server - - -class InvalidServerCertificate(TlsProtocolException): - def __repr__(self): - # In contrast to most others, this is a user-facing error which needs to look good. - return str(self) - - -class Socks5ProtocolException(ProtocolException): - pass - - -class HttpProtocolException(ProtocolException): - pass - - -class Http2ProtocolException(ProtocolException): - pass - - -class Http2ZombieException(ProtocolException): - pass - - -class ServerException(MitmproxyException): - pass - - class ContentViewException(MitmproxyException): pass @@ -89,10 +48,6 @@ class ControlException(MitmproxyException): pass -class SetServerNotAllowedException(MitmproxyException): - pass - - class CommandError(Exception): pass @@ -116,62 +71,22 @@ class TypeError(MitmproxyException): pass -""" - Net-layer exceptions -""" - - class NetlibException(MitmproxyException): """ Base class for all exceptions thrown by mitmproxy.net. """ + def __init__(self, message=None): super().__init__(message) -class SessionLoadException(MitmproxyException): - pass - - -class Disconnect: - """Immediate EOF""" - - class HttpException(NetlibException): pass -class HttpReadDisconnect(HttpException, Disconnect): - pass - - class HttpSyntaxException(HttpException): pass -class TcpException(NetlibException): - pass - - -class TcpDisconnect(TcpException, Disconnect): - pass - - -class TcpReadIncomplete(TcpException): - pass - - -class TcpTimeout(TcpException): - pass - - class TlsException(NetlibException): pass - - -class InvalidCertificateException(TlsException): - pass - - -class Timeout(TcpException): - pass diff --git a/mitmproxy/master.py b/mitmproxy/master.py index e88c79a93..64e271903 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -1,54 +1,28 @@ -import sys -import traceback -import threading import asyncio -import logging +import sys +import threading +import traceback from mitmproxy import addonmanager -from mitmproxy import options +from mitmproxy import command from mitmproxy import controller from mitmproxy import eventsequence -from mitmproxy import command from mitmproxy import http -from mitmproxy import websocket from mitmproxy import log +from mitmproxy import options +from mitmproxy import websocket from mitmproxy.net import server_spec -from mitmproxy.coretypes import basethread - from . import ctx as mitmproxy_ctx -# Conclusively preventing cross-thread races on proxy shutdown turns out to be -# very hard. We could build a thread sync infrastructure for this, or we could -# wait until we ditch threads and move all the protocols into the async loop. -# Until then, silence non-critical errors. -logging.getLogger('asyncio').setLevel(logging.CRITICAL) - - -class ServerThread(basethread.BaseThread): - def __init__(self, server): - self.server = server - address = getattr(self.server, "address", None) - super().__init__( - "ServerThread ({})".format(repr(address)) - ) - - def run(self): - self.server.serve_forever() - - class Master: """ The master handles mitmproxy's main event loop. """ + def __init__(self, opts): self.should_exit = threading.Event() - self.channel = controller.Channel( - self, - asyncio.get_event_loop(), - self.should_exit, - ) - + self.loop = asyncio.get_event_loop() self.options: options.Options = opts or options.Options() self.commands = command.CommandManager(self) self.addons = addonmanager.AddonManager(self) @@ -60,19 +34,8 @@ class Master: mitmproxy_ctx.log = self.log mitmproxy_ctx.options = self.options - @property - def server(self): - return self._server - - @server.setter - def server(self, server): - server.set_channel(self.channel) - self._server = server - def start(self): self.should_exit.clear() - if self.server: - ServerThread(self.server).start() async def running(self): self.addons.trigger("running") @@ -109,8 +72,6 @@ class Master: async def _shutdown(self): self.should_exit.set() - if self.server: - self.server.shutdown() loop = asyncio.get_event_loop() loop.stop() @@ -120,13 +81,13 @@ class Master: """ if not self.should_exit.is_set(): self.should_exit.set() - ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.channel.loop) + ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.loop) # Weird band-aid to make sure that self._shutdown() is actually executed, # which otherwise hangs the process as the proxy server is threaded. # This all needs to be simplified when the proxy server runs on asyncio as well. - if not self.channel.loop.is_running(): # pragma: no cover + if not self.loop.is_running(): # pragma: no cover try: - self.channel.loop.run_until_complete(asyncio.wrap_future(ret)) + self.loop.run_until_complete(asyncio.wrap_future(ret)) except RuntimeError: pass # Event loop stopped before Future completed. diff --git a/mitmproxy/net/http/http1/__init__.py b/mitmproxy/net/http/http1/__init__.py index e4bf01c5c..d05ef2fdc 100644 --- a/mitmproxy/net/http/http1/__init__.py +++ b/mitmproxy/net/http/http1/__init__.py @@ -1,7 +1,6 @@ from .read import ( - read_request, read_request_head, - read_response, read_response_head, - read_body, + read_request_head, + read_response_head, connection_close, expected_http_body_size, ) @@ -13,9 +12,8 @@ from .assemble import ( __all__ = [ - "read_request", "read_request_head", - "read_response", "read_response_head", - "read_body", + "read_request_head", + "read_response_head", "connection_close", "expected_http_body_size", "assemble_request", "assemble_request_head", diff --git a/mitmproxy/net/http/http1/read.py b/mitmproxy/net/http/http1/read.py index 42b9fe530..c5e4b6e2b 100644 --- a/mitmproxy/net/http/http1/read.py +++ b/mitmproxy/net/http/http1/read.py @@ -1,13 +1,9 @@ import re -import sys import time -import typing +from typing import List, Tuple, Iterable, Optional from mitmproxy import exceptions -from mitmproxy.net.http import headers -from mitmproxy.net.http import request -from mitmproxy.net.http import response -from mitmproxy.net.http import url +from mitmproxy.net.http import request, response, headers, url def get_header_tokens(headers, key): @@ -22,137 +18,6 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] -def read_request(rfile, body_size_limit=None): - request = read_request_head(rfile) - expected_body_size = expected_http_body_size(request) - request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) - request.timestamp_end = time.time() - return request - - -def read_request_head(rfile): - """ - Parse an HTTP request head (request line + headers) from an input stream - - Args: - rfile: The input stream - - Returns: - The HTTP request object (without body) - - Raises: - exceptions.HttpReadDisconnect: No bytes can be read from rfile. - exceptions.HttpSyntaxException: The input is malformed HTTP. - exceptions.HttpException: Any other error occurred. - """ - timestamp_start = time.time() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - host, port, method, scheme, authority, path, http_version = _read_request_line(rfile) - headers = _read_headers(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - return request.Request( - host, port, method, scheme, authority, path, http_version, headers, None, None, timestamp_start, None - ) - - -def read_response(rfile, request, body_size_limit=None): - response = read_response_head(rfile) - expected_body_size = expected_http_body_size(request, response) - response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) - response.timestamp_end = time.time() - return response - - -def read_response_head(rfile): - """ - Parse an HTTP response head (response line + headers) from an input stream - - Args: - rfile: The input stream - - Returns: - The HTTP request object (without body) - - Raises: - exceptions.HttpReadDisconnect: No bytes can be read from rfile. - exceptions.HttpSyntaxException: The input is malformed HTTP. - exceptions.HttpException: Any other error occurred. - """ - - timestamp_start = time.time() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - http_version, status_code, message = _read_response_line(rfile) - headers = _read_headers(rfile) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - return response.Response(http_version, status_code, message, headers, None, None, timestamp_start, None) - - -def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): - """ - Read an HTTP message body - - Args: - rfile: The input stream - expected_size: The expected body size (see :py:meth:`expected_body_size`) - limit: Maximum body size - max_chunk_size: Maximium chunk size that gets yielded - - Returns: - A generator that yields byte chunks of the content. - - Raises: - exceptions.HttpException, if an error occurs - - Caveats: - max_chunk_size is not considered if the transfer encoding is chunked. - """ - if not limit or limit < 0: - limit = sys.maxsize - if not max_chunk_size: - max_chunk_size = limit - - if expected_size is None: - yield from _read_chunked(rfile, limit) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise exceptions.HttpException( - "HTTP Body too large. " - "Limit is {}, content length was advertised as {}".format(limit, expected_size) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if len(content) < chunk_size: - raise exceptions.HttpException("Unexpected EOF") - yield content - bytes_left -= chunk_size - else: - bytes_left = limit - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise exceptions.HttpException(f"HTTP body too large. Limit is {limit}.") - - def connection_close(http_version, headers): """ Checks the message to see if the client connection should be closed @@ -175,7 +40,7 @@ def connection_close(http_version, headers): def expected_http_body_size( request: request.Request, - response: typing.Optional[response.Response] = None, + response: Optional[response.Response] = None, expect_continue_as_0: bool = True ): """ @@ -195,6 +60,8 @@ def expected_http_body_size( # http://tools.ietf.org/html/rfc7230#section-3.3 if not response: headers = request.headers + if request.method.upper() == "CONNECT": + return 0 if expect_continue_as_0 and headers.get("expect", "").lower() == "100-continue": return 0 else: @@ -227,28 +94,20 @@ def expected_http_body_size( return -1 -def _get_first_line(rfile): - try: - line = rfile.readline() - if line == b"\r\n" or line == b"\n": - # Possible leftover from previous message - line = rfile.readline() - except (exceptions.TcpDisconnect, exceptions.TlsException): - raise exceptions.HttpReadDisconnect("Remote disconnected") - if not line: - raise exceptions.HttpReadDisconnect("Remote disconnected") - return line.strip() +def _check_http_version(http_version): + if not re.match(br"^HTTP/\d\.\d$", http_version): + raise exceptions.HttpSyntaxException(f"Unknown HTTP version: {http_version}") -def _read_request_line(rfile): - try: - line = _get_first_line(rfile) - except exceptions.HttpReadDisconnect: - # We want to provide a better error message. - raise exceptions.HttpReadDisconnect("Client disconnected") +def raise_if_http_version_unknown(http_version: bytes) -> None: + if not re.match(br"^HTTP/\d\.\d$", http_version): + raise ValueError(f"Unknown HTTP version: {http_version!r}") + +def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, bytes, bytes]: try: method, target, http_version = line.split() + port: Optional[int] if target == b"*" or target.startswith(b"/"): scheme, authority, path = b"", b"", target @@ -269,41 +128,29 @@ def _read_request_line(rfile): # TODO: we can probably get rid of this check? url.parse(target) - _check_http_version(http_version) - except ValueError: - raise exceptions.HttpSyntaxException(f"Bad HTTP request line: {line}") + raise_if_http_version_unknown(http_version) + except ValueError as e: + raise ValueError(f"Bad HTTP request line: {line!r}") from e return host, port, method, scheme, authority, path, http_version -def _read_response_line(rfile): - try: - line = _get_first_line(rfile) - except exceptions.HttpReadDisconnect: - # We want to provide a better error message. - raise exceptions.HttpReadDisconnect("Server disconnected") - +def _read_response_line(line: bytes) -> Tuple[bytes, int, bytes]: try: parts = line.split(None, 2) if len(parts) == 2: # handle missing message gracefully parts.append(b"") - http_version, status_code, message = parts - status_code = int(status_code) - _check_http_version(http_version) + http_version, status_code_str, reason = parts + status_code = int(status_code_str) + raise_if_http_version_unknown(http_version) + except ValueError as e: + raise ValueError(f"Bad HTTP response line: {line!r}") from e - except ValueError: - raise exceptions.HttpSyntaxException(f"Bad HTTP response line: {line}") - - return http_version, status_code, message + return http_version, status_code, reason -def _check_http_version(http_version): - if not re.match(br"^HTTP/\d\.\d$", http_version): - raise exceptions.HttpSyntaxException(f"Unknown HTTP version: {http_version}") - - -def _read_headers(rfile): +def _read_headers(lines: Iterable[bytes]) -> headers.Headers: """ Read a set of headers. Stop once a blank line is reached. @@ -314,15 +161,11 @@ def _read_headers(rfile): Raises: exceptions.HttpSyntaxException """ - ret = [] - while True: - line = rfile.readline() - if not line or line == b"\r\n" or line == b"\n": - # we do have coverage of this, but coverage.py does not detect it. - break # pragma: no cover + ret: List[Tuple[bytes, bytes]] = [] + for line in lines: if line[0] in b" \t": if not ret: - raise exceptions.HttpSyntaxException("Invalid headers") + raise ValueError("Invalid headers") # continued header ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) else: @@ -333,40 +176,65 @@ def _read_headers(rfile): raise ValueError() ret.append((name, value)) except ValueError: - raise exceptions.HttpSyntaxException( - "Invalid header line: %s" % repr(line) - ) + raise ValueError(f"Invalid header line: {line!r}") return headers.Headers(ret) -def _read_chunked(rfile, limit=sys.maxsize): +def read_request_head(lines: List[bytes]) -> request.Request: """ - Read a HTTP body with chunked transfer encoding. + Parse an HTTP request head (request line + headers) from an iterable of lines Args: - rfile: the input file - limit: A positive integer + lines: The input lines + + Returns: + The HTTP request object (without body) + + Raises: + ValueError: The input is malformed. """ - total = 0 - while True: - line = rfile.readline(128) - if line == b"": - raise exceptions.HttpException("Connection closed prematurely") - if line != b"\r\n" and line != b"\n": - try: - length = int(line, 16) - except ValueError: - raise exceptions.HttpSyntaxException(f"Invalid chunked encoding length: {line}") - total += length - if total > limit: - raise exceptions.HttpException( - "HTTP Body too large. Limit is {}, " - "chunked content longer than {}".format(limit, total) - ) - chunk = rfile.read(length) - suffix = rfile.readline(5) - if suffix != b"\r\n": - raise exceptions.HttpSyntaxException("Malformed chunked body") - if length == 0: - return - yield chunk + host, port, method, scheme, authority, path, http_version = _read_request_line(lines[0]) + headers = _read_headers(lines[1:]) + + return request.Request( + host=host, + port=port, + method=method, + scheme=scheme, + authority=authority, + path=path, + http_version=http_version, + headers=headers, + content=None, + trailers=None, + timestamp_start=time.time(), + timestamp_end=None + ) + + +def read_response_head(lines: List[bytes]) -> response.Response: + """ + Parse an HTTP response head (response line + headers) from an iterable of lines + + Args: + lines: The input lines + + Returns: + The HTTP response object (without body) + + Raises: + ValueError: The input is malformed. + """ + http_version, status_code, reason = _read_response_line(lines[0]) + headers = _read_headers(lines[1:]) + + return response.Response( + http_version=http_version, + status_code=status_code, + reason=reason, + headers=headers, + content=None, + trailers=None, + timestamp_start=time.time(), + timestamp_end=None, + ) diff --git a/mitmproxy/net/http/http1/read_sansio.py b/mitmproxy/net/http/http1/read_sansio.py deleted file mode 100644 index fe9253382..000000000 --- a/mitmproxy/net/http/http1/read_sansio.py +++ /dev/null @@ -1,160 +0,0 @@ -import re -import time -from typing import Iterable, List, Optional, Tuple - -from mitmproxy.net.http import headers, request, response, url -from mitmproxy.net.http.http1 import read - - -def raise_if_http_version_unknown(http_version: bytes) -> None: - if not re.match(br"^HTTP/\d\.\d$", http_version): - raise ValueError(f"Unknown HTTP version: {http_version!r}") - - -def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, bytes, bytes]: - try: - method, target, http_version = line.split() - port: Optional[int] - - if target == b"*" or target.startswith(b"/"): - scheme, authority, path = b"", b"", target - host, port = "", 0 - elif method == b"CONNECT": - scheme, authority, path = b"", target, b"" - host, port = url.parse_authority(authority, check=True) - if not port: - raise ValueError - else: - scheme, rest = target.split(b"://", maxsplit=1) - authority, path_ = rest.split(b"/", maxsplit=1) - path = b"/" + path_ - host, port = url.parse_authority(authority, check=True) - port = port or url.default_port(scheme) - if not port: - raise ValueError - # TODO: we can probably get rid of this check? - url.parse(target) - - raise_if_http_version_unknown(http_version) - except ValueError as e: - raise ValueError(f"Bad HTTP request line: {line!r}") from e - - return host, port, method, scheme, authority, path, http_version - - -def _read_response_line(line: bytes) -> Tuple[bytes, int, bytes]: - try: - parts = line.split(None, 2) - if len(parts) == 2: # handle missing message gracefully - parts.append(b"") - - http_version, status_code_str, reason = parts - status_code = int(status_code_str) - raise_if_http_version_unknown(http_version) - except ValueError as e: - raise ValueError(f"Bad HTTP response line: {line!r}") from e - - return http_version, status_code, reason - - -def _read_headers(lines: Iterable[bytes]) -> headers.Headers: - """ - Read a set of headers. - Stop once a blank line is reached. - - Returns: - A headers object - - Raises: - exceptions.HttpSyntaxException - """ - ret: List[Tuple[bytes, bytes]] = [] - for line in lines: - if line[0] in b" \t": - if not ret: - raise ValueError("Invalid headers") - # continued header - ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) - else: - try: - name, value = line.split(b":", 1) - value = value.strip() - if not name: - raise ValueError() - ret.append((name, value)) - except ValueError: - raise ValueError(f"Invalid header line: {line!r}") - return headers.Headers(ret) - - -def read_request_head(lines: List[bytes]) -> request.Request: - """ - Parse an HTTP request head (request line + headers) from an iterable of lines - - Args: - lines: The input lines - - Returns: - The HTTP request object (without body) - - Raises: - ValueError: The input is malformed. - """ - host, port, method, scheme, authority, path, http_version = _read_request_line(lines[0]) - headers = _read_headers(lines[1:]) - - return request.Request( - host=host, - port=port, - method=method, - scheme=scheme, - authority=authority, - path=path, - http_version=http_version, - headers=headers, - content=None, - trailers=None, - timestamp_start=time.time(), - timestamp_end=None - ) - - -def read_response_head(lines: List[bytes]) -> response.Response: - """ - Parse an HTTP response head (response line + headers) from an iterable of lines - - Args: - lines: The input lines - - Returns: - The HTTP response object (without body) - - Raises: - ValueError: The input is malformed. - """ - http_version, status_code, reason = _read_response_line(lines[0]) - headers = _read_headers(lines[1:]) - - return response.Response( - http_version=http_version, - status_code=status_code, - reason=reason, - headers=headers, - content=None, - trailers=None, - timestamp_start=time.time(), - timestamp_end=None, - ) - - -def expected_http_body_size( - request: request.Request, - response: Optional[response.Response] = None, - expect_continue_as_0: bool = True, -): - """ - Like the non-sans-io version, but also treating CONNECT as content-length: 0 - """ - if request.data.method.upper() == b"CONNECT": - return 0 - return read.expected_http_body_size(request, response, expect_continue_as_0) diff --git a/mitmproxy/proxy/__init__.py b/mitmproxy/proxy/__init__.py index e2da76da2..acb219868 100644 --- a/mitmproxy/proxy/__init__.py +++ b/mitmproxy/proxy/__init__.py @@ -18,28 +18,3 @@ The most important primitives are: - Context: The context is the connection context each layer is provided with, which is always a client connection and sometimes also a server connection. """ - -from .config import ProxyConfig - - -class DummyServer: - bound = False - - def __init__(self, config=None): - self.config = config - self.address = "dummy" - - def set_channel(self, channel): - pass - - def serve_forever(self): - pass - - def shutdown(self): - pass - - -__all__ = [ - "DummyServer", - "ProxyConfig", -] diff --git a/mitmproxy/proxy/config.py b/mitmproxy/proxy/config.py deleted file mode 100644 index 0ba46ee58..000000000 --- a/mitmproxy/proxy/config.py +++ /dev/null @@ -1,92 +0,0 @@ -import os -import re -import typing - -from OpenSSL import crypto - -from mitmproxy import certs -from mitmproxy import exceptions -from mitmproxy import options as moptions -from mitmproxy.net import server_spec - - -class HostMatcher: - def __init__(self, handle, patterns=tuple()): - self.handle = handle - self.patterns = list(patterns) - self.regexes = [re.compile(p, re.IGNORECASE) for p in self.patterns] - - def __call__(self, address): - if not address: - return False - host = "%s:%s" % address - if self.handle in ["ignore", "tcp"]: - return any(rex.search(host) for rex in self.regexes) - else: # self.handle == "allow" - return not any(rex.search(host) for rex in self.regexes) - - def __bool__(self): - return bool(self.patterns) - - -class ProxyConfig: - - def __init__(self, options: moptions.Options) -> None: - self.options = options - - self.certstore: certs.CertStore - self.check_filter: typing.Optional[HostMatcher] = None - self.check_tcp: typing.Optional[HostMatcher] = None - self.upstream_server: typing.Optional[server_spec.ServerSpec] = None - self.configure(options, set(options.keys())) - options.changed.connect(self.configure) - - def configure(self, options: moptions.Options, updated: typing.Any) -> None: - if options.allow_hosts and options.ignore_hosts: - raise exceptions.OptionsError("--ignore-hosts and --allow-hosts are mutually " - "exclusive; please choose one.") - - if options.ignore_hosts: - self.check_filter = HostMatcher("ignore", options.ignore_hosts) - elif options.allow_hosts: - self.check_filter = HostMatcher("allow", options.allow_hosts) - else: - self.check_filter = HostMatcher(False) - if "tcp_hosts" in updated: - self.check_tcp = HostMatcher("tcp", options.tcp_hosts) - - certstore_path = os.path.expanduser(options.confdir) - if not os.path.exists(os.path.dirname(certstore_path)): - raise exceptions.OptionsError( - "Certificate Authority parent directory does not exist: %s" % - os.path.dirname(certstore_path) - ) - key_size = options.key_size - passphrase = options.cert_passphrase.encode("utf-8") if options.cert_passphrase else None - self.certstore = certs.CertStore.from_store( - certstore_path, - moptions.CONF_BASENAME, - key_size, - passphrase - ) - - for c in options.certs: - parts = c.split("=", 1) - if len(parts) == 1: - parts = ["*", parts[0]] - - cert = os.path.expanduser(parts[1]) - if not os.path.exists(cert): - raise exceptions.OptionsError( - "Certificate file does not exist: %s" % cert - ) - try: - self.certstore.add_cert_file(parts[0], cert, passphrase) - except crypto.Error: - raise exceptions.OptionsError( - "Invalid certificate format: %s" % cert - ) - m = options.mode - if m.startswith("upstream:") or m.startswith("reverse:"): - _, spec = server_spec.parse_with_mode(options.mode) - self.upstream_server = spec diff --git a/mitmproxy/proxy/layers/http/_http1.py b/mitmproxy/proxy/layers/http/_http1.py index 8b388d2fb..dc330c1dd 100644 --- a/mitmproxy/proxy/layers/http/_http1.py +++ b/mitmproxy/proxy/layers/http/_http1.py @@ -8,7 +8,6 @@ from h11._receivebuffer import ReceiveBuffer from mitmproxy import exceptions, http from mitmproxy.net import http as net_http from mitmproxy.net.http import http1, status_codes -from mitmproxy.net.http.http1 import read_sansio as http1_sansio from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy.context import Connection, ConnectionState, Context from mitmproxy.proxy.layers.http._base import ReceiveHttp, StreamId @@ -148,7 +147,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta): yield from self.make_pipe() return connection_done = ( - http1_sansio.expected_http_body_size(self.request, self.response) == -1 + http1.expected_http_body_size(self.request, self.response) == -1 or http1.connection_close(self.request.http_version, self.request.headers) or http1.connection_close(self.response.http_version, self.response.headers) # If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request. @@ -227,8 +226,8 @@ class Http1Server(Http1Connection): if request_head: request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays try: - self.request = http1_sansio.read_request_head(request_head) - expected_body_size = http1_sansio.expected_http_body_size(self.request, expect_continue_as_0=False) + self.request = http1.read_request_head(request_head) + expected_body_size = http1.expected_http_body_size(self.request, expect_continue_as_0=False) except (ValueError, exceptions.HttpSyntaxException) as e: yield commands.Log(f"{human.format_address(self.conn.peername)}: {e}") yield commands.CloseConnection(self.conn) @@ -294,7 +293,7 @@ class Http1Client(Http1Connection): assert self.request if "chunked" in self.request.headers.get("transfer-encoding", "").lower(): yield commands.SendData(self.conn, b"0\r\n\r\n") - elif http1_sansio.expected_http_body_size(self.request, self.response) == -1: + elif http1.expected_http_body_size(self.request, self.response) == -1: yield commands.CloseConnection(self.conn, half_close=True) yield from self.mark_done(request=True) elif isinstance(event, RequestProtocolError): @@ -316,8 +315,8 @@ class Http1Client(Http1Connection): if response_head: response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays try: - self.response = http1_sansio.read_response_head(response_head) - expected_size = http1_sansio.expected_http_body_size(self.request, self.response) + self.response = http1.read_response_head(response_head) + expected_size = http1.expected_http_body_size(self.request, self.response) except (ValueError, exceptions.HttpSyntaxException) as e: yield commands.CloseConnection(self.conn) yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"Cannot parse HTTP response: {e}")) diff --git a/mitmproxy/proxy/layers/http/_upstream_proxy.py b/mitmproxy/proxy/layers/http/_upstream_proxy.py index 759be314f..58a3af973 100644 --- a/mitmproxy/proxy/layers/http/_upstream_proxy.py +++ b/mitmproxy/proxy/layers/http/_upstream_proxy.py @@ -5,7 +5,6 @@ from h11._receivebuffer import ReceiveBuffer from mitmproxy import http from mitmproxy.net import server_spec from mitmproxy.net.http import http1 -from mitmproxy.net.http.http1 import read_sansio as http1_sansio from mitmproxy.proxy import commands, context, layer, tunnel from mitmproxy.utils import human @@ -57,7 +56,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer): if response_head: response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays try: - response = http1_sansio.read_response_head(response_head) + response = http1.read_response_head(response_head) except ValueError as e: yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {e}") return False, str(e) diff --git a/mitmproxy/proxy/protocol/__init__.py b/mitmproxy/proxy/protocol/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/mitmproxy/test/taddons.py b/mitmproxy/test/taddons.py index 86f9f4637..fa4ca74ed 100644 --- a/mitmproxy/test/taddons.py +++ b/mitmproxy/test/taddons.py @@ -38,16 +38,16 @@ class RecordingMaster(mitmproxy.master.Master): return True return False - async def await_log(self, txt, level=None): + async def await_log(self, txt, level=None, timeout=1): # start with a sleep(0), which lets all other coroutines advance. # often this is enough to not sleep at all. await asyncio.sleep(0) - for i in range(20): + for i in range(int(timeout / 0.001)): if self.has_log(txt, level): return True else: await asyncio.sleep(0.001) - return False + raise AssertionError(f"Did not find log entry {txt!r} in {self.logs}.") def clear(self): self.logs = [] diff --git a/mitmproxy/tools/dump.py b/mitmproxy/tools/dump.py index af04f8a39..16291922c 100644 --- a/mitmproxy/tools/dump.py +++ b/mitmproxy/tools/dump.py @@ -1,7 +1,7 @@ from mitmproxy import addons from mitmproxy import options from mitmproxy import master -from mitmproxy.addons import dumper, termlog, termstatus, keepserving, readfile +from mitmproxy.addons import dumper, termlog, keepserving, readfile class ErrorCheck: @@ -24,7 +24,7 @@ class DumpMaster(master.Master): super().__init__(options) self.errorcheck = ErrorCheck() if with_termlog: - self.addons.add(termlog.TermLog(), termstatus.TermStatus()) + self.addons.add(termlog.TermLog()) self.addons.add(*addons.default_addons()) if with_dumper: self.addons.add(dumper.Dumper()) diff --git a/mitmproxy/tools/main.py b/mitmproxy/tools/main.py index 9bcfb011a..7d8992787 100644 --- a/mitmproxy/tools/main.py +++ b/mitmproxy/tools/main.py @@ -1,15 +1,14 @@ -import os -import sys -import asyncio import argparse +import asyncio +import os import signal +import sys import typing -from mitmproxy.tools import cmdline from mitmproxy import exceptions, master from mitmproxy import options from mitmproxy import optmanager -from mitmproxy import proxy +from mitmproxy.tools import cmdline from mitmproxy.utils import debug, arg_check @@ -48,8 +47,6 @@ def process_options(parser, opts, args): adict[n] = getattr(args, n) opts.merge(adict) - return proxy.config.ProxyConfig(opts) - def run( master_cls: typing.Type[master.Master], @@ -85,10 +82,7 @@ def run( os.path.join(opts.confdir, "config.yaml"), os.path.join(opts.confdir, "config.yml"), ) - pconf = process_options(parser, opts, args) - - # new core initializes itself as an addon - master.server = proxy.DummyServer(pconf) + process_options(parser, opts, args) if args.options: print(optmanager.dump_defaults(opts)) @@ -97,7 +91,7 @@ def run( master.commands.dump() sys.exit(0) if extra: - if(args.filter_args): + if args.filter_args: master.log.info(f"Only processing flows that match \"{' & '.join(args.filter_args)}\"") opts.update(**extra(args)) diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 35009d41c..54942fc38 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -14,7 +14,6 @@ import tornado.websocket import mitmproxy.flow import mitmproxy.tools.web.master # noqa from mitmproxy import contentviews -from mitmproxy import exceptions from mitmproxy import flowfilter from mitmproxy import http from mitmproxy import io @@ -380,14 +379,7 @@ class RevertFlow(RequestHandler): class ReplayFlow(RequestHandler): def post(self, flow_id): - self.flow.backup() - self.flow.response = None - self.view.update([self.flow]) - - try: - self.master.commands.call("replay.client", [self.flow]) - except exceptions.ReplayException as e: - raise APIError(400, str(e)) + self.master.commands.call("replay.client", [self.flow]) class FlowContent(RequestHandler): diff --git a/mitmproxy/tools/web/master.py b/mitmproxy/tools/web/master.py index 2f1608800..23f932041 100644 --- a/mitmproxy/tools/web/master.py +++ b/mitmproxy/tools/web/master.py @@ -11,7 +11,6 @@ from mitmproxy.addons import intercept from mitmproxy.addons import readfile from mitmproxy.addons import termlog from mitmproxy.addons import view -from mitmproxy.addons import termstatus from mitmproxy.tools.web import app, webaddons, static_viewer @@ -41,7 +40,7 @@ class WebMaster(master.Master): self.events, ) if with_termlog: - self.addons.add(termlog.TermLog(), termstatus.TermStatus()) + self.addons.add(termlog.TermLog()) self.app = app.Application( self, self.options.web_debug ) diff --git a/test/mitmproxy/addons/test_check_ca.py b/test/mitmproxy/addons/test_check_ca.py deleted file mode 100644 index 27e6f7e68..000000000 --- a/test/mitmproxy/addons/test_check_ca.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest -from unittest import mock - -from mitmproxy.addons import check_ca -from mitmproxy.test import taddons - - -class TestCheckCA: - - @pytest.mark.parametrize('expired', [False, True]) - @pytest.mark.asyncio - async def test_check_ca(self, expired): - msg = 'The mitmproxy certificate authority has expired!' - - a = check_ca.CheckCA() - with taddons.context(a) as tctx: - tctx.master.server = mock.MagicMock() - tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock( - return_value = expired - ) - tctx.configure(a) - assert await tctx.master.await_log(msg) == expired diff --git a/test/mitmproxy/addons/test_disable_h2c.py b/test/mitmproxy/addons/test_disable_h2c.py index a26d28a77..1b416583b 100644 --- a/test/mitmproxy/addons/test_disable_h2c.py +++ b/test/mitmproxy/addons/test_disable_h2c.py @@ -1,9 +1,6 @@ -import io - from mitmproxy.addons import disable_h2c from mitmproxy.exceptions import Kill -from mitmproxy.net.http import http1 -from mitmproxy.test import taddons +from mitmproxy.test import taddons, tutils from mitmproxy.test import tflow @@ -28,9 +25,12 @@ class TestDisableH2CleartextUpgrade: a = disable_h2c.DisableH2C() tctx.configure(a) - b = io.BytesIO(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") f = tflow.tflow() - f.request = http1.read_request(b) + f.request = tutils.treq( + method=b"PRI", + path=b"*", + http_version=b"HTTP/2.0", + ) f.intercept() a.request(f) diff --git a/test/mitmproxy/addons/test_termstatus.py b/test/mitmproxy/addons/test_termstatus.py deleted file mode 100644 index 96221e4b3..000000000 --- a/test/mitmproxy/addons/test_termstatus.py +++ /dev/null @@ -1,18 +0,0 @@ -import pytest - -from mitmproxy import proxy -from mitmproxy.addons import termstatus -from mitmproxy.test import taddons - - -@pytest.mark.asyncio -async def test_configure(): - ts = termstatus.TermStatus() - with taddons.context() as ctx: - ctx.master.server = proxy.DummyServer() - ctx.master.server.bound = True - ctx.configure(ts, server=False) - ts.running() - ctx.configure(ts, server=True) - ts.running() - await ctx.master.await_log("server listening") diff --git a/test/mitmproxy/addons/test_tlsconfig.py b/test/mitmproxy/addons/test_tlsconfig.py index a3f45ae47..51cb2a4b4 100644 --- a/test/mitmproxy/addons/test_tlsconfig.py +++ b/test/mitmproxy/addons/test_tlsconfig.py @@ -219,3 +219,11 @@ class TestTlsConfig: assert self.do_handshake(tssl_client, tssl_server) assert tssl_server.obj.getpeercert() + + @pytest.mark.asyncio + async def test_ca_expired(self, monkeypatch): + monkeypatch.setattr(SSL.X509, "has_expired", lambda self: True) + ta = tlsconfig.TlsConfig() + with taddons.context(ta) as tctx: + ta.configure(["confdir"]) + await tctx.master.await_log("The mitmproxy certificate authority has expired", "warn") \ No newline at end of file diff --git a/test/mitmproxy/net/http/http1/test_read.py b/test/mitmproxy/net/http/http1/test_read.py index be2197a85..5f2246f53 100644 --- a/test/mitmproxy/net/http/http1/test_read.py +++ b/test/mitmproxy/net/http/http1/test_read.py @@ -1,14 +1,12 @@ -from io import BytesIO -from unittest.mock import Mock import pytest from mitmproxy import exceptions from mitmproxy.net.http import Headers from mitmproxy.net.http.http1.read import ( - read_request, read_response, read_request_head, - read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line, + read_request_head, + read_response_head, connection_close, expected_http_body_size, _read_request_line, _read_response_line, _check_http_version, - _read_headers, _read_chunked, get_header_tokens + _read_headers, get_header_tokens ) from mitmproxy.test.tutils import treq, tresp @@ -24,124 +22,6 @@ def test_get_header_tokens(): assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] -@pytest.mark.parametrize("input", [ - b"GET / HTTP/1.1\r\n\r\nskip", - b"GET / HTTP/1.1\r\n\r\nskip", - b"GET / HTTP/1.1\r\n\r\nskip", - b"GET / HTTP/1.1 \r\n\r\nskip", -]) -def test_read_request(input): - rfile = BytesIO(input) - r = read_request(rfile) - assert r.method == "GET" - assert r.content == b"" - assert r.http_version == "HTTP/1.1" - assert r.timestamp_end - assert rfile.read() == b"skip" - - -@pytest.mark.parametrize("input", [ - b"CONNECT :0 0", -]) -def test_read_request_error(input): - rfile = BytesIO(input) - with pytest.raises(exceptions.HttpException): - read_request(rfile) - - -def test_read_request_head(): - rfile = BytesIO( - b"GET / HTTP/1.1\r\n" - b"Content-Length: 4\r\n" - b"\r\n" - b"skip" - ) - rfile.reset_timestamps = Mock() - rfile.first_byte_timestamp = 42 - r = read_request_head(rfile) - assert r.method == "GET" - assert r.headers["Content-Length"] == "4" - assert r.content is None - assert rfile.reset_timestamps.called - assert r.timestamp_start == 42 - assert rfile.read() == b"skip" - - -@pytest.mark.parametrize("input", [ - b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", - b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", - b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", - b"HTTP/1.1 418 I'm a teapot \r\n\r\nbody", -]) -def test_read_response(input): - req = treq() - rfile = BytesIO(input) - r = read_response(rfile, req) - assert r.http_version == "HTTP/1.1" - assert r.status_code == 418 - assert r.reason == "I'm a teapot" - assert r.content == b"body" - assert r.timestamp_end - - -def test_read_response_head(): - rfile = BytesIO( - b"HTTP/1.1 418 I'm a teapot\r\n" - b"Content-Length: 4\r\n" - b"\r\n" - b"skip" - ) - rfile.reset_timestamps = Mock() - rfile.first_byte_timestamp = 42 - r = read_response_head(rfile) - assert r.status_code == 418 - assert r.headers["Content-Length"] == "4" - assert r.content is None - assert rfile.reset_timestamps.called - assert r.timestamp_start == 42 - assert rfile.read() == b"skip" - - -class TestReadBody: - def test_chunked(self): - rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar") - body = b"".join(read_body(rfile, None)) - assert body == b"foo" - assert rfile.read() == b"bar" - - def test_known_size(self): - rfile = BytesIO(b"foobar") - body = b"".join(read_body(rfile, 3)) - assert body == b"foo" - assert rfile.read() == b"bar" - - def test_known_size_limit(self): - rfile = BytesIO(b"foobar") - with pytest.raises(exceptions.HttpException): - b"".join(read_body(rfile, 3, 2)) - - def test_known_size_too_short(self): - rfile = BytesIO(b"foo") - with pytest.raises(exceptions.HttpException): - b"".join(read_body(rfile, 6)) - - def test_unknown_size(self): - rfile = BytesIO(b"foobar") - body = b"".join(read_body(rfile, -1)) - assert body == b"foobar" - - def test_unknown_size_limit(self): - rfile = BytesIO(b"foobar") - with pytest.raises(exceptions.HttpException): - b"".join(read_body(rfile, -1, 3)) - - def test_max_chunk_size(self): - rfile = BytesIO(b"123456") - assert list(read_body(rfile, -1, max_chunk_size=None)) == [b"123456"] - rfile = BytesIO(b"123456") - assert list(read_body(rfile, -1, max_chunk_size=1)) == [b"1", b"2", b"3", b"4", b"5", b"6"] - - def test_connection_close(): headers = Headers() assert connection_close(b"HTTP/1.0", headers) @@ -159,6 +39,41 @@ def test_connection_close(): assert not connection_close(b"HTTP/1.1", headers) +def test_check_http_version(): + _check_http_version(b"HTTP/0.9") + _check_http_version(b"HTTP/1.0") + _check_http_version(b"HTTP/1.1") + _check_http_version(b"HTTP/2.0") + with pytest.raises(exceptions.HttpSyntaxException): + _check_http_version(b"WTF/1.0") + with pytest.raises(exceptions.HttpSyntaxException): + _check_http_version(b"HTTP/1.10") + with pytest.raises(exceptions.HttpSyntaxException): + _check_http_version(b"HTTP/1.b") + + +def test_read_request_head(): + rfile = [ + b"GET / HTTP/1.1\r\n", + b"Content-Length: 4\r\n", + ] + r = read_request_head(rfile) + assert r.method == "GET" + assert r.headers["Content-Length"] == "4" + assert r.content is None + + +def test_read_response_head(): + rfile = [ + b"HTTP/1.1 418 I'm a teapot\r\n", + b"Content-Length: 4\r\n", + ] + r = read_response_head(rfile) + assert r.status_code == 418 + assert r.headers["Content-Length"] == "4" + assert r.content is None + + def test_expected_http_body_size(): # Expect: 100-continue assert expected_http_body_size( @@ -176,6 +91,10 @@ def test_expected_http_body_size(): treq(method=b"HEAD"), tresp(headers=Headers(content_length="42")) ) == 0 + assert expected_http_body_size( + treq(method=b"CONNECT"), + None, + ) == 0 assert expected_http_body_size( treq(method=b"CONNECT"), tresp() @@ -221,26 +140,9 @@ def test_expected_http_body_size(): ) == -1 -def test_get_first_line(): - rfile = BytesIO(b"foo\r\nbar") - assert _get_first_line(rfile) == b"foo" - - rfile = BytesIO(b"\r\nfoo\r\nbar") - assert _get_first_line(rfile) == b"foo" - - with pytest.raises(exceptions.HttpReadDisconnect): - rfile = BytesIO(b"") - _get_first_line(rfile) - - with pytest.raises(exceptions.HttpReadDisconnect): - rfile = Mock() - rfile.readline.side_effect = exceptions.TcpDisconnect - _get_first_line(rfile) - - def test_read_request_line(): def t(b): - return _read_request_line(BytesIO(b)) + return _read_request_line(b) assert (t(b"GET / HTTP/1.1") == ("", 0, b"GET", b"", b"", b"/", b"HTTP/1.1")) @@ -251,21 +153,21 @@ def test_read_request_line(): assert (t(b"GET http://foo:42/bar HTTP/1.1") == ("foo", 42, b"GET", b"http", b"foo:42", b"/bar", b"HTTP/1.1")) - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): t(b"GET / WTF/1.1") - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): t(b"CONNECT example.com HTTP/1.1") # port missing - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): t(b"GET ws://example.com/ HTTP/1.1") # port missing - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): t(b"this is not http") - with pytest.raises(exceptions.HttpReadDisconnect): + with pytest.raises(ValueError): t(b"") def test_read_response_line(): def t(b): - return _read_response_line(BytesIO(b)) + return _read_response_line(b) assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") @@ -273,40 +175,26 @@ def test_read_response_line(): # https://github.com/mitmproxy/mitmproxy/issues/784 assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9") - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): assert t(b"HTTP/1.1") - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): t(b"HTTP/1.1 OK OK") - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): t(b"WTF/1.1 200 OK") - with pytest.raises(exceptions.HttpReadDisconnect): + with pytest.raises(ValueError): t(b"") -def test_check_http_version(): - _check_http_version(b"HTTP/0.9") - _check_http_version(b"HTTP/1.0") - _check_http_version(b"HTTP/1.1") - _check_http_version(b"HTTP/2.0") - with pytest.raises(exceptions.HttpSyntaxException): - _check_http_version(b"WTF/1.0") - with pytest.raises(exceptions.HttpSyntaxException): - _check_http_version(b"HTTP/1.10") - with pytest.raises(exceptions.HttpSyntaxException): - _check_http_version(b"HTTP/1.b") - - class TestReadHeaders: @staticmethod def _read(data): - return _read_headers(BytesIO(data)) + return _read_headers(data.splitlines(keepends=True)) def test_read_simple(self): data = ( b"Header: one\r\n" b"Header2: two\r\n" - b"\r\n" ) headers = self._read(data) assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two")) @@ -315,7 +203,6 @@ class TestReadHeaders: data = ( b"Header: one\r\n" b"Header: two\r\n" - b"\r\n" ) headers = self._read(data) assert headers.fields == ((b"Header", b"one"), (b"Header", b"two")) @@ -325,58 +212,26 @@ class TestReadHeaders: b"Header: one\r\n" b"\ttwo\r\n" b"Header2: three\r\n" - b"\r\n" ) headers = self._read(data) assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three")) def test_read_continued_err(self): data = b"\tfoo: bar\r\n" - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): self._read(data) def test_read_err(self): data = b"foo" - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): self._read(data) def test_read_empty_name(self): data = b":foo" - with pytest.raises(exceptions.HttpSyntaxException): + with pytest.raises(ValueError): self._read(data) def test_read_empty_value(self): data = b"bar:" headers = self._read(data) assert headers.fields == ((b"bar", b""),) - - -def test_read_chunked(): - req = treq(content=None) - req.headers["Transfer-Encoding"] = "chunked" - - data = b"1\r\na\r\n0\r\n" - with pytest.raises(exceptions.HttpSyntaxException): - b"".join(_read_chunked(BytesIO(data))) - - data = b"1\r\na\r\n0\r\n\r\n" - assert b"".join(_read_chunked(BytesIO(data))) == b"a" - - data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" - assert b"".join(_read_chunked(BytesIO(data))) == b"ab" - - data = b"\r\n" - with pytest.raises(Exception, match="closed prematurely"): - b"".join(_read_chunked(BytesIO(data))) - - data = b"1\r\nfoo" - with pytest.raises(Exception, match="Malformed chunked body"): - b"".join(_read_chunked(BytesIO(data))) - - data = b"foo\r\nfoo" - with pytest.raises(exceptions.HttpSyntaxException): - b"".join(_read_chunked(BytesIO(data))) - - data = b"5\r\naaaaa\r\n0\r\n\r\n" - with pytest.raises(Exception, match="too large"): - b"".join(_read_chunked(BytesIO(data), limit=2)) diff --git a/test/mitmproxy/net/http/http1/test_read_sansio.py b/test/mitmproxy/net/http/http1/test_read_sansio.py deleted file mode 100644 index a204fed06..000000000 --- a/test/mitmproxy/net/http/http1/test_read_sansio.py +++ /dev/null @@ -1,192 +0,0 @@ -import pytest - -from mitmproxy import exceptions -from mitmproxy.net.http import Headers -from mitmproxy.net.http.http1.read_sansio import ( - read_request_head, - read_response_head, expected_http_body_size, - _read_request_line, _read_response_line, - _read_headers, -) -from mitmproxy.test.tutils import treq, tresp - - -def test_read_request_head(): - rfile = [ - b"GET / HTTP/1.1\r\n", - b"Content-Length: 4\r\n", - ] - r = read_request_head(rfile) - assert r.method == "GET" - assert r.headers["Content-Length"] == "4" - assert r.content is None - - -def test_read_response_head(): - rfile = [ - b"HTTP/1.1 418 I'm a teapot\r\n", - b"Content-Length: 4\r\n", - ] - r = read_response_head(rfile) - assert r.status_code == 418 - assert r.headers["Content-Length"] == "4" - assert r.content is None - - -def test_expected_http_body_size(): - # Expect: 100-continue - assert expected_http_body_size( - treq(headers=Headers(expect="100-continue", content_length="42")), - expect_continue_as_0=True - ) == 0 - # Expect: 100-continue - assert expected_http_body_size( - treq(headers=Headers(expect="100-continue", content_length="42")), - expect_continue_as_0=False - ) == 42 - - # http://tools.ietf.org/html/rfc7230#section-3.3 - assert expected_http_body_size( - treq(method=b"HEAD"), - tresp(headers=Headers(content_length="42")) - ) == 0 - assert expected_http_body_size( - treq(method=b"CONNECT"), - tresp() - ) == 0 - for code in (100, 204, 304): - assert expected_http_body_size( - treq(), - tresp(status_code=code) - ) == 0 - - # chunked - assert expected_http_body_size( - treq(headers=Headers(transfer_encoding="chunked")), - ) is None - - # explicit length - for val in (b"foo", b"-7"): - with pytest.raises(exceptions.HttpSyntaxException): - expected_http_body_size( - treq(headers=Headers(content_length=val)) - ) - assert expected_http_body_size( - treq(headers=Headers(content_length="42")) - ) == 42 - - # more than 1 content-length headers with same value - assert expected_http_body_size( - treq(headers=Headers([(b'content-length', b'42'), (b'content-length', b'42')])) - ) == 42 - - # more than 1 content-length headers with conflicting value - with pytest.raises(exceptions.HttpSyntaxException): - expected_http_body_size( - treq(headers=Headers([(b'content-length', b'42'), (b'content-length', b'45')])) - ) - - # no length - assert expected_http_body_size( - treq(headers=Headers()) - ) == 0 - assert expected_http_body_size( - treq(headers=Headers()), tresp(headers=Headers()) - ) == -1 - - -def test_read_request_line(): - def t(b): - return _read_request_line(b) - - assert (t(b"GET / HTTP/1.1") == - ("", 0, b"GET", b"", b"", b"/", b"HTTP/1.1")) - assert (t(b"OPTIONS * HTTP/1.1") == - ("", 0, b"OPTIONS", b"", b"", b"*", b"HTTP/1.1")) - assert (t(b"CONNECT foo:42 HTTP/1.1") == - ("foo", 42, b"CONNECT", b"", b"foo:42", b"", b"HTTP/1.1")) - assert (t(b"GET http://foo:42/bar HTTP/1.1") == - ("foo", 42, b"GET", b"http", b"foo:42", b"/bar", b"HTTP/1.1")) - - with pytest.raises(ValueError): - t(b"GET / WTF/1.1") - with pytest.raises(ValueError): - t(b"CONNECT example.com HTTP/1.1") # port missing - with pytest.raises(ValueError): - t(b"GET ws://example.com/ HTTP/1.1") # port missing - with pytest.raises(ValueError): - t(b"this is not http") - with pytest.raises(ValueError): - t(b"") - - -def test_read_response_line(): - def t(b): - return _read_response_line(b) - - assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") - assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") - - # https://github.com/mitmproxy/mitmproxy/issues/784 - assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9") - - with pytest.raises(ValueError): - assert t(b"HTTP/1.1") - - with pytest.raises(ValueError): - t(b"HTTP/1.1 OK OK") - with pytest.raises(ValueError): - t(b"WTF/1.1 200 OK") - with pytest.raises(ValueError): - t(b"") - - -class TestReadHeaders: - @staticmethod - def _read(data): - return _read_headers(data.splitlines(keepends=True)) - - def test_read_simple(self): - data = ( - b"Header: one\r\n" - b"Header2: two\r\n" - ) - headers = self._read(data) - assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two")) - - def test_read_multi(self): - data = ( - b"Header: one\r\n" - b"Header: two\r\n" - ) - headers = self._read(data) - assert headers.fields == ((b"Header", b"one"), (b"Header", b"two")) - - def test_read_continued(self): - data = ( - b"Header: one\r\n" - b"\ttwo\r\n" - b"Header2: three\r\n" - ) - headers = self._read(data) - assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three")) - - def test_read_continued_err(self): - data = b"\tfoo: bar\r\n" - with pytest.raises(ValueError): - self._read(data) - - def test_read_err(self): - data = b"foo" - with pytest.raises(ValueError): - self._read(data) - - def test_read_empty_name(self): - data = b":foo" - with pytest.raises(ValueError): - self._read(data) - - def test_read_empty_value(self): - data = b"bar:" - headers = self._read(data) - assert headers.fields == ((b"bar", b""),) diff --git a/test/mitmproxy/proxy/test_config.py b/test/mitmproxy/proxy/test_config.py deleted file mode 100644 index 38a6e1ade..000000000 --- a/test/mitmproxy/proxy/test_config.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest - -from mitmproxy import options -from mitmproxy import exceptions -from mitmproxy.proxy.config import ProxyConfig - - -class TestProxyConfig: - def test_invalid_confdir(self): - opts = options.Options() - opts.confdir = "foo" - with pytest.raises(exceptions.OptionsError, match="parent directory does not exist"): - ProxyConfig(opts) - - def test_invalid_certificate(self, tdata): - opts = options.Options() - opts.certs = [tdata.path("mitmproxy/data/dumpfile-011.bin")] - with pytest.raises(exceptions.OptionsError, match="Invalid certificate format"): - ProxyConfig(opts) - - def test_cannot_set_both_allow_and_filter_options(self): - opts = options.Options() - opts.ignore_hosts = ["foo"] - opts.allow_hosts = ["bar"] - with pytest.raises(exceptions.OptionsError, match="--ignore-hosts and --allow-hosts are " - "mutually exclusive; please choose " - "one."): - ProxyConfig(opts) diff --git a/test/mitmproxy/test_addonmanager.py b/test/mitmproxy/test_addonmanager.py index 660db549d..ef479d8e4 100644 --- a/test/mitmproxy/test_addonmanager.py +++ b/test/mitmproxy/test_addonmanager.py @@ -138,7 +138,8 @@ async def test_simple(): tctx.master.clear() a.get("one").response = addons a.trigger("response") - assert not await tctx.master.await_log("not callable") + with pytest.raises(AssertionError): + await tctx.master.await_log("not callable") a.remove(a.get("one")) assert not a.get("one") diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index 104440e08..29b415b4e 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -3,7 +3,6 @@ import argparse import pytest from mitmproxy import options -from mitmproxy.proxy import DummyServer from mitmproxy.tools import cmdline from mitmproxy.tools import main @@ -41,11 +40,3 @@ class TestProcessProxyOptions: self.assert_noerr( "--cert", tdata.path("mitmproxy/data/testkey.pem")) - - -class TestDummyServer: - - def test_simple(self): - d = DummyServer(None) - d.set_channel(None) - d.shutdown() diff --git a/test/mitmproxy/test_taddons.py b/test/mitmproxy/test_taddons.py index 53091bc14..e3717a282 100644 --- a/test/mitmproxy/test_taddons.py +++ b/test/mitmproxy/test_taddons.py @@ -19,7 +19,7 @@ async def test_recordingmaster(): async def test_dumplog(): with taddons.context() as tctx: ctx.log.info("testing") - await ctx.master.await_log("testing") + assert await ctx.master.await_log("testing") s = io.StringIO() tctx.master.dump_log(s) assert s.getvalue() diff --git a/test/mitmproxy/tools/test_cmdline.py b/test/mitmproxy/tools/test_cmdline.py index e247dc1db..d644f594d 100644 --- a/test/mitmproxy/tools/test_cmdline.py +++ b/test/mitmproxy/tools/test_cmdline.py @@ -10,7 +10,7 @@ def test_common(): opts = options.Options() cmdline.common_options(parser, opts) args = parser.parse_args(args=[]) - assert main.process_options(parser, opts, args) + main.process_options(parser, opts, args) def test_mitmproxy():