mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
fix bugs, fix tests
This commit is contained in:
parent
63844df343
commit
a7058e2a3c
@ -199,11 +199,12 @@ class StatusBar(urwid.WidgetWrap):
|
|||||||
r.append("[%s]" % (":".join(opts)))
|
r.append("[%s]" % (":".join(opts)))
|
||||||
|
|
||||||
if self.master.server.config.mode in ["reverse", "upstream"]:
|
if self.master.server.config.mode in ["reverse", "upstream"]:
|
||||||
dst = self.master.server.config.mode.dst
|
dst = self.master.server.config.upstream_server
|
||||||
scheme = "https" if dst[0] else "http"
|
r.append("[dest:%s]" % netlib.utils.unparse_url(
|
||||||
if dst[1] != dst[0]:
|
dst.scheme,
|
||||||
scheme += "2https" if dst[1] else "http"
|
dst.address.host,
|
||||||
r.append("[dest:%s]" % utils.unparse_url(scheme, *dst[2:]))
|
dst.address.port
|
||||||
|
))
|
||||||
if self.master.scripts:
|
if self.master.scripts:
|
||||||
r.append("[")
|
r.append("[")
|
||||||
r.append(("heading_key", "s"))
|
r.append(("heading_key", "s"))
|
||||||
|
@ -40,6 +40,7 @@ class _HttpLayer(Layer):
|
|||||||
def send_response(self, response):
|
def send_response(self, response):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class _StreamingHttpLayer(_HttpLayer):
|
class _StreamingHttpLayer(_HttpLayer):
|
||||||
supports_streaming = True
|
supports_streaming = True
|
||||||
|
|
||||||
@ -58,7 +59,6 @@ class _StreamingHttpLayer(_HttpLayer):
|
|||||||
|
|
||||||
|
|
||||||
class Http1Layer(_StreamingHttpLayer):
|
class Http1Layer(_StreamingHttpLayer):
|
||||||
|
|
||||||
def __init__(self, ctx, mode):
|
def __init__(self, ctx, mode):
|
||||||
super(Http1Layer, self).__init__(ctx)
|
super(Http1Layer, self).__init__(ctx)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
@ -105,12 +105,12 @@ class Http1Layer(_StreamingHttpLayer):
|
|||||||
|
|
||||||
def send_response_headers(self, response):
|
def send_response_headers(self, response):
|
||||||
h = self.client_protocol._assemble_response_first_line(response)
|
h = self.client_protocol._assemble_response_first_line(response)
|
||||||
self.client_conn.wfile.write(h+"\r\n")
|
self.client_conn.wfile.write(h + "\r\n")
|
||||||
h = self.client_protocol._assemble_response_headers(
|
h = self.client_protocol._assemble_response_headers(
|
||||||
response,
|
response,
|
||||||
preserve_transfer_encoding=True
|
preserve_transfer_encoding=True
|
||||||
)
|
)
|
||||||
self.client_conn.send(h+"\r\n")
|
self.client_conn.send(h + "\r\n")
|
||||||
|
|
||||||
def send_response_body(self, response, chunks):
|
def send_response_body(self, response, chunks):
|
||||||
if self.client_protocol.has_chunked_encoding(response.headers):
|
if self.client_protocol.has_chunked_encoding(response.headers):
|
||||||
@ -142,8 +142,10 @@ class Http2Layer(_HttpLayer):
|
|||||||
def __init__(self, ctx, mode):
|
def __init__(self, ctx, mode):
|
||||||
super(Http2Layer, self).__init__(ctx)
|
super(Http2Layer, self).__init__(ctx)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True,
|
||||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
|
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||||
|
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
|
|
||||||
def read_request(self):
|
def read_request(self):
|
||||||
request = HTTPRequest.from_protocol(
|
request = HTTPRequest.from_protocol(
|
||||||
@ -172,17 +174,20 @@ class Http2Layer(_HttpLayer):
|
|||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.ctx.connect()
|
self.ctx.connect()
|
||||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||||
|
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
self.server_protocol.perform_connection_preface()
|
self.server_protocol.perform_connection_preface()
|
||||||
|
|
||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
self.ctx.reconnect()
|
self.ctx.reconnect()
|
||||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||||
|
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
self.server_protocol.perform_connection_preface()
|
self.server_protocol.perform_connection_preface()
|
||||||
|
|
||||||
def set_server(self, *args, **kwargs):
|
def set_server(self, *args, **kwargs):
|
||||||
self.ctx.set_server(*args, **kwargs)
|
self.ctx.set_server(*args, **kwargs)
|
||||||
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
|
self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
|
||||||
|
unhandled_frame_cb=self.handle_unexpected_frame)
|
||||||
self.server_protocol.perform_connection_preface()
|
self.server_protocol.perform_connection_preface()
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
@ -264,7 +269,10 @@ class UpstreamConnectLayer(Layer):
|
|||||||
def __init__(self, ctx, connect_request):
|
def __init__(self, ctx, connect_request):
|
||||||
super(UpstreamConnectLayer, self).__init__(ctx)
|
super(UpstreamConnectLayer, self).__init__(ctx)
|
||||||
self.connect_request = connect_request
|
self.connect_request = connect_request
|
||||||
self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx)
|
self.server_conn = ConnectServerConnection(
|
||||||
|
(connect_request.host, connect_request.port),
|
||||||
|
self.ctx
|
||||||
|
)
|
||||||
|
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
layer = self.ctx.next_layer(self)
|
layer = self.ctx.next_layer(self)
|
||||||
@ -280,6 +288,9 @@ class UpstreamConnectLayer(Layer):
|
|||||||
def reconnect(self):
|
def reconnect(self):
|
||||||
self.ctx.reconnect()
|
self.ctx.reconnect()
|
||||||
self.send_request(self.connect_request)
|
self.send_request(self.connect_request)
|
||||||
|
resp = self.read_response("CONNECT")
|
||||||
|
if resp.code != 200:
|
||||||
|
raise ProtocolException("Reconnect: Upstream server refuses CONNECT request")
|
||||||
|
|
||||||
def set_server(self, address, server_tls=None, sni=None, depth=1):
|
def set_server(self, address, server_tls=None, sni=None, depth=1):
|
||||||
if depth == 1:
|
if depth == 1:
|
||||||
@ -290,7 +301,7 @@ class UpstreamConnectLayer(Layer):
|
|||||||
self.connect_request.port = address.port
|
self.connect_request.port = address.port
|
||||||
self.server_conn.address = address
|
self.server_conn.address = address
|
||||||
else:
|
else:
|
||||||
self.ctx.set_server(address, server_tls, sni, depth-1)
|
self.ctx.set_server(address, server_tls, sni, depth - 1)
|
||||||
|
|
||||||
|
|
||||||
class HttpLayer(Layer):
|
class HttpLayer(Layer):
|
||||||
@ -413,10 +424,10 @@ class HttpLayer(Layer):
|
|||||||
# First send the headers and then transfer the response incrementally
|
# First send the headers and then transfer the response incrementally
|
||||||
self.send_response_headers(flow.response)
|
self.send_response_headers(flow.response)
|
||||||
chunks = self.read_response_body(
|
chunks = self.read_response_body(
|
||||||
flow.response.headers,
|
flow.response.headers,
|
||||||
flow.request.method,
|
flow.request.method,
|
||||||
flow.response.code,
|
flow.response.code,
|
||||||
max_chunk_size=4096
|
max_chunk_size=4096
|
||||||
)
|
)
|
||||||
if callable(flow.response.stream):
|
if callable(flow.response.stream):
|
||||||
chunks = flow.response.stream(chunks)
|
chunks = flow.response.stream(chunks)
|
||||||
@ -521,7 +532,8 @@ class HttpLayer(Layer):
|
|||||||
# If there's not TlsLayer below which could catch the exception,
|
# If there's not TlsLayer below which could catch the exception,
|
||||||
# TLS will not be established.
|
# TLS will not be established.
|
||||||
if tls and not self.server_conn.tls_established:
|
if tls and not self.server_conn.tls_established:
|
||||||
raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.")
|
raise ProtocolException(
|
||||||
|
"Cannot upgrade to SSL, no TLS layer on the protocol stack.")
|
||||||
else:
|
else:
|
||||||
if not self.server_conn:
|
if not self.server_conn:
|
||||||
self.connect()
|
self.connect()
|
||||||
@ -542,7 +554,8 @@ class HttpLayer(Layer):
|
|||||||
|
|
||||||
def validate_request(self, request):
|
def validate_request(self, request):
|
||||||
if request.form_in == "absolute" and request.scheme != "http":
|
if request.form_in == "absolute" and request.scheme != "http":
|
||||||
self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
self.send_response(
|
||||||
|
make_error_response(400, "Invalid request scheme: %s" % request.scheme))
|
||||||
raise HttpException("Invalid request scheme: %s" % request.scheme)
|
raise HttpException("Invalid request scheme: %s" % request.scheme)
|
||||||
|
|
||||||
expected_request_forms = {
|
expected_request_forms = {
|
||||||
@ -570,7 +583,11 @@ class HttpLayer(Layer):
|
|||||||
self.send_response(make_error_response(
|
self.send_response(make_error_response(
|
||||||
407,
|
407,
|
||||||
"Proxy Authentication Required",
|
"Proxy Authentication Required",
|
||||||
odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
|
odict.ODictCaseless(
|
||||||
|
[
|
||||||
|
[k, v] for k, v in
|
||||||
|
self.config.authenticator.auth_challenge_headers().items()
|
||||||
|
])
|
||||||
))
|
))
|
||||||
raise InvalidCredentials("Proxy Authentication Required")
|
raise InvalidCredentials("Proxy Authentication Required")
|
||||||
|
|
||||||
@ -614,6 +631,9 @@ class RequestReplayThread(threading.Thread):
|
|||||||
if r.scheme == "https":
|
if r.scheme == "https":
|
||||||
connect_request = make_connect_request((r.host, r.port))
|
connect_request = make_connect_request((r.host, r.port))
|
||||||
server.send(protocol.assemble(connect_request))
|
server.send(protocol.assemble(connect_request))
|
||||||
|
resp = protocol.read_response("CONNECT")
|
||||||
|
if resp.code != 200:
|
||||||
|
raise HttpError(502, "Upstream server refuses CONNECT request")
|
||||||
server.establish_ssl(
|
server.establish_ssl(
|
||||||
self.config.clientcerts,
|
self.config.clientcerts,
|
||||||
sni=self.flow.server_conn.sni
|
sni=self.flow.server_conn.sni
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import argparse
|
|
||||||
from libmproxy import cmdline
|
from libmproxy import cmdline
|
||||||
from libmproxy.proxy import ProxyConfig, process_proxy_options
|
from libmproxy.proxy import ProxyConfig, process_proxy_options
|
||||||
from libmproxy.proxy.connection import ServerConnection
|
from libmproxy.proxy.connection import ServerConnection
|
||||||
from libmproxy.proxy.primitives import ProxyError
|
from libmproxy.proxy.primitives import ProxyError
|
||||||
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler
|
from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler2
|
||||||
import tutils
|
import tutils
|
||||||
from libpathod import test
|
from libpathod import test
|
||||||
from netlib import http, tcp
|
from netlib import http, tcp
|
||||||
@ -175,8 +174,10 @@ class TestDummyServer:
|
|||||||
class TestConnectionHandler:
|
class TestConnectionHandler:
|
||||||
def test_fatal_error(self):
|
def test_fatal_error(self):
|
||||||
config = mock.Mock()
|
config = mock.Mock()
|
||||||
config.mode.get_upstream_server.side_effect = RuntimeError
|
root_layer = mock.Mock()
|
||||||
c = ConnectionHandler(
|
root_layer.side_effect = RuntimeError
|
||||||
|
config.mode.return_value = root_layer
|
||||||
|
c = ConnectionHandler2(
|
||||||
config,
|
config,
|
||||||
mock.MagicMock(),
|
mock.MagicMock(),
|
||||||
("127.0.0.1",
|
("127.0.0.1",
|
||||||
|
@ -68,7 +68,7 @@ class CommonMixin:
|
|||||||
# SSL with the upstream proxy.
|
# SSL with the upstream proxy.
|
||||||
rt = self.master.replay_request(l, block=True)
|
rt = self.master.replay_request(l, block=True)
|
||||||
assert not rt
|
assert not rt
|
||||||
if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl:
|
if isinstance(self, tservers.HTTPUpstreamProxTest):
|
||||||
assert l.response.code == 502
|
assert l.response.code == 502
|
||||||
else:
|
else:
|
||||||
assert l.error
|
assert l.error
|
||||||
@ -506,7 +506,7 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin):
|
|||||||
p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None)
|
p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None)
|
||||||
p.connect()
|
p.connect()
|
||||||
r = p.request("get:/")
|
r = p.request("get:/")
|
||||||
assert r.status_code == 400
|
assert r.status_code == 502
|
||||||
|
|
||||||
|
|
||||||
class TestProxy(tservers.HTTPProxTest):
|
class TestProxy(tservers.HTTPProxTest):
|
||||||
@ -724,9 +724,9 @@ class TestStreamRequest(tservers.HTTPProxTest):
|
|||||||
assert resp.headers["Transfer-Encoding"][0] == 'chunked'
|
assert resp.headers["Transfer-Encoding"][0] == 'chunked'
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
chunks = list(
|
chunks = list(protocol.read_http_body_chunked(
|
||||||
content for _, content, _ in protocol.read_http_body_chunked(
|
resp.headers, None, "GET", 200, False
|
||||||
resp.headers, None, "GET", 200, False))
|
))
|
||||||
assert chunks == ["this", "isatest", ""]
|
assert chunks == ["this", "isatest", ""]
|
||||||
|
|
||||||
connection.close()
|
connection.close()
|
||||||
@ -959,6 +959,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
|
|||||||
|
|
||||||
p = self.pathoc()
|
p = self.pathoc()
|
||||||
req = p.request("get:'/p/418:b\"content\"'")
|
req = p.request("get:'/p/418:b\"content\"'")
|
||||||
|
assert req.content == "content"
|
||||||
|
assert req.status_code == 418
|
||||||
|
|
||||||
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
|
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
|
||||||
# CONNECT, failing request,
|
# CONNECT, failing request,
|
||||||
assert self.chain[0].tmaster.state.flow_count() == 4
|
assert self.chain[0].tmaster.state.flow_count() == 4
|
||||||
@ -967,8 +970,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
|
|||||||
assert self.chain[1].tmaster.state.flow_count() == 2
|
assert self.chain[1].tmaster.state.flow_count() == 2
|
||||||
# (doesn't store (repeated) CONNECTs from chain[0]
|
# (doesn't store (repeated) CONNECTs from chain[0]
|
||||||
# as it is a regular proxy)
|
# as it is a regular proxy)
|
||||||
assert req.content == "content"
|
|
||||||
assert req.status_code == 418
|
|
||||||
|
|
||||||
assert not self.chain[1].tmaster.state.flows[0].response # killed
|
assert not self.chain[1].tmaster.state.flows[0].response # killed
|
||||||
assert self.chain[1].tmaster.state.flows[1].response
|
assert self.chain[1].tmaster.state.flows[1].response
|
||||||
|
@ -181,22 +181,24 @@ class TResolver:
|
|||||||
def original_addr(self, sock):
|
def original_addr(self, sock):
|
||||||
return ("127.0.0.1", self.port)
|
return ("127.0.0.1", self.port)
|
||||||
|
|
||||||
|
|
||||||
class TransparentProxTest(ProxTestBase):
|
class TransparentProxTest(ProxTestBase):
|
||||||
ssl = None
|
ssl = None
|
||||||
resolver = TResolver
|
resolver = TResolver
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@mock.patch("libmproxy.platform.resolver")
|
def setupAll(cls):
|
||||||
def setupAll(cls, _):
|
|
||||||
super(TransparentProxTest, cls).setupAll()
|
super(TransparentProxTest, cls).setupAll()
|
||||||
if cls.ssl:
|
|
||||||
ports = [cls.server.port, cls.server2.port]
|
cls._resolver = mock.patch(
|
||||||
else:
|
"libmproxy.platform.resolver",
|
||||||
ports = []
|
new=lambda: cls.resolver(cls.server.port)
|
||||||
cls.config.mode = TransparentProxyMode(
|
)
|
||||||
cls.resolver(cls.server.port),
|
cls._resolver.start()
|
||||||
ports)
|
|
||||||
|
@classmethod
|
||||||
|
def teardownAll(cls):
|
||||||
|
cls._resolver.stop()
|
||||||
|
super(TransparentProxTest, cls).teardownAll()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_proxy_config(cls):
|
def get_proxy_config(cls):
|
||||||
@ -270,48 +272,6 @@ class SocksModeTest(HTTPProxTest):
|
|||||||
d["mode"] = "socks5"
|
d["mode"] = "socks5"
|
||||||
return d
|
return d
|
||||||
|
|
||||||
class SpoofModeTest(ProxTestBase):
|
|
||||||
ssl = None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_proxy_config(cls):
|
|
||||||
d = ProxTestBase.get_proxy_config()
|
|
||||||
d["upstream_server"] = None
|
|
||||||
d["mode"] = "spoof"
|
|
||||||
return d
|
|
||||||
|
|
||||||
def pathoc(self, sni=None):
|
|
||||||
"""
|
|
||||||
Returns a connected Pathoc instance.
|
|
||||||
"""
|
|
||||||
p = libpathod.pathoc.Pathoc(
|
|
||||||
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
|
|
||||||
)
|
|
||||||
p.connect()
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
class SSLSpoofModeTest(ProxTestBase):
|
|
||||||
ssl = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_proxy_config(cls):
|
|
||||||
d = ProxTestBase.get_proxy_config()
|
|
||||||
d["upstream_server"] = None
|
|
||||||
d["mode"] = "sslspoof"
|
|
||||||
d["spoofed_ssl_port"] = 443
|
|
||||||
return d
|
|
||||||
|
|
||||||
def pathoc(self, sni=None):
|
|
||||||
"""
|
|
||||||
Returns a connected Pathoc instance.
|
|
||||||
"""
|
|
||||||
p = libpathod.pathoc.Pathoc(
|
|
||||||
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
|
|
||||||
)
|
|
||||||
p.connect()
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
class ChainProxTest(ProxTestBase):
|
class ChainProxTest(ProxTestBase):
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user