cleanup old proxy server

This commit is contained in:
Maximilian Hils 2020-12-28 14:33:10 +01:00
parent 0dbf69dfe9
commit 1655f54817
31 changed files with 211 additions and 1244 deletions

View File

@ -2,7 +2,6 @@ from mitmproxy.addons import anticache
from mitmproxy.addons import anticomp from mitmproxy.addons import anticomp
from mitmproxy.addons import block from mitmproxy.addons import block
from mitmproxy.addons import browser from mitmproxy.addons import browser
from mitmproxy.addons import check_ca
from mitmproxy.addons import clientplayback from mitmproxy.addons import clientplayback
from mitmproxy.addons import command_history from mitmproxy.addons import command_history
from mitmproxy.addons import core from mitmproxy.addons import core
@ -34,7 +33,6 @@ def default_addons():
block.Block(), block.Block(),
anticache.AntiCache(), anticache.AntiCache(),
anticomp.AntiComp(), anticomp.AntiComp(),
check_ca.CheckCA(),
clientplayback.ClientPlayback(), clientplayback.ClientPlayback(),
command_history.CommandHistory(), command_history.CommandHistory(),
cut.Cut(), cut.Cut(),

View File

@ -1,24 +0,0 @@
import mitmproxy
from mitmproxy import ctx
class CheckCA:
def __init__(self):
self.failed = False
def configure(self, updated):
has_ca = (
mitmproxy.ctx.master.server and
mitmproxy.ctx.master.server.config and
mitmproxy.ctx.master.server.config.certstore and
mitmproxy.ctx.master.server.config.certstore.default_ca
)
if has_ca:
self.failed = mitmproxy.ctx.master.server.config.certstore.default_ca.has_expired()
if self.failed:
ctx.log.warn(
"The mitmproxy certificate authority has expired!\n"
"Please delete all CA-related files in your ~/.mitmproxy folder.\n"
"The CA will be regenerated automatically after restarting mitmproxy.\n"
"Then make sure all your clients have the new CA installed.",
)

View File

@ -1,17 +0,0 @@
from mitmproxy import ctx
from mitmproxy.utils import human
"""
A tiny addon to print the proxy status to terminal. Eventually this could
also print some stats on exit.
"""
class TermStatus:
def running(self):
if ctx.master.server.bound:
ctx.log.info(
"Proxy server listening at http://{}".format(
human.format_address(ctx.master.server.address)
)
)

View File

@ -214,6 +214,14 @@ class TlsConfig:
key_size=ctx.options.key_size, key_size=ctx.options.key_size,
passphrase=ctx.options.cert_passphrase.encode("utf8") if ctx.options.cert_passphrase else None, passphrase=ctx.options.cert_passphrase.encode("utf8") if ctx.options.cert_passphrase else None,
) )
if self.certstore.default_ca.has_expired():
ctx.log.warn(
"The mitmproxy certificate authority has expired!\n"
"Please delete all CA-related files in your ~/.mitmproxy folder.\n"
"The CA will be regenerated automatically after restarting mitmproxy.\n"
"Then make sure all your clients have the new CA installed.",
)
for certspec in ctx.options.certs: for certspec in ctx.options.certs:
parts = certspec.split("=", 1) parts = certspec.split("=", 1)
if len(parts) == 1: if len(parts) == 1:

View File

@ -1,50 +1,7 @@
import queue import queue
import asyncio
from mitmproxy import exceptions from mitmproxy import exceptions
class Channel:
"""
The only way for the proxy server to communicate with the master
is to use the channel it has been given.
"""
def __init__(self, master, loop, should_exit):
self.master = master
self.loop = loop
self.should_exit = should_exit
def ask(self, mtype, m):
"""
Decorate a message with a reply attribute, and send it to the master.
Then wait for a response.
Raises:
exceptions.Kill: All connections should be closed immediately.
"""
if not self.should_exit.is_set():
m.reply = Reply(m)
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
g = m.reply.q.get()
if g == exceptions.Kill:
raise exceptions.Kill()
return g
def tell(self, mtype, m):
"""
Decorate a message with a dummy reply attribute, send it to the master,
then return immediately.
"""
if not self.should_exit.is_set():
m.reply = DummyReply()
asyncio.run_coroutine_threadsafe(
self.master.addons.handle_lifecycle(mtype, m),
self.loop,
)
NO_REPLY = object() # special object we can distinguish from a valid "None" reply. NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
@ -53,6 +10,7 @@ class Reply:
Messages sent through a channel are decorated with a "reply" attribute. This Messages sent through a channel are decorated with a "reply" attribute. This
object is used to respond to the message through the return channel. object is used to respond to the message through the return channel.
""" """
def __init__(self, obj): def __init__(self, obj):
self.obj = obj self.obj = obj
# Spawn an event loop in the current thread # Spawn an event loop in the current thread
@ -138,6 +96,7 @@ class DummyReply(Reply):
handler so that they can be used multiple times. Useful when we need an handler so that they can be used multiple times. Useful when we need an
object to seem like it has a channel, and during testing. object to seem like it has a channel, and during testing.
""" """
def __init__(self): def __init__(self):
super().__init__(None) super().__init__(None)
self._should_reset = False self._should_reset = False

View File

@ -1,4 +1,10 @@
""" """
Edit 2020-12 @mhils:
The advice below hasn't paid off in any form. We now just use builtin exceptions and specialize where necessary.
---
We try to be very hygienic regarding the exceptions we throw: We try to be very hygienic regarding the exceptions we throw:
- Every exception that might be externally visible to users shall be a subclass - Every exception that might be externally visible to users shall be a subclass
@ -11,7 +17,6 @@ See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/
class MitmproxyException(Exception): class MitmproxyException(Exception):
""" """
Base class for all exceptions thrown by mitmproxy. Base class for all exceptions thrown by mitmproxy.
""" """
@ -21,58 +26,12 @@ class MitmproxyException(Exception):
class Kill(MitmproxyException): class Kill(MitmproxyException):
""" """
Signal that both client and server connection(s) should be killed immediately. Signal that both client and server connection(s) should be killed immediately.
""" """
pass pass
class ProtocolException(MitmproxyException):
"""
ProtocolExceptions are caused by invalid user input, unavailable network resources,
or other events that are outside of our influence.
"""
pass
class TlsProtocolException(ProtocolException):
pass
class ClientHandshakeException(TlsProtocolException):
def __init__(self, message, server):
super().__init__(message)
self.server = server
class InvalidServerCertificate(TlsProtocolException):
def __repr__(self):
# In contrast to most others, this is a user-facing error which needs to look good.
return str(self)
class Socks5ProtocolException(ProtocolException):
pass
class HttpProtocolException(ProtocolException):
pass
class Http2ProtocolException(ProtocolException):
pass
class Http2ZombieException(ProtocolException):
pass
class ServerException(MitmproxyException):
pass
class ContentViewException(MitmproxyException): class ContentViewException(MitmproxyException):
pass pass
@ -89,10 +48,6 @@ class ControlException(MitmproxyException):
pass pass
class SetServerNotAllowedException(MitmproxyException):
pass
class CommandError(Exception): class CommandError(Exception):
pass pass
@ -116,62 +71,22 @@ class TypeError(MitmproxyException):
pass pass
"""
Net-layer exceptions
"""
class NetlibException(MitmproxyException): class NetlibException(MitmproxyException):
""" """
Base class for all exceptions thrown by mitmproxy.net. Base class for all exceptions thrown by mitmproxy.net.
""" """
def __init__(self, message=None): def __init__(self, message=None):
super().__init__(message) super().__init__(message)
class SessionLoadException(MitmproxyException):
pass
class Disconnect:
"""Immediate EOF"""
class HttpException(NetlibException): class HttpException(NetlibException):
pass pass
class HttpReadDisconnect(HttpException, Disconnect):
pass
class HttpSyntaxException(HttpException): class HttpSyntaxException(HttpException):
pass pass
class TcpException(NetlibException):
pass
class TcpDisconnect(TcpException, Disconnect):
pass
class TcpReadIncomplete(TcpException):
pass
class TcpTimeout(TcpException):
pass
class TlsException(NetlibException): class TlsException(NetlibException):
pass pass
class InvalidCertificateException(TlsException):
pass
class Timeout(TcpException):
pass

