Reply: remove return value

This commit is contained in:
Maximilian Hils 2020-12-28 17:35:23 +01:00
parent 1655f54817
commit 70f1d173e2
53 changed files with 224 additions and 385 deletions

View File

@ -18,4 +18,4 @@ def websocket_message(flow):
if 'FOOBAR' in message.content:
# kill the message and not send it to the other endpoint
message.kill()
message.content = ""

View File

@ -215,8 +215,6 @@ class AddonManager:
if message.reply.state == "start":
message.reply.take()
if not message.reply.has_message:
message.reply.ack()
message.reply.commit()
if isinstance(message.reply, controller.DummyReply):

View File

@ -1,4 +1,5 @@
import asyncio
import traceback
import urllib.parse
import asgiref.compatibility
@ -28,7 +29,7 @@ class ASGIApp:
assert flow.reply
return bool(
(flow.request.pretty_host, flow.request.port) == (self.host, self.port)
and not flow.reply.has_message
and flow.reply.state == "start" and not flow.error and not flow.response
and not isinstance(flow.reply, DummyReply) # ignore the HTTP flows of this app loaded from somewhere
)
@ -96,6 +97,7 @@ async def serve(app, flow: http.HTTPFlow):
scope = make_scope(flow)
done = asyncio.Event()
received_body = False
sent_response = False
async def receive():
nonlocal received_body
@ -120,18 +122,18 @@ async def serve(app, flow: http.HTTPFlow):
elif event["type"] == "http.response.body":
flow.response.content += event.get("body", b"")
if not event.get("more_body", False):
flow.reply.ack()
nonlocal sent_response
sent_response = True
else:
raise AssertionError(f"Unexpected event: {event['type']}")
try:
await app(scope, receive, send)
if not flow.reply.has_message:
if not sent_response:
raise RuntimeError(f"no response sent.")
except Exception as e:
ctx.log.error(f"Error in asgi app: {e}")
except Exception:
ctx.log.error(f"Error in asgi app:\n{traceback.format_exc(limit=-5)}")
flow.response = http.HTTPResponse.make(500, b"ASGI Error.")
flow.reply.ack(force=True)
finally:
flow.reply.commit()
done.set()

View File

@ -105,7 +105,7 @@ class MapLocal:
self.replacements.append(spec)
def request(self, flow: http.HTTPFlow) -> None:
if flow.reply and flow.reply.has_message:
if flow.response or flow.error or (flow.reply and flow.reply.state == "taken"):
return
url = flow.request.pretty_url

View File

@ -48,7 +48,7 @@ class MapRemote:
self.replacements.append(spec)
def request(self, flow: http.HTTPFlow) -> None:
if flow.reply and flow.reply.has_message:
if flow.response or flow.error or (flow.reply and flow.reply.state == "taken"):
return
for spec in self.replacements:
if spec.matches(flow):

View File

@ -31,12 +31,14 @@ class ModifyBody:
self.replacements.append(spec)
def request(self, flow):
if not flow.reply.has_message:
self.run(flow)
if flow.response or flow.error or flow.reply.state == "taken":
return
self.run(flow)
def response(self, flow):
if not flow.reply.has_message:
self.run(flow)
if flow.error or flow.reply.state == "taken":
return
self.run(flow)
def run(self, flow):
for spec in self.replacements:

View File

@ -73,12 +73,14 @@ class ModifyHeaders:
self.replacements.append(spec)
def request(self, flow):
if not flow.reply.has_message:
self.run(flow, flow.request.headers)
if flow.response or flow.error or flow.reply.state == "taken":
return
self.run(flow, flow.request.headers)
def response(self, flow):
if not flow.reply.has_message:
self.run(flow, flow.response.headers)
if flow.error or flow.reply.state == "taken":
return
self.run(flow, flow.response.headers)
def run(self, flow: http.HTTPFlow, hdrs: Headers) -> None:
# unset all specified headers

View File

@ -212,5 +212,4 @@ class ServerPlayback:
f.request.url
)
)
assert f.reply
f.reply.kill()
f.kill()

View File

@ -41,8 +41,8 @@ class StreamBodies:
expected_size = http1.expected_http_body_size(
f.request, f.response if not is_request else None
)
except exceptions.HttpException:
f.reply.kill()
except ValueError:
f.kill()
return
if expected_size and not r.raw_content and not (0 <= expected_size <= self.max_size):
# r.stream may already be a callable, which we want to preserve.

View File

@ -203,6 +203,11 @@ class TlsConfig:
tls_start.ssl_conn.set_tlsext_host_name(server.sni)
tls_start.ssl_conn.set_connect_state()
def running(self):
# FIXME: We have a weird bug where the contract for configure is not followed and it is never called with
# confdir or command_history as updated.
self.configure("confdir") # pragma: no cover
def configure(self, updated):
if "confdir" not in updated and "certs" not in updated:
return
@ -219,7 +224,7 @@ class TlsConfig:
"The mitmproxy certificate authority has expired!\n"
"Please delete all CA-related files in your ~/.mitmproxy folder.\n"
"The CA will be regenerated automatically after restarting mitmproxy.\n"
"Then make sure all your clients have the new CA installed.",
"See https://docs.mitmproxy.org/stable/concepts-certificates/ for additional help.",
)
for certspec in ctx.options.certs:

View File

