mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-22 15:37:45 +00:00
[sans-io] add flow killing, tests++
This commit is contained in:
parent
34d96da876
commit
40fc542cf6
@ -3,6 +3,7 @@ from enum import Flag, auto
|
||||
from typing import List, Literal, Optional, Sequence, Union
|
||||
|
||||
from mitmproxy import certs
|
||||
from mitmproxy.flow import Error
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.options import Options
|
||||
|
||||
@ -101,3 +102,9 @@ class Context:
|
||||
ret.server = self.server
|
||||
ret.layers = self.layers.copy()
|
||||
return ret
|
||||
|
||||
|
||||
# FIXME: Move to mitmproxy.flow, adjust Flow.kill()
|
||||
class Killed(Error):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Connection killed.")
|
@ -8,7 +8,7 @@ from mitmproxy.net import server_spec
|
||||
from mitmproxy.net.http import url
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import commands, events, layer, tunnel
|
||||
from mitmproxy.proxy2.context import Connection, Context, Server
|
||||
from mitmproxy.proxy2.context import Connection, Context, Killed, Server
|
||||
from mitmproxy.proxy2.layers import tls
|
||||
from mitmproxy.proxy2.layers.http import _upstream_proxy
|
||||
from mitmproxy.proxy2.utils import expect
|
||||
@ -173,6 +173,8 @@ class HttpStream(layer.Layer):
|
||||
self.flow.request.host_header = self.context.server.address[0]
|
||||
|
||||
yield HttpRequestHeadersHook(self.flow)
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
|
||||
if self.flow.request.headers.get("expect", "").lower() == "100-continue":
|
||||
continue_response = http.HTTPResponse.make(100)
|
||||
@ -212,10 +214,14 @@ class HttpStream(layer.Layer):
|
||||
self.flow.request.data.content = self.request_body_buf
|
||||
self.request_body_buf = b""
|
||||
yield HttpRequestHook(self.flow)
|
||||
if self.flow.response:
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
elif self.flow.response:
|
||||
# response was set by an inline script.
|
||||
# we now need to emulate the responseheaders hook.
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
yield from self.send_response()
|
||||
else:
|
||||
ok = yield from self.make_server_connection()
|
||||
@ -233,7 +239,9 @@ class HttpStream(layer.Layer):
|
||||
def state_wait_for_response_headers(self, event: ResponseHeaders) -> layer.CommandGenerator[None]:
|
||||
self.flow.response = event.response
|
||||
yield HttpResponseHeadersHook(self.flow)
|
||||
if self.flow.response.stream:
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
elif self.flow.response.stream:
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
self.server_state = self.state_stream_response_body
|
||||
else:
|
||||
@ -261,8 +269,17 @@ class HttpStream(layer.Layer):
|
||||
yield from self.send_response()
|
||||
self.server_state = self.state_done
|
||||
|
||||
def check_killed(self) -> layer.CommandGenerator[bool]:
|
||||
if isinstance(self.flow.error, Killed):
|
||||
yield commands.CloseConnection(self.context.client)
|
||||
self._handle_event = self.state_errored
|
||||
return True
|
||||
return False
|
||||
|
||||
def send_response(self):
|
||||
yield HttpResponseHook(self.flow)
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
yield SendHttp(ResponseHeaders(self.stream_id, self.flow.response), self.context.client)
|
||||
if self.flow.response.raw_content:
|
||||
yield SendHttp(ResponseData(self.stream_id, self.flow.response.raw_content), self.context.client)
|
||||
@ -274,8 +291,9 @@ class HttpStream(layer.Layer):
|
||||
) -> layer.CommandGenerator[None]:
|
||||
self.flow.error = flow.Error(event.message)
|
||||
yield HttpErrorHook(self.flow)
|
||||
|
||||
if isinstance(event, ResponseProtocolError):
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
elif isinstance(event, ResponseProtocolError):
|
||||
yield SendHttp(event, self.context.client)
|
||||
|
||||
def make_server_connection(self) -> layer.CommandGenerator[bool]:
|
||||
@ -293,6 +311,8 @@ class HttpStream(layer.Layer):
|
||||
|
||||
def handle_connect(self) -> layer.CommandGenerator[None]:
|
||||
yield HttpConnectHook(self.flow)
|
||||
if (yield from self.check_killed()):
|
||||
return
|
||||
|
||||
self.context.server.address = (self.flow.request.host, self.flow.request.port)
|
||||
|
||||
|
@ -6,7 +6,7 @@ from mitmproxy.http import HTTPFlow, HTTPResponse
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
from mitmproxy.proxy2 import layer
|
||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||
from mitmproxy.proxy2.context import Server
|
||||
from mitmproxy.proxy2.context import Killed, Server
|
||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||
from mitmproxy.proxy2.layers import TCPLayer, http, tls
|
||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer
|
||||
@ -630,7 +630,8 @@ def test_http_client_aborts(tctx, stream):
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"POST http://example.com/ HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Length: 6\r\n\r\n"
|
||||
b"Content-Length: 6\r\n"
|
||||
b"\r\n"
|
||||
b"abc")
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
)
|
||||
@ -642,7 +643,8 @@ def test_http_client_aborts(tctx, stream):
|
||||
>> reply(None)
|
||||
<< SendData(server, b"POST / HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Length: 6\r\n\r\n"
|
||||
b"Content-Length: 6\r\n"
|
||||
b"\r\n"
|
||||
b"abc")
|
||||
)
|
||||
else:
|
||||
@ -656,5 +658,138 @@ def test_http_client_aborts(tctx, stream):
|
||||
|
||||
)
|
||||
|
||||
flow: Callable[[], HTTPFlow]
|
||||
assert "peer closed connection" in flow().error.msg
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_http_server_aborts(tctx, stream):
|
||||
"""Test handling of the case where a server aborts during response transmission."""
|
||||
server = Placeholder(Server)
|
||||
flow = Placeholder(HTTPFlow)
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
|
||||
def enable_streaming(flow: HTTPFlow):
|
||||
flow.response.stream = True
|
||||
|
||||
assert (
|
||||
playbook
|
||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n\r\n")
|
||||
<< http.HttpRequestHeadersHook(flow)
|
||||
>> reply()
|
||||
<< http.HttpRequestHook(flow)
|
||||
>> reply()
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"GET / HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n\r\n")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\n"
|
||||
b"Content-Length: 6\r\n"
|
||||
b"\r\n"
|
||||
b"abc")
|
||||
<< http.HttpResponseHeadersHook(flow)
|
||||
)
|
||||
if stream:
|
||||
assert (
|
||||
playbook
|
||||
>> reply(side_effect=enable_streaming)
|
||||
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\n"
|
||||
b"Content-Length: 6\r\n"
|
||||
b"\r\n"
|
||||
b"abc")
|
||||
)
|
||||
else:
|
||||
assert playbook >> reply()
|
||||
assert (
|
||||
playbook
|
||||
>> ConnectionClosed(server)
|
||||
<< CloseConnection(server)
|
||||
<< http.HttpErrorHook(flow)
|
||||
)
|
||||
if stream:
|
||||
assert (
|
||||
playbook
|
||||
>> reply()
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
else:
|
||||
error_html = Placeholder(bytes)
|
||||
assert (
|
||||
playbook
|
||||
>> reply()
|
||||
<< SendData(tctx.client, error_html)
|
||||
<< CloseConnection(tctx.client)
|
||||
)
|
||||
assert b"502 Bad Gateway" in error_html()
|
||||
assert b"peer closed connection" in error_html()
|
||||
|
||||
assert "peer closed connection" in flow().error.msg
|
||||
|
||||
|
||||
@pytest.mark.parametrize("when", ["http_connect", "requestheaders", "request", "script-response-responseheaders",
|
||||
"responseheaders",
|
||||
"response", "error"])
|
||||
def test_kill_flow(tctx, when):
|
||||
"""Test that we properly kill flows if instructed to do so"""
|
||||
server = Placeholder(Server)
|
||||
connect_flow = Placeholder(HTTPFlow)
|
||||
flow = Placeholder(HTTPFlow)
|
||||
|
||||
def kill(flow: HTTPFlow):
|
||||
flow.error = Killed()
|
||||
|
||||
def assert_kill():
|
||||
assert (playbook
|
||||
>> reply(side_effect=kill)
|
||||
<< CloseConnection(tctx.client))
|
||||
|
||||
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||
assert (playbook
|
||||
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
|
||||
<< http.HttpConnectHook(connect_flow))
|
||||
if when == "http_connect":
|
||||
return assert_kill()
|
||||
assert (playbook
|
||||
>> reply()
|
||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
<< layer.NextLayerHook(Placeholder())
|
||||
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||
<< http.HttpRequestHeadersHook(flow))
|
||||
if when == "requestheaders":
|
||||
return assert_kill()
|
||||
assert (playbook
|
||||
>> reply()
|
||||
<< http.HttpRequestHook(flow))
|
||||
if when == "request":
|
||||
return assert_kill()
|
||||
if when == "script-response-responseheaders":
|
||||
assert (playbook
|
||||
>> reply(side_effect=lambda f: setattr(f, "response", HTTPResponse.make()))
|
||||
<< http.HttpResponseHeadersHook(flow))
|
||||
return assert_kill()
|
||||
assert (playbook
|
||||
>> reply()
|
||||
<< OpenConnection(server)
|
||||
>> reply(None)
|
||||
<< SendData(server, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World")
|
||||
<< http.HttpResponseHeadersHook(flow))
|
||||
if when == "responseheaders":
|
||||
return assert_kill()
|
||||
|
||||
if when == "response":
|
||||
assert (playbook
|
||||
>> reply()
|
||||
>> DataReceived(server, b"!")
|
||||
<< http.HttpResponseHook(flow))
|
||||
return assert_kill()
|
||||
elif when == "error":
|
||||
assert (playbook
|
||||
>> reply()
|
||||
>> ConnectionClosed(server)
|
||||
<< CloseConnection(server)
|
||||
<< http.HttpErrorHook(flow))
|
||||
return assert_kill()
|
||||
else:
|
||||
raise AssertionError
|
||||
|
Loading…
Reference in New Issue
Block a user