[sans-io] upstream proxy tests and fixes

This commit is contained in:
Maximilian Hils 2020-01-04 02:29:19 +01:00
parent 605da3afb6
commit 549eb8df4b
5 changed files with 150 additions and 32 deletions

View File

@ -54,7 +54,7 @@ class Client(Connection):
class Server(Connection): class Server(Connection):
sni = True sni = True
"""True: client SNI, False: no SNI, bytes: custom value""" """True: client SNI, False: no SNI, bytes: custom value"""
via: Sequence["Server"] = () via: Sequence[server_spec.ServerSpec] = ()
def __init__(self, address: Optional[tuple]): def __init__(self, address: Optional[tuple]):
self.address = address self.address = address

View File

@ -8,7 +8,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer, tunnel from mitmproxy.proxy2 import commands, events, layer, tunnel
from mitmproxy.proxy2.context import Connection, Context, Server from mitmproxy.proxy2.context import Connection, Context, Server
from mitmproxy.proxy2.layers import tls from mitmproxy.proxy2.layers import tls
from mitmproxy.proxy2.layers.http import upstream_proxy from mitmproxy.proxy2.layers.http import _upstream_proxy
from mitmproxy.proxy2.utils import expect from mitmproxy.proxy2.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
@ -29,7 +29,7 @@ class GetHttpConnection(HttpCommand):
tls: bool tls: bool
via: typing.Sequence[server_spec.ServerSpec] via: typing.Sequence[server_spec.ServerSpec]
def __init__(self, address: typing.Tuple[str, int], tls: bool, via: typing.Sequence[str]): def __init__(self, address: typing.Tuple[str, int], tls: bool, via: typing.Sequence[server_spec.ServerSpec]):
self.address = address self.address = address
self.tls = tls self.tls = tls
self.via = tuple(via) self.via = tuple(via)
@ -116,38 +116,33 @@ class HttpStream(layer.Layer):
) )
self.flow.request = event.request self.flow.request = event.request
if self.flow.request.first_line_format == "authority":
yield from self.handle_connect()
return
if self.flow.request.headers.get("expect", "").lower() == "100-continue": if self.flow.request.headers.get("expect", "").lower() == "100-continue":
raise NotImplementedError("expect nothing") raise NotImplementedError("expect nothing")
# self.send_response(http.expect_continue_response) # self.send_response(http.expect_continue_response)
# request.headers.pop("expect") # request.headers.pop("expect")
# set first line format to relative in regular mode, if self.flow.request.first_line_format == "authority":
# see https://github.com/mitmproxy/mitmproxy/issues/1759 yield from self.handle_connect()
if self.mode is HTTPMode.regular and self.flow.request.first_line_format == "absolute": return
self.flow.request.first_line_format = "relative"
# update host header in reverse proxy mode # Determine .scheme, .host and .port attributes for relative-form requests
if self.context.options.mode.startswith("reverse:") and not self.context.options.keep_host_header: if self.flow.request.first_line_format in "relative":
self.flow.request.host_header = self.context.server.address[0] # Setting request.host also updates the host header, which we want to preserve
# Determine .scheme, .host and .port attributes for inline scripts. For
# absolute-form requests, they are directly given in the request. For
# authority-form requests, we only need to determine the request
# scheme. For relative-form requests, we need to determine host and
# port as well.
if self.mode is HTTPMode.transparent:
# Setting request.host also updates the host header, which we want
# to preserve
host_header = self.flow.request.host_header host_header = self.flow.request.host_header
self.flow.request.host = self.context.server.address[0] self.flow.request.host = self.context.server.address[0]
self.flow.request.port = self.context.server.address[1] self.flow.request.port = self.context.server.address[1]
self.flow.request.host_header = host_header # set again as .host overwrites this. self.flow.request.host_header = host_header # set again as .host overwrites this.
self.flow.request.scheme = "https" if self.context.server.tls else "http" self.flow.request.scheme = "https" if self.context.server.tls else "http"
# set first line format to relative in regular mode,
# see https://github.com/mitmproxy/mitmproxy/issues/1759
if self.context.options.mode == "regular" and self.flow.request.first_line_format == "absolute":
self.flow.request.first_line_format = "relative"
# update host header in reverse proxy mode
if self.context.options.mode.startswith("reverse:") and not self.context.options.keep_host_header:
self.flow.request.host_header = self.context.server.address[0]
self.flow.request.via = [] # FIXME: Make this an official attribute. self.flow.request.via = [] # FIXME: Make this an official attribute.
if self.context.options.mode.startswith("upstream:"): if self.context.options.mode.startswith("upstream:"):
self.flow.request.via.append( self.flow.request.via.append(
@ -272,7 +267,13 @@ class HttpStream(layer.Layer):
yield HttpConnectHook(self.flow) yield HttpConnectHook(self.flow)
self.context.server = Server((self.flow.request.host, self.flow.request.port)) self.context.server = Server((self.flow.request.host, self.flow.request.port))
if self.context.options.connection_strategy == "eager":
# We must not connect to the actual destination in upstream mode.
connect_now = (
self.context.options.connection_strategy == "eager"
and not self.context.options.mode.startswith("upstream:")
)
if connect_now:
err = yield commands.OpenConnection(self.context.server) err = yield commands.OpenConnection(self.context.server)
if err: if err:
self.flow.response = http.HTTPResponse.make( self.flow.response = http.HTTPResponse.make(
@ -426,13 +427,17 @@ class HttpLayer(layer.Layer):
if not can_reuse_context_connection: if not can_reuse_context_connection:
context.server = Server(event.address) context.server = Server(event.address)
context.server.via = event.via
if context.options.http2: if context.options.http2:
context.server.alpn_offers = tls.HTTP_ALPNS context.server.alpn_offers = tls.HTTP_ALPNS
else: else:
context.server.alpn_offers = tls.HTTP1_ALPNS context.server.alpn_offers = tls.HTTP1_ALPNS
for via in reversed(event.via): for via in reversed(event.via):
stack /= upstream_proxy.HttpUpstreamProxy(context, via.address) needs_http_connect = (
self.mode != HTTPMode.regular or via != event.via[-1]
)
stack /= _upstream_proxy.HttpUpstreamProxy(context, via.address, needs_http_connect)
if event.tls: if event.tls:
stack /= tls.ServerTLSLayer(context) stack /= tls.ServerTLSLayer(context)

View File

@ -5,25 +5,33 @@ from h11._receivebuffer import ReceiveBuffer
from mitmproxy import http from mitmproxy import http
from mitmproxy.net.http import http1 from mitmproxy.net.http import http1
from mitmproxy.net.http.http1 import read_sansio as http1_sansio from mitmproxy.net.http.http1 import read_sansio as http1_sansio
from mitmproxy.proxy2 import commands, context, events, layer, tunnel from mitmproxy.proxy2 import commands, context, layer, tunnel
from mitmproxy.utils import human from mitmproxy.utils import human
class HttpUpstreamProxy(tunnel.TunnelLayer): class HttpUpstreamProxy(tunnel.TunnelLayer):
buf: ReceiveBuffer buf: ReceiveBuffer
send_connect: bool
def __init__(self, ctx: context.Context, address: tuple): def __init__(self, ctx: context.Context, address: tuple, send_connect: bool):
s = context.Server(address) super().__init__(
ctx.server.via = (*ctx.server.via, s) ctx,
super().__init__(ctx, tunnel_connection=s, conn=ctx.server) tunnel_connection=context.Server(address),
conn=ctx.server
)
self.buf = ReceiveBuffer() self.buf = ReceiveBuffer()
self.send_connect = send_connect
def start_handshake(self) -> layer.CommandGenerator[None]: def start_handshake(self) -> layer.CommandGenerator[None]:
if not self.send_connect:
return (yield from super().start_handshake())
req = http.make_connect_request(self.conn.address) req = http.make_connect_request(self.conn.address)
raw = http1.assemble_request(req) raw = http1.assemble_request(req)
yield commands.SendData(self.tunnel_connection, raw) yield commands.SendData(self.tunnel_connection, raw)
def receive_handshake_data(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]: def receive_handshake_data(self, data: bytes) -> layer.CommandGenerator[Tuple[bool, Optional[str]]]:
if not self.send_connect:
return (yield from super().receive_handshake_data(data))
self.buf += data self.buf += data
response_head = self.buf.maybe_extract_lines() response_head = self.buf.maybe_extract_lines()
if response_head: if response_head:

View File

@ -49,6 +49,7 @@ class TunnelLayer(layer.Layer):
if self.tunnel_state is TunnelState.ESTABLISHING: if self.tunnel_state is TunnelState.ESTABLISHING:
done, err = yield from self.receive_handshake_data(event.data) done, err = yield from self.receive_handshake_data(event.data)
if done: if done:
self.conn.state = context.ConnectionState.OPEN
self.tunnel_state = TunnelState.OPEN self.tunnel_state = TunnelState.OPEN
if err: if err:
self.tunnel_state = TunnelState.CLOSED self.tunnel_state = TunnelState.CLOSED
@ -84,7 +85,6 @@ class TunnelLayer(layer.Layer):
if err: if err:
yield from self.event_to_child(events.OpenConnectionReply(command, err)) yield from self.event_to_child(events.OpenConnectionReply(command, err))
else: else:
self.conn.state = context.ConnectionState.OPEN
self.command_to_reply_to = command self.command_to_reply_to = command
yield from self.start_handshake() yield from self.start_handshake()
else: else:

View File

@ -1,6 +1,7 @@
import pytest import pytest
from mitmproxy.http import HTTPFlow, HTTPResponse from mitmproxy.http import HTTPFlow, HTTPResponse
from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import layer from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
@ -345,3 +346,107 @@ def test_server_aborts(tctx, data):
) )
assert flow().error assert flow().error
assert b"502 Bad Gateway" in err() assert b"502 Bad Gateway" in err()
@pytest.mark.parametrize("redirect", [None, "proxy", "destination"])
@pytest.mark.parametrize("scheme", ["http", "https"])
@pytest.mark.parametrize("strategy", ["eager", "lazy"])
def test_upstream_proxy(tctx, redirect, scheme, strategy):
    """Check that traffic is relayed through the configured upstream HTTP proxy.

    The proxy runs in ``upstream:http://proxy:8080`` mode and handles two
    requests from the same client:

    1. A first request (absolute-form GET for ``http``, CONNECT followed by a
       relative-form GET for ``https``), which must open a connection to the
       upstream proxy.
    2. A second request whose flow is optionally modified in the request hook:
       ``redirect="proxy"`` swaps the upstream proxy via ``request.via``,
       ``redirect="destination"`` changes ``request.host`` while keeping the
       original host header, and ``None`` leaves the flow untouched.

    The expected exchange is the same for both connection strategies; the
    parameter only sets ``tctx.options.connection_strategy``.
    """
    plain_http = scheme == "http"
    redirect_to_destination = redirect == "destination"

    # Byte strings that occur several times in the expected exchange.
    tunnel_established = b"HTTP/1.1 200 Connection established\r\n\r\n"
    teapot_response = b"HTTP/1.1 418 OK\r\nContent-Length: 0\r\n\r\n"

    first_conn = Placeholder()
    second_conn = Placeholder()
    flow = Placeholder()
    tctx.options.mode = "upstream:http://proxy:8080"
    tctx.options.connection_strategy = strategy
    playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)

    # --- first request -----------------------------------------------------
    if plain_http:
        playbook >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
        playbook << OpenConnection(first_conn)
        playbook >> reply(None)
        # FIXME: We really shouldn't have the port here.
        playbook << SendData(first_conn, b"GET http://example.com:80/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
    else:
        # The client's CONNECT is acknowledged right away; the connection to the
        # proxy is only expected once the tunneled request has been received.
        playbook >> DataReceived(tctx.client, b"CONNECT example.com:443 HTTP/1.1\r\n\r\n")
        playbook << SendData(tctx.client, tunnel_established)
        playbook >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
        playbook << layer.NextLayerHook(Placeholder())
        playbook >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
        playbook << OpenConnection(first_conn)
        playbook >> reply(None)
        playbook << SendData(first_conn, b"CONNECT example.com:443 HTTP/1.1\r\n\r\n")
        playbook >> DataReceived(first_conn, tunnel_established)
        playbook << SendData(first_conn, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

    playbook >> DataReceived(first_conn, teapot_response)
    playbook << SendData(tctx.client, teapot_response)

    assert playbook
    assert first_conn().address == ("proxy", 8080)

    # --- second request, optionally redirected in the request hook ---------
    second_request = (
        b"GET http://example.com/two HTTP/1.1\r\nHost: example.com\r\n\r\n"
        if plain_http
        else b"GET /two HTTP/1.1\r\nHost: example.com\r\n\r\n"
    )
    playbook >> DataReceived(tctx.client, second_request)

    assert (playbook << http.HttpRequestHook(flow))
    if redirect == "proxy":
        flow().request.via = [server_spec.ServerSpec("http", ("other-proxy", 1234))]
    elif redirect_to_destination:
        flow().request.host = "other-server"
        flow().request.host_header = "example.com"
    playbook >> reply()

    if not redirect:
        # Nothing changed, so the existing proxy connection is reused.
        second_conn = first_conn
    else:
        # Protocol-wise we wouldn't need to open a new connection for plain http host redirects,
        # but we disregard this edge case to simplify implementation.
        playbook << OpenConnection(second_conn)
        playbook >> reply(None)

    if plain_http:
        forwarded = (
            b"GET http://other-server:80/two HTTP/1.1\r\nHost: example.com\r\n\r\n"
            if redirect_to_destination
            else b"GET http://example.com:80/two HTTP/1.1\r\nHost: example.com\r\n\r\n"
        )
        playbook << SendData(second_conn, forwarded)
    else:
        if redirect:
            # A fresh proxy connection needs its own CONNECT handshake first.
            connect_request = (
                b"CONNECT other-server:443 HTTP/1.1\r\n\r\n"
                if redirect_to_destination
                else b"CONNECT example.com:443 HTTP/1.1\r\n\r\n"
            )
            playbook << SendData(second_conn, connect_request)
            playbook >> DataReceived(second_conn, tunnel_established)
        playbook << SendData(second_conn, b"GET /two HTTP/1.1\r\nHost: example.com\r\n\r\n")

    playbook >> DataReceived(second_conn, teapot_response)
    playbook << SendData(tctx.client, teapot_response)

    assert playbook

    expected_address = ("other-proxy", 1234) if redirect == "proxy" else ("proxy", 8080)
    assert second_conn().address == expected_address

    # --- client disconnect --------------------------------------------------
    playbook >> ConnectionClosed(tctx.client)
    playbook << CloseConnection(tctx.client)
    assert playbook
@pytest.mark.xfail(reason="h11 enforces host headers by default")
def test_no_headers(tctx):
    """Check that header-less requests and responses are reassembled correctly.

    A request without any headers is sent through a regular-mode HTTP layer;
    both the forwarded request and the relayed response are expected to carry
    no headers either. The test is marked as an expected failure because h11
    enforces host headers by default.
    """
    upstream = Placeholder()
    playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)

    # Client request with an absolute-form target and no headers at all.
    playbook >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\n\r\n")
    playbook << OpenConnection(upstream)
    playbook >> reply(None)
    playbook << SendData(upstream, b"GET / HTTP/1.1\r\n\r\n")

    # Header-less response travelling back to the client.
    playbook >> DataReceived(upstream, b"HTTP/1.1 204 No Content\r\n\r\n")
    playbook << SendData(tctx.client, b"HTTP/1.1 204 No Content\r\n\r\n")

    assert playbook
    assert upstream().address == ("example.com", 80)