View File

@ -1,54 +1,28 @@
import sys
import traceback
import threading
import asyncio import asyncio
import logging import sys
import threading
import traceback
from mitmproxy import addonmanager from mitmproxy import addonmanager
from mitmproxy import options from mitmproxy import command
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import eventsequence from mitmproxy import eventsequence
from mitmproxy import command
from mitmproxy import http from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import log from mitmproxy import log
from mitmproxy import options
from mitmproxy import websocket
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.coretypes import basethread
from . import ctx as mitmproxy_ctx from . import ctx as mitmproxy_ctx
# Conclusively preventing cross-thread races on proxy shutdown turns out to be
# very hard. We could build a thread sync infrastructure for this, or we could
# wait until we ditch threads and move all the protocols into the async loop.
# Until then, silence non-critical errors.
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class ServerThread(basethread.BaseThread):
def __init__(self, server):
self.server = server
address = getattr(self.server, "address", None)
super().__init__(
"ServerThread ({})".format(repr(address))
)
def run(self):
self.server.serve_forever()
class Master: class Master:
""" """
The master handles mitmproxy's main event loop. The master handles mitmproxy's main event loop.
""" """
def __init__(self, opts): def __init__(self, opts):
self.should_exit = threading.Event() self.should_exit = threading.Event()
self.channel = controller.Channel( self.loop = asyncio.get_event_loop()
self,
asyncio.get_event_loop(),
self.should_exit,
)
self.options: options.Options = opts or options.Options() self.options: options.Options = opts or options.Options()
self.commands = command.CommandManager(self) self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self) self.addons = addonmanager.AddonManager(self)
@ -60,19 +34,8 @@ class Master:
mitmproxy_ctx.log = self.log mitmproxy_ctx.log = self.log
mitmproxy_ctx.options = self.options mitmproxy_ctx.options = self.options
@property
def server(self):
return self._server
@server.setter
def server(self, server):
server.set_channel(self.channel)
self._server = server
def start(self): def start(self):
self.should_exit.clear() self.should_exit.clear()
if self.server:
ServerThread(self.server).start()
async def running(self): async def running(self):
self.addons.trigger("running") self.addons.trigger("running")
@ -109,8 +72,6 @@ class Master:
async def _shutdown(self): async def _shutdown(self):
self.should_exit.set() self.should_exit.set()
if self.server:
self.server.shutdown()
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.stop() loop.stop()
@ -120,13 +81,13 @@ class Master:
""" """
if not self.should_exit.is_set(): if not self.should_exit.is_set():
self.should_exit.set() self.should_exit.set()
ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.channel.loop) ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.loop)
# Weird band-aid to make sure that self._shutdown() is actually executed, # Weird band-aid to make sure that self._shutdown() is actually executed,
# which otherwise hangs the process as the proxy server is threaded. # which otherwise hangs the process as the proxy server is threaded.
# This all needs to be simplified when the proxy server runs on asyncio as well. # This all needs to be simplified when the proxy server runs on asyncio as well.
if not self.channel.loop.is_running(): # pragma: no cover if not self.loop.is_running(): # pragma: no cover
try: try:
self.channel.loop.run_until_complete(asyncio.wrap_future(ret)) self.loop.run_until_complete(asyncio.wrap_future(ret))
except RuntimeError: except RuntimeError:
pass # Event loop stopped before Future completed. pass # Event loop stopped before Future completed.

View File

