Merge pull request #3089 from cortesi/creplay

Revamp client replay
This commit is contained in:
Aldo Cortesi 2018-05-02 11:33:45 +12:00 committed by GitHub
commit 0f6072050a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 400 additions and 425 deletions

View File

@ -1,18 +1,140 @@
import queue
import typing
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import options
from mitmproxy import connections
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread
from mitmproxy.utils import human
from mitmproxy import ctx from mitmproxy import ctx
from mitmproxy import io from mitmproxy import io
from mitmproxy import flow
from mitmproxy import command from mitmproxy import command
import mitmproxy.types import mitmproxy.types
import typing
class RequestReplayThread(basethread.BaseThread):
daemon = True
def __init__(
self,
opts: options.Options,
channel: controller.Channel,
queue: queue.Queue,
) -> None:
self.options = opts
self.channel = channel
self.queue = queue
super().__init__("RequestReplayThread")
def run(self):
while True:
f = self.queue.get()
self.replay(f)
def replay(self, f): # pragma: no cover
f.live = True
r = f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
f.response = None
# If we have a channel, run script hooks.
request_reply = self.channel.ask("request", f)
if isinstance(request_reply, http.HTTPResponse):
f.response = request_reply
if not f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(
server_address, (self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
server.wfile.write(http1.assemble_request(connect_request))
server.wfile.flush()
resp = http1.read_response(
server.rfile,
connect_request,
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException(
"Upstream server refuses CONNECT request"
)
server.establish_tls(
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
else:
r.first_line_format = "absolute"
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(
server_address,
(self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
if f.server_conn:
f.server_conn.close()
f.server_conn = server
f.response = http.HTTPResponse.wrap(
http1.read_response(server.rfile, r, body_size_limit=bsl)
)
response_reply = self.channel.ask("response", f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
f.error = flow.Error(str(e))
self.channel.ask("error", f)
except exceptions.Kill:
self.channel.tell("log", log.LogEntry("Connection killed", "info"))
except Exception as e:
self.channel.tell("log", log.LogEntry(repr(e), "error"))
finally:
r.first_line_format = first_line_format_backup
f.live = False
if server.connected():
server.finish()
server.close()
class ClientPlayback: class ClientPlayback:
def __init__(self): def __init__(self):
self.flows: typing.List[flow.Flow] = [] self.q = queue.Queue()
self.current_thread = None self.thread: RequestReplayThread = None
self.configured = False
def check(self, f: http.HTTPFlow):
if f.live:
return "Can't replay live flow."
if f.intercepted:
return "Can't replay intercepted flow."
if not f.request:
return "Can't replay flow with missing request."
if f.request.raw_content is None:
return "Can't replay flow with missing content."
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
@ -20,65 +142,77 @@ class ClientPlayback:
"Replay client requests from a saved file." "Replay client requests from a saved file."
) )
def count(self) -> int: def running(self):
if self.current_thread: self.thread = RequestReplayThread(
current = 1 ctx.options,
else: ctx.master.channel,
current = 0 self.q,
return current + len(self.flows) )
self.thread.start()
@command.command("replay.client.stop")
def stop_replay(self) -> None:
"""
Stop client replay.
"""
self.flows = []
ctx.log.alert("Client replay stopped.")
ctx.master.addons.trigger("update", [])
@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
Replay requests from flows.
"""
for f in flows:
if f.live:
raise exceptions.CommandError("Can't replay live flow.")
self.flows = list(flows)
ctx.log.alert("Replaying %s flows." % len(self.flows))
ctx.master.addons.trigger("update", [])
@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
ctx.log.alert("Replaying %s flows." % len(self.flows))
self.flows = flows
ctx.master.addons.trigger("update", [])
def configure(self, updated): def configure(self, updated):
if not self.configured and ctx.options.client_replay: if "client_replay" in updated and ctx.options.client_replay:
self.configured = True
ctx.log.info("Client Replay: {}".format(ctx.options.client_replay))
try: try:
flows = io.read_flows_from_paths(ctx.options.client_replay) flows = io.read_flows_from_paths(ctx.options.client_replay)
except exceptions.FlowReadException as e: except exceptions.FlowReadException as e:
raise exceptions.OptionsError(str(e)) raise exceptions.OptionsError(str(e))
self.start_replay(flows) self.start_replay(flows)
def tick(self): @command.command("replay.client.count")
current_is_done = self.current_thread and not self.current_thread.is_alive() def count(self) -> int:
can_start_new = not self.current_thread or current_is_done """
will_start_new = can_start_new and self.flows Approximate number of flows queued for replay.
"""
return self.q.qsize()
if current_is_done: @command.command("replay.client.stop")
self.current_thread = None def stop_replay(self) -> None:
ctx.master.addons.trigger("update", []) """
if will_start_new: Clear the replay queue.
f = self.flows.pop(0) """
self.current_thread = ctx.master.replay_request(f) with self.q.mutex:
ctx.master.addons.trigger("update", [f]) lst = list(self.q.queue)
if current_is_done and not will_start_new: self.q.queue.clear()
ctx.master.addons.trigger("processing_complete") for f in lst:
f.revert()
ctx.master.addons.trigger("update", lst)
ctx.log.alert("Client replay queue cleared.")
@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
Add flows to the replay queue, skipping flows that can't be replayed.
"""
lst = []
for f in flows:
hf = typing.cast(http.HTTPFlow, f)
err = self.check(hf)
if err:
ctx.log.warn(err)
continue
lst.append(hf)
# Prepare the flow for replay
hf.backup()
hf.request.is_replay = True
hf.response = None
hf.error = None
# https://github.com/mitmproxy/mitmproxy/issues/2197
if hf.request.http_version == "HTTP/2.0":
hf.request.http_version = "HTTP/1.1"
host = hf.request.headers.pop(":authority")
hf.request.headers.insert(0, "host", host)
self.q.put(hf)
ctx.master.addons.trigger("update", lst)
@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
"""
Load flows from file, and add them to the replay queue.
"""
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
self.start_replay(flows)

View File

@ -204,7 +204,15 @@ class CommandManager(mitmproxy.types._CommandBase):
return parse, remhelp return parse, remhelp
def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: def call(self, path: str, *args: typing.Sequence[typing.Any]) -> typing.Any:
"""
Call a command with native arguments. May raise CommandError.
"""
if path not in self.commands:
raise exceptions.CommandError("Unknown command: %s" % path)
return self.commands[path].func(*args)
def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any:
""" """
Call a command using a list of string arguments. May raise CommandError. Call a command using a list of string arguments. May raise CommandError.
""" """
@ -212,14 +220,14 @@ class CommandManager(mitmproxy.types._CommandBase):
raise exceptions.CommandError("Unknown command: %s" % path) raise exceptions.CommandError("Unknown command: %s" % path)
return self.commands[path].call(args) return self.commands[path].call(args)
def call(self, cmdstr: str): def execute(self, cmdstr: str):
""" """
Call a command using a string. May raise CommandError. Execute a command string. May raise CommandError.
""" """
parts = list(lexer(cmdstr)) parts = list(lexer(cmdstr))
if not len(parts) >= 1: if not len(parts) >= 1:
raise exceptions.CommandError("Invalid command: %s" % cmdstr) raise exceptions.CommandError("Invalid command: %s" % cmdstr)
return self.call_args(parts[0], parts[1:]) return self.call_strings(parts[0], parts[1:])
def dump(self, out=sys.stdout) -> None: def dump(self, out=sys.stdout) -> None:
cmds = list(self.commands.values()) cmds = list(self.commands.values())

View File

@ -8,13 +8,11 @@ from mitmproxy import addonmanager
from mitmproxy import options from mitmproxy import options
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import eventsequence from mitmproxy import eventsequence
from mitmproxy import exceptions
from mitmproxy import command from mitmproxy import command
from mitmproxy import http from mitmproxy import http
from mitmproxy import websocket from mitmproxy import websocket
from mitmproxy import log from mitmproxy import log
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol import http_replay
from mitmproxy.coretypes import basethread from mitmproxy.coretypes import basethread
from . import ctx as mitmproxy_ctx from . import ctx as mitmproxy_ctx
@ -164,58 +162,3 @@ class Master:
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):
await self.addons.handle_lifecycle(e, o) await self.addons.handle_lifecycle(e, o)
def replay_request(
self,
f: http.HTTPFlow,
block: bool=False
) -> http_replay.RequestReplayThread:
"""
Replay a HTTP request to receive a new response from the server.
Args:
f: The flow to replay.
block: If True, this function will wait for the replay to finish.
This causes a deadlock if activated in the main thread.
Returns:
The thread object doing the replay.
Raises:
exceptions.ReplayException, if the flow is in a state
where it is ineligible for replay.
"""
if f.live:
raise exceptions.ReplayException(
"Can't replay live flow."
)
if f.intercepted:
raise exceptions.ReplayException(
"Can't replay intercepted flow."
)
if not f.request:
raise exceptions.ReplayException(
"Can't replay flow with missing request."
)
if f.request.raw_content is None:
raise exceptions.ReplayException(
"Can't replay flow with missing content."
)
f.backup()
f.request.is_replay = True
f.response = None
f.error = None
if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197
f.request.http_version = "HTTP/1.1"
host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host)
rt = http_replay.RequestReplayThread(self.options, f, self.channel)
rt.start() # pragma: no cover
if block:
rt.join()
return rt

View File

@ -372,9 +372,8 @@ class TCPClient(_Connection):
# Make sure to close the real socket, not the SSL proxy. # Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
# it tries to renegotiate... # it tries to renegotiate...
if not self.connection: if self.connection:
return if isinstance(self.connection, SSL.Connection):
elif isinstance(self.connection, SSL.Connection):
close_socket(self.connection._socket) close_socket(self.connection._socket)
else: else:
close_socket(self.connection) close_socket(self.connection)

View File

@ -1,125 +0,0 @@
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import options
from mitmproxy import connections
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread
from mitmproxy.utils import human
# TODO: Doesn't really belong into mitmproxy.proxy.protocol...
class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
def __init__(
self,
opts: options.Options,
f: http.HTTPFlow,
channel: controller.Channel,
) -> None:
self.options = opts
self.f = f
f.live = True
self.channel = channel
super().__init__(
"RequestReplay (%s)" % f.request.url
)
self.daemon = True
def run(self):
r = self.f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
self.f.response = None
# If we have a channel, run script hooks.
if self.channel:
request_reply = self.channel.ask("request", self.f)
if isinstance(request_reply, http.HTTPResponse):
self.f.response = request_reply
if not self.f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(server_address, (self.options.listen_host, 0))
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
server.wfile.write(http1.assemble_request(connect_request))
server.wfile.flush()
resp = http1.read_response(
server.rfile,
connect_request,
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
else:
r.first_line_format = "absolute"
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(
server_address,
(self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
if self.f.server_conn:
self.f.server_conn.close()
self.f.server_conn = server
self.f.response = http.HTTPResponse.wrap(
http1.read_response(
server.rfile,
r,
body_size_limit=bsl
)
)
if self.channel:
response_reply = self.channel.ask("response", self.f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
self.f.error = flow.Error(str(e))
if self.channel:
self.channel.ask("error", self.f)
except exceptions.Kill:
# Kill should only be raised if there's a channel in the
# first place.
self.channel.tell(
"log",
log.LogEntry("Connection killed", "info")
)
except Exception as e:
self.channel.tell(
"log",
log.LogEntry(repr(e), "error")
)
finally:
r.first_line_format = first_line_format_backup
self.f.live = False
if server.connected():
server.finish()

View File

@ -112,12 +112,10 @@ class context:
if addon not in self.master.addons: if addon not in self.master.addons:
self.master.addons.register(addon) self.master.addons.register(addon)
with self.options.rollback(kwargs.keys(), reraise=True): with self.options.rollback(kwargs.keys(), reraise=True):
if kwargs:
self.options.update(**kwargs) self.options.update(**kwargs)
self.master.addons.invoke_addon( else:
addon, self.master.addons.invoke_addon(addon, "configure", {})
"configure",
kwargs.keys()
)
def script(self, path): def script(self, path):
""" """

View File

@ -258,7 +258,7 @@ class ConsoleAddon:
command, then invoke another command with all occurrences of {choice} command, then invoke another command with all occurrences of {choice}
replaced by the choice the user made. replaced by the choice the user made.
""" """
choices = ctx.master.commands.call_args(choicecmd, []) choices = ctx.master.commands.call_strings(choicecmd, [])
def callback(opt): def callback(opt):
# We're now outside of the call context... # We're now outside of the call context...
@ -514,7 +514,7 @@ class ConsoleAddon:
raise exceptions.CommandError("Invalid flowview mode.") raise exceptions.CommandError("Invalid flowview mode.")
try: try:
self.master.commands.call_args( self.master.commands.call_strings(
"view.setval", "view.setval",
["@focus", "flowview_mode_%s" % idx, mode] ["@focus", "flowview_mode_%s" % idx, mode]
) )
@ -537,7 +537,7 @@ class ConsoleAddon:
if not fv: if not fv:
raise exceptions.CommandError("Not viewing a flow.") raise exceptions.CommandError("Not viewing a flow.")
idx = fv.body.tab_offset idx = fv.body.tab_offset
return self.master.commands.call_args( return self.master.commands.call_strings(
"view.getval", "view.getval",
[ [
"@focus", "@focus",

View File

@ -167,6 +167,7 @@ class StatusBar(urwid.WidgetWrap):
self.ib = urwid.WidgetWrap(urwid.Text("")) self.ib = urwid.WidgetWrap(urwid.Text(""))
self.ab = ActionBar(self.master) self.ab = ActionBar(self.master)
super().__init__(urwid.Pile([self.ib, self.ab])) super().__init__(urwid.Pile([self.ib, self.ab]))
signals.flow_change.connect(self.sig_update)
signals.update_settings.connect(self.sig_update) signals.update_settings.connect(self.sig_update)
signals.flowlist_change.connect(self.sig_update) signals.flowlist_change.connect(self.sig_update)
master.options.changed.connect(self.sig_update) master.options.changed.connect(self.sig_update)
@ -184,7 +185,7 @@ class StatusBar(urwid.WidgetWrap):
r = [] r = []
sreplay = self.master.addons.get("serverplayback") sreplay = self.master.addons.get("serverplayback")
creplay = self.master.addons.get("clientplayback") creplay = self.master.commands.call("replay.client.count")
if len(self.master.options.setheaders): if len(self.master.options.setheaders):
r.append("[") r.append("[")
@ -192,10 +193,10 @@ class StatusBar(urwid.WidgetWrap):
r.append("eaders]") r.append("eaders]")
if len(self.master.options.replacements): if len(self.master.options.replacements):
r.append("[%d replacements]" % len(self.master.options.replacements)) r.append("[%d replacements]" % len(self.master.options.replacements))
if creplay.count(): if creplay:
r.append("[") r.append("[")
r.append(("heading_key", "cplayback")) r.append(("heading_key", "cplayback"))
r.append(":%s]" % creplay.count()) r.append(":%s]" % creplay)
if sreplay.count(): if sreplay.count():
r.append("[") r.append("[")
r.append(("heading_key", "splayback")) r.append(("heading_key", "splayback"))

View File

@ -344,7 +344,7 @@ class ReplayFlow(RequestHandler):
self.view.update([self.flow]) self.view.update([self.flow])
try: try:
self.master.replay_request(self.flow) self.master.commands.call("replay.client", [self.flow])
except exceptions.ReplayException as e: except exceptions.ReplayException as e:
raise APIError(400, str(e)) raise APIError(400, str(e))

View File

@ -47,10 +47,10 @@ class Choice:
class _CommandBase: class _CommandBase:
commands: typing.MutableMapping[str, typing.Any] = {} commands: typing.MutableMapping[str, typing.Any] = {}
def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any: def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any:
raise NotImplementedError raise NotImplementedError
def call(self, cmd: str) -> typing.Any: def execute(self, cmd: str) -> typing.Any:
raise NotImplementedError raise NotImplementedError
@ -337,7 +337,7 @@ class _FlowType(_BaseFlowType):
def parse(self, manager: _CommandBase, t: type, s: str) -> flow.Flow: def parse(self, manager: _CommandBase, t: type, s: str) -> flow.Flow:
try: try:
flows = manager.call_args("view.resolve", [s]) flows = manager.call_strings("view.resolve", [s])
except exceptions.CommandError as e: except exceptions.CommandError as e:
raise exceptions.TypeError from e raise exceptions.TypeError from e
if len(flows) != 1: if len(flows) != 1:
@ -356,7 +356,7 @@ class _FlowsType(_BaseFlowType):
def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]: def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]:
try: try:
return manager.call_args("view.resolve", [s]) return manager.call_strings("view.resolve", [s])
except exceptions.CommandError as e: except exceptions.CommandError as e:
raise exceptions.TypeError from e raise exceptions.TypeError from e
@ -401,17 +401,17 @@ class _ChoiceType(_BaseType):
display = "choice" display = "choice"
def completion(self, manager: _CommandBase, t: Choice, s: str) -> typing.Sequence[str]: def completion(self, manager: _CommandBase, t: Choice, s: str) -> typing.Sequence[str]:
return manager.call(t.options_command) return manager.execute(t.options_command)
def parse(self, manager: _CommandBase, t: Choice, s: str) -> str: def parse(self, manager: _CommandBase, t: Choice, s: str) -> str:
opts = manager.call(t.options_command) opts = manager.execute(t.options_command)
if s not in opts: if s not in opts:
raise exceptions.TypeError("Invalid choice.") raise exceptions.TypeError("Invalid choice.")
return s return s
def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool: def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
try: try:
opts = manager.call(typ.options_command) opts = manager.execute(typ.options_command)
except exceptions.CommandError: except exceptions.CommandError:
return False return False
return val in opts return val in opts

View File

@ -12,11 +12,11 @@ class TestCheckCA:
async def test_check_ca(self, expired): async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!' msg = 'The mitmproxy certificate authority has expired!'
with taddons.context() as tctx: a = check_ca.CheckCA()
with taddons.context(a) as tctx:
tctx.master.server = mock.MagicMock() tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock( tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired return_value = expired
) )
a = check_ca.CheckCA()
tctx.configure(a) tctx.configure(a)
assert await tctx.master.await_log(msg) == expired assert await tctx.master.await_log(msg) == expired

View File

@ -1,13 +1,16 @@
import time
import pytest import pytest
from unittest import mock
from mitmproxy.test import tflow from mitmproxy.test import tflow, tutils
from mitmproxy import io from mitmproxy import io
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy.net import http as net_http
from mitmproxy.addons import clientplayback from mitmproxy.addons import clientplayback
from mitmproxy.test import taddons from mitmproxy.test import taddons
from .. import tservers
def tdump(path, flows): def tdump(path, flows):
with open(path, "wb") as f: with open(path, "wb") as f:
@ -21,48 +24,87 @@ class MockThread():
return False return False
class TBase(tservers.HTTPProxyTest):
@staticmethod
def wait_response(flow):
"""
Race condition: We don't want to replay the flow while it is still live.
"""
s = time.time()
while True:
if flow.response or flow.error:
break
time.sleep(0.001)
if time.time() - s > 5:
raise RuntimeError("Flow is live for too long.")
@staticmethod
def reset(f):
f.live = False
f.repsonse = False
f.error = False
def addons(self):
return [clientplayback.ClientPlayback()]
def test_replay(self):
cr = self.master.addons.get("clientplayback")
assert self.pathod("304").status_code == 304
assert len(self.master.state.flows) == 1
l = self.master.state.flows[-1]
assert l.response.status_code == 304
l.request.path = "/p/305"
cr.start_replay([l])
self.wait_response(l)
assert l.response.status_code == 305
# Disconnect error
cr.stop_replay()
self.reset(l)
l.request.path = "/p/305:d0"
cr.start_replay([l])
self.wait_response(l)
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
# # Port error
cr.stop_replay()
self.reset(l)
l.request.port = 1
# In upstream mode, we get a 502 response from the upstream proxy server.
# In upstream mode with ssl, the replay will fail as we cannot establish
# SSL with the upstream proxy.
cr.start_replay([l])
self.wait_response(l)
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
class TestHTTPProxy(TBase, tservers.HTTPProxyTest):
pass
class TestHTTPSProxy(TBase, tservers.HTTPProxyTest):
ssl = True
class TestUpstreamProxy(TBase, tservers.HTTPUpstreamProxyTest):
pass
class TestClientPlayback: class TestClientPlayback:
def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp) as tctx:
assert cp.count() == 0
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1
RP = "mitmproxy.proxy.protocol.http_replay.RequestReplayThread"
with mock.patch(RP) as rp:
assert not cp.current_thread
cp.tick()
assert rp.called
assert cp.current_thread
cp.flows = []
cp.current_thread.is_alive.return_value = False
assert cp.count() == 1
cp.tick()
assert cp.count() == 0
assert tctx.master.has_event("update")
assert tctx.master.has_event("processing_complete")
cp.current_thread = MockThread()
cp.tick()
assert cp.current_thread is None
cp.start_replay([f])
cp.stop_replay()
assert not cp.flows
df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True)
with pytest.raises(exceptions.CommandError, match="Can't replay live flow."):
cp.start_replay([df])
def test_load_file(self, tmpdir): def test_load_file(self, tmpdir):
cp = clientplayback.ClientPlayback() cp = clientplayback.ClientPlayback()
with taddons.context(cp): with taddons.context(cp):
fpath = str(tmpdir.join("flows")) fpath = str(tmpdir.join("flows"))
tdump(fpath, [tflow.tflow(resp=True)]) tdump(fpath, [tflow.tflow(resp=True)])
cp.load_file(fpath) cp.load_file(fpath)
assert cp.flows assert cp.count() == 1
with pytest.raises(exceptions.CommandError): with pytest.raises(exceptions.CommandError):
cp.load_file("/nonexistent") cp.load_file("/nonexistent")
@ -71,11 +113,63 @@ class TestClientPlayback:
with taddons.context(cp) as tctx: with taddons.context(cp) as tctx:
path = str(tmpdir.join("flows")) path = str(tmpdir.join("flows"))
tdump(path, [tflow.tflow()]) tdump(path, [tflow.tflow()])
assert cp.count() == 0
tctx.configure(cp, client_replay=[path]) tctx.configure(cp, client_replay=[path])
cp.configured = False assert cp.count() == 1
tctx.configure(cp, client_replay=[]) tctx.configure(cp, client_replay=[])
cp.configured = False
tctx.configure(cp)
cp.configured = False
with pytest.raises(exceptions.OptionsError): with pytest.raises(exceptions.OptionsError):
tctx.configure(cp, client_replay=["nonexistent"]) tctx.configure(cp, client_replay=["nonexistent"])
def test_check(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
f = tflow.tflow(resp=True)
f.live = True
assert "live flow" in cp.check(f)
f = tflow.tflow(resp=True)
f.intercepted = True
assert "intercepted flow" in cp.check(f)
f = tflow.tflow(resp=True)
f.request = None
assert "missing request" in cp.check(f)
f = tflow.tflow(resp=True)
f.request.raw_content = None
assert "missing content" in cp.check(f)
@pytest.mark.asyncio
async def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp) as ctx:
assert cp.count() == 0
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1
cp.stop_replay()
assert cp.count() == 0
f.live = True
cp.start_replay([f])
assert cp.count() == 0
await ctx.master.await_log("live")
def test_http2(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
req = tutils.treq(
headers = net_http.Headers(
(
(b":authority", b"foo"),
(b"header", b"qvalue"),
(b"content-length", b"7")
)
)
)
f = tflow.tflow(req=req)
f.request.http_version = "HTTP/2.0"
cp.start_replay([f])
assert f.request.http_version == "HTTP/1.1"
assert ":authority" not in f.request.headers

View File

@ -42,7 +42,7 @@ def corrupt_data():
class TestReadFile: class TestReadFile:
def test_configure(self): def test_configure(self):
rf = readfile.ReadFile() rf = readfile.ReadFile()
with taddons.context() as tctx: with taddons.context(rf) as tctx:
tctx.configure(rf, readfile_filter="~q") tctx.configure(rf, readfile_filter="~q")
with pytest.raises(Exception, match="Invalid readfile filter"): with pytest.raises(Exception, match="Invalid readfile filter"):
tctx.configure(rf, readfile_filter="~~") tctx.configure(rf, readfile_filter="~~")

View File

@ -11,7 +11,7 @@ from mitmproxy.addons import view
def test_configure(tmpdir): def test_configure(tmpdir):
sa = save.Save() sa = save.Save()
with taddons.context() as tctx: with taddons.context(sa) as tctx:
with pytest.raises(exceptions.OptionsError): with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, save_stream_file=str(tmpdir)) tctx.configure(sa, save_stream_file=str(tmpdir))
with pytest.raises(Exception, match="Invalid filter"): with pytest.raises(Exception, match="Invalid filter"):
@ -32,7 +32,7 @@ def rd(p):
def test_tcp(tmpdir): def test_tcp(tmpdir):
sa = save.Save() sa = save.Save()
with taddons.context() as tctx: with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo")) p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p) tctx.configure(sa, save_stream_file=p)
@ -45,7 +45,7 @@ def test_tcp(tmpdir):
def test_websocket(tmpdir): def test_websocket(tmpdir):
sa = save.Save() sa = save.Save()
with taddons.context() as tctx: with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo")) p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p) tctx.configure(sa, save_stream_file=p)
@ -73,12 +73,12 @@ def test_save_command(tmpdir):
v = view.View() v = view.View()
tctx.master.addons.add(v) tctx.master.addons.add(v)
tctx.master.addons.add(sa) tctx.master.addons.add(sa)
tctx.master.commands.call_args("save.file", ["@shown", p]) tctx.master.commands.call_strings("save.file", ["@shown", p])
def test_simple(tmpdir): def test_simple(tmpdir):
sa = save.Save() sa = save.Save()
with taddons.context() as tctx: with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo")) p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p) tctx.configure(sa, save_stream_file=p)

View File

@ -92,14 +92,13 @@ class TestScript:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_simple(self, tdata): async def test_simple(self, tdata):
with taddons.context() as tctx:
sc = script.Script( sc = script.Script(
tdata.path( tdata.path(
"mitmproxy/data/addonscripts/recorder/recorder.py" "mitmproxy/data/addonscripts/recorder/recorder.py"
), ),
True, True,
) )
tctx.master.addons.add(sc) with taddons.context(sc) as tctx:
tctx.configure(sc) tctx.configure(sc)
await tctx.master.await_log("recorder running") await tctx.master.await_log("recorder running")
rec = tctx.master.addons.get("recorder") rec = tctx.master.addons.get("recorder")
@ -284,7 +283,7 @@ class TestScriptLoader:
rec = tdata.path("mitmproxy/data/addonscripts/recorder") rec = tdata.path("mitmproxy/data/addonscripts/recorder")
sc = script.ScriptLoader() sc = script.ScriptLoader()
sc.is_running = True sc.is_running = True
with taddons.context() as tctx: with taddons.context(sc) as tctx:
tctx.configure( tctx.configure(
sc, sc,
scripts = [ scripts = [

View File

@ -155,7 +155,7 @@ def test_create():
def test_orders(): def test_orders():
v = view.View() v = view.View()
with taddons.context(): with taddons.context(v):
assert v.order_options() assert v.order_options()
@ -303,7 +303,7 @@ def test_setgetval():
def test_order(): def test_order():
v = view.View() v = view.View()
with taddons.context() as tctx: with taddons.context(v) as tctx:
v.request(tft(method="get", start=1)) v.request(tft(method="get", start=1))
v.request(tft(method="put", start=2)) v.request(tft(method="put", start=2))
v.request(tft(method="get", start=3)) v.request(tft(method="get", start=3))
@ -434,7 +434,7 @@ def test_signals():
def test_focus_follow(): def test_focus_follow():
v = view.View() v = view.View()
with taddons.context() as tctx: with taddons.context(v) as tctx:
console_addon = consoleaddons.ConsoleAddon(tctx.master) console_addon = consoleaddons.ConsoleAddon(tctx.master)
tctx.configure(console_addon) tctx.configure(console_addon)
tctx.configure(v, console_focus_follow=True, view_filter="~m get") tctx.configure(v, console_focus_follow=True, view_filter="~m get")
@ -553,7 +553,7 @@ def test_settings():
def test_configure(): def test_configure():
v = view.View() v = view.View()
with taddons.context() as tctx: with taddons.context(v) as tctx:
tctx.configure(v, view_filter="~q") tctx.configure(v, view_filter="~q")
with pytest.raises(Exception, match="Invalid interception filter"): with pytest.raises(Exception, match="Invalid interception filter"):
tctx.configure(v, view_filter="~~") tctx.configure(v, view_filter="~~")

View File

@ -1 +0,0 @@
# TODO: write tests

View File

@ -31,48 +31,6 @@ class CommonMixin:
def test_large(self): def test_large(self):
assert len(self.pathod("200:b@50k").content) == 1024 * 50 assert len(self.pathod("200:b@50k").content) == 1024 * 50
@staticmethod
def wait_until_not_live(flow):
"""
Race condition: We don't want to replay the flow while it is still live.
"""
s = time.time()
while flow.live:
time.sleep(0.001)
if time.time() - s > 5:
raise RuntimeError("Flow is live for too long.")
def test_replay(self):
assert self.pathod("304").status_code == 304
assert len(self.master.state.flows) == 1
l = self.master.state.flows[-1]
assert l.response.status_code == 304
l.request.path = "/p/305"
self.wait_until_not_live(l)
rt = self.master.replay_request(l, block=True)
assert l.response.status_code == 305
# Disconnect error
l.request.path = "/p/305:d0"
rt = self.master.replay_request(l, block=True)
assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
# Port error
l.request.port = 1
# In upstream mode, we get a 502 response from the upstream proxy server.
# In upstream mode with ssl, the replay will fail as we cannot establish
# SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True)
assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
def test_http(self): def test_http(self):
f = self.pathod("304") f = self.pathod("304")
assert f.status_code == 304 assert f.status_code == 304

View File

@ -47,7 +47,7 @@ class AOption:
def test_command(): def test_command():
with taddons.context() as tctx: with taddons.context() as tctx:
tctx.master.addons.add(TAddon("test")) tctx.master.addons.add(TAddon("test"))
assert tctx.master.commands.call("test.command") == "here" assert tctx.master.commands.execute("test.command") == "here"
def test_halt(): def test_halt():

View File

@ -242,16 +242,19 @@ def test_simple():
a = TAddon() a = TAddon()
c.add("one.two", a.cmd1) c.add("one.two", a.cmd1)
assert c.commands["one.two"].help == "cmd1 help" assert c.commands["one.two"].help == "cmd1 help"
assert(c.call("one.two foo") == "ret foo") assert(c.execute("one.two foo") == "ret foo")
assert(c.call("one.two", "foo") == "ret foo")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.execute("nonexistent")
with pytest.raises(exceptions.CommandError, match="Invalid"):
c.execute("")
with pytest.raises(exceptions.CommandError, match="argument mismatch"):
c.execute("one.two too many args")
with pytest.raises(exceptions.CommandError, match="Unknown"): with pytest.raises(exceptions.CommandError, match="Unknown"):
c.call("nonexistent") c.call("nonexistent")
with pytest.raises(exceptions.CommandError, match="Invalid"):
c.call("")
with pytest.raises(exceptions.CommandError, match="argument mismatch"):
c.call("one.two too many args")
c.add("empty", a.empty) c.add("empty", a.empty)
c.call("empty") c.execute("empty")
fp = io.StringIO() fp = io.StringIO()
c.dump(fp) c.dump(fp)
@ -340,13 +343,13 @@ def test_decorator():
a = TDec() a = TDec()
c.collect_commands(a) c.collect_commands(a)
assert "cmd1" in c.commands assert "cmd1" in c.commands
assert c.call("cmd1 bar") == "ret bar" assert c.execute("cmd1 bar") == "ret bar"
assert "empty" in c.commands assert "empty" in c.commands
assert c.call("empty") is None assert c.execute("empty") is None
with taddons.context() as tctx: with taddons.context() as tctx:
tctx.master.addons.add(a) tctx.master.addons.add(a)
assert tctx.master.commands.call("cmd1 bar") == "ret bar" assert tctx.master.commands.execute("cmd1 bar") == "ret bar"
def test_verify_arg_signature(): def test_verify_arg_signature():

View File

@ -1,17 +1,14 @@
import io import io
from unittest import mock
import pytest import pytest
from mitmproxy.test import tflow, tutils, taddons from mitmproxy.test import tflow, taddons
import mitmproxy.io import mitmproxy.io
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import options from mitmproxy import options
from mitmproxy.io import tnetstring from mitmproxy.io import tnetstring
from mitmproxy.exceptions import FlowReadException, ReplayException from mitmproxy.exceptions import FlowReadException
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import http from mitmproxy import http
from mitmproxy.net import http as net_http
from mitmproxy import master
from . import tservers from . import tservers
@ -122,34 +119,6 @@ class TestFlowMaster:
assert s.flows[1].handshake_flow == f.handshake_flow assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages) assert len(s.flows[1].messages) == len(f.messages)
def test_replay(self):
opts = options.Options()
fm = master.Master(opts)
f = tflow.tflow(resp=True)
f.request.content = None
with pytest.raises(ReplayException, match="missing"):
fm.replay_request(f)
f.request = None
with pytest.raises(ReplayException, match="request"):
fm.replay_request(f)
f.intercepted = True
with pytest.raises(ReplayException, match="intercepted"):
fm.replay_request(f)
f.live = True
with pytest.raises(ReplayException, match="live"):
fm.replay_request(f)
req = tutils.treq(headers=net_http.Headers(((b":authority", b"foo"), (b"header", b"qvalue"), (b"content-length", b"7"))))
f = tflow.tflow(req=req)
f.request.http_version = "HTTP/2.0"
with mock.patch('mitmproxy.proxy.protocol.http_replay.RequestReplayThread.run'):
rt = fm.replay_request(f)
assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all(self): async def test_all(self):
opts = options.Options( opts = options.Options(

View File

@ -9,7 +9,6 @@ import tornado.testing
from tornado import httpclient from tornado import httpclient
from tornado import websocket from tornado import websocket
from mitmproxy import exceptions
from mitmproxy import options from mitmproxy import options
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy.tools.web import app from mitmproxy.tools.web import app
@ -186,13 +185,9 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
assert not f._backup assert not f._backup
def test_flow_replay(self): def test_flow_replay(self):
with mock.patch("mitmproxy.master.Master.replay_request") as replay_request: with mock.patch("mitmproxy.command.CommandManager.call") as replay_call:
assert self.fetch("/flows/42/replay", method="POST").code == 200 assert self.fetch("/flows/42/replay", method="POST").code == 200
assert replay_request.called assert replay_call.called
replay_request.side_effect = exceptions.ReplayException(
"out of replays"
)
assert self.fetch("/flows/42/replay", method="POST").code == 400
def test_flow_content(self): def test_flow_content(self):
f = self.view.get_by_id("42") f = self.view.get_by_id("42")