Merge pull request #3089 from cortesi/creplay

Revamp client replay
This commit is contained in:
Aldo Cortesi 2018-05-02 11:33:45 +12:00 committed by GitHub
commit 0f6072050a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 400 additions and 425 deletions

View File

@ -1,18 +1,140 @@
import queue
import typing
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import options
from mitmproxy import connections
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread
from mitmproxy.utils import human
from mitmproxy import ctx
from mitmproxy import io
from mitmproxy import flow
from mitmproxy import command
import mitmproxy.types
import typing
class RequestReplayThread(basethread.BaseThread):
daemon = True
def __init__(
self,
opts: options.Options,
channel: controller.Channel,
queue: queue.Queue,
) -> None:
self.options = opts
self.channel = channel
self.queue = queue
super().__init__("RequestReplayThread")
def run(self):
while True:
f = self.queue.get()
self.replay(f)
def replay(self, f): # pragma: no cover
f.live = True
r = f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
f.response = None
# If we have a channel, run script hooks.
request_reply = self.channel.ask("request", f)
if isinstance(request_reply, http.HTTPResponse):
f.response = request_reply
if not f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(
server_address, (self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
server.wfile.write(http1.assemble_request(connect_request))
server.wfile.flush()
resp = http1.read_response(
server.rfile,
connect_request,
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException(
"Upstream server refuses CONNECT request"
)
server.establish_tls(
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
else:
r.first_line_format = "absolute"
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(
server_address,
(self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
if f.server_conn:
f.server_conn.close()
f.server_conn = server
f.response = http.HTTPResponse.wrap(
http1.read_response(server.rfile, r, body_size_limit=bsl)
)
response_reply = self.channel.ask("response", f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
f.error = flow.Error(str(e))
self.channel.ask("error", f)
except exceptions.Kill:
self.channel.tell("log", log.LogEntry("Connection killed", "info"))
except Exception as e:
self.channel.tell("log", log.LogEntry(repr(e), "error"))
finally:
r.first_line_format = first_line_format_backup
f.live = False
if server.connected():
server.finish()
server.close()
class ClientPlayback:
def __init__(self):
self.flows: typing.List[flow.Flow] = []
self.current_thread = None
self.configured = False
self.q = queue.Queue()
self.thread: RequestReplayThread = None
def check(self, f: http.HTTPFlow):
if f.live:
return "Can't replay live flow."
if f.intercepted:
return "Can't replay intercepted flow."
if not f.request:
return "Can't replay flow with missing request."
if f.request.raw_content is None:
return "Can't replay flow with missing content."
def load(self, loader):
loader.add_option(
@ -20,65 +142,77 @@ class ClientPlayback:
"Replay client requests from a saved file."
)
def count(self) -> int:
if self.current_thread:
current = 1
else:
current = 0
return current + len(self.flows)
@command.command("replay.client.stop")
def stop_replay(self) -> None:
"""
Stop client replay.
"""
self.flows = []
ctx.log.alert("Client replay stopped.")
ctx.master.addons.trigger("update", [])
@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
Replay requests from flows.
"""
for f in flows:
if f.live:
raise exceptions.CommandError("Can't replay live flow.")
self.flows = list(flows)
ctx.log.alert("Replaying %s flows." % len(self.flows))
ctx.master.addons.trigger("update", [])
@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
ctx.log.alert("Replaying %s flows." % len(self.flows))
self.flows = flows
ctx.master.addons.trigger("update", [])
def running(self):
self.thread = RequestReplayThread(
ctx.options,
ctx.master.channel,
self.q,
)
self.thread.start()
def configure(self, updated):
if not self.configured and ctx.options.client_replay:
self.configured = True
ctx.log.info("Client Replay: {}".format(ctx.options.client_replay))
if "client_replay" in updated and ctx.options.client_replay:
try:
flows = io.read_flows_from_paths(ctx.options.client_replay)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(str(e))
self.start_replay(flows)
def tick(self):
current_is_done = self.current_thread and not self.current_thread.is_alive()
can_start_new = not self.current_thread or current_is_done
will_start_new = can_start_new and self.flows
@command.command("replay.client.count")
def count(self) -> int:
"""
Approximate number of flows queued for replay.
"""
return self.q.qsize()
if current_is_done:
self.current_thread = None
ctx.master.addons.trigger("update", [])
if will_start_new:
f = self.flows.pop(0)
self.current_thread = ctx.master.replay_request(f)
ctx.master.addons.trigger("update", [f])
if current_is_done and not will_start_new:
ctx.master.addons.trigger("processing_complete")
@command.command("replay.client.stop")
def stop_replay(self) -> None:
"""
Clear the replay queue.
"""
with self.q.mutex:
lst = list(self.q.queue)
self.q.queue.clear()
for f in lst:
f.revert()
ctx.master.addons.trigger("update", lst)
ctx.log.alert("Client replay queue cleared.")
@command.command("replay.client")
def start_replay(self, flows: typing.Sequence[flow.Flow]) -> None:
"""
Add flows to the replay queue, skipping flows that can't be replayed.
"""
lst = []
for f in flows:
hf = typing.cast(http.HTTPFlow, f)
err = self.check(hf)
if err:
ctx.log.warn(err)
continue
lst.append(hf)
# Prepare the flow for replay
hf.backup()
hf.request.is_replay = True
hf.response = None
hf.error = None
# https://github.com/mitmproxy/mitmproxy/issues/2197
if hf.request.http_version == "HTTP/2.0":
hf.request.http_version = "HTTP/1.1"
host = hf.request.headers.pop(":authority")
hf.request.headers.insert(0, "host", host)
self.q.put(hf)
ctx.master.addons.trigger("update", lst)
@command.command("replay.client.file")
def load_file(self, path: mitmproxy.types.Path) -> None:
"""
Load flows from file, and add them to the replay queue.
"""
try:
flows = io.read_flows_from_paths([path])
except exceptions.FlowReadException as e:
raise exceptions.CommandError(str(e))
self.start_replay(flows)

View File

@ -204,7 +204,15 @@ class CommandManager(mitmproxy.types._CommandBase):
return parse, remhelp
def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any:
def call(self, path: str, *args: typing.Sequence[typing.Any]) -> typing.Any:
"""
Call a command with native arguments. May raise CommandError.
"""
if path not in self.commands:
raise exceptions.CommandError("Unknown command: %s" % path)
return self.commands[path].func(*args)
def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any:
"""
Call a command using a list of string arguments. May raise CommandError.
"""
@ -212,14 +220,14 @@ class CommandManager(mitmproxy.types._CommandBase):
raise exceptions.CommandError("Unknown command: %s" % path)
return self.commands[path].call(args)
def call(self, cmdstr: str):
def execute(self, cmdstr: str):
"""
Call a command using a string. May raise CommandError.
Execute a command string. May raise CommandError.
"""
parts = list(lexer(cmdstr))
if not len(parts) >= 1:
raise exceptions.CommandError("Invalid command: %s" % cmdstr)
return self.call_args(parts[0], parts[1:])
return self.call_strings(parts[0], parts[1:])
def dump(self, out=sys.stdout) -> None:
cmds = list(self.commands.values())

View File

@ -8,13 +8,11 @@ from mitmproxy import addonmanager
from mitmproxy import options
from mitmproxy import controller
from mitmproxy import eventsequence
from mitmproxy import exceptions
from mitmproxy import command
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import log
from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol import http_replay
from mitmproxy.coretypes import basethread
from . import ctx as mitmproxy_ctx
@ -164,58 +162,3 @@ class Master:
f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f):
await self.addons.handle_lifecycle(e, o)
def replay_request(
self,
f: http.HTTPFlow,
block: bool=False
) -> http_replay.RequestReplayThread:
"""
Replay a HTTP request to receive a new response from the server.
Args:
f: The flow to replay.
block: If True, this function will wait for the replay to finish.
This causes a deadlock if activated in the main thread.
Returns:
The thread object doing the replay.
Raises:
exceptions.ReplayException, if the flow is in a state
where it is ineligible for replay.
"""
if f.live:
raise exceptions.ReplayException(
"Can't replay live flow."
)
if f.intercepted:
raise exceptions.ReplayException(
"Can't replay intercepted flow."
)
if not f.request:
raise exceptions.ReplayException(
"Can't replay flow with missing request."
)
if f.request.raw_content is None:
raise exceptions.ReplayException(
"Can't replay flow with missing content."
)
f.backup()
f.request.is_replay = True
f.response = None
f.error = None
if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197
f.request.http_version = "HTTP/1.1"
host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host)
rt = http_replay.RequestReplayThread(self.options, f, self.channel)
rt.start() # pragma: no cover
if block:
rt.join()
return rt

View File

@ -372,12 +372,11 @@ class TCPClient(_Connection):
# Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
# it tries to renegotiate...
if not self.connection:
return
elif isinstance(self.connection, SSL.Connection):
close_socket(self.connection._socket)
else:
close_socket(self.connection)
if self.connection:
if isinstance(self.connection, SSL.Connection):
close_socket(self.connection._socket)
else:
close_socket(self.connection)
def convert_to_tls(self, sni=None, alpn_protos=None, **sslctx_kwargs):
context = tls.create_client_context(

View File

@ -1,125 +0,0 @@
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import options
from mitmproxy import connections
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread
from mitmproxy.utils import human
# TODO: Doesn't really belong into mitmproxy.proxy.protocol...
class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
def __init__(
self,
opts: options.Options,
f: http.HTTPFlow,
channel: controller.Channel,
) -> None:
self.options = opts
self.f = f
f.live = True
self.channel = channel
super().__init__(
"RequestReplay (%s)" % f.request.url
)
self.daemon = True
def run(self):
r = self.f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
self.f.response = None
# If we have a channel, run script hooks.
if self.channel:
request_reply = self.channel.ask("request", self.f)
if isinstance(request_reply, http.HTTPResponse):
self.f.response = request_reply
if not self.f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(server_address, (self.options.listen_host, 0))
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
server.wfile.write(http1.assemble_request(connect_request))
server.wfile.flush()
resp = http1.read_response(
server.rfile,
connect_request,
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
else:
r.first_line_format = "absolute"
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(
server_address,
(self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
if self.f.server_conn:
self.f.server_conn.close()
self.f.server_conn = server
self.f.response = http.HTTPResponse.wrap(
http1.read_response(
server.rfile,
r,
body_size_limit=bsl
)
)
if self.channel:
response_reply = self.channel.ask("response", self.f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
self.f.error = flow.Error(str(e))
if self.channel:
self.channel.ask("error", self.f)
except exceptions.Kill:
# Kill should only be raised if there's a channel in the
# first place.
self.channel.tell(
"log",
log.LogEntry("Connection killed", "info")
)
except Exception as e:
self.channel.tell(
"log",
log.LogEntry(repr(e), "error")
)
finally:
r.first_line_format = first_line_format_backup
self.f.live = False
if server.connected():
server.finish()

View File

@ -112,12 +112,10 @@ class context:
if addon not in self.master.addons:
self.master.addons.register(addon)
with self.options.rollback(kwargs.keys(), reraise=True):
self.options.update(**kwargs)
self.master.addons.invoke_addon(
addon,
"configure",
kwargs.keys()
)
if kwargs:
self.options.update(**kwargs)
else:
self.master.addons.invoke_addon(addon, "configure", {})
def script(self, path):
"""

View File

@ -258,7 +258,7 @@ class ConsoleAddon:
command, then invoke another command with all occurrences of {choice}
replaced by the choice the user made.
"""
choices = ctx.master.commands.call_args(choicecmd, [])
choices = ctx.master.commands.call_strings(choicecmd, [])
def callback(opt):
# We're now outside of the call context...
@ -514,7 +514,7 @@ class ConsoleAddon:
raise exceptions.CommandError("Invalid flowview mode.")
try:
self.master.commands.call_args(
self.master.commands.call_strings(
"view.setval",
["@focus", "flowview_mode_%s" % idx, mode]
)
@ -537,7 +537,7 @@ class ConsoleAddon:
if not fv:
raise exceptions.CommandError("Not viewing a flow.")
idx = fv.body.tab_offset
return self.master.commands.call_args(
return self.master.commands.call_strings(
"view.getval",
[
"@focus",

View File

@ -167,6 +167,7 @@ class StatusBar(urwid.WidgetWrap):
self.ib = urwid.WidgetWrap(urwid.Text(""))
self.ab = ActionBar(self.master)
super().__init__(urwid.Pile([self.ib, self.ab]))
signals.flow_change.connect(self.sig_update)
signals.update_settings.connect(self.sig_update)
signals.flowlist_change.connect(self.sig_update)
master.options.changed.connect(self.sig_update)
@ -184,7 +185,7 @@ class StatusBar(urwid.WidgetWrap):
r = []
sreplay = self.master.addons.get("serverplayback")
creplay = self.master.addons.get("clientplayback")
creplay = self.master.commands.call("replay.client.count")
if len(self.master.options.setheaders):
r.append("[")
@ -192,10 +193,10 @@ class StatusBar(urwid.WidgetWrap):
r.append("eaders]")
if len(self.master.options.replacements):
r.append("[%d replacements]" % len(self.master.options.replacements))
if creplay.count():
if creplay:
r.append("[")
r.append(("heading_key", "cplayback"))
r.append(":%s]" % creplay.count())
r.append(":%s]" % creplay)
if sreplay.count():
r.append("[")
r.append(("heading_key", "splayback"))

View File

@ -344,7 +344,7 @@ class ReplayFlow(RequestHandler):
self.view.update([self.flow])
try:
self.master.replay_request(self.flow)
self.master.commands.call("replay.client", [self.flow])
except exceptions.ReplayException as e:
raise APIError(400, str(e))

View File

@ -47,10 +47,10 @@ class Choice:
class _CommandBase:
commands: typing.MutableMapping[str, typing.Any] = {}
def call_args(self, path: str, args: typing.Sequence[str]) -> typing.Any:
def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any:
raise NotImplementedError
def call(self, cmd: str) -> typing.Any:
def execute(self, cmd: str) -> typing.Any:
raise NotImplementedError
@ -337,7 +337,7 @@ class _FlowType(_BaseFlowType):
def parse(self, manager: _CommandBase, t: type, s: str) -> flow.Flow:
try:
flows = manager.call_args("view.resolve", [s])
flows = manager.call_strings("view.resolve", [s])
except exceptions.CommandError as e:
raise exceptions.TypeError from e
if len(flows) != 1:
@ -356,7 +356,7 @@ class _FlowsType(_BaseFlowType):
def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]:
try:
return manager.call_args("view.resolve", [s])
return manager.call_strings("view.resolve", [s])
except exceptions.CommandError as e:
raise exceptions.TypeError from e
@ -401,17 +401,17 @@ class _ChoiceType(_BaseType):
display = "choice"
def completion(self, manager: _CommandBase, t: Choice, s: str) -> typing.Sequence[str]:
return manager.call(t.options_command)
return manager.execute(t.options_command)
def parse(self, manager: _CommandBase, t: Choice, s: str) -> str:
opts = manager.call(t.options_command)
opts = manager.execute(t.options_command)
if s not in opts:
raise exceptions.TypeError("Invalid choice.")
return s
def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
try:
opts = manager.call(typ.options_command)
opts = manager.execute(typ.options_command)
except exceptions.CommandError:
return False
return val in opts

View File

@ -12,11 +12,11 @@ class TestCheckCA:
async def test_check_ca(self, expired):
msg = 'The mitmproxy certificate authority has expired!'
with taddons.context() as tctx:
a = check_ca.CheckCA()
with taddons.context(a) as tctx:
tctx.master.server = mock.MagicMock()
tctx.master.server.config.certstore.default_ca.has_expired = mock.MagicMock(
return_value = expired
)
a = check_ca.CheckCA()
tctx.configure(a)
assert await tctx.master.await_log(msg) == expired

View File

@ -1,13 +1,16 @@
import time
import pytest
from unittest import mock
from mitmproxy.test import tflow
from mitmproxy.test import tflow, tutils
from mitmproxy import io
from mitmproxy import exceptions
from mitmproxy.net import http as net_http
from mitmproxy.addons import clientplayback
from mitmproxy.test import taddons
from .. import tservers
def tdump(path, flows):
with open(path, "wb") as f:
@ -21,48 +24,87 @@ class MockThread():
return False
class TBase(tservers.HTTPProxyTest):
@staticmethod
def wait_response(flow):
"""
Race condition: We don't want to replay the flow while it is still live.
"""
s = time.time()
while True:
if flow.response or flow.error:
break
time.sleep(0.001)
if time.time() - s > 5:
raise RuntimeError("Flow is live for too long.")
@staticmethod
def reset(f):
f.live = False
f.repsonse = False
f.error = False
def addons(self):
return [clientplayback.ClientPlayback()]
def test_replay(self):
cr = self.master.addons.get("clientplayback")
assert self.pathod("304").status_code == 304
assert len(self.master.state.flows) == 1
l = self.master.state.flows[-1]
assert l.response.status_code == 304
l.request.path = "/p/305"
cr.start_replay([l])
self.wait_response(l)
assert l.response.status_code == 305
# Disconnect error
cr.stop_replay()
self.reset(l)
l.request.path = "/p/305:d0"
cr.start_replay([l])
self.wait_response(l)
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
# # Port error
cr.stop_replay()
self.reset(l)
l.request.port = 1
# In upstream mode, we get a 502 response from the upstream proxy server.
# In upstream mode with ssl, the replay will fail as we cannot establish
# SSL with the upstream proxy.
cr.start_replay([l])
self.wait_response(l)
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
class TestHTTPProxy(TBase, tservers.HTTPProxyTest):
pass
class TestHTTPSProxy(TBase, tservers.HTTPProxyTest):
ssl = True
class TestUpstreamProxy(TBase, tservers.HTTPUpstreamProxyTest):
pass
class TestClientPlayback:
def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp) as tctx:
assert cp.count() == 0
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1
RP = "mitmproxy.proxy.protocol.http_replay.RequestReplayThread"
with mock.patch(RP) as rp:
assert not cp.current_thread
cp.tick()
assert rp.called
assert cp.current_thread
cp.flows = []
cp.current_thread.is_alive.return_value = False
assert cp.count() == 1
cp.tick()
assert cp.count() == 0
assert tctx.master.has_event("update")
assert tctx.master.has_event("processing_complete")
cp.current_thread = MockThread()
cp.tick()
assert cp.current_thread is None
cp.start_replay([f])
cp.stop_replay()
assert not cp.flows
df = tflow.DummyFlow(tflow.tclient_conn(), tflow.tserver_conn(), True)
with pytest.raises(exceptions.CommandError, match="Can't replay live flow."):
cp.start_replay([df])
def test_load_file(self, tmpdir):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
fpath = str(tmpdir.join("flows"))
tdump(fpath, [tflow.tflow(resp=True)])
cp.load_file(fpath)
assert cp.flows
assert cp.count() == 1
with pytest.raises(exceptions.CommandError):
cp.load_file("/nonexistent")
@ -71,11 +113,63 @@ class TestClientPlayback:
with taddons.context(cp) as tctx:
path = str(tmpdir.join("flows"))
tdump(path, [tflow.tflow()])
assert cp.count() == 0
tctx.configure(cp, client_replay=[path])
cp.configured = False
assert cp.count() == 1
tctx.configure(cp, client_replay=[])
cp.configured = False
tctx.configure(cp)
cp.configured = False
with pytest.raises(exceptions.OptionsError):
tctx.configure(cp, client_replay=["nonexistent"])
def test_check(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
f = tflow.tflow(resp=True)
f.live = True
assert "live flow" in cp.check(f)
f = tflow.tflow(resp=True)
f.intercepted = True
assert "intercepted flow" in cp.check(f)
f = tflow.tflow(resp=True)
f.request = None
assert "missing request" in cp.check(f)
f = tflow.tflow(resp=True)
f.request.raw_content = None
assert "missing content" in cp.check(f)
@pytest.mark.asyncio
async def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp) as ctx:
assert cp.count() == 0
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1
cp.stop_replay()
assert cp.count() == 0
f.live = True
cp.start_replay([f])
assert cp.count() == 0
await ctx.master.await_log("live")
def test_http2(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp):
req = tutils.treq(
headers = net_http.Headers(
(
(b":authority", b"foo"),
(b"header", b"qvalue"),
(b"content-length", b"7")
)
)
)
f = tflow.tflow(req=req)
f.request.http_version = "HTTP/2.0"
cp.start_replay([f])
assert f.request.http_version == "HTTP/1.1"
assert ":authority" not in f.request.headers

View File

@ -42,7 +42,7 @@ def corrupt_data():
class TestReadFile:
def test_configure(self):
rf = readfile.ReadFile()
with taddons.context() as tctx:
with taddons.context(rf) as tctx:
tctx.configure(rf, readfile_filter="~q")
with pytest.raises(Exception, match="Invalid readfile filter"):
tctx.configure(rf, readfile_filter="~~")

View File

@ -11,7 +11,7 @@ from mitmproxy.addons import view
def test_configure(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
with pytest.raises(exceptions.OptionsError):
tctx.configure(sa, save_stream_file=str(tmpdir))
with pytest.raises(Exception, match="Invalid filter"):
@ -32,7 +32,7 @@ def rd(p):
def test_tcp(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
@ -45,7 +45,7 @@ def test_tcp(tmpdir):
def test_websocket(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
@ -73,12 +73,12 @@ def test_save_command(tmpdir):
v = view.View()
tctx.master.addons.add(v)
tctx.master.addons.add(sa)
tctx.master.commands.call_args("save.file", ["@shown", p])
tctx.master.commands.call_strings("save.file", ["@shown", p])
def test_simple(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
with taddons.context(sa) as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)

View File

@ -92,14 +92,13 @@ class TestScript:
@pytest.mark.asyncio
async def test_simple(self, tdata):
with taddons.context() as tctx:
sc = script.Script(
tdata.path(
"mitmproxy/data/addonscripts/recorder/recorder.py"
),
True,
)
tctx.master.addons.add(sc)
sc = script.Script(
tdata.path(
"mitmproxy/data/addonscripts/recorder/recorder.py"
),
True,
)
with taddons.context(sc) as tctx:
tctx.configure(sc)
await tctx.master.await_log("recorder running")
rec = tctx.master.addons.get("recorder")
@ -284,7 +283,7 @@ class TestScriptLoader:
rec = tdata.path("mitmproxy/data/addonscripts/recorder")
sc = script.ScriptLoader()
sc.is_running = True
with taddons.context() as tctx:
with taddons.context(sc) as tctx:
tctx.configure(
sc,
scripts = [

View File

@ -155,7 +155,7 @@ def test_create():
def test_orders():
v = view.View()
with taddons.context():
with taddons.context(v):
assert v.order_options()
@ -303,7 +303,7 @@ def test_setgetval():
def test_order():
v = view.View()
with taddons.context() as tctx:
with taddons.context(v) as tctx:
v.request(tft(method="get", start=1))
v.request(tft(method="put", start=2))
v.request(tft(method="get", start=3))
@ -434,7 +434,7 @@ def test_signals():
def test_focus_follow():
v = view.View()
with taddons.context() as tctx:
with taddons.context(v) as tctx:
console_addon = consoleaddons.ConsoleAddon(tctx.master)
tctx.configure(console_addon)
tctx.configure(v, console_focus_follow=True, view_filter="~m get")
@ -553,7 +553,7 @@ def test_settings():
def test_configure():
v = view.View()
with taddons.context() as tctx:
with taddons.context(v) as tctx:
tctx.configure(v, view_filter="~q")
with pytest.raises(Exception, match="Invalid interception filter"):
tctx.configure(v, view_filter="~~")

View File

@ -1 +0,0 @@
# TODO: write tests

View File

@ -31,48 +31,6 @@ class CommonMixin:
def test_large(self):
assert len(self.pathod("200:b@50k").content) == 1024 * 50
@staticmethod
def wait_until_not_live(flow):
"""
Race condition: We don't want to replay the flow while it is still live.
"""
s = time.time()
while flow.live:
time.sleep(0.001)
if time.time() - s > 5:
raise RuntimeError("Flow is live for too long.")
def test_replay(self):
assert self.pathod("304").status_code == 304
assert len(self.master.state.flows) == 1
l = self.master.state.flows[-1]
assert l.response.status_code == 304
l.request.path = "/p/305"
self.wait_until_not_live(l)
rt = self.master.replay_request(l, block=True)
assert l.response.status_code == 305
# Disconnect error
l.request.path = "/p/305:d0"
rt = self.master.replay_request(l, block=True)
assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
# Port error
l.request.port = 1
# In upstream mode, we get a 502 response from the upstream proxy server.
# In upstream mode with ssl, the replay will fail as we cannot establish
# SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True)
assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
def test_http(self):
f = self.pathod("304")
assert f.status_code == 304

View File

@ -47,7 +47,7 @@ class AOption:
def test_command():
with taddons.context() as tctx:
tctx.master.addons.add(TAddon("test"))
assert tctx.master.commands.call("test.command") == "here"
assert tctx.master.commands.execute("test.command") == "here"
def test_halt():

View File

@ -242,16 +242,19 @@ def test_simple():
a = TAddon()
c.add("one.two", a.cmd1)
assert c.commands["one.two"].help == "cmd1 help"
assert(c.call("one.two foo") == "ret foo")
assert(c.execute("one.two foo") == "ret foo")
assert(c.call("one.two", "foo") == "ret foo")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.execute("nonexistent")
with pytest.raises(exceptions.CommandError, match="Invalid"):
c.execute("")
with pytest.raises(exceptions.CommandError, match="argument mismatch"):
c.execute("one.two too many args")
with pytest.raises(exceptions.CommandError, match="Unknown"):
c.call("nonexistent")
with pytest.raises(exceptions.CommandError, match="Invalid"):
c.call("")
with pytest.raises(exceptions.CommandError, match="argument mismatch"):
c.call("one.two too many args")
c.add("empty", a.empty)
c.call("empty")
c.execute("empty")
fp = io.StringIO()
c.dump(fp)
@ -340,13 +343,13 @@ def test_decorator():
a = TDec()
c.collect_commands(a)
assert "cmd1" in c.commands
assert c.call("cmd1 bar") == "ret bar"
assert c.execute("cmd1 bar") == "ret bar"
assert "empty" in c.commands
assert c.call("empty") is None
assert c.execute("empty") is None
with taddons.context() as tctx:
tctx.master.addons.add(a)
assert tctx.master.commands.call("cmd1 bar") == "ret bar"
assert tctx.master.commands.execute("cmd1 bar") == "ret bar"
def test_verify_arg_signature():

View File

@ -1,17 +1,14 @@
import io
from unittest import mock
import pytest
from mitmproxy.test import tflow, tutils, taddons
from mitmproxy.test import tflow, taddons
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy import options
from mitmproxy.io import tnetstring
from mitmproxy.exceptions import FlowReadException, ReplayException
from mitmproxy.exceptions import FlowReadException
from mitmproxy import flow
from mitmproxy import http
from mitmproxy.net import http as net_http
from mitmproxy import master
from . import tservers
@ -122,34 +119,6 @@ class TestFlowMaster:
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
def test_replay(self):
opts = options.Options()
fm = master.Master(opts)
f = tflow.tflow(resp=True)
f.request.content = None
with pytest.raises(ReplayException, match="missing"):
fm.replay_request(f)
f.request = None
with pytest.raises(ReplayException, match="request"):
fm.replay_request(f)
f.intercepted = True
with pytest.raises(ReplayException, match="intercepted"):
fm.replay_request(f)
f.live = True
with pytest.raises(ReplayException, match="live"):
fm.replay_request(f)
req = tutils.treq(headers=net_http.Headers(((b":authority", b"foo"), (b"header", b"qvalue"), (b"content-length", b"7"))))
f = tflow.tflow(req=req)
f.request.http_version = "HTTP/2.0"
with mock.patch('mitmproxy.proxy.protocol.http_replay.RequestReplayThread.run'):
rt = fm.replay_request(f)
assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers
@pytest.mark.asyncio
async def test_all(self):
opts = options.Options(

View File

@ -9,7 +9,6 @@ import tornado.testing
from tornado import httpclient
from tornado import websocket
from mitmproxy import exceptions
from mitmproxy import options
from mitmproxy.test import tflow
from mitmproxy.tools.web import app
@ -186,13 +185,9 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
assert not f._backup
def test_flow_replay(self):
with mock.patch("mitmproxy.master.Master.replay_request") as replay_request:
with mock.patch("mitmproxy.command.CommandManager.call") as replay_call:
assert self.fetch("/flows/42/replay", method="POST").code == 200
assert replay_request.called
replay_request.side_effect = exceptions.ReplayException(
"out of replays"
)
assert self.fetch("/flows/42/replay", method="POST").code == 400
assert replay_call.called
def test_flow_content(self):
f = self.view.get_by_id("42")