@ -1,7 +1,6 @@
from .read import ( from .read import (
read_request, read_request_head, read_request_head,
read_response, read_response_head, read_response_head,
read_body,
connection_close, connection_close,
expected_http_body_size, expected_http_body_size,
) )
@ -13,9 +12,8 @@ from .assemble import (
__all__ = [ __all__ = [
"read_request", "read_request_head", "read_request_head",
"read_response", "read_response_head", "read_response_head",
"read_body",
"connection_close", "connection_close",
"expected_http_body_size", "expected_http_body_size",
"assemble_request", "assemble_request_head", "assemble_request", "assemble_request_head",

View File

@ -1,13 +1,9 @@
import re import re
import sys
import time import time
import typing from typing import List, Tuple, Iterable, Optional
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.net.http import headers from mitmproxy.net.http import request, response, headers, url
from mitmproxy.net.http import request
from mitmproxy.net.http import response
from mitmproxy.net.http import url
def get_header_tokens(headers, key): def get_header_tokens(headers, key):
@ -22,137 +18,6 @@ def get_header_tokens(headers, key):
return [token.strip() for token in tokens] return [token.strip() for token in tokens]
def read_request(rfile, body_size_limit=None):
request = read_request_head(rfile)
expected_body_size = expected_http_body_size(request)
request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
request.timestamp_end = time.time()
return request
def read_request_head(rfile):
"""
Parse an HTTP request head (request line + headers) from an input stream
Args:
rfile: The input stream
Returns:
The HTTP request object (without body)
Raises:
exceptions.HttpReadDisconnect: No bytes can be read from rfile.
exceptions.HttpSyntaxException: The input is malformed HTTP.
exceptions.HttpException: Any other error occurred.
"""
timestamp_start = time.time()
if hasattr(rfile, "reset_timestamps"):
rfile.reset_timestamps()
host, port, method, scheme, authority, path, http_version = _read_request_line(rfile)
headers = _read_headers(rfile)
if hasattr(rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp
return request.Request(
host, port, method, scheme, authority, path, http_version, headers, None, None, timestamp_start, None
)
def read_response(rfile, request, body_size_limit=None):
response = read_response_head(rfile)
expected_body_size = expected_http_body_size(request, response)
response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit))
response.timestamp_end = time.time()
return response
def read_response_head(rfile):
"""
Parse an HTTP response head (response line + headers) from an input stream
Args:
rfile: The input stream
Returns:
The HTTP request object (without body)
Raises:
exceptions.HttpReadDisconnect: No bytes can be read from rfile.
exceptions.HttpSyntaxException: The input is malformed HTTP.
exceptions.HttpException: Any other error occurred.
"""
timestamp_start = time.time()
if hasattr(rfile, "reset_timestamps"):
rfile.reset_timestamps()
http_version, status_code, message = _read_response_line(rfile)
headers = _read_headers(rfile)
if hasattr(rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = rfile.first_byte_timestamp
return response.Response(http_version, status_code, message, headers, None, None, timestamp_start, None)
def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
"""
Read an HTTP message body
Args:
rfile: The input stream
expected_size: The expected body size (see :py:meth:`expected_body_size`)
limit: Maximum body size
max_chunk_size: Maximium chunk size that gets yielded
Returns:
A generator that yields byte chunks of the content.
Raises:
exceptions.HttpException, if an error occurs
Caveats:
max_chunk_size is not considered if the transfer encoding is chunked.
"""
if not limit or limit < 0:
limit = sys.maxsize
if not max_chunk_size:
max_chunk_size = limit
if expected_size is None:
yield from _read_chunked(rfile, limit)
elif expected_size >= 0:
if limit is not None and expected_size > limit:
raise exceptions.HttpException(
"HTTP Body too large. "
"Limit is {}, content length was advertised as {}".format(limit, expected_size)
)
bytes_left = expected_size
while bytes_left:
chunk_size = min(bytes_left, max_chunk_size)
content = rfile.read(chunk_size)
if len(content) < chunk_size:
raise exceptions.HttpException("Unexpected EOF")
yield content
bytes_left -= chunk_size
else:
bytes_left = limit
while bytes_left:
chunk_size = min(bytes_left, max_chunk_size)
content = rfile.read(chunk_size)
if not content:
return
yield content
bytes_left -= chunk_size
not_done = rfile.read(1)
if not_done:
raise exceptions.HttpException(f"HTTP body too large. Limit is {limit}.")
def connection_close(http_version, headers): def connection_close(http_version, headers):
""" """
Checks the message to see if the client connection should be closed Checks the message to see if the client connection should be closed
@ -175,7 +40,7 @@ def connection_close(http_version, headers):
def expected_http_body_size( def expected_http_body_size(
request: request.Request, request: request.Request,
response: typing.Optional[response.Response] = None, response: Optional[response.Response] = None,
expect_continue_as_0: bool = True expect_continue_as_0: bool = True
): ):
""" """
@ -195,6 +60,8 @@ def expected_http_body_size(
# http://tools.ietf.org/html/rfc7230#section-3.3 # http://tools.ietf.org/html/rfc7230#section-3.3
if not response: if not response:
headers = request.headers headers = request.headers
if request.method.upper() == "CONNECT":
return 0
if expect_continue_as_0 and headers.get("expect", "").lower() == "100-continue": if expect_continue_as_0 and headers.get("expect", "").lower() == "100-continue":
return 0 return 0
else: else:
@ -227,28 +94,20 @@ def expected_http_body_size(
return -1 return -1
def _get_first_line(rfile): def _check_http_version(http_version):
try: if not re.match(br"^HTTP/\d\.\d$", http_version):
line = rfile.readline() raise exceptions.HttpSyntaxException(f"Unknown HTTP version: {http_version}")
if line == b"\r\n" or line == b"\n":
# Possible leftover from previous message
line = rfile.readline()
except (exceptions.TcpDisconnect, exceptions.TlsException):
raise exceptions.HttpReadDisconnect("Remote disconnected")
if not line:
raise exceptions.HttpReadDisconnect("Remote disconnected")
return line.strip()
def _read_request_line(rfile): def raise_if_http_version_unknown(http_version: bytes) -> None:
try: if not re.match(br"^HTTP/\d\.\d$", http_version):
line = _get_first_line(rfile) raise ValueError(f"Unknown HTTP version: {http_version!r}")
except exceptions.HttpReadDisconnect:
# We want to provide a better error message.
raise exceptions.HttpReadDisconnect("Client disconnected")
def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, bytes, bytes]:
try: try:
method, target, http_version = line.split() method, target, http_version = line.split()
port: Optional[int]
if target == b"*" or target.startswith(b"/"): if target == b"*" or target.startswith(b"/"):
scheme, authority, path = b"", b"", target scheme, authority, path = b"", b"", target
@ -269,41 +128,29 @@ def _read_request_line(rfile):
# TODO: we can probably get rid of this check? # TODO: we can probably get rid of this check?
url.parse(target) url.parse(target)
_check_http_version(http_version) raise_if_http_version_unknown(http_version)
except ValueError: except ValueError as e:
raise exceptions.HttpSyntaxException(f"Bad HTTP request line: {line}") raise ValueError(f"Bad HTTP request line: {line!r}") from e
return host, port, method, scheme, authority, path, http_version return host, port, method, scheme, authority, path, http_version
def _read_response_line(rfile): def _read_response_line(line: bytes) -> Tuple[bytes, int, bytes]:
try:
line = _get_first_line(rfile)
except exceptions.HttpReadDisconnect:
# We want to provide a better error message.
raise exceptions.HttpReadDisconnect("Server disconnected")
try: try:
parts = line.split(None, 2) parts = line.split(None, 2)
if len(parts) == 2: # handle missing message gracefully if len(parts) == 2: # handle missing message gracefully
parts.append(b"") parts.append(b"")
http_version, status_code, message = parts http_version, status_code_str, reason = parts
status_code = int(status_code) status_code = int(status_code_str)
_check_http_version(http_version) raise_if_http_version_unknown(http_version)
except ValueError as e:
raise ValueError(f"Bad HTTP response line: {line!r}") from e
except ValueError: return http_version, status_code, reason
raise exceptions.HttpSyntaxException(f"Bad HTTP response line: {line}")
return http_version, status_code, message
def _check_http_version(http_version): def _read_headers(lines: Iterable[bytes]) -> headers.Headers:
if not re.match(br"^HTTP/\d\.\d$", http_version):
raise exceptions.HttpSyntaxException(f"Unknown HTTP version: {http_version}")
def _read_headers(rfile):
""" """
Read a set of headers. Read a set of headers.
Stop once a blank line is reached. Stop once a blank line is reached.
@ -314,15 +161,11 @@ def _read_headers(rfile):
Raises: Raises:
exceptions.HttpSyntaxException exceptions.HttpSyntaxException
""" """
ret = [] ret: List[Tuple[bytes, bytes]] = []
while True: for line in lines:
line = rfile.readline()
if not line or line == b"\r\n" or line == b"\n":
# we do have coverage of this, but coverage.py does not detect it.
break # pragma: no cover
if line[0] in b" \t": if line[0] in b" \t":
if not ret: if not ret:
raise exceptions.HttpSyntaxException("Invalid headers") raise ValueError("Invalid headers")
# continued header # continued header
ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
else: else:
@ -333,40 +176,65 @@ def _read_headers(rfile):
raise ValueError() raise ValueError()
ret.append((name, value)) ret.append((name, value))
except ValueError: except ValueError:
raise exceptions.HttpSyntaxException( raise ValueError(f"Invalid header line: {line!r}")
"Invalid header line: %s" % repr(line)
)
return headers.Headers(ret) return headers.Headers(ret)
def _read_chunked(rfile, limit=sys.maxsize): def read_request_head(lines: List[bytes]) -> request.Request:
""" """
Read a HTTP body with chunked transfer encoding. Parse an HTTP request head (request line + headers) from an iterable of lines
Args: Args:
rfile: the input file lines: The input lines
limit: A positive integer
Returns:
The HTTP request object (without body)
Raises:
ValueError: The input is malformed.
""" """
total = 0 host, port, method, scheme, authority, path, http_version = _read_request_line(lines[0])
while True: headers = _read_headers(lines[1:])
line = rfile.readline(128)
if line == b"": return request.Request(
raise exceptions.HttpException("Connection closed prematurely") host=host,
if line != b"\r\n" and line != b"\n": port=port,
try: method=method,
length = int(line, 16) scheme=scheme,
except ValueError: authority=authority,
raise exceptions.HttpSyntaxException(f"Invalid chunked encoding length: {line}") path=path,
total += length http_version=http_version,
if total > limit: headers=headers,
raise exceptions.HttpException( content=None,
"HTTP Body too large. Limit is {}, " trailers=None,
"chunked content longer than {}".format(limit, total) timestamp_start=time.time(),
) timestamp_end=None
chunk = rfile.read(length) )
suffix = rfile.readline(5)
if suffix != b"\r\n":
raise exceptions.HttpSyntaxException("Malformed chunked body") def read_response_head(lines: List[bytes]) -> response.Response:
if length == 0: """
return Parse an HTTP response head (response line + headers) from an iterable of lines
yield chunk
Args:
lines: The input lines
Returns:
The HTTP response object (without body)
Raises:
ValueError: The input is malformed.
"""
http_version, status_code, reason = _read_response_line(lines[0])
headers = _read_headers(lines[1:])
return response.Response(
http_version=http_version,
status_code=status_code,
reason=reason,
headers=headers,
content=None,
trailers=None,
timestamp_start=time.time(),
timestamp_end=None,
)

View File

@ -1,160 +0,0 @@
import re
import time
from typing import Iterable, List, Optional, Tuple
from mitmproxy.net.http import headers, request, response, url
from mitmproxy.net.http.http1 import read
def raise_if_http_version_unknown(http_version: bytes) -> None:
if not re.match(br"^HTTP/\d\.\d$", http_version):
raise ValueError(f"Unknown HTTP version: {http_version!r}")
def _read_request_line(line: bytes) -> Tuple[str, int, bytes, bytes, bytes, bytes, bytes]:
try:
method, target, http_version = line.split()
port: Optional[int]
if target == b"*" or target.startswith(b"/"):
scheme, authority, path = b"", b"", target
host, port = "", 0
elif method == b"CONNECT":
scheme, authority, path = b"", target, b""
host, port = url.parse_authority(authority, check=True)
if not port:
raise ValueError
else:
scheme, rest = target.split(b"://", maxsplit=1)
authority, path_ = rest.split(b"/", maxsplit=1)
path = b"/" + path_
host, port = url.parse_authority(authority, check=True)
port = port or url.default_port(scheme)
if not port:
raise ValueError
# TODO: we can probably get rid of this check?
url.parse(target)
raise_if_http_version_unknown(http_version)
except ValueError as e:
raise ValueError(f"Bad HTTP request line: {line!r}") from e
return host, port, method, scheme, authority, path, http_version
def _read_response_line(line: bytes) -> Tuple[bytes, int, bytes]:
try:
parts = line.split(None, 2)
if len(parts) == 2: # handle missing message gracefully
parts.append(b"")
http_version, status_code_str, reason = parts
status_code = int(status_code_str)
raise_if_http_version_unknown(http_version)
except ValueError as e:
raise ValueError(f"Bad HTTP response line: {line!r}") from e
return http_version, status_code, reason
def _read_headers(lines: Iterable[bytes]) -> headers.Headers:
"""
Read a set of headers.
Stop once a blank line is reached.
Returns:
A headers object
Raises:
exceptions.HttpSyntaxException
"""
ret: List[Tuple[bytes, bytes]] = []
for line in lines:
if line[0] in b" \t":
if not ret:
raise ValueError("Invalid headers")
# continued header
ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
else:
try:
name, value = line.split(b":", 1)
value = value.strip()
if not name:
raise ValueError()
ret.append((name, value))
except ValueError:
raise ValueError(f"Invalid header line: {line!r}")
return headers.Headers(ret)
def read_request_head(lines: List[bytes]) -> request.Request:
"""
Parse an HTTP request head (request line + headers) from an iterable of lines
Args:
lines: The input lines
Returns:
The HTTP request object (without body)
Raises:
ValueError: The input is malformed.
"""
host, port, method, scheme, authority, path, http_version = _read_request_line(lines[0])
headers = _read_headers(lines[1:])
return request.Request(
host=host,
port=port,
method=method,
scheme=scheme,
authority=authority,
path=path,
http_version=http_version,
headers=headers,
content=None,
trailers=None,
timestamp_start=time.time(),
timestamp_end=None
)
def read_response_head(lines: List[bytes]) -> response.Response:
"""
Parse an HTTP response head (response line + headers) from an iterable of lines
Args:
lines: The input lines
Returns:
The HTTP response object (without body)
Raises:
ValueError: The input is malformed.
"""
http_version, status_code, reason = _read_response_line(lines[0])
headers = _read_headers(lines[1:])
return response.Response(
http_version=http_version,
status_code=status_code,
reason=reason,
headers=headers,
content=None,
trailers=None,
timestamp_start=time.time(),
timestamp_end=None,
)
def expected_http_body_size(
request: request.Request,
response: Optional[response.Response] = None,
expect_continue_as_0: bool = True,
):
"""
Like the non-sans-io version, but also treating CONNECT as content-length: 0
"""
if request.data.method.upper() == b"CONNECT":
return 0
return read.expected_http_body_size(request, response, expect_continue_as_0)

View File

@ -18,28 +18,3 @@ The most important primitives are:
- Context: The context is the connection context each layer is provided with, which is always a client connection - Context: The context is the connection context each layer is provided with, which is always a client connection
and sometimes also a server connection. and sometimes also a server connection.
""" """
from .config import ProxyConfig
class DummyServer:
bound = False
def __init__(self, config=None):
self.config = config
self.address = "dummy"
def set_channel(self, channel):
pass
def serve_forever(self):
pass
def shutdown(self):
pass
__all__ = [
"DummyServer",
"ProxyConfig",
]

View File

@ -1,92 +0,0 @@
import os
import re
import typing
from OpenSSL import crypto
from mitmproxy import certs
from mitmproxy import exceptions
from mitmproxy import options as moptions
from mitmproxy.net import server_spec
class HostMatcher:
def __init__(self, handle, patterns=tuple()):
self.handle = handle
self.patterns = list(patterns)
self.regexes = [re.compile(p, re.IGNORECASE) for p in self.patterns]
def __call__(self, address):
if not address:
return False
host = "%s:%s" % address
if self.handle in ["ignore", "tcp"]:
return any(rex.search(host) for rex in self.regexes)
else: # self.handle == "allow"
return not any(rex.search(host) for rex in self.regexes)
def __bool__(self):
return bool(self.patterns)
class ProxyConfig:
def __init__(self, options: moptions.Options) -> None:
self.options = options
self.certstore: certs.CertStore
self.check_filter: typing.Optional[HostMatcher] = None
self.check_tcp: typing.Optional[HostMatcher] = None
self.upstream_server: typing.Optional[server_spec.ServerSpec] = None
self.configure(options, set(options.keys()))
options.changed.connect(self.configure)
def configure(self, options: moptions.Options, updated: typing.Any) -> None:
if options.allow_hosts and options.ignore_hosts:
raise exceptions.OptionsError("--ignore-hosts and --allow-hosts are mutually "
"exclusive; please choose one.")
if options.ignore_hosts:
self.check_filter = HostMatcher("ignore", options.ignore_hosts)
elif options.allow_hosts:
self.check_filter = HostMatcher("allow", options.allow_hosts)
else:
self.check_filter = HostMatcher(False)
if "tcp_hosts" in updated:
self.check_tcp = HostMatcher("tcp", options.tcp_hosts)
certstore_path = os.path.expanduser(options.confdir)
if not os.path.exists(os.path.dirname(certstore_path)):
raise exceptions.OptionsError(
"Certificate Authority parent directory does not exist: %s" %
os.path.dirname(certstore_path)
)
key_size = options.key_size
passphrase = options.cert_passphrase.encode("utf-8") if options.cert_passphrase else None
self.certstore = certs.CertStore.from_store(
certstore_path,
moptions.CONF_BASENAME,
key_size,
passphrase
)
for c in options.certs:
parts = c.split("=", 1)
if len(parts) == 1:
parts = ["*", parts[0]]
cert = os.path.expanduser(parts[1])
if not os.path.exists(cert):
raise exceptions.OptionsError(
"Certificate file does not exist: %s" % cert
)
try:
self.certstore.add_cert_file(parts[0], cert, passphrase)
except crypto.Error:
raise exceptions.OptionsError(
"Invalid certificate format: %s" % cert
)
m = options.mode
if m.startswith("upstream:") or m.startswith("reverse:"):
_, spec = server_spec.parse_with_mode(options.mode)
self.upstream_server = spec

View File

@ -8,7 +8,6 @@ from h11._receivebuffer import ReceiveBuffer
from mitmproxy import exceptions, http from mitmproxy import exceptions, http
from mitmproxy.net import http as net_http from mitmproxy.net import http as net_http
from mitmproxy.net.http import http1, status_codes from mitmproxy.net.http import http1, status_codes
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy import commands, events, layer
from mitmproxy.proxy.context import Connection, ConnectionState, Context from mitmproxy.proxy.context import Connection, ConnectionState, Context
from mitmproxy.proxy.layers.http._base import ReceiveHttp, StreamId from mitmproxy.proxy.layers.http._base import ReceiveHttp, StreamId
@ -148,7 +147,7 @@ class Http1Connection(HttpConnection, metaclass=abc.ABCMeta):
yield from self.make_pipe() yield from self.make_pipe()
return return
connection_done = ( connection_done = (
http1_sansio.expected_http_body_size(self.request, self.response) == -1 http1.expected_http_body_size(self.request, self.response) == -1
or http1.connection_close(self.request.http_version, self.request.headers) or http1.connection_close(self.request.http_version, self.request.headers)
or http1.connection_close(self.response.http_version, self.response.headers) or http1.connection_close(self.response.http_version, self.response.headers)
# If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request. # If we proxy HTTP/2 to HTTP/1, we only use upstream connections for one request.
@ -227,8 +226,8 @@ class Http1Server(Http1Connection):
if request_head: if request_head:
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
try: try:
self.request = http1_sansio.read_request_head(request_head) self.request = http1.read_request_head(request_head)
expected_body_size = http1_sansio.expected_http_body_size(self.request, expect_continue_as_0=False) expected_body_size = http1.expected_http_body_size(self.request, expect_continue_as_0=False)
except (ValueError, exceptions.HttpSyntaxException) as e: except (ValueError, exceptions.HttpSyntaxException) as e:
yield commands.Log(f"{human.format_address(self.conn.peername)}: {e}") yield commands.Log(f"{human.format_address(self.conn.peername)}: {e}")
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
@ -294,7 +293,7 @@ class Http1Client(Http1Connection):
assert self.request assert self.request
if "chunked" in self.request.headers.get("transfer-encoding", "").lower(): if "chunked" in self.request.headers.get("transfer-encoding", "").lower():
yield commands.SendData(self.conn, b"0\r\n\r\n") yield commands.SendData(self.conn, b"0\r\n\r\n")
elif http1_sansio.expected_http_body_size(self.request, self.response) == -1: elif http1.expected_http_body_size(self.request, self.response) == -1:
yield commands.CloseConnection(self.conn, half_close=True) yield commands.CloseConnection(self.conn, half_close=True)
yield from self.mark_done(request=True) yield from self.mark_done(request=True)
elif isinstance(event, RequestProtocolError): elif isinstance(event, RequestProtocolError):
@ -316,8 +315,8 @@ class Http1Client(Http1Connection):
if response_head: if response_head:
response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays
try: try:
self.response = http1_sansio.read_response_head(response_head) self.response = http1.read_response_head(response_head)
expected_size = http1_sansio.expected_http_body_size(self.request, self.response) expected_size = http1.expected_http_body_size(self.request, self.response)
except (ValueError, exceptions.HttpSyntaxException) as e: except (ValueError, exceptions.HttpSyntaxException) as e:
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"Cannot parse HTTP response: {e}")) yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"Cannot parse HTTP response: {e}"))

View File

@ -5,7 +5,6 @@ from h11._receivebuffer import ReceiveBuffer
from mitmproxy import http from mitmproxy import http
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.net.http import http1 from mitmproxy.net.http import http1
from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy import commands, context, layer, tunnel from mitmproxy.proxy import commands, context, layer, tunnel
from mitmproxy.utils import human from mitmproxy.utils import human
@ -57,7 +56,7 @@ class HttpUpstreamProxy(tunnel.TunnelLayer):
if response_head: if response_head:
response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays response_head = [bytes(x) for x in response_head] # TODO: Make url.parse compatible with bytearrays
try: try:
response = http1_sansio.read_response_head(response_head) response = http1.read_response_head(response_head)
except ValueError as e: except ValueError as e:
yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {e}") yield commands.Log(f"{human.format_address(self.tunnel_connection.address)}: {e}")
return False, str(e) return False, str(e)

View File

@ -38,16 +38,16 @@ class RecordingMaster(mitmproxy.master.Master):
return True return True
return False return False
async def await_log(self, txt, level=None): async def await_log(self, txt, level=None, timeout=1):
# start with a sleep(0), which lets all other coroutines advance. # start with a sleep(0), which lets all other coroutines advance.
# often this is enough to not sleep at all. # often this is enough to not sleep at all.
await asyncio.sleep(0) await asyncio.sleep(0)
for i in range(20): for i in range(int(timeout / 0.001)):
if self.has_log(txt, level): if self.has_log(txt, level):
return True return True
else: else:
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
return False raise AssertionError(f"Did not find log entry {txt!r} in {self.logs}.")
def clear(self): def clear(self):
self.logs = [] self.logs = []

View File

@ -1,7 +1,7 @@
from mitmproxy import addons from mitmproxy import addons
from mitmproxy import options from mitmproxy import options
from mitmproxy import master from mitmproxy import master
from mitmproxy.addons import dumper, termlog, termstatus, keepserving, readfile from mitmproxy.addons import dumper, termlog, keepserving, readfile
class ErrorCheck: class ErrorCheck:
@ -24,7 +24,7 @@ class DumpMaster(master.Master):
super().__init__(options) super().__init__(options)
self.errorcheck = ErrorCheck() self.errorcheck = ErrorCheck()
if with_termlog: if with_termlog:
self.addons.add(termlog.TermLog(), termstatus.TermStatus()) self.addons.add(termlog.TermLog())
self.addons.add(*addons.default_addons()) self.addons.add(*addons.default_addons())
if with_dumper: if with_dumper:
self.addons.add(dumper.Dumper()) self.addons.add(dumper.Dumper())

View File

@ -1,15 +1,14 @@
import os
import sys
import asyncio
import argparse import argparse
import asyncio
import os
import signal import signal
import sys
import typing import typing
from mitmproxy.tools import cmdline
from mitmproxy import exceptions, master from mitmproxy import exceptions, master
from mitmproxy import options from mitmproxy import options
from mitmproxy import optmanager from mitmproxy import optmanager
from mitmproxy import proxy from mitmproxy.tools import cmdline
from mitmproxy.utils import debug, arg_check from mitmproxy.utils import debug, arg_check
@ -48,8 +47,6 @@ def process_options(parser, opts, args):
adict[n] = getattr(args, n) adict[n] = getattr(args, n)
opts.merge(adict) opts.merge(adict)
return proxy.config.ProxyConfig(opts)
def run( def run(
master_cls: typing.Type[master.Master], master_cls: typing.Type[master.Master],
@ -85,10 +82,7 @@ def run(
os.path.join(opts.confdir, "config.yaml"), os.path.join(opts.confdir, "config.yaml"),
os.path.join(opts.confdir, "config.yml"), os.path.join(opts.confdir, "config.yml"),
) )
pconf = process_options(parser, opts, args) process_options(parser, opts, args)
# new core initializes itself as an addon
master.server = proxy.DummyServer(pconf)
if args.options: if args.options:
print(optmanager.dump_defaults(opts)) print(optmanager.dump_defaults(opts))
@ -97,7 +91,7 @@ def run(
master.commands.dump() master.commands.dump()
sys.exit(0) sys.exit(0)
if extra: if extra:
if(args.filter_args): if args.filter_args:
master.log.info(f"Only processing flows that match \"{' & '.join(args.filter_args)}\"") master.log.info(f"Only processing flows that match \"{' & '.join(args.filter_args)}\"")
opts.update(**extra(args)) opts.update(**extra(args))

View File

@ -14,7 +14,6 @@ import tornado.websocket
import mitmproxy.flow import mitmproxy.flow
import mitmproxy.tools.web.master # noqa import mitmproxy.tools.web.master # noqa
from mitmproxy import contentviews from mitmproxy import contentviews
from mitmproxy import exceptions
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import http from mitmproxy import http
from mitmproxy import io from mitmproxy import io
@ -380,14 +379,7 @@ class RevertFlow(RequestHandler):
class ReplayFlow(RequestHandler): class ReplayFlow(RequestHandler):
def post(self, flow_id): def post(self, flow_id):
self.flow.backup() self.master.commands.call("replay.client", [self.flow])
self.flow.response = None
self.view.update([self.flow])
try:
self.master.commands.call("replay.client", [self.flow])
except exceptions.ReplayException as e:
raise APIError(400, str(e))
class FlowContent(RequestHandler): class FlowContent(RequestHandler):

View File

@ -11,7 +11,6 @@ from mitmproxy.addons import intercept
from mitmproxy.addons import readfile from mitmproxy.addons import readfile
from mitmproxy.addons import termlog from mitmproxy.addons import termlog
from mitmproxy.addons import view from mitmproxy.addons import view
from mitmproxy.addons import termstatus
from mitmproxy.tools.web import app, webaddons, static_viewer from mitmproxy.tools.web import app, webaddons, static_viewer
@ -41,7 +40,7 @@ class WebMaster(master.Master):
self.events, self.events,
) )
if with_termlog: if with_termlog:
self.addons.add(termlog.TermLog(), termstatus.TermStatus()) self.addons.add(termlog.TermLog())
self.app = app.Application( self.app = app.Application(
self, self.options.web_debug self, self.options.web_debug
) )

View File

@ -1,22 +0,0 @@
import pytest
from unittest import mock
from mitmproxy.addons import check_ca
from mitmproxy.test import taddons
class TestCheckCA:
@pytest.mark.parametrize('expired', [False, True])
@pytest.mark.asyncio
async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!'
a = check_ca.CheckCA()
with taddons.context(a) as tctx:
tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired
)
tctx.configure(a)
assert await tctx.master.await_log(msg) == expired

View File

@ -1,9 +1,6 @@
import io
from mitmproxy.addons import disable_h2c from mitmproxy.addons import disable_h2c
from mitmproxy.exceptions import Kill from mitmproxy.exceptions import Kill
from mitmproxy.net.http import http1 from mitmproxy.test import taddons, tutils
from mitmproxy.test import taddons
from mitmproxy.test import tflow from mitmproxy.test import tflow
@ -28,9 +25,12 @@ class TestDisableH2CleartextUpgrade:
a = disable_h2c.DisableH2C() a = disable_h2c.DisableH2C()
tctx.configure(a) tctx.configure(a)
b = io.BytesIO(b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n")
f = tflow.tflow() f = tflow.tflow()
f.request = http1.read_request(b) f.request = tutils.treq(
method=b"PRI",
path=b"*",
http_version=b"HTTP/2.0",
)
f.intercept() f.intercept()
a.request(f) a.request(f)

View File

@ -1,18 +0,0 @@
import pytest
from mitmproxy import proxy
from mitmproxy.addons import termstatus
from mitmproxy.test import taddons
@pytest.mark.asyncio
async def test_configure():
ts = termstatus.TermStatus()
with taddons.context() as ctx:
ctx.master.server = proxy.DummyServer()
ctx.master.server.bound = True
ctx.configure(ts, server=False)
ts.running()
ctx.configure(ts, server=True)
ts.running()
await ctx.master.await_log("server listening")

View File

@ -219,3 +219,11 @@ class TestTlsConfig:
assert self.do_handshake(tssl_client, tssl_server) assert self.do_handshake(tssl_client, tssl_server)
assert tssl_server.obj.getpeercert() assert tssl_server.obj.getpeercert()
@pytest.mark.asyncio
async def test_ca_expired(self, monkeypatch):
monkeypatch.setattr(SSL.X509, "has_expired", lambda self: True)
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ta.configure(["confdir"])
await tctx.master.await_log("The mitmproxy certificate authority has expired", "warn")

View File

@ -1,14 +1,12 @@
from io import BytesIO
from unittest.mock import Mock
import pytest import pytest
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.net.http import Headers from mitmproxy.net.http import Headers
from mitmproxy.net.http.http1.read import ( from mitmproxy.net.http.http1.read import (
read_request, read_response, read_request_head, read_request_head,
read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line, read_response_head, connection_close, expected_http_body_size,
_read_request_line, _read_response_line, _check_http_version, _read_request_line, _read_response_line, _check_http_version,
_read_headers, _read_chunked, get_header_tokens _read_headers, get_header_tokens
) )
from mitmproxy.test.tutils import treq, tresp from mitmproxy.test.tutils import treq, tresp
@ -24,124 +22,6 @@ def test_get_header_tokens():
assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]
@pytest.mark.parametrize("input", [
b"GET / HTTP/1.1\r\n\r\nskip",
b"GET / HTTP/1.1\r\n\r\nskip",
b"GET / HTTP/1.1\r\n\r\nskip",
b"GET / HTTP/1.1 \r\n\r\nskip",
])
def test_read_request(input):
rfile = BytesIO(input)
r = read_request(rfile)
assert r.method == "GET"
assert r.content == b""
assert r.http_version == "HTTP/1.1"
assert r.timestamp_end
assert rfile.read() == b"skip"
@pytest.mark.parametrize("input", [
b"CONNECT :0 0",
])
def test_read_request_error(input):
rfile = BytesIO(input)
with pytest.raises(exceptions.HttpException):
read_request(rfile)
def test_read_request_head():
rfile = BytesIO(
b"GET / HTTP/1.1\r\n"
b"Content-Length: 4\r\n"
b"\r\n"
b"skip"
)
rfile.reset_timestamps = Mock()
rfile.first_byte_timestamp = 42
r = read_request_head(rfile)
assert r.method == "GET"
assert r.headers["Content-Length"] == "4"
assert r.content is None
assert rfile.reset_timestamps.called
assert r.timestamp_start == 42
assert rfile.read() == b"skip"
@pytest.mark.parametrize("input", [
b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody",
b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody",
b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody",
b"HTTP/1.1 418 I'm a teapot \r\n\r\nbody",
])
def test_read_response(input):
req = treq()
rfile = BytesIO(input)
r = read_response(rfile, req)
assert r.http_version == "HTTP/1.1"
assert r.status_code == 418
assert r.reason == "I'm a teapot"
assert r.content == b"body"
assert r.timestamp_end
def test_read_response_head():
rfile = BytesIO(
b"HTTP/1.1 418 I'm a teapot\r\n"
b"Content-Length: 4\r\n"
b"\r\n"
b"skip"
)
rfile.reset_timestamps = Mock()
rfile.first_byte_timestamp = 42
r = read_response_head(rfile)
assert r.status_code == 418
assert r.headers["Content-Length"] == "4"
assert r.content is None
assert rfile.reset_timestamps.called
assert r.timestamp_start == 42
assert rfile.read() == b"skip"
class TestReadBody:
def test_chunked(self):
rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar")
body = b"".join(read_body(rfile, None))
assert body == b"foo"
assert rfile.read() == b"bar"
def test_known_size(self):
rfile = BytesIO(b"foobar")
body = b"".join(read_body(rfile, 3))
assert body == b"foo"
assert rfile.read() == b"bar"
def test_known_size_limit(self):
rfile = BytesIO(b"foobar")
with pytest.raises(exceptions.HttpException):
b"".join(read_body(rfile, 3, 2))
def test_known_size_too_short(self):
rfile = BytesIO(b"foo")
with pytest.raises(exceptions.HttpException):
b"".join(read_body(rfile, 6))
def test_unknown_size(self):
rfile = BytesIO(b"foobar")
body = b"".join(read_body(rfile, -1))
assert body == b"foobar"
def test_unknown_size_limit(self):
rfile = BytesIO(b"foobar")
with pytest.raises(exceptions.HttpException):
b"".join(read_body(rfile, -1, 3))
def test_max_chunk_size(self):
rfile = BytesIO(b"123456")
assert list(read_body(rfile, -1, max_chunk_size=None)) == [b"123456"]
rfile = BytesIO(b"123456")
assert list(read_body(rfile, -1, max_chunk_size=1)) == [b"1", b"2", b"3", b"4", b"5", b"6"]
def test_connection_close(): def test_connection_close():
headers = Headers() headers = Headers()
assert connection_close(b"HTTP/1.0", headers) assert connection_close(b"HTTP/1.0", headers)
@ -159,6 +39,41 @@ def test_connection_close():
assert not connection_close(b"HTTP/1.1", headers) assert not connection_close(b"HTTP/1.1", headers)
def test_check_http_version():
_check_http_version(b"HTTP/0.9")
_check_http_version(b"HTTP/1.0")
_check_http_version(b"HTTP/1.1")
_check_http_version(b"HTTP/2.0")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"WTF/1.0")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"HTTP/1.10")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"HTTP/1.b")
def test_read_request_head():
rfile = [
b"GET / HTTP/1.1\r\n",
b"Content-Length: 4\r\n",
]
r = read_request_head(rfile)
assert r.method == "GET"
assert r.headers["Content-Length"] == "4"
assert r.content is None
def test_read_response_head():
rfile = [
b"HTTP/1.1 418 I'm a teapot\r\n",
b"Content-Length: 4\r\n",
]
r = read_response_head(rfile)
assert r.status_code == 418
assert r.headers["Content-Length"] == "4"
assert r.content is None
def test_expected_http_body_size(): def test_expected_http_body_size():
# Expect: 100-continue # Expect: 100-continue
assert expected_http_body_size( assert expected_http_body_size(
@ -176,6 +91,10 @@ def test_expected_http_body_size():
treq(method=b"HEAD"), treq(method=b"HEAD"),
tresp(headers=Headers(content_length="42")) tresp(headers=Headers(content_length="42"))
) == 0 ) == 0
assert expected_http_body_size(
treq(method=b"CONNECT"),
None,
) == 0
assert expected_http_body_size( assert expected_http_body_size(
treq(method=b"CONNECT"), treq(method=b"CONNECT"),
tresp() tresp()
@ -221,26 +140,9 @@ def test_expected_http_body_size():
) == -1 ) == -1
def test_get_first_line():
rfile = BytesIO(b"foo\r\nbar")
assert _get_first_line(rfile) == b"foo"
rfile = BytesIO(b"\r\nfoo\r\nbar")
assert _get_first_line(rfile) == b"foo"
with pytest.raises(exceptions.HttpReadDisconnect):
rfile = BytesIO(b"")
_get_first_line(rfile)
with pytest.raises(exceptions.HttpReadDisconnect):
rfile = Mock()
rfile.readline.side_effect = exceptions.TcpDisconnect
_get_first_line(rfile)
def test_read_request_line(): def test_read_request_line():
def t(b): def t(b):
return _read_request_line(BytesIO(b)) return _read_request_line(b)
assert (t(b"GET / HTTP/1.1") == assert (t(b"GET / HTTP/1.1") ==
("", 0, b"GET", b"", b"", b"/", b"HTTP/1.1")) ("", 0, b"GET", b"", b"", b"/", b"HTTP/1.1"))
@ -251,21 +153,21 @@ def test_read_request_line():
assert (t(b"GET http://foo:42/bar HTTP/1.1") == assert (t(b"GET http://foo:42/bar HTTP/1.1") ==
("foo", 42, b"GET", b"http", b"foo:42", b"/bar", b"HTTP/1.1")) ("foo", 42, b"GET", b"http", b"foo:42", b"/bar", b"HTTP/1.1"))
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
t(b"GET / WTF/1.1") t(b"GET / WTF/1.1")
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
t(b"CONNECT example.com HTTP/1.1") # port missing t(b"CONNECT example.com HTTP/1.1") # port missing
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
t(b"GET ws://example.com/ HTTP/1.1") # port missing t(b"GET ws://example.com/ HTTP/1.1") # port missing
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
t(b"this is not http") t(b"this is not http")
with pytest.raises(exceptions.HttpReadDisconnect): with pytest.raises(ValueError):
t(b"") t(b"")
def test_read_response_line(): def test_read_response_line():
def t(b): def t(b):
return _read_response_line(BytesIO(b)) return _read_response_line(b)
assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK")
assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"")
@ -273,40 +175,26 @@ def test_read_response_line():
# https://github.com/mitmproxy/mitmproxy/issues/784 # https://github.com/mitmproxy/mitmproxy/issues/784
assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9") assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9")
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
assert t(b"HTTP/1.1") assert t(b"HTTP/1.1")
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
t(b"HTTP/1.1 OK OK") t(b"HTTP/1.1 OK OK")
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
t(b"WTF/1.1 200 OK") t(b"WTF/1.1 200 OK")
with pytest.raises(exceptions.HttpReadDisconnect): with pytest.raises(ValueError):
t(b"") t(b"")
def test_check_http_version():
_check_http_version(b"HTTP/0.9")
_check_http_version(b"HTTP/1.0")
_check_http_version(b"HTTP/1.1")
_check_http_version(b"HTTP/2.0")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"WTF/1.0")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"HTTP/1.10")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"HTTP/1.b")
class TestReadHeaders: class TestReadHeaders:
@staticmethod @staticmethod
def _read(data): def _read(data):
return _read_headers(BytesIO(data)) return _read_headers(data.splitlines(keepends=True))
def test_read_simple(self): def test_read_simple(self):
data = ( data = (
b"Header: one\r\n" b"Header: one\r\n"
b"Header2: two\r\n" b"Header2: two\r\n"
b"\r\n"
) )
headers = self._read(data) headers = self._read(data)
assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two")) assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two"))
@ -315,7 +203,6 @@ class TestReadHeaders:
data = ( data = (
b"Header: one\r\n" b"Header: one\r\n"
b"Header: two\r\n" b"Header: two\r\n"
b"\r\n"
) )
headers = self._read(data) headers = self._read(data)
assert headers.fields == ((b"Header", b"one"), (b"Header", b"two")) assert headers.fields == ((b"Header", b"one"), (b"Header", b"two"))
@ -325,58 +212,26 @@ class TestReadHeaders:
b"Header: one\r\n" b"Header: one\r\n"
b"\ttwo\r\n" b"\ttwo\r\n"
b"Header2: three\r\n" b"Header2: three\r\n"
b"\r\n"
) )
headers = self._read(data) headers = self._read(data)
assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three")) assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three"))
def test_read_continued_err(self): def test_read_continued_err(self):
data = b"\tfoo: bar\r\n" data = b"\tfoo: bar\r\n"
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
self._read(data) self._read(data)
def test_read_err(self): def test_read_err(self):
data = b"foo" data = b"foo"
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
self._read(data) self._read(data)
def test_read_empty_name(self): def test_read_empty_name(self):
data = b":foo" data = b":foo"
with pytest.raises(exceptions.HttpSyntaxException): with pytest.raises(ValueError):
self._read(data) self._read(data)
def test_read_empty_value(self): def test_read_empty_value(self):
data = b"bar:" data = b"bar:"
headers = self._read(data) headers = self._read(data)
assert headers.fields == ((b"bar", b""),) assert headers.fields == ((b"bar", b""),)
def test_read_chunked():
req = treq(content=None)
req.headers["Transfer-Encoding"] = "chunked"
data = b"1\r\na\r\n0\r\n"
with pytest.raises(exceptions.HttpSyntaxException):
b"".join(_read_chunked(BytesIO(data)))
data = b"1\r\na\r\n0\r\n\r\n"
assert b"".join(_read_chunked(BytesIO(data))) == b"a"
data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n"
assert b"".join(_read_chunked(BytesIO(data))) == b"ab"
data = b"\r\n"
with pytest.raises(Exception, match="closed prematurely"):
b"".join(_read_chunked(BytesIO(data)))
data = b"1\r\nfoo"
with pytest.raises(Exception, match="Malformed chunked body"):
b"".join(_read_chunked(BytesIO(data)))
data = b"foo\r\nfoo"
with pytest.raises(exceptions.HttpSyntaxException):
b"".join(_read_chunked(BytesIO(data)))
data = b"5\r\naaaaa\r\n0\r\n\r\n"
with pytest.raises(Exception, match="too large"):
b"".join(_read_chunked(BytesIO(data), limit=2))