@ -18,7 +18,6 @@ import traceback
from typing import Dict, Optional # noqa
from typing import List # noqa
from mitmproxy import exceptions
from mitmproxy.net import http
from mitmproxy.utils import strutils
from . import (
@ -42,7 +41,7 @@ def add(view: View) -> None:
# TODO: auto-select a different name (append an integer?)
for i in views:
if i.name == view.name:
raise exceptions.ContentViewException("Duplicate view: " + view.name)
raise ValueError("Duplicate view: " + view.name)
views.append(view)

View File

@ -1,8 +1,8 @@
import queue
import asyncio
import warnings
from typing import Any
from mitmproxy import exceptions
NO_REPLY = object() # special object we can distinguish from a valid "None" reply.
from mitmproxy import exceptions, flow
class Reply:
@ -12,14 +12,10 @@ class Reply:
"""
def __init__(self, obj):
self.obj = obj
# Spawn an event loop in the current thread
self.q = queue.Queue()
self._state = "start" # "start" -> "taken" -> "committed"
# Holds the reply value. May change before things are actually committed.
self.value = NO_REPLY
self.obj: Any = obj
self.done: asyncio.Event = asyncio.Event()
self._loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
self._state: str = "start" # "start" -> "taken" -> "committed"
@property
def state(self):
@ -36,19 +32,13 @@ class Reply:
"""
return self._state
@property
def has_message(self):
return self.value != NO_REPLY
def take(self):
"""
Scripts or other parties make "take" a reply out of a normal flow.
For example, intercepted flows are taken out so that the connection thread does not proceed.
"""
if self.state != "start":
raise exceptions.ControlException(
f"Reply is {self.state}, but expected it to be start."
)
raise exceptions.ControlException(f"Reply is {self.state}, but expected it to be start.")
self._state = "taken"
def commit(self):
@ -58,35 +48,22 @@ class Reply:
called .take().
"""
if self.state != "taken":
raise exceptions.ControlException(
f"Reply is {self.state}, but expected it to be taken."
)
if not self.has_message:
raise exceptions.ControlException("There is no reply message.")
raise exceptions.ControlException(f"Reply is {self.state}, but expected it to be taken.")
self._state = "committed"
self.q.put(self.value)
try:
self._loop.call_soon_threadsafe(lambda: self.done.set())
except RuntimeError: # pragma: no cover
pass # event loop may already be closed.
def ack(self, force=False):
self.send(self.obj, force)
def kill(self, force=False):
self.send(exceptions.Kill, force)
if self._state == "taken":
self.commit()
def send(self, msg, force=False):
if self.state not in {"start", "taken"}:
raise exceptions.ControlException(
f"Reply is {self.state}, but expected it to be start or taken."
)
if self.has_message and not force:
raise exceptions.ControlException("There is already a reply message.")
self.value = msg
def kill(self, force=False): # pragma: no cover
warnings.warn("reply.kill() is deprecated, use flow.kill() or set the error attribute instead.",
DeprecationWarning, stacklevel=2)
self.obj.error = flow.Error(flow.Error.KILLED_MESSAGE)
def __del__(self):
if self.state != "committed":
# This will be ignored by the interpreter, but emit a warning
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
raise exceptions.ControlException(f"Uncommitted reply: {self.obj}")
class DummyReply(Reply):
@ -103,13 +80,12 @@ class DummyReply(Reply):
def mark_reset(self):
if self.state != "committed":
raise exceptions.ControlException("Uncommitted reply: %s" % self.obj)
raise exceptions.ControlException(f"Uncommitted reply: {self.obj}")
self._should_reset = True
def reset(self):
if self._should_reset:
self._state = "start"
self.value = NO_REPLY
def __del__(self):
pass

View File

@ -25,21 +25,6 @@ class MitmproxyException(Exception):
super().__init__(message)
class Kill(MitmproxyException):
"""
Signal that both client and server connection(s) should be killed immediately.
"""
pass
class ContentViewException(MitmproxyException):
pass
class ReplayException(MitmproxyException):
pass
class FlowReadException(MitmproxyException):
pass
@ -69,24 +54,3 @@ class AddonHalt(MitmproxyException):
class TypeError(MitmproxyException):
pass
class NetlibException(MitmproxyException):
"""
Base class for all exceptions thrown by mitmproxy.net.
"""
def __init__(self, message=None):
super().__init__(message)
class HttpException(NetlibException):
pass
class HttpSyntaxException(HttpException):
pass
class TlsException(NetlibException):
pass

View File

@ -151,16 +151,17 @@ class Flow(stateobject.StateObject):
return (
self.reply and
self.reply.state in {"start", "taken"} and
self.reply.value != exceptions.Kill
not (self.error and self.error.msg == Error.KILLED_MESSAGE)
)
def kill(self):
"""
Kill this request.
"""
if not self.killable:
raise exceptions.ControlException("Flow is not killable.")
self.error = Error(Error.KILLED_MESSAGE)
self.intercepted = False
self.reply.kill(force=True)
self.live = False
def intercept(self):
@ -182,7 +183,6 @@ class Flow(stateobject.StateObject):
self.intercepted = False
# If a flow is intercepted and then duplicated, the duplicated one is not taken.
if self.reply.state == "taken":
self.reply.ack()
self.reply.commit()
@property

View File

@ -1,4 +1,5 @@
import asyncio
import logging
import sys
import threading
import traceback
@ -14,6 +15,12 @@ from mitmproxy import websocket
from mitmproxy.net import server_spec
from . import ctx as mitmproxy_ctx
# Conclusively preventing cross-thread races on proxy shutdown turns out to be
# very hard. We could build a thread sync infrastructure for this, or we could
# wait until we ditch threads and move all the protocols into the async loop.
# Until then, silence non-critical errors.
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class Master:
"""
@ -22,7 +29,7 @@ class Master:
def __init__(self, opts):
self.should_exit = threading.Event()
self.loop = asyncio.get_event_loop()
self.event_loop = asyncio.get_event_loop()
self.options: options.Options = opts or options.Options()
self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self)
@ -81,13 +88,13 @@ class Master:
"""
if not self.should_exit.is_set():
self.should_exit.set()
ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.loop)
ret = asyncio.run_coroutine_threadsafe(self._shutdown(), loop=self.event_loop)
# Weird band-aid to make sure that self._shutdown() is actually executed,
# which otherwise hangs the process as the proxy server is threaded.
# This all needs to be simplified when the proxy server runs on asyncio as well.
if not self.loop.is_running(): # pragma: no cover
if not self.event_loop.is_running(): # pragma: no cover
try:
self.loop.run_until_complete(asyncio.wrap_future(ret))
self.event_loop.run_until_complete(asyncio.wrap_future(ret))
except RuntimeError:
pass # Event loop stopped before Future completed.

View File

@ -2,12 +2,12 @@ from mitmproxy.net.http.request import Request
from mitmproxy.net.http.response import Response
from mitmproxy.net.http.message import Message
from mitmproxy.net.http.headers import Headers, parse_content_type
from mitmproxy.net.http import http1, http2, status_codes, multipart
from mitmproxy.net.http import http1, status_codes, multipart
__all__ = [
"Request",
"Response",
"Message",
"Headers", "parse_content_type",
"http1", "http2", "status_codes", "multipart",
"http1", "status_codes", "multipart",
]

View File

@ -1,9 +1,6 @@
from mitmproxy import exceptions
def assemble_request(request):
if request.data.content is None:
raise exceptions.HttpException("Cannot assemble flow with missing content")
raise ValueError("Cannot assemble flow with missing content")
head = assemble_request_head(request)
body = b"".join(assemble_body(request.data.headers, [request.data.content], request.data.trailers))
return head + body
@ -17,7 +14,7 @@ def assemble_request_head(request):
def assemble_response(response):
if response.data.content is None:
raise exceptions.HttpException("Cannot assemble flow with missing content")
raise ValueError("Cannot assemble flow with missing content")
head = assemble_response_head(response)
body = b"".join(assemble_body(response.data.headers, [response.data.content], response.data.trailers))
return head + body
@ -40,7 +37,7 @@ def assemble_body(headers, body_chunks, trailers):
yield b"0\r\n\r\n"
else:
if trailers:
raise exceptions.HttpException("Sending HTTP/1.1 trailer headers requires transfer-encoding: chunked")
raise ValueError("Sending HTTP/1.1 trailer headers requires transfer-encoding: chunked")
for chunk in body_chunks:
yield chunk

View File

@ -2,7 +2,6 @@ import re
import time
from typing import List, Tuple, Iterable, Optional
from mitmproxy import exceptions
from mitmproxy.net.http import request, response, headers, url
@ -54,7 +53,7 @@ def expected_http_body_size(
- -1, if all data should be read until end of stream.
Raises:
exceptions.HttpSyntaxException, if the content length header is invalid
ValueError, if the content length header is invalid
"""
# Determine response size according to
# http://tools.ietf.org/html/rfc7230#section-3.3
@ -78,27 +77,19 @@ def expected_http_body_size(
if "chunked" in headers.get("transfer-encoding", "").lower():
return None
if "content-length" in headers:
try:
sizes = headers.get_all("content-length")
different_content_length_headers = any(x != sizes[0] for x in sizes)
if different_content_length_headers:
raise exceptions.HttpSyntaxException("Conflicting Content Length Headers")
size = int(sizes[0])
if size < 0:
raise ValueError()
return size
except ValueError as e:
raise exceptions.HttpSyntaxException("Unparseable Content Length") from e
sizes = headers.get_all("content-length")
different_content_length_headers = any(x != sizes[0] for x in sizes)
if different_content_length_headers:
raise ValueError("Conflicting Content Length Headers")
size = int(sizes[0])
if size < 0:
raise ValueError("Negative Content Length")
return size
if not response:
return 0
return -1
def _check_http_version(http_version):
if not re.match(br"^HTTP/\d\.\d$", http_version):
raise exceptions.HttpSyntaxException(f"Unknown HTTP version: {http_version}")
def raise_if_http_version_unknown(http_version: bytes) -> None:
if not re.match(br"^HTTP/\d\.\d$", http_version):
raise ValueError(f"Unknown HTTP version: {http_version!r}")

View File

@ -1,28 +0,0 @@
import codecs
from hyperframe.frame import Frame
from mitmproxy import exceptions
def read_frame(rfile, parse=True):
"""
Reads a full HTTP/2 frame from a file-like object.
Returns a parsed frame and the consumed bytes.
"""
header = rfile.safe_read(9)
length = int(codecs.encode(header[:3], 'hex_codec'), 16)
if length == 4740180:
raise exceptions.HttpException("Length field looks more like HTTP/1.1:\n{}".format(rfile.read(-1)))
body = rfile.safe_read(length)
if parse:
frame, _ = Frame.parse_frame_header(header)
frame.parse_body(memoryview(body))
else:
frame = None
return frame, b''.join([header, body])

View File

@ -9,7 +9,7 @@ import certifi
from kaitaistruct import KaitaiStream
from OpenSSL import SSL
from mitmproxy import certs, exceptions
from mitmproxy import certs
from mitmproxy.contrib.kaitaistruct import tls_client_hello
from mitmproxy.net import check
@ -101,7 +101,7 @@ def _create_ssl_context(
ok = SSL._lib.SSL_CTX_set_min_proto_version(context._context, min_version.value)
ok += SSL._lib.SSL_CTX_set_max_proto_version(context._context, max_version.value)
if ok != 2:
raise exceptions.TlsException(
raise RuntimeError(
f"Error setting TLS versions ({min_version=}, {max_version=}). "
"The version you specified may be unavailable in your libssl."
)
@ -113,8 +113,8 @@ def _create_ssl_context(
if cipher_list is not None:
try:
context.set_cipher_list(b":".join(x.encode() for x in cipher_list))
except SSL.Error as v:
raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
except SSL.Error as e:
raise RuntimeError("SSL cipher specification error: {e}") from e
# SSLKEYLOGFILE
if log_master_secret:
@ -143,7 +143,7 @@ def create_proxy_server_context(
)
if verify is not Verify.VERIFY_NONE and sni is None:
raise exceptions.TlsException("Cannot validate certificate hostname without SNI")
raise ValueError("Cannot validate certificate hostname without SNI")
context.set_verify(verify.value, None)
if sni is not None:
@ -164,16 +164,16 @@ def create_proxy_server_context(
ca_pemfile = certifi.where()
try:
context.load_verify_locations(ca_pemfile, ca_path)
except SSL.Error:
raise exceptions.TlsException(f"Cannot load trusted certificates ({ca_pemfile=}, {ca_path=}).")
except SSL.Error as e:
raise RuntimeError(f"Cannot load trusted certificates ({ca_pemfile=}, {ca_path=}).") from e
# Client Certs
if client_cert:
try:
context.use_privatekey_file(client_cert)
context.use_certificate_chain_file(client_cert)
except SSL.Error as v:
raise exceptions.TlsException(f"TLS client certificate error: {v}")
except SSL.Error as e:
raise RuntimeError(f"Cannot load TLS client certificate: {e}") from e
if alpn_protos is not None:
# advertise application layer protocols
@ -206,8 +206,8 @@ def create_client_proxy_context(
context.use_privatekey(key)
try:
context.load_verify_locations(chain_file, None)
except SSL.Error:
raise exceptions.TlsException(f"Cannot load certificate chain ({chain_file}).")
except SSL.Error as e:
raise RuntimeError(f"Cannot load certificate chain ({chain_file}).") from e
if alpn_select_callback is not None:
assert callable(alpn_select_callback)

View File

@ -367,6 +367,7 @@ class HttpStream(layer.Layer):
is_client_error_but_we_already_talk_upstream = (
isinstance(event, RequestProtocolError)
and self.client_state in (self.state_stream_request_body, self.state_done)
and self.server_state != self.state_errored
)
need_error_hook = not (
self.client_state in (self.state_wait_for_request_headers, self.state_errored)

View File

@ -5,7 +5,7 @@ import h11
from h11._readers import ChunkedReader, ContentLengthReader, Http10Reader
from h11._receivebuffer import ReceiveBuffer
from mitmproxy import exceptions, http
from mitmproxy import http
from mitmproxy.net import http as net_http
from mitmproxy.net.http import http1, status_codes
from mitmproxy.proxy import commands, events, layer
@ -228,7 +228,7 @@ class Http1Server(Http1Connection):
try:
self.request = http1.read_request_head(request_head)
expected_body_size = http1.expected_http_body_size(self.request, expect_continue_as_0=False)
except (ValueError, exceptions.HttpSyntaxException) as e:
except ValueError as e:
yield commands.Log(f"{human.format_address(self.conn.peername)}: {e}")
yield commands.CloseConnection(self.conn)
self.state = self.done
@ -317,7 +317,7 @@ class Http1Client(Http1Connection):
try:
self.response = http1.read_response_head(response_head)
expected_size = http1.expected_http_body_size(self.request, self.response)
except (ValueError, exceptions.HttpSyntaxException) as e:
except ValueError as e:
yield commands.CloseConnection(self.conn)
yield ReceiveHttp(ResponseProtocolError(self.stream_id, f"Cannot parse HTTP response: {e}"))
return

View File

@ -140,6 +140,9 @@ class _TLSLayer(tunnel.TunnelLayer):
tls_start = TlsStartData(self.conn, self.context)
yield TlsStartHook(tls_start)
if not tls_start.ssl_conn:
yield commands.Log("No TLS context was provided, failing connection.", "error")
yield commands.CloseConnection(self.conn)
assert tls_start.ssl_conn
self.tls = tls_start.ssl_conn

View File

@ -26,12 +26,10 @@ def concurrent(fn):
def run():
fn(*args)
if obj.reply.state == "taken":
if not obj.reply.has_message:
obj.reply.ack()
obj.reply.commit()
obj.reply.take()
ScriptThread(
"script.concurrent (%s)" % fn.__name__,
f"script.concurrent {fn.__name__}",
target=run
).start()

View File

@ -1,14 +1,13 @@
import asyncio
import json
import pytest
import flask
import pytest
from flask import request
from mitmproxy.addons import asgiapp
from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.addons import next_layer
from mitmproxy.addons.proxyserver import Proxyserver
from mitmproxy.test import taddons
tapp = flask.Flask(__name__)
@ -53,7 +52,7 @@ async def test_asgi_full():
tctx.master.addons.add(next_layer.NextLayer())
tctx.configure(ps, listen_host="127.0.0.1", listen_port=0)
ps.running()
assert await tctx.master.await_log("Proxy server listening", level="info")
await tctx.master.await_log("Proxy server listening", level="info")
proxy_addr = ps.server.sockets[0].getsockname()[:2]
reader, writer = await asyncio.open_connection(*proxy_addr)

View File

@ -17,7 +17,7 @@ async def test_browser():
b.start()
b.browser.poll = lambda: None
b.start()
assert await tctx.master.await_log("already running")
await tctx.master.await_log("already running")
b.done()
assert not b.browser
@ -30,4 +30,4 @@ async def test_no_browser():
b = browser.Browser()
with taddons.context() as tctx:
b.start()
assert await tctx.master.await_log("platform is not supported")
await tctx.master.await_log("platform is not supported")

View File

@ -76,7 +76,7 @@ async def test_playback_crash(monkeypatch):
with taddons.context(cp) as tctx:
cp.running()
cp.start_replay([tflow.tflow()])
assert await tctx.master.await_log("Client replay has crashed!", level="error")
await tctx.master.await_log("Client replay has crashed!", level="error")
assert cp.count() == 0
@ -110,7 +110,7 @@ async def test_start_stop(tdata):
assert cp.count() == 1
cp.start_replay([tflow.twebsocketflow()])
assert await tctx.master.await_log("Can only replay HTTP flows.", level="warn")
await tctx.master.await_log("Can only replay HTTP flows.", level="warn")
assert cp.count() == 1
cp.stop_replay()

View File

@ -36,7 +36,7 @@ class TestCommandHistory:
ch.history.append('cmd3')
tctx.options.confdir = '/non/existent/path/foobar1234/'
ch.done()
assert await tctx.master.await_log(f"Failed writing to {ch.history_file}")
await tctx.master.await_log(f"Failed writing to {ch.history_file}")
def test_add_command(self):
ch = command_history.CommandHistory()
@ -54,7 +54,7 @@ class TestCommandHistory:
with taddons.context(ch) as tctx:
tctx.options.confdir = '/non/existent/path/foobar1234/'
ch.add_command('cmd1')
assert await tctx.master.await_log(f"Failed writing to {ch.history_file}")
await tctx.master.await_log(f"Failed writing to {ch.history_file}")
def test_get_next_and_prev(self, tmpdir):
ch = command_history.CommandHistory()
@ -168,7 +168,7 @@ class TestCommandHistory:
with patch.object(Path, 'unlink') as mock_unlink:
mock_unlink.side_effect = IOError()
ch.clear_history()
assert await tctx.master.await_log(f"Failed deleting {ch.history_file}")
await tctx.master.await_log(f"Failed deleting {ch.history_file}")
def test_filter(self, tmpdir):
ch = command_history.CommandHistory()

View File

@ -94,7 +94,7 @@ async def test_cut_clip():
"copy/paste mechanism for your system."
pc.side_effect = pyperclip.PyperclipException(log_message)
tctx.command(c.clip, "@all", "request.method")
assert await tctx.master.await_log(log_message, level="error")
await tctx.master.await_log(log_message, level="error")
def test_cut_save(tmpdir):
@ -136,7 +136,7 @@ async def test_cut_save_open(exception, log_message, tmpdir):
with mock.patch("mitmproxy.addons.cut.open") as m:
m.side_effect = exception(log_message)
tctx.command(c.save, "@all", "request.method", f)
assert await tctx.master.await_log(log_message, level="error")
await tctx.master.await_log(log_message, level="error")
def test_cut():

View File

@ -1,5 +1,5 @@
from mitmproxy import flow
from mitmproxy.addons import disable_h2c
from mitmproxy.exceptions import Kill
from mitmproxy.test import taddons, tutils
from mitmproxy.test import tflow
@ -35,4 +35,4 @@ class TestDisableH2CleartextUpgrade:
a.request(f)
assert not f.killable
assert f.reply.value == Kill
assert f.error.msg == flow.Error.KILLED_MESSAGE

View File

@ -193,14 +193,14 @@ class TestContentView:
@pytest.mark.asyncio
async def test_contentview(self):
with mock.patch("mitmproxy.contentviews.auto.ViewAuto.__call__") as va:
va.side_effect = exceptions.ContentViewException("")
va.side_effect = ValueError("")
sio = io.StringIO()
sio_err = io.StringIO()
d = dumper.Dumper(sio, sio_err)
with taddons.context(d) as ctx:
ctx.configure(d, flow_detail=4)
with taddons.context(d) as tctx:
tctx.configure(d, flow_detail=4)
d.response(tflow.tflow())
assert await ctx.master.await_log("content viewer failed")
await tctx.master.await_log("content viewer failed")
def test_tcp():

View File

@ -232,7 +232,7 @@ async def test_export_open(exception, log_message, tmpdir):
with mock.patch("mitmproxy.addons.export.open") as m:
m.side_effect = exception(log_message)
e.file("raw_request", tflow.tflow(resp=True), f)
assert await tctx.master.await_log(log_message, level="error")
await tctx.master.await_log(log_message, level="error")
@pytest.mark.asyncio
@ -263,4 +263,4 @@ async def test_clip(tmpdir):
"copy/paste mechanism for your system."
pc.side_effect = pyperclip.PyperclipException(log_message)
e.clip("raw_request", tflow.tflow(resp=True))
assert await tctx.master.await_log(log_message, level="error")
await tctx.master.await_log(log_message, level="error")

View File

@ -151,7 +151,7 @@ class TestMapLocal:
f.request.url = b"https://example.org/css/nonexistent"
ml.request(f)
assert f.response.status_code == 404
assert await tctx.master.await_log("None of the local file candidates exist")
await tctx.master.await_log("None of the local file candidates exist")
tmpfile = tmpdir.join("foo.jpg")
tmpfile.write("foo")
@ -166,7 +166,7 @@ class TestMapLocal:
f = tflow.tflow()
f.request.url = b"https://example.org/images/foo.jpg"
ml.request(f)
assert await tctx.master.await_log("could not read file")
await tctx.master.await_log("could not read file")
def test_has_reply(self, tmpdir):
ml = MapLocal()
@ -181,6 +181,6 @@ class TestMapLocal:
)
f = tflow.tflow()
f.request.url = b"https://example.org/images/foo.jpg"
f.kill()
f.reply.take()
ml.request(f)
assert not f.response

View File

@ -34,6 +34,6 @@ class TestMapRemote:
tctx.configure(mr, map_remote=[":example.org:mitmproxy.org"])
f = tflow.tflow()
f.request.url = b"https://example.org/images/test.jpg"
f.kill()
f.reply.take()
mr.request(f)
assert f.request.url == "https://example.org/images/test.jpg"

View File

@ -33,6 +33,25 @@ class TestModifyBody:
mb.response(f)
assert f.response.content == b"bar"
@pytest.mark.parametrize("take", [True, False])
def test_taken(self, take):
mb = modifybody.ModifyBody()
with taddons.context(mb) as tctx:
tctx.configure(mb, modify_body=["/foo/bar"])
f = tflow.tflow()
f.request.content = b"foo"
if take:
f.reply.take()
mb.request(f)
assert (f.request.content == b"bar") ^ take
f = tflow.tflow(resp=True)
f.response.content = b"foo"
if take:
f.reply.take()
mb.response(f)
assert (f.response.content == b"bar") ^ take
def test_order(self):
mb = modifybody.ModifyBody()
with taddons.context(mb) as tctx:
@ -86,4 +105,4 @@ class TestModifyBodyFile:
f = tflow.tflow()
f.request.content = b"foo"
mb.request(f)
assert await tctx.master.await_log("could not read")
await tctx.master.await_log("could not read")

View File

@ -1,9 +1,8 @@
import pytest
from mitmproxy.test import tflow
from mitmproxy.test import taddons
from mitmproxy.addons.modifyheaders import parse_modify_spec, ModifyHeaders
from mitmproxy.test import taddons
from mitmproxy.test import tflow
def test_parse_modify_spec():
@ -114,6 +113,23 @@ class TestModifyHeaders:
mh.response(f)
assert "one" not in f.response.headers
@pytest.mark.parametrize("take", [True, False])
def test_taken(self, take):
mh = ModifyHeaders()
with taddons.context(mh) as tctx:
tctx.configure(mh, modify_headers=["/content-length/42"])
f = tflow.tflow()
if take:
f.reply.take()
mh.request(f)
assert (f.request.headers["content-length"] == "42") ^ take
f = tflow.tflow(resp=True)
if take:
f.reply.take()
mh.response(f)
assert (f.response.headers["content-length"] == "42") ^ take
class TestModifyHeadersFile:
def test_simple(self, tmpdir):
@ -150,4 +166,4 @@ class TestModifyHeadersFile:
f = tflow.tflow()
f.request.content = b"foo"
mh.request(f)
assert await tctx.master.await_log("could not read")
await tctx.master.await_log("could not read")

View File

@ -23,20 +23,20 @@ class TestApp:
@pytest.mark.parametrize("ext", ["pem", "p12", "cer"])
@pytest.mark.asyncio
async def test_cert(self, client, ext):
async def test_cert(self, client, ext, tdata):
ob = onboarding.Onboarding()
with taddons.context(ob) as tctx:
tctx.configure(ob)
tctx.configure(ob, confdir=tdata.path("mitmproxy/data/confdir"))
resp = client.get(f"/cert/{ext}")
assert resp.status_code == 200
assert resp.data
@pytest.mark.parametrize("ext", ["pem", "p12", "cer"])
@pytest.mark.asyncio
async def test_head(self, client, ext):
async def test_head(self, client, ext, tdata):
ob = onboarding.Onboarding()
with taddons.context(ob) as tctx:
tctx.configure(ob)
tctx.configure(ob, confdir=tdata.path("mitmproxy/data/confdir"))
resp = client.head(f"http://{tctx.options.onboarding_host}/cert/{ext}")
assert resp.status_code == 200
assert "Content-Length" in resp.headers

View File

@ -51,7 +51,7 @@ async def test_start_stop():
tctx.configure(ps, listen_host="127.0.0.1", listen_port=0)
assert not ps.server
ps.running()
assert await tctx.master.await_log("Proxy server listening", level="info")
await tctx.master.await_log("Proxy server listening", level="info")
assert ps.server
proxy_addr = ps.server.sockets[0].getsockname()[:2]
@ -61,7 +61,7 @@ async def test_start_stop():
assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n"
tctx.configure(ps, server=False)
assert await tctx.master.await_log("Stopping server", level="info")
await tctx.master.await_log("Stopping server", level="info")
assert not ps.server
assert state.flows
assert state.flows[0].request.path == "/hello"
@ -78,6 +78,6 @@ async def test_warn_no_nextlayer():
with taddons.context(ps) as tctx:
tctx.configure(ps, listen_host="127.0.0.1", listen_port=0)
ps.running()
assert await tctx.master.await_log("Proxy server listening at", level="info")
await tctx.master.await_log("Proxy server listening at", level="info")
assert tctx.master.has_log("Warning: Running proxyserver without nextlayer addon!", level="warn")
await ps.shutdown_server()

View File

@ -69,7 +69,7 @@ class TestReadFile:
tf.write(corrupt_data.getvalue())
tctx.configure(rf, rfile=str(tf))
rf.running()
assert await tctx.master.await_log("corrupted")
await tctx.master.await_log("corrupted")
@pytest.mark.asyncio
async def test_corrupt(self, corrupt_data):
@ -81,7 +81,7 @@ class TestReadFile:
tctx.master.clear()
with pytest.raises(exceptions.FlowReadException):
await rf.load_flows(corrupt_data)
assert await tctx.master.await_log("file corrupted")
await tctx.master.await_log("file corrupted")
@pytest.mark.asyncio
async def test_nonexistent_file(self):
@ -89,7 +89,7 @@ class TestReadFile:
with taddons.context(rf) as tctx:
with pytest.raises(exceptions.FlowReadException):
await rf.load_flows_from_path("nonexistent")
assert await tctx.master.await_log("nonexistent")
await tctx.master.await_log("nonexistent")
class TestReadFileStdin:

View File

@ -29,14 +29,14 @@ async def test_load_script(tdata):
script.load_script(
"nonexistent"
)
assert await tctx.master.await_log("No such file or directory")
await tctx.master.await_log("No such file or directory")
script.load_script(
tdata.path(
"mitmproxy/data/addonscripts/recorder/error.py"
)
)
assert await tctx.master.await_log("invalid syntax")
await tctx.master.await_log("invalid syntax")
def test_load_fullname(tdata):
@ -108,7 +108,7 @@ class TestScript:
f.write("\n")
sc = script.Script(str(f), True)
tctx.configure(sc)
assert await tctx.master.await_log("Loading")
await tctx.master.await_log("Loading")
tctx.master.clear()
for i in range(20):
@ -133,8 +133,8 @@ class TestScript:
f = tflow.tflow(resp=True)
tctx.master.addons.trigger("request", f)
assert await tctx.master.await_log("ValueError: Error!")
assert await tctx.master.await_log("error.py")
await tctx.master.await_log("ValueError: Error!")
await tctx.master.await_log("error.py")
@pytest.mark.asyncio
async def test_optionexceptions(self, tdata):
@ -145,7 +145,7 @@ class TestScript:
)
tctx.master.addons.add(sc)
tctx.configure(sc)
assert await tctx.master.await_log("Options Error")
await tctx.master.await_log("Options Error")
@pytest.mark.asyncio
async def test_addon(self, tdata):
@ -201,7 +201,7 @@ class TestScriptLoader:
sc = script.ScriptLoader()
with taddons.context(sc) as tctx:
sc.script_run([tflow.tflow(resp=True)], "/")
assert await tctx.master.await_log("No such script")
await tctx.master.await_log("No such script")
def test_simple(self, tdata):
sc = script.ScriptLoader()
@ -258,10 +258,10 @@ class TestScriptLoader:
tb = True
with taddons.context() as tctx:
script.script_error_handler(path, exc, msg, tb)
assert await tctx.master.await_log("/sample/path/example.py")
assert await tctx.master.await_log("Error raised")
assert await tctx.master.await_log("lineno")
assert await tctx.master.await_log("NoneType")
await tctx.master.await_log("/sample/path/example.py")
await tctx.master.await_log("Error raised")
await tctx.master.await_log("lineno")
await tctx.master.await_log("NoneType")
@pytest.mark.asyncio
async def test_order(self, tdata):

View File

@ -356,7 +356,7 @@ def test_server_playback_kill():
f = tflow.tflow()
f.request.host = "nonexistent"
tctx.cycle(s, f)
assert f.reply.value == exceptions.Kill
assert f.error
def test_server_playback_response_deleted():

View File

@ -226,4 +226,4 @@ class TestTlsConfig:
ta = tlsconfig.TlsConfig()
with taddons.context(ta) as tctx:
ta.configure(["confdir"])
await tctx.master.await_log("The mitmproxy certificate authority has expired", "warn")
await tctx.master.await_log("The mitmproxy certificate authority has expired", "warn")

View File

@ -217,7 +217,7 @@ async def test_load(tmpdir):
with open(path, "wb") as f:
f.write(b"invalidflows")
v.load_file(path)
assert await tctx.master.await_log("Invalid data format.")
await tctx.master.await_log("Invalid data format.")
def test_resolve():

View File

@ -1,11 +1,11 @@
from unittest import mock
import pytest
from mitmproxy import contentviews
from mitmproxy.exceptions import ContentViewException
from mitmproxy.net.http import Headers
from mitmproxy.test import tutils
from mitmproxy.test import tflow
from mitmproxy.test import tutils
class TestContentView(contentviews.View):
@ -19,7 +19,7 @@ def test_add_remove():
assert tcv in contentviews.views
# repeated addition causes exception
with pytest.raises(ContentViewException, match="Duplicate view"):
with pytest.raises(ValueError, match="Duplicate view"):
contentviews.add(tcv)
contentviews.remove(tcv)

View File

@ -1,6 +1,5 @@
import pytest
from mitmproxy import exceptions
from mitmproxy.net.http import Headers
from mitmproxy.net.http.http1.assemble import (
assemble_request, assemble_request_head, assemble_response,
@ -19,7 +18,7 @@ def test_assemble_request():
b"content"
)
with pytest.raises(exceptions.HttpException):
with pytest.raises(ValueError):
assemble_request(treq(content=None))
@ -56,7 +55,7 @@ def test_assemble_response():
b"my-little-trailer: foobar\r\n\r\n"
)
with pytest.raises(exceptions.HttpException):
with pytest.raises(ValueError):
assemble_response(tresp(content=None))
@ -80,7 +79,7 @@ def test_assemble_body():
c = list(assemble_body(Headers(transfer_encoding="chunked"), [b"123456789a"], Headers(trailer="trailer")))
assert c == [b"a\r\n123456789a\r\n", b"0\r\ntrailer: trailer\r\n\r\n"]
with pytest.raises(exceptions.HttpException):
with pytest.raises(ValueError):
list(assemble_body(Headers(), [b"body"], Headers(trailer="trailer")))

View File

@ -1,12 +1,10 @@
import pytest
from mitmproxy import exceptions
from mitmproxy.net.http import Headers
from mitmproxy.net.http.http1.read import (
read_request_head,
read_response_head, connection_close, expected_http_body_size,
_read_request_line, _read_response_line, _check_http_version,
_read_headers, get_header_tokens
_read_request_line, _read_response_line, _read_headers, get_header_tokens
)
from mitmproxy.test.tutils import treq, tresp
@ -39,19 +37,6 @@ def test_connection_close():
assert not connection_close(b"HTTP/1.1", headers)
def test_check_http_version():
_check_http_version(b"HTTP/0.9")
_check_http_version(b"HTTP/1.0")
_check_http_version(b"HTTP/1.1")
_check_http_version(b"HTTP/2.0")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"WTF/1.0")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"HTTP/1.10")
with pytest.raises(exceptions.HttpSyntaxException):
_check_http_version(b"HTTP/1.b")
def test_read_request_head():
rfile = [
b"GET / HTTP/1.1\r\n",
@ -112,7 +97,7 @@ def test_expected_http_body_size():
# explicit length
for val in (b"foo", b"-7"):
with pytest.raises(exceptions.HttpSyntaxException):
with pytest.raises(ValueError):
expected_http_body_size(
treq(headers=Headers(content_length=val))
)
@ -126,7 +111,7 @@ def test_expected_http_body_size():
) == 42
# more than 1 content-length headers with conflicting value
with pytest.raises(exceptions.HttpSyntaxException):
with pytest.raises(ValueError):
expected_http_body_size(
treq(headers=Headers([(b'content-length', b'42'), (b'content-length', b'45')]))
)

View File

@ -1,37 +0,0 @@
import pytest
import codecs
from io import BytesIO
import hyperframe
from mitmproxy import exceptions
from mitmproxy.net.http import http2
def test_read_frame():
raw = codecs.decode('000006000101234567666f6f626172', 'hex_codec')
bio = BytesIO(raw)
bio.safe_read = bio.read
frame, consumed_bytes = http2.read_frame(bio)
assert isinstance(frame, hyperframe.frame.DataFrame)
assert frame.stream_id == 19088743
assert 'END_STREAM' in frame.flags
assert len(frame.flags) == 1
assert frame.data == b'foobar'
assert consumed_bytes == raw
bio = BytesIO(raw)
bio.safe_read = bio.read
frame, consumed_bytes = http2.read_frame(bio, False)
assert frame is None
assert consumed_bytes == raw
def test_read_frame_failed():
raw = codecs.decode('485454000000000000', 'hex_codec')
bio = BytesIO(raw)
bio.safe_read = bio.read
with pytest.raises(exceptions.HttpException):
_ = http2.read_frame(bio, False)

View File

@ -39,7 +39,7 @@ class TestConcurrent:
"mitmproxy/data/addonscripts/concurrent_decorator_err.py"
)
)
assert await tctx.master.await_log("decorator not supported")
await tctx.master.await_log("decorator not supported")
def test_concurrent_class(self, tdata):
with taddons.context() as tctx:

View File

@ -129,17 +129,17 @@ async def test_simple():
a.add(TAddon("one"))
a.trigger("nonexistent")
assert await tctx.master.await_log("unknown event")
await tctx.master.await_log("unknown event")
a.trigger("running")
a.trigger("response")
assert await tctx.master.await_log("not callable")
await tctx.master.await_log("not callable")
tctx.master.clear()
a.get("one").response = addons
a.trigger("response")
with pytest.raises(AssertionError):
await tctx.master.await_log("not callable")
await tctx.master.await_log("not callable", timeout=0.01)
a.remove(a.get("one"))
assert not a.get("one")

View File

@ -1,110 +1,55 @@
import asyncio
import queue
import pytest
from mitmproxy.exceptions import Kill, ControlException
from mitmproxy import controller
from mitmproxy.test import taddons
import mitmproxy.ctx
from mitmproxy import controller
from mitmproxy.exceptions import ControlException
from mitmproxy.test import taddons
@pytest.mark.asyncio
async def test_master():
class tAddon:
def log(self, _):
ctx.master.should_exit.set()
mitmproxy.ctx.master.should_exit.set()
with taddons.context(tAddon()) as ctx:
assert not ctx.master.should_exit.is_set()
with taddons.context(tAddon()) as tctx:
assert not tctx.master.should_exit.is_set()
async def test():
mitmproxy.ctx.log("test")
asyncio.ensure_future(test())
assert await ctx.master.await_log("test")
assert ctx.master.should_exit.is_set()
await tctx.master.await_log("test")
assert tctx.master.should_exit.is_set()
class TestReply:
def test_simple(self):
@pytest.mark.asyncio
async def test_simple(self):
reply = controller.Reply(42)
assert reply.state == "start"
reply.send("foo")
assert reply.value == "foo"
reply.take()
assert reply.state == "taken"
with pytest.raises(queue.Empty):
reply.q.get_nowait()
assert not reply.done.is_set()
reply.commit()
assert reply.state == "committed"
assert reply.q.get() == "foo"
assert await asyncio.wait_for(reply.done.wait(), 1)
def test_kill(self):
reply = controller.Reply(43)
reply.kill()
def test_double_commit(self):
reply = controller.Reply(47)
reply.take()
reply.commit()
assert reply.q.get() == Kill
def test_ack(self):
reply = controller.Reply(44)
reply.ack()
reply.take()
reply.commit()
assert reply.q.get() == 44
def test_reply_none(self):
reply = controller.Reply(45)
reply.send(None)
reply.take()
reply.commit()
assert reply.q.get() is None
def test_commit_no_reply(self):
reply = controller.Reply(46)
reply.take()
with pytest.raises(ControlException):
reply.commit()
reply.ack()
reply.commit()
def test_double_send(self):
reply = controller.Reply(47)
reply.send(1)
with pytest.raises(ControlException):
reply.send(2)
reply.take()
reply.commit()
def test_state_transitions(self):
states = {"start", "taken", "committed"}
accept = {
"take": {"start"},
"commit": {"taken"},
"ack": {"start", "taken"},
}
for fn, ok in accept.items():
for state in states:
r = controller.Reply(48)
r._state = state
if fn == "commit":
r.value = 49
if state in ok:
getattr(r, fn)()
else:
with pytest.raises(ControlException):
getattr(r, fn)()
r._state = "committed" # hide warnings on deletion
def test_del(self):
reply = controller.Reply(47)
with pytest.raises(ControlException):
reply.__del__()
reply.ack()
reply.take()
reply.commit()
@ -113,7 +58,6 @@ class TestDummyReply:
def test_simple(self):
reply = controller.DummyReply()
for _ in range(2):
reply.ack()
reply.take()
reply.commit()
reply.mark_reset()
@ -122,7 +66,6 @@ class TestDummyReply:
def test_reset(self):
reply = controller.DummyReply()
reply.ack()
reply.take()
with pytest.raises(ControlException):
reply.mark_reset()

