client replay: move all client replay-related code into addon

This commit is contained in:
Aldo Cortesi 2018-04-27 16:34:56 +12:00
parent bc3ace6082
commit 28d53d5a24
9 changed files with 260 additions and 269 deletions

View File

@ -1,19 +1,194 @@
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import options
from mitmproxy import connections
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread
from mitmproxy.utils import human
from mitmproxy import ctx
from mitmproxy import io
from mitmproxy import flow
from mitmproxy import command
import mitmproxy.types
import typing
class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
def __init__(
self,
opts: options.Options,
f: http.HTTPFlow,
channel: controller.Channel,
) -> None:
self.options = opts
self.f = f
f.live = True
self.channel = channel
super().__init__(
"RequestReplay (%s)" % f.request.url
)
self.daemon = True
def run(self):
r = self.f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
self.f.response = None
# If we have a channel, run script hooks.
if self.channel:
request_reply = self.channel.ask("request", self.f)
if isinstance(request_reply, http.HTTPResponse):
self.f.response = request_reply
if not self.f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(server_address, (self.options.listen_host, 0))
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
server.wfile.write(http1.assemble_request(connect_request))
server.wfile.flush()
resp = http1.read_response(
server.rfile,
connect_request,
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
else:
r.first_line_format = "absolute"
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(
server_address,
(self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
if self.f.server_conn:
self.f.server_conn.close()
self.f.server_conn = server
self.f.response = http.HTTPResponse.wrap(
http1.read_response(
server.rfile,
r,
body_size_limit=bsl
)
)
if self.channel:
response_reply = self.channel.ask("response", self.f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
self.f.error = flow.Error(str(e))
if self.channel:
self.channel.ask("error", self.f)
except exceptions.Kill:
# Kill should only be raised if there's a channel in the
# first place.
self.channel.tell(
"log",
log.LogEntry("Connection killed", "info")
)
except Exception as e:
self.channel.tell(
"log",
log.LogEntry(repr(e), "error")
)
finally:
r.first_line_format = first_line_format_backup
self.f.live = False
if server.connected():
server.finish()
class ClientPlayback:
def __init__(self):
self.flows: typing.List[flow.Flow] = []
self.current_thread = None
self.configured = False
def replay_request(
self,
f: http.HTTPFlow,
block: bool=False
) -> RequestReplayThread:
"""
Replay a HTTP request to receive a new response from the server.
Args:
f: The flow to replay.
block: If True, this function will wait for the replay to finish.
This causes a deadlock if activated in the main thread.
Returns:
The thread object doing the replay.
Raises:
exceptions.ReplayException, if the flow is in a state
where it is ineligible for replay.
"""
if f.live:
raise exceptions.ReplayException(
"Can't replay live flow."
)
if f.intercepted:
raise exceptions.ReplayException(
"Can't replay intercepted flow."
)
if not f.request:
raise exceptions.ReplayException(
"Can't replay flow with missing request."
)
if f.request.raw_content is None:
raise exceptions.ReplayException(
"Can't replay flow with missing content."
)
f.backup()
f.request.is_replay = True
f.response = None
f.error = None
if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197
f.request.http_version = "HTTP/1.1"
host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host)
rt = RequestReplayThread(ctx.master.options, f, ctx.master.channel)
rt.start() # pragma: no cover
if block:
rt.join()
return rt
def load(self, loader):
loader.add_option(
"client_replay", typing.Sequence[str], [],
@ -78,7 +253,7 @@ class ClientPlayback:
ctx.master.addons.trigger("update", [])
if will_start_new:
f = self.flows.pop(0)
self.current_thread = ctx.master.replay_request(f)
self.current_thread = self.replay_request(f)
ctx.master.addons.trigger("update", [f])
if current_is_done and not will_start_new:
ctx.master.addons.trigger("processing_complete")

View File

@ -8,13 +8,11 @@ from mitmproxy import addonmanager
from mitmproxy import options
from mitmproxy import controller
from mitmproxy import eventsequence
from mitmproxy import exceptions
from mitmproxy import command
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import log
from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol import http_replay
from mitmproxy.coretypes import basethread
from . import ctx as mitmproxy_ctx
@ -164,58 +162,3 @@ class Master:
f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f):
await self.addons.handle_lifecycle(e, o)
def replay_request(
self,
f: http.HTTPFlow,
block: bool=False
) -> http_replay.RequestReplayThread:
"""
Replay a HTTP request to receive a new response from the server.
Args:
f: The flow to replay.
block: If True, this function will wait for the replay to finish.
This causes a deadlock if activated in the main thread.
Returns:
The thread object doing the replay.
Raises:
exceptions.ReplayException, if the flow is in a state
where it is ineligible for replay.
"""
if f.live:
raise exceptions.ReplayException(
"Can't replay live flow."
)
if f.intercepted:
raise exceptions.ReplayException(
"Can't replay intercepted flow."
)
if not f.request:
raise exceptions.ReplayException(
"Can't replay flow with missing request."
)
if f.request.raw_content is None:
raise exceptions.ReplayException(
"Can't replay flow with missing content."
)
f.backup()
f.request.is_replay = True
f.response = None
f.error = None
if f.request.http_version == "HTTP/2.0": # https://github.com/mitmproxy/mitmproxy/issues/2197
f.request.http_version = "HTTP/1.1"
host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host)
rt = http_replay.RequestReplayThread(self.options, f, self.channel)
rt.start() # pragma: no cover
if block:
rt.join()
return rt

View File

@ -1,125 +0,0 @@
from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import options
from mitmproxy import connections
from mitmproxy.net import server_spec, tls
from mitmproxy.net.http import http1
from mitmproxy.coretypes import basethread
from mitmproxy.utils import human
# TODO: Doesn't really belong into mitmproxy.proxy.protocol...
class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
def __init__(
self,
opts: options.Options,
f: http.HTTPFlow,
channel: controller.Channel,
) -> None:
self.options = opts
self.f = f
f.live = True
self.channel = channel
super().__init__(
"RequestReplay (%s)" % f.request.url
)
self.daemon = True
def run(self):
r = self.f.request
bsl = human.parse_size(self.options.body_size_limit)
first_line_format_backup = r.first_line_format
server = None
try:
self.f.response = None
# If we have a channel, run script hooks.
if self.channel:
request_reply = self.channel.ask("request", self.f)
if isinstance(request_reply, http.HTTPResponse):
self.f.response = request_reply
if not self.f.response:
# In all modes, we directly connect to the server displayed
if self.options.mode.startswith("upstream:"):
server_address = server_spec.parse_with_mode(self.options.mode)[1].address
server = connections.ServerConnection(server_address, (self.options.listen_host, 0))
server.connect()
if r.scheme == "https":
connect_request = http.make_connect_request((r.data.host, r.port))
server.wfile.write(http1.assemble_request(connect_request))
server.wfile.flush()
resp = http1.read_response(
server.rfile,
connect_request,
body_size_limit=bsl
)
if resp.status_code != 200:
raise exceptions.ReplayException("Upstream server refuses CONNECT request")
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
else:
r.first_line_format = "absolute"
else:
server_address = (r.host, r.port)
server = connections.ServerConnection(
server_address,
(self.options.listen_host, 0)
)
server.connect()
if r.scheme == "https":
server.establish_tls(
sni=self.f.server_conn.sni,
**tls.client_arguments_from_options(self.options)
)
r.first_line_format = "relative"
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
if self.f.server_conn:
self.f.server_conn.close()
self.f.server_conn = server
self.f.response = http.HTTPResponse.wrap(
http1.read_response(
server.rfile,
r,
body_size_limit=bsl
)
)
if self.channel:
response_reply = self.channel.ask("response", self.f)
if response_reply == exceptions.Kill:
raise exceptions.Kill()
except (exceptions.ReplayException, exceptions.NetlibException) as e:
self.f.error = flow.Error(str(e))
if self.channel:
self.channel.ask("error", self.f)
except exceptions.Kill:
# Kill should only be raised if there's a channel in the
# first place.
self.channel.tell(
"log",
log.LogEntry("Connection killed", "info")
)
except Exception as e:
self.channel.tell(
"log",
log.LogEntry(repr(e), "error")
)
finally:
r.first_line_format = first_line_format_backup
self.f.live = False
if server.connected():
server.finish()

View File

@ -344,6 +344,7 @@ class ReplayFlow(RequestHandler):
self.view.update([self.flow])
try:
self.master.command.call
self.master.replay_request(self.flow)
except exceptions.ReplayException as e:
raise APIError(400, str(e))

View File

@ -22,6 +22,76 @@ class MockThread():
class TestClientPlayback:
# @staticmethod
# def wait_until_not_live(flow):
# """
# Race condition: We don't want to replay the flow while it is still live.
# """
# s = time.time()
# while flow.live:
# time.sleep(0.001)
# if time.time() - s > 5:
# raise RuntimeError("Flow is live for too long.")
# def test_replay(self):
# assert self.pathod("304").status_code == 304
# assert len(self.master.state.flows) == 1
# l = self.master.state.flows[-1]
# assert l.response.status_code == 304
# l.request.path = "/p/305"
# self.wait_until_not_live(l)
# rt = self.master.replay_request(l, block=True)
# assert l.response.status_code == 305
# # Disconnect error
# l.request.path = "/p/305:d0"
# rt = self.master.replay_request(l, block=True)
# assert rt
# if isinstance(self, tservers.HTTPUpstreamProxyTest):
# assert l.response.status_code == 502
# else:
# assert l.error
# # Port error
# l.request.port = 1
# # In upstream mode, we get a 502 response from the upstream proxy server.
# # In upstream mode with ssl, the replay will fail as we cannot establish
# # SSL with the upstream proxy.
# rt = self.master.replay_request(l, block=True)
# assert rt
# if isinstance(self, tservers.HTTPUpstreamProxyTest):
# assert l.response.status_code == 502
# else:
# assert l.error
# def test_replay(self):
# opts = options.Options()
# fm = master.Master(opts)
# f = tflow.tflow(resp=True)
# f.request.content = None
# with pytest.raises(ReplayException, match="missing"):
# fm.replay_request(f)
# f.request = None
# with pytest.raises(ReplayException, match="request"):
# fm.replay_request(f)
# f.intercepted = True
# with pytest.raises(ReplayException, match="intercepted"):
# fm.replay_request(f)
# f.live = True
# with pytest.raises(ReplayException, match="live"):
# fm.replay_request(f)
# req = tutils.treq(headers=net_http.Headers(((b":authority", b"foo"), (b"header", b"qvalue"), (b"content-length", b"7"))))
# f = tflow.tflow(req=req)
# f.request.http_version = "HTTP/2.0"
# with mock.patch('mitmproxy.proxy.protocol.http_replay.RequestReplayThread.run'):
# rt = fm.replay_request(f)
# assert rt.f.request.http_version == "HTTP/1.1"
# assert ":authority" not in rt.f.request.headers
def test_playback(self):
cp = clientplayback.ClientPlayback()
with taddons.context(cp) as tctx:
@ -29,7 +99,7 @@ class TestClientPlayback:
f = tflow.tflow(resp=True)
cp.start_replay([f])
assert cp.count() == 1
RP = "mitmproxy.proxy.protocol.http_replay.RequestReplayThread"
RP = "mitmproxy.addons.clientplayback.RequestReplayThread"
with mock.patch(RP) as rp:
assert not cp.current_thread
cp.tick()

View File

@ -1 +0,0 @@
# TODO: write tests

View File

@ -31,48 +31,6 @@ class CommonMixin:
def test_large(self):
assert len(self.pathod("200:b@50k").content) == 1024 * 50
@staticmethod
def wait_until_not_live(flow):
"""
Race condition: We don't want to replay the flow while it is still live.
"""
s = time.time()
while flow.live:
time.sleep(0.001)
if time.time() - s > 5:
raise RuntimeError("Flow is live for too long.")
def test_replay(self):
assert self.pathod("304").status_code == 304
assert len(self.master.state.flows) == 1
l = self.master.state.flows[-1]
assert l.response.status_code == 304
l.request.path = "/p/305"
self.wait_until_not_live(l)
rt = self.master.replay_request(l, block=True)
assert l.response.status_code == 305
# Disconnect error
l.request.path = "/p/305:d0"
rt = self.master.replay_request(l, block=True)
assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
# Port error
l.request.port = 1
# In upstream mode, we get a 502 response from the upstream proxy server.
# In upstream mode with ssl, the replay will fail as we cannot establish
# SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True)
assert rt
if isinstance(self, tservers.HTTPUpstreamProxyTest):
assert l.response.status_code == 502
else:
assert l.error
def test_http(self):
f = self.pathod("304")
assert f.status_code == 304

View File

@ -1,17 +1,14 @@
import io
from unittest import mock
import pytest
from mitmproxy.test import tflow, tutils, taddons
from mitmproxy.test import tflow, taddons
import mitmproxy.io
from mitmproxy import flowfilter
from mitmproxy import options
from mitmproxy.io import tnetstring
from mitmproxy.exceptions import FlowReadException, ReplayException
from mitmproxy.exceptions import FlowReadException
from mitmproxy import flow
from mitmproxy import http
from mitmproxy.net import http as net_http
from mitmproxy import master
from . import tservers
@ -122,34 +119,6 @@ class TestFlowMaster:
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
def test_replay(self):
opts = options.Options()
fm = master.Master(opts)
f = tflow.tflow(resp=True)
f.request.content = None
with pytest.raises(ReplayException, match="missing"):
fm.replay_request(f)
f.request = None
with pytest.raises(ReplayException, match="request"):
fm.replay_request(f)
f.intercepted = True
with pytest.raises(ReplayException, match="intercepted"):
fm.replay_request(f)
f.live = True
with pytest.raises(ReplayException, match="live"):
fm.replay_request(f)
req = tutils.treq(headers=net_http.Headers(((b":authority", b"foo"), (b"header", b"qvalue"), (b"content-length", b"7"))))
f = tflow.tflow(req=req)
f.request.http_version = "HTTP/2.0"
with mock.patch('mitmproxy.proxy.protocol.http_replay.RequestReplayThread.run'):
rt = fm.replay_request(f)
assert rt.f.request.http_version == "HTTP/1.1"
assert ":authority" not in rt.f.request.headers
@pytest.mark.asyncio
async def test_all(self):
opts = options.Options(

View File

@ -185,14 +185,15 @@ class TestApp(tornado.testing.AsyncHTTPTestCase):
self.fetch("/flows/42/revert", method="POST")
assert not f._backup
def test_flow_replay(self):
with mock.patch("mitmproxy.master.Master.replay_request") as replay_request:
assert self.fetch("/flows/42/replay", method="POST").code == 200
assert replay_request.called
replay_request.side_effect = exceptions.ReplayException(
"out of replays"
)
assert self.fetch("/flows/42/replay", method="POST").code == 400
# FIXME
# def test_flow_replay(self):
# with mock.patch("mitmproxy.master.Master.replay_request") as replay_request:
# assert self.fetch("/flows/42/replay", method="POST").code == 200
# assert replay_request.called
# replay_request.side_effect = exceptions.ReplayException(
# "out of replays"
# )
# assert self.fetch("/flows/42/replay", method="POST").code == 400
def test_flow_content(self):
f = self.view.get_by_id("42")