View File

@ -1,192 +0,0 @@
import pytest
from mitmproxy import exceptions
from mitmproxy.net.http import Headers
from mitmproxy.net.http.http1.read_sansio import (
read_request_head,
read_response_head, expected_http_body_size,
_read_request_line, _read_response_line,
_read_headers,
)
from mitmproxy.test.tutils import treq, tresp
def test_read_request_head():
rfile = [
b"GET / HTTP/1.1\r\n",
b"Content-Length: 4\r\n",
]
r = read_request_head(rfile)
assert r.method == "GET"
assert r.headers["Content-Length"] == "4"
assert r.content is None
def test_read_response_head():
rfile = [
b"HTTP/1.1 418 I'm a teapot\r\n",
b"Content-Length: 4\r\n",
]
r = read_response_head(rfile)
assert r.status_code == 418
assert r.headers["Content-Length"] == "4"
assert r.content is None
def test_expected_http_body_size():
# Expect: 100-continue
assert expected_http_body_size(
treq(headers=Headers(expect="100-continue", content_length="42")),
expect_continue_as_0=True
) == 0
# Expect: 100-continue
assert expected_http_body_size(
treq(headers=Headers(expect="100-continue", content_length="42")),
expect_continue_as_0=False
) == 42
# http://tools.ietf.org/html/rfc7230#section-3.3
assert expected_http_body_size(
treq(method=b"HEAD"),
tresp(headers=Headers(content_length="42"))
) == 0
assert expected_http_body_size(
treq(method=b"CONNECT"),
tresp()
) == 0
for code in (100, 204, 304):
assert expected_http_body_size(
treq(),
tresp(status_code=code)
) == 0
# chunked
assert expected_http_body_size(
treq(headers=Headers(transfer_encoding="chunked")),
) is None
# explicit length
for val in (b"foo", b"-7"):
with pytest.raises(exceptions.HttpSyntaxException):
expected_http_body_size(
treq(headers=Headers(content_length=val))
)
assert expected_http_body_size(
treq(headers=Headers(content_length="42"))
) == 42
# more than 1 content-length headers with same value
assert expected_http_body_size(
treq(headers=Headers([(b'content-length', b'42'), (b'content-length', b'42')]))
) == 42
# more than 1 content-length headers with conflicting value
with pytest.raises(exceptions.HttpSyntaxException):
expected_http_body_size(
treq(headers=Headers([(b'content-length', b'42'), (b'content-length', b'45')]))
)
# no length
assert expected_http_body_size(
treq(headers=Headers())
) == 0
assert expected_http_body_size(
treq(headers=Headers()), tresp(headers=Headers())
) == -1
def test_read_request_line():
def t(b):
return _read_request_line(b)
assert (t(b"GET / HTTP/1.1") ==
("", 0, b"GET", b"", b"", b"/", b"HTTP/1.1"))
assert (t(b"OPTIONS * HTTP/1.1") ==
("", 0, b"OPTIONS", b"", b"", b"*", b"HTTP/1.1"))
assert (t(b"CONNECT foo:42 HTTP/1.1") ==
("foo", 42, b"CONNECT", b"", b"foo:42", b"", b"HTTP/1.1"))
assert (t(b"GET http://foo:42/bar HTTP/1.1") ==
("foo", 42, b"GET", b"http", b"foo:42", b"/bar", b"HTTP/1.1"))
with pytest.raises(ValueError):
t(b"GET / WTF/1.1")
with pytest.raises(ValueError):
t(b"CONNECT example.com HTTP/1.1") # port missing
with pytest.raises(ValueError):
t(b"GET ws://example.com/ HTTP/1.1") # port missing
with pytest.raises(ValueError):
t(b"this is not http")
with pytest.raises(ValueError):
t(b"")
def test_read_response_line():
def t(b):
return _read_response_line(b)
assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK")
assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"")
# https://github.com/mitmproxy/mitmproxy/issues/784
assert t(b"HTTP/1.1 200 Non-Autoris\xc3\xa9") == (b"HTTP/1.1", 200, b"Non-Autoris\xc3\xa9")
with pytest.raises(ValueError):
assert t(b"HTTP/1.1")
with pytest.raises(ValueError):
t(b"HTTP/1.1 OK OK")
with pytest.raises(ValueError):
t(b"WTF/1.1 200 OK")
with pytest.raises(ValueError):
t(b"")
class TestReadHeaders:
@staticmethod
def _read(data):
return _read_headers(data.splitlines(keepends=True))
def test_read_simple(self):
data = (
b"Header: one\r\n"
b"Header2: two\r\n"
)
headers = self._read(data)
assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two"))
def test_read_multi(self):
data = (
b"Header: one\r\n"
b"Header: two\r\n"
)
headers = self._read(data)
assert headers.fields == ((b"Header", b"one"), (b"Header", b"two"))
def test_read_continued(self):
data = (
b"Header: one\r\n"
b"\ttwo\r\n"
b"Header2: three\r\n"
)
headers = self._read(data)
assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three"))
def test_read_continued_err(self):
data = b"\tfoo: bar\r\n"
with pytest.raises(ValueError):
self._read(data)
def test_read_err(self):
data = b"foo"
with pytest.raises(ValueError):
self._read(data)
def test_read_empty_name(self):
data = b":foo"
with pytest.raises(ValueError):
self._read(data)
def test_read_empty_value(self):
data = b"bar:"
headers = self._read(data)
assert headers.fields == ((b"bar", b""),)