View File

@ -1,12 +1,12 @@
import pytest
from mitmproxy.test import tflow
from mitmproxy.net.http import Headers
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy.exceptions import Kill, ControlException
from mitmproxy import flow
from mitmproxy import flowfilter
from mitmproxy import http
from mitmproxy.exceptions import ControlException
from mitmproxy.net.http import Headers
from mitmproxy.test import tflow
class TestHTTPRequest:
@ -169,7 +169,7 @@ class TestHTTPFlow:
assert f.killable
f.kill()
assert not f.killable
assert f.reply.value == Kill
assert f.error.msg == flow.Error.KILLED_MESSAGE
def test_intercept(self):
f = tflow.tflow()

View File

@ -12,14 +12,14 @@ async def test_recordingmaster():
assert not tctx.master.has_log("nonexistent")
ctx.log.error("foo")
assert not tctx.master.has_log("foo", level="debug")
assert await tctx.master.await_log("foo", level="error")
await tctx.master.await_log("foo", level="error")
@pytest.mark.asyncio
async def test_dumplog():
with taddons.context() as tctx:
ctx.log.info("testing")
assert await ctx.master.await_log("testing")
await tctx.master.await_log("testing")
s = io.StringIO()
tctx.master.dump_log(s)
assert s.getvalue()

View File

@ -3,7 +3,7 @@ import io
import pytest
from mitmproxy import flowfilter
from mitmproxy.exceptions import Kill, ControlException
from mitmproxy.exceptions import ControlException
from mitmproxy.io import tnetstring
from mitmproxy.test import tflow
@ -56,7 +56,6 @@ class TestWebSocketFlow:
assert f.killable
f.kill()
assert not f.killable
assert f.reply.value == Kill
def test_match(self):
f = tflow.twebsocketflow()