View File

@ -1,28 +0,0 @@
import pytest
from mitmproxy import options
from mitmproxy import exceptions
from mitmproxy.proxy.config import ProxyConfig
class TestProxyConfig:
def test_invalid_confdir(self):
opts = options.Options()
opts.confdir = "foo"
with pytest.raises(exceptions.OptionsError, match="parent directory does not exist"):
ProxyConfig(opts)
def test_invalid_certificate(self, tdata):
opts = options.Options()
opts.certs = [tdata.path("mitmproxy/data/dumpfile-011.bin")]
with pytest.raises(exceptions.OptionsError, match="Invalid certificate format"):
ProxyConfig(opts)
def test_cannot_set_both_allow_and_filter_options(self):
opts = options.Options()
opts.ignore_hosts = ["foo"]
opts.allow_hosts = ["bar"]
with pytest.raises(exceptions.OptionsError, match="--ignore-hosts and --allow-hosts are "
"mutually exclusive; please choose "
"one."):
ProxyConfig(opts)

View File

@ -138,7 +138,8 @@ async def test_simple():
tctx.master.clear() tctx.master.clear()
a.get("one").response = addons a.get("one").response = addons
a.trigger("response") a.trigger("response")
assert not await tctx.master.await_log("not callable") with pytest.raises(AssertionError):
await tctx.master.await_log("not callable")
a.remove(a.get("one")) a.remove(a.get("one"))
assert not a.get("one") assert not a.get("one")

View File

@ -3,7 +3,6 @@ import argparse
import pytest import pytest
from mitmproxy import options from mitmproxy import options
from mitmproxy.proxy import DummyServer
from mitmproxy.tools import cmdline from mitmproxy.tools import cmdline
from mitmproxy.tools import main from mitmproxy.tools import main
@ -41,11 +40,3 @@ class TestProcessProxyOptions:
self.assert_noerr( self.assert_noerr(
"--cert", "--cert",
tdata.path("mitmproxy/data/testkey.pem")) tdata.path("mitmproxy/data/testkey.pem"))
class TestDummyServer:
def test_simple(self):
d = DummyServer(None)
d.set_channel(None)
d.shutdown()

View File

@ -19,7 +19,7 @@ async def test_recordingmaster():
async def test_dumplog(): async def test_dumplog():
with taddons.context() as tctx: with taddons.context() as tctx:
ctx.log.info("testing") ctx.log.info("testing")
await ctx.master.await_log("testing") assert await ctx.master.await_log("testing")
s = io.StringIO() s = io.StringIO()
tctx.master.dump_log(s) tctx.master.dump_log(s)
assert s.getvalue() assert s.getvalue()

View File

@ -10,7 +10,7 @@ def test_common():
opts = options.Options() opts = options.Options()
cmdline.common_options(parser, opts) cmdline.common_options(parser, opts)
args = parser.parse_args(args=[]) args = parser.parse_args(args=[])
assert main.process_options(parser, opts, args) main.process_options(parser, opts, args)
def test_mitmproxy(): def test_mitmproxy():