mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
[sans-io] fixes, fixes, fixes
This commit is contained in:
parent
b2060356b6
commit
a30a6758f3
@ -86,12 +86,11 @@ class NextLayer:
|
|||||||
|
|
||||||
def next_layer(self, nextlayer: layer.NextLayer):
|
def next_layer(self, nextlayer: layer.NextLayer):
|
||||||
nextlayer.layer = self._next_layer(nextlayer.context, nextlayer.data_client())
|
nextlayer.layer = self._next_layer(nextlayer.context, nextlayer.data_client())
|
||||||
|
# nextlayer.layer.debug = " " * len(nextlayer.context.layers)
|
||||||
|
|
||||||
def _next_layer(self, context: context.Context, data_client: bytes) -> typing.Optional[layer.Layer]:
|
def _next_layer(self, context: context.Context, data_client: bytes) -> typing.Optional[layer.Layer]:
|
||||||
if len(context.layers) == 0:
|
if len(context.layers) == 0:
|
||||||
return self.make_top_layer(context)
|
return self.make_top_layer(context)
|
||||||
if len(context.layers) == 1:
|
|
||||||
return layers.ServerTLSLayer(context)
|
|
||||||
|
|
||||||
if len(data_client) < 3:
|
if len(data_client) < 3:
|
||||||
return
|
return
|
||||||
@ -113,21 +112,14 @@ class NextLayer:
|
|||||||
if isinstance(top_layer, layers.ServerTLSLayer):
|
if isinstance(top_layer, layers.ServerTLSLayer):
|
||||||
return layers.ClientTLSLayer(context)
|
return layers.ClientTLSLayer(context)
|
||||||
else:
|
else:
|
||||||
if s(modes.HttpProxy):
|
|
||||||
# A "Secure Web Proxy" (https://www.chromium.org/developers/design-documents/secure-web-proxy)
|
|
||||||
# This does not imply TLS on the server side.
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# In all other cases, client TLS implies TLS for both ends.
|
|
||||||
context.server.tls = True
|
|
||||||
return layers.ServerTLSLayer(context)
|
return layers.ServerTLSLayer(context)
|
||||||
|
|
||||||
# 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy.
|
# 3. Setup the HTTP layer for a regular HTTP proxy or an upstream proxy.
|
||||||
if any([
|
if any([
|
||||||
s(modes.HttpProxy, layers.ServerTLSLayer),
|
s(modes.HttpProxy),
|
||||||
s(modes.HttpProxy, layers.ServerTLSLayer, layers.ClientTLSLayer),
|
s(modes.HttpProxy, layers.ClientTLSLayer),
|
||||||
]):
|
]):
|
||||||
return layers.HTTPLayer(context, HTTPMode.regular)
|
return layers.HttpLayer(context, HTTPMode.regular)
|
||||||
if ctx.options.mode.startswith("upstream:") and len(context.layers) <= 3 and isinstance(top_layer,
|
if ctx.options.mode.startswith("upstream:") and len(context.layers) <= 3 and isinstance(top_layer,
|
||||||
layers.ServerTLSLayer):
|
layers.ServerTLSLayer):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@ -153,7 +145,7 @@ class NextLayer:
|
|||||||
return layers.TCPLayer(context)
|
return layers.TCPLayer(context)
|
||||||
|
|
||||||
# 6. Assume HTTP by default.
|
# 6. Assume HTTP by default.
|
||||||
return layers.HTTPLayer(context, HTTPMode.transparent)
|
return layers.HttpLayer(context, HTTPMode.transparent)
|
||||||
|
|
||||||
def make_top_layer(self, context: context.Context) -> layer.Layer:
|
def make_top_layer(self, context: context.Context) -> layer.Layer:
|
||||||
if ctx.options.mode == "regular":
|
if ctx.options.mode == "regular":
|
||||||
|
@ -13,9 +13,11 @@ from mitmproxy.proxy2.layers import tls
|
|||||||
|
|
||||||
def alpn_select_callback(conn: SSL.Connection, options):
|
def alpn_select_callback(conn: SSL.Connection, options):
|
||||||
server_alpn = conn.get_app_data()["server_alpn"]
|
server_alpn = conn.get_app_data()["server_alpn"]
|
||||||
|
http2 = conn.get_app_data()["http2"]
|
||||||
if server_alpn and server_alpn in options:
|
if server_alpn and server_alpn in options:
|
||||||
return server_alpn
|
return server_alpn
|
||||||
for alpn in tls.HTTP_ALPNS:
|
http_alpns = tls.HTTP_ALPNS if http2 else tls.HTTP1_ALPNS
|
||||||
|
for alpn in http_alpns:
|
||||||
if alpn in options:
|
if alpn in options:
|
||||||
return alpn
|
return alpn
|
||||||
else:
|
else:
|
||||||
@ -76,11 +78,11 @@ class TlsConfig:
|
|||||||
)
|
)
|
||||||
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
tls_start.ssl_conn = SSL.Connection(ssl_ctx)
|
||||||
tls_start.ssl_conn.set_app_data({
|
tls_start.ssl_conn.set_app_data({
|
||||||
"server_alpn": tls_start.context.server.alpn
|
"server_alpn": tls_start.context.server.alpn,
|
||||||
|
"http2": ctx.options.http2,
|
||||||
})
|
})
|
||||||
tls_start.ssl_conn.set_accept_state()
|
tls_start.ssl_conn.set_accept_state()
|
||||||
|
|
||||||
|
|
||||||
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
|
def create_proxy_server_ssl_conn(self, tls_start: tls.TlsStartData) -> None:
|
||||||
client = tls_start.context.client
|
client = tls_start.context.client
|
||||||
server: context.Server = tls_start.conn
|
server: context.Server = tls_start.conn
|
||||||
@ -120,7 +122,7 @@ class TlsConfig:
|
|||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
client_cert = path
|
client_cert = path
|
||||||
|
|
||||||
args["cipher_list"] = ':'.join(server.cipher_list) if server.cipher_list else None
|
args["cipher_list"] = b':'.join(server.cipher_list) if server.cipher_list else None
|
||||||
ssl_ctx = net_tls.create_client_context(
|
ssl_ctx = net_tls.create_client_context(
|
||||||
cert=client_cert,
|
cert=client_cert,
|
||||||
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
|
sni=server.sni.decode("idna"), # FIXME: Should pass-through here.
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from enum import Flag, auto
|
from enum import Flag, auto
|
||||||
from typing import List, Optional, Sequence, Union
|
from typing import List, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
from mitmproxy import certs
|
from mitmproxy import certs
|
||||||
from mitmproxy.options import Options
|
from mitmproxy.options import Options
|
||||||
@ -16,7 +16,7 @@ class Connection:
|
|||||||
"""
|
"""
|
||||||
Connections exposed to the layers only contain metadata, no socket objects.
|
Connections exposed to the layers only contain metadata, no socket objects.
|
||||||
"""
|
"""
|
||||||
address: tuple
|
address: Optional[tuple]
|
||||||
state: ConnectionState
|
state: ConnectionState
|
||||||
tls: bool = False
|
tls: bool = False
|
||||||
tls_established: bool = False
|
tls_established: bool = False
|
||||||
@ -25,7 +25,7 @@ class Connection:
|
|||||||
alpn_offers: Sequence[bytes] = ()
|
alpn_offers: Sequence[bytes] = ()
|
||||||
cipher_list: Sequence[bytes] = ()
|
cipher_list: Sequence[bytes] = ()
|
||||||
tls_version: Optional[str] = None
|
tls_version: Optional[str] = None
|
||||||
sni: Union[bytes, bool, None]
|
sni: Union[bytes, Literal[True], None]
|
||||||
|
|
||||||
timestamp_tls_setup: Optional[float] = None
|
timestamp_tls_setup: Optional[float] = None
|
||||||
|
|
||||||
@ -33,21 +33,17 @@ class Connection:
|
|||||||
def connected(self):
|
def connected(self):
|
||||||
return self.state is ConnectionState.OPEN
|
return self.state is ConnectionState.OPEN
|
||||||
|
|
||||||
@connected.setter
|
|
||||||
def connected(self, val: bool) -> None:
|
|
||||||
# We should really set .state, but verdict is still due if we even want to keep .state around.
|
|
||||||
# We allow setting .connected while we figure that out.
|
|
||||||
if val:
|
|
||||||
self.state = ConnectionState.OPEN
|
|
||||||
else:
|
|
||||||
self.state = ConnectionState.CLOSED
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{type(self).__name__}({repr(self.__dict__)})"
|
attrs = repr({
|
||||||
|
k: {"cipher_list": lambda: f"<{len(v)} ciphers>"}.get(k,lambda: v)()
|
||||||
|
for k, v in self.__dict__.items()
|
||||||
|
})
|
||||||
|
return f"{type(self).__name__}({attrs})"
|
||||||
|
|
||||||
|
|
||||||
class Client(Connection):
|
class Client(Connection):
|
||||||
sni: Optional[bytes] = None
|
sni: Union[bytes, None] = None
|
||||||
|
address: tuple
|
||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
self.address = address
|
self.address = address
|
||||||
@ -55,9 +51,9 @@ class Client(Connection):
|
|||||||
|
|
||||||
|
|
||||||
class Server(Connection):
|
class Server(Connection):
|
||||||
sni: Union[bytes, bool] = True
|
sni = True
|
||||||
"""True: client SNI, False: no SNI, bytes: custom value"""
|
"""True: client SNI, False: no SNI, bytes: custom value"""
|
||||||
address: Optional[tuple]
|
via: Optional["Server"] = None
|
||||||
|
|
||||||
def __init__(self, address: Optional[tuple]):
|
def __init__(self, address: Optional[tuple]):
|
||||||
self.address = address
|
self.address = address
|
||||||
|
@ -5,7 +5,7 @@ The counterpart to events are commands.
|
|||||||
"""
|
"""
|
||||||
import socket
|
import socket
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, is_dataclass
|
||||||
|
|
||||||
from mitmproxy.proxy2 import commands
|
from mitmproxy.proxy2 import commands
|
||||||
from mitmproxy.proxy2.context import Connection
|
from mitmproxy.proxy2.context import Connection
|
||||||
@ -55,7 +55,6 @@ class DataReceived(ConnectionEvent):
|
|||||||
return f"DataReceived({target}, {self.data})"
|
return f"DataReceived({target}, {self.data})"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CommandReply(Event):
|
class CommandReply(Event):
|
||||||
"""
|
"""
|
||||||
Emitted when a command has been finished, e.g.
|
Emitted when a command has been finished, e.g.
|
||||||
@ -65,6 +64,7 @@ class CommandReply(Event):
|
|||||||
reply: typing.Any
|
reply: typing.Any
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
|
assert is_dataclass(cls)
|
||||||
if cls is CommandReply:
|
if cls is CommandReply:
|
||||||
raise TypeError("CommandReply may not be instantiated directly.")
|
raise TypeError("CommandReply may not be instantiated directly.")
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
|
@ -5,6 +5,7 @@ import collections
|
|||||||
import textwrap
|
import textwrap
|
||||||
import typing
|
import typing
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy import log
|
from mitmproxy import log
|
||||||
from mitmproxy.proxy2 import commands, events
|
from mitmproxy.proxy2 import commands, events
|
||||||
@ -149,11 +150,14 @@ class NextLayer(Layer):
|
|||||||
events: typing.List[mevents.Event]
|
events: typing.List[mevents.Event]
|
||||||
"""All events that happened before a decision was made."""
|
"""All events that happened before a decision was made."""
|
||||||
|
|
||||||
def __init__(self, context: Context) -> None:
|
_ask_on_start: bool
|
||||||
|
|
||||||
|
def __init__(self, context: Context, ask_on_start: bool = False) -> None:
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.context.layers.remove(self)
|
self.context.layers.remove(self)
|
||||||
self.events = []
|
|
||||||
self.layer = None
|
self.layer = None
|
||||||
|
self.events = []
|
||||||
|
self._ask_on_start = ask_on_start
|
||||||
self._handle = None
|
self._handle = None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -169,11 +173,13 @@ class NextLayer(Layer):
|
|||||||
self.events.append(event)
|
self.events.append(event)
|
||||||
|
|
||||||
# We receive new data. Let's find out if we can determine the next layer now?
|
# We receive new data. Let's find out if we can determine the next layer now?
|
||||||
if isinstance(event, mevents.DataReceived):
|
if self._ask_on_start and isinstance(event, events.Start):
|
||||||
|
yield from self._ask()
|
||||||
|
elif isinstance(event, mevents.DataReceived):
|
||||||
# For now, we only ask if we have received new data to reduce hook noise.
|
# For now, we only ask if we have received new data to reduce hook noise.
|
||||||
yield from self.ask_now()
|
yield from self._ask()
|
||||||
|
|
||||||
def ask_now(self):
|
def _ask(self):
|
||||||
"""
|
"""
|
||||||
Manually trigger a next_layer hook.
|
Manually trigger a next_layer hook.
|
||||||
The only use at the moment is to make sure that the top layer is initialized.
|
The only use at the moment is to make sure that the top layer is initialized.
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from . import modes
|
from . import modes
|
||||||
from .http import HTTPLayer
|
from .http import HttpLayer
|
||||||
from .tcp import TCPLayer
|
from .tcp import TCPLayer
|
||||||
from .tls import ClientTLSLayer, ServerTLSLayer
|
from .tls import ClientTLSLayer, ServerTLSLayer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"modes",
|
"modes",
|
||||||
"HTTPLayer",
|
"HttpLayer",
|
||||||
"TCPLayer",
|
"TCPLayer",
|
||||||
"ClientTLSLayer", "ServerTLSLayer",
|
"ClientTLSLayer", "ServerTLSLayer",
|
||||||
]
|
]
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import collections
|
import collections
|
||||||
import typing
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy import flow, http
|
from mitmproxy import flow, http
|
||||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||||
@ -8,7 +9,7 @@ from mitmproxy.proxy2.context import Connection, Context, Server
|
|||||||
from mitmproxy.proxy2.layers import tls
|
from mitmproxy.proxy2.layers import tls
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
from mitmproxy.utils import human
|
from mitmproxy.utils import human
|
||||||
from ._base import HttpConnection, StreamId, HttpCommand, ReceiveHttp
|
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
|
||||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||||
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
from ._hooks import HttpConnectHook, HttpErrorHook, HttpRequestHeadersHook, HttpRequestHook, HttpResponseHeadersHook, \
|
||||||
@ -37,6 +38,7 @@ class GetHttpConnection(HttpCommand):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class GetHttpConnectionReply(events.CommandReply):
|
class GetHttpConnectionReply(events.CommandReply):
|
||||||
command: GetHttpConnection
|
command: GetHttpConnection
|
||||||
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
|
reply: typing.Tuple[typing.Optional[Connection], typing.Optional[str]]
|
||||||
@ -67,8 +69,6 @@ class SendHttp(HttpCommand):
|
|||||||
return f"Send({self.event})"
|
return f"Send({self.event})"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class HttpStream(layer.Layer):
|
class HttpStream(layer.Layer):
|
||||||
request_body_buf: bytes
|
request_body_buf: bytes
|
||||||
response_body_buf: bytes
|
response_body_buf: bytes
|
||||||
@ -78,7 +78,7 @@ class HttpStream(layer.Layer):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def mode(self):
|
def mode(self):
|
||||||
parent: HTTPLayer = self.context.layers[-2]
|
parent: HttpLayer = self.context.layers[-2]
|
||||||
return parent.mode
|
return parent.mode
|
||||||
|
|
||||||
def __init__(self, context: Context):
|
def __init__(self, context: Context):
|
||||||
@ -309,7 +309,7 @@ class HttpStream(layer.Layer):
|
|||||||
yield from ()
|
yield from ()
|
||||||
|
|
||||||
|
|
||||||
class HTTPLayer(layer.Layer):
|
class HttpLayer(layer.Layer):
|
||||||
"""
|
"""
|
||||||
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
|
ConnectionEvent: We have received b"GET /\r\n\r\n" from the client.
|
||||||
HttpEvent: We have received request headers
|
HttpEvent: We have received request headers
|
||||||
@ -339,22 +339,22 @@ class HTTPLayer(layer.Layer):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"HTTPLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
|
return f"HttpLayer(conns: {len(self.connections)}, queue: {[type(e).__name__ for e in self.command_queue]})"
|
||||||
|
|
||||||
def _handle_event(self, event: events.Event):
|
def _handle_event(self, event: events.Event):
|
||||||
if isinstance(event, events.Start):
|
if isinstance(event, events.Start):
|
||||||
return
|
return
|
||||||
elif isinstance(event, events.CommandReply):
|
elif isinstance(event, events.CommandReply):
|
||||||
try:
|
stream = self.stream_by_command.pop(event.command)
|
||||||
stream = self.stream_by_command.pop(event.command)
|
|
||||||
except KeyError:
|
|
||||||
raise
|
|
||||||
self.event_to_child(stream, event)
|
self.event_to_child(stream, event)
|
||||||
elif isinstance(event, events.ConnectionEvent):
|
elif isinstance(event, events.ConnectionEvent):
|
||||||
if event.connection == self.context.server and self.context.server not in self.connections:
|
if event.connection == self.context.server and self.context.server not in self.connections:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
handler = self.connections[event.connection]
|
try:
|
||||||
|
handler = self.connections[event.connection]
|
||||||
|
except KeyError:
|
||||||
|
raise
|
||||||
self.event_to_child(handler, event)
|
self.event_to_child(handler, event)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
@ -388,7 +388,7 @@ class HTTPLayer(layer.Layer):
|
|||||||
|
|
||||||
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
|
def get_connection(self, event: GetHttpConnection, *, reuse: bool = True):
|
||||||
# Do we already have a connection we can re-use?
|
# Do we already have a connection we can re-use?
|
||||||
for connection, layer in self.connections.items():
|
for connection in self.connections:
|
||||||
connection_suitable = (
|
connection_suitable = (
|
||||||
reuse and
|
reuse and
|
||||||
event.connection_spec_matches(connection) and
|
event.connection_spec_matches(connection) and
|
||||||
@ -402,7 +402,7 @@ class HTTPLayer(layer.Layer):
|
|||||||
self.waiting_for_establishment[connection].append(event)
|
self.waiting_for_establishment[connection].append(event)
|
||||||
else:
|
else:
|
||||||
stream = self.stream_by_command.pop(event)
|
stream = self.stream_by_command.pop(event)
|
||||||
self.event_to_child(stream, GetHttpConnectionReply(event, (layer, None)))
|
self.event_to_child(stream, GetHttpConnectionReply(event, (connection, None)))
|
||||||
return
|
return
|
||||||
|
|
||||||
can_reuse_context_connection = (
|
can_reuse_context_connection = (
|
||||||
@ -414,8 +414,11 @@ class HTTPLayer(layer.Layer):
|
|||||||
layer = HttpClient(context)
|
layer = HttpClient(context)
|
||||||
if not can_reuse_context_connection:
|
if not can_reuse_context_connection:
|
||||||
context.server = Server(event.address)
|
context.server = Server(event.address)
|
||||||
|
if context.options.http2:
|
||||||
|
context.server.alpn_offers = tls.HTTP_ALPNS
|
||||||
|
else:
|
||||||
|
context.server.alpn_offers = tls.HTTP1_ALPNS
|
||||||
if event.tls:
|
if event.tls:
|
||||||
context.server.tls = True
|
|
||||||
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
|
# TODO: This is a bit ugly, let's make up nicer syntax, e.g. using __truediv__
|
||||||
orig = layer
|
orig = layer
|
||||||
layer = tls.ServerTLSLayer(context)
|
layer = tls.ServerTLSLayer(context)
|
||||||
@ -432,7 +435,6 @@ class HTTPLayer(layer.Layer):
|
|||||||
|
|
||||||
if command.err:
|
if command.err:
|
||||||
reply = (None, command.err)
|
reply = (None, command.err)
|
||||||
self.connections.pop(command.connection)
|
|
||||||
else:
|
else:
|
||||||
reply = (command.connection, None)
|
reply = (command.connection, None)
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from mitmproxy import http
|
from mitmproxy import http
|
||||||
from mitmproxy.proxy2 import commands
|
from mitmproxy.proxy2 import commands
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from mitmproxy.proxy2 import commands, events, layer
|
|||||||
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
from mitmproxy.proxy2.context import Client, Connection, Context, Server
|
||||||
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
|
from mitmproxy.proxy2.layers.http._base import ReceiveHttp, StreamId
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
|
from mitmproxy.utils import human
|
||||||
from ._base import HttpConnection
|
from ._base import HttpConnection
|
||||||
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
|
||||||
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
|
||||||
@ -131,7 +132,7 @@ class Http1Server(Http1Connection):
|
|||||||
elif isinstance(event, ResponseEndOfMessage):
|
elif isinstance(event, ResponseEndOfMessage):
|
||||||
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
if "chunked" in self.response.headers.get("transfer-encoding", "").lower():
|
||||||
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
yield commands.SendData(self.conn, b"0\r\n\r\n")
|
||||||
elif http1.expected_http_body_size(self.request, self.response) == -1:
|
elif http1.expected_http_body_size(self.request, self.response) == -1 or self.request.first_line_format == "authority":
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
yield from self.mark_done(response=True)
|
yield from self.mark_done(response=True)
|
||||||
elif isinstance(event, ResponseProtocolError):
|
elif isinstance(event, ResponseProtocolError):
|
||||||
@ -162,7 +163,13 @@ class Http1Server(Http1Connection):
|
|||||||
request_head = self.buf.maybe_extract_lines()
|
request_head = self.buf.maybe_extract_lines()
|
||||||
if request_head:
|
if request_head:
|
||||||
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
request_head = [bytes(x) for x in request_head] # TODO: Make url.parse compatible with bytearrays
|
||||||
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
try:
|
||||||
|
self.request = http.HTTPRequest.wrap(http1_sansio.read_request_head(request_head))
|
||||||
|
except ValueError as e:
|
||||||
|
yield commands.Log(f"{human.format_address(self.conn.address)}: {e}")
|
||||||
|
yield commands.CloseConnection(self.conn)
|
||||||
|
self.state = self.wait
|
||||||
|
return
|
||||||
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
|
yield ReceiveHttp(RequestHeaders(self.stream_id, self.request))
|
||||||
|
|
||||||
if self.request.first_line_format == "authority":
|
if self.request.first_line_format == "authority":
|
||||||
@ -181,6 +188,7 @@ class Http1Server(Http1Connection):
|
|||||||
elif isinstance(event, events.ConnectionClosed):
|
elif isinstance(event, events.ConnectionClosed):
|
||||||
if bytes(self.buf).strip():
|
if bytes(self.buf).strip():
|
||||||
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
|
yield commands.Log(f"Client closed connection before sending request headers: {bytes(self.buf)}")
|
||||||
|
yield commands.Log(f"Receive Buffer: {bytes(self.buf)}", level="debug")
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected event: {event}")
|
raise ValueError(f"Unexpected event: {event}")
|
||||||
|
@ -2,6 +2,7 @@ from mitmproxy import platform
|
|||||||
from mitmproxy.net import server_spec
|
from mitmproxy.net import server_spec
|
||||||
from mitmproxy.proxy2 import commands, events, layer
|
from mitmproxy.proxy2 import commands, events, layer
|
||||||
from mitmproxy.proxy2.context import Server
|
from mitmproxy.proxy2.context import Server
|
||||||
|
from mitmproxy.proxy2.layers import tls
|
||||||
from mitmproxy.proxy2.utils import expect
|
from mitmproxy.proxy2.utils import expect
|
||||||
|
|
||||||
|
|
||||||
@ -11,10 +12,11 @@ class ReverseProxy(layer.Layer):
|
|||||||
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
|
spec = server_spec.parse_with_mode(self.context.options.mode)[1]
|
||||||
self.context.server = Server(spec.address)
|
self.context.server = Server(spec.address)
|
||||||
if spec.scheme not in ("http", "tcp"):
|
if spec.scheme not in ("http", "tcp"):
|
||||||
self.context.server.tls = True
|
|
||||||
if not self.context.options.keep_host_header:
|
if not self.context.options.keep_host_header:
|
||||||
self.context.server.sni = spec.address[0]
|
self.context.server.sni = spec.address[0]
|
||||||
child_layer = layer.NextLayer(self.context)
|
child_layer = tls.ServerTLSLayer(self.context)
|
||||||
|
else:
|
||||||
|
child_layer = layer.NextLayer(self.context)
|
||||||
self._handle_event = child_layer.handle_event
|
self._handle_event = child_layer.handle_event
|
||||||
yield from child_layer.handle_event(event)
|
yield from child_layer.handle_event(event)
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from mitmproxy import flow, tcp
|
from mitmproxy import flow, tcp
|
||||||
|
@ -90,7 +90,8 @@ def parse_client_hello(data: bytes) -> Optional[net_tls.ClientHello]:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
HTTP_ALPNS = (b"h2", b"http/1.1", b"http/1.0", b"http/0.9")
|
HTTP1_ALPNS = (b"http/1.1", b"http/1.0", b"http/0.9")
|
||||||
|
HTTP_ALPNS = (b"h2",) + HTTP1_ALPNS
|
||||||
|
|
||||||
|
|
||||||
# We need these classes as hooks can only have one argument at the moment.
|
# We need these classes as hooks can only have one argument at the moment.
|
||||||
@ -122,10 +123,16 @@ class _TLSLayer(layer.Layer):
|
|||||||
tls: SSL.Connection = None
|
tls: SSL.Connection = None
|
||||||
"""The OpenSSL connection object"""
|
"""The OpenSSL connection object"""
|
||||||
child_layer: layer.Layer
|
child_layer: layer.Layer
|
||||||
|
errored: bool = False
|
||||||
|
"""Have we errored yet?"""
|
||||||
|
|
||||||
def __init__(self, context: context.Context, conn: context.Connection):
|
def __init__(self, context: context.Context, conn: context.Connection):
|
||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
|
|
||||||
|
assert not conn.tls
|
||||||
|
conn.tls = True
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
|
self.child_layer = layer.NextLayer(self.context)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
if not self.tls:
|
if not self.tls:
|
||||||
@ -138,8 +145,6 @@ class _TLSLayer(layer.Layer):
|
|||||||
|
|
||||||
def start_tls(self, initial_data: bytes = b""):
|
def start_tls(self, initial_data: bytes = b""):
|
||||||
assert not self.tls
|
assert not self.tls
|
||||||
assert self.conn.connected
|
|
||||||
self.conn.tls = True
|
|
||||||
|
|
||||||
tls_start = TlsStartData(self.conn, self.context)
|
tls_start = TlsStartData(self.conn, self.context)
|
||||||
yield TlsStartHook(tls_start)
|
yield TlsStartHook(tls_start)
|
||||||
@ -227,7 +232,7 @@ class _TLSLayer(layer.Layer):
|
|||||||
)
|
)
|
||||||
if close:
|
if close:
|
||||||
self.conn.state &= ~context.ConnectionState.CAN_READ
|
self.conn.state &= ~context.ConnectionState.CAN_READ
|
||||||
yield commands.Log(f"TLS close_notify {self.conn}")
|
yield commands.Log(f"TLS close_notify {self.conn}", level="debug")
|
||||||
yield from self.event_to_child(
|
yield from self.event_to_child(
|
||||||
events.ConnectionClosed(self.conn)
|
events.ConnectionClosed(self.conn)
|
||||||
)
|
)
|
||||||
@ -252,12 +257,13 @@ class _TLSLayer(layer.Layer):
|
|||||||
pass # We have already dispatched a ConnectionClosed to the child layer.
|
pass # We have already dispatched a ConnectionClosed to the child layer.
|
||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
else:
|
elif not self.errored:
|
||||||
yield from self.on_handshake_error("connection closed without notice")
|
yield from self.on_handshake_error("connection closed without notice")
|
||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
def on_handshake_error(self, err: str) -> layer.CommandGenerator[None]:
|
||||||
|
self.errored = True
|
||||||
yield commands.CloseConnection(self.conn)
|
yield commands.CloseConnection(self.conn)
|
||||||
|
|
||||||
|
|
||||||
@ -269,13 +275,12 @@ class ServerTLSLayer(_TLSLayer):
|
|||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
super().__init__(context, context.server)
|
super().__init__(context, context.server)
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
|
||||||
|
|
||||||
@expect(events.Start)
|
@expect(events.Start)
|
||||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
def state_start(self, event) -> layer.CommandGenerator[None]:
|
||||||
self.context.server.tls = True
|
|
||||||
if self.context.server.connected:
|
if self.context.server.connected:
|
||||||
yield from self.start_tls()
|
yield from self.start_tls()
|
||||||
|
yield from self.event_to_child(event)
|
||||||
self._handle_event = super()._handle_event
|
self._handle_event = super()._handle_event
|
||||||
|
|
||||||
_handle_event = state_start
|
_handle_event = state_start
|
||||||
@ -327,20 +332,12 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
recv_buffer: bytearray
|
recv_buffer: bytearray
|
||||||
|
server_tls_available: bool
|
||||||
|
|
||||||
def __init__(self, context: context.Context):
|
def __init__(self, context: context.Context):
|
||||||
assert isinstance(context.layers[-1], ServerTLSLayer)
|
super().__init__(context, context.client)
|
||||||
super().__init__(context, self.context.client)
|
self.server_tls_available = isinstance(self.context.layers[-2], ServerTLSLayer)
|
||||||
self.recv_buffer = bytearray()
|
self.recv_buffer = bytearray()
|
||||||
self.child_layer = layer.NextLayer(self.context)
|
|
||||||
|
|
||||||
@expect(events.Start)
|
|
||||||
def state_start(self, _) -> layer.CommandGenerator[None]:
|
|
||||||
self.context.client.tls = True
|
|
||||||
self._handle_event = self.state_wait_for_clienthello
|
|
||||||
yield from ()
|
|
||||||
|
|
||||||
_handle_event = state_start
|
|
||||||
|
|
||||||
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def state_wait_for_clienthello(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
if isinstance(event, events.DataReceived) and event.connection == self.conn:
|
||||||
@ -361,8 +358,8 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
if tls_clienthello.establish_server_tls_first and not self.context.server.tls_established:
|
if tls_clienthello.establish_server_tls_first and not self.context.server.tls_established:
|
||||||
err = yield from self.start_server_tls()
|
err = yield from self.start_server_tls()
|
||||||
if err:
|
if err:
|
||||||
yield commands.Log("Unable to establish TLS connection with server. "
|
yield commands.Log(f"Unable to establish TLS connection with server ({err}). "
|
||||||
"Trying to establish TLS with client anyway.")
|
f"Trying to establish TLS with client anyway.")
|
||||||
|
|
||||||
yield from self.start_tls(bytes(self.recv_buffer))
|
yield from self.start_tls(bytes(self.recv_buffer))
|
||||||
self.recv_buffer.clear()
|
self.recv_buffer.clear()
|
||||||
@ -373,14 +370,17 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
else:
|
else:
|
||||||
yield from self.event_to_child(event)
|
yield from self.event_to_child(event)
|
||||||
|
|
||||||
|
_handle_event = state_wait_for_clienthello
|
||||||
|
|
||||||
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
|
def start_server_tls(self) -> layer.CommandGenerator[Optional[str]]:
|
||||||
"""
|
"""
|
||||||
We often need information from the upstream connection to establish TLS with the client.
|
We often need information from the upstream connection to establish TLS with the client.
|
||||||
For example, we need to check if the client does ALPN or not.
|
For example, we need to check if the client does ALPN or not.
|
||||||
"""
|
"""
|
||||||
|
if not self.server_tls_available:
|
||||||
|
return "No server TLS available."
|
||||||
err = yield commands.OpenConnection(self.context.server)
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
if err:
|
if err:
|
||||||
yield commands.Log(f"Cannot establish server connection: {err}")
|
|
||||||
return err
|
return err
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
@ -400,3 +400,14 @@ class ClientTLSLayer(_TLSLayer):
|
|||||||
level="warn"
|
level="warn"
|
||||||
)
|
)
|
||||||
yield from super().on_handshake_error(err)
|
yield from super().on_handshake_error(err)
|
||||||
|
|
||||||
|
|
||||||
|
class MockTLSLayer(_TLSLayer):
|
||||||
|
"""Mock layer to disable actual TLS and use cleartext in tests.
|
||||||
|
|
||||||
|
Use like so:
|
||||||
|
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, ctx: context.Context):
|
||||||
|
super().__init__(ctx, context.Server(None))
|
||||||
|
@ -15,6 +15,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
@ -27,11 +28,6 @@ from mitmproxy.utils import human
|
|||||||
from mitmproxy.utils.data import pkg_data
|
from mitmproxy.utils.data import pkg_data
|
||||||
|
|
||||||
|
|
||||||
class StreamIO(typing.NamedTuple):
|
|
||||||
r: asyncio.StreamReader
|
|
||||||
w: asyncio.StreamWriter
|
|
||||||
|
|
||||||
|
|
||||||
class TimeoutWatchdog:
|
class TimeoutWatchdog:
|
||||||
last_activity: float
|
last_activity: float
|
||||||
CONNECTION_TIMEOUT = 10 * 60
|
CONNECTION_TIMEOUT = 10 * 60
|
||||||
@ -72,8 +68,15 @@ class TimeoutWatchdog:
|
|||||||
self.can_timeout.set()
|
self.can_timeout.set()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConnectionIO:
|
||||||
|
handler: typing.Optional[asyncio.Task] = None
|
||||||
|
reader: typing.Optional[asyncio.StreamReader] = None
|
||||||
|
writer: typing.Optional[asyncio.StreamWriter] = None
|
||||||
|
|
||||||
|
|
||||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||||
transports: typing.MutableMapping[Connection, StreamIO]
|
transports: typing.MutableMapping[Connection, ConnectionIO]
|
||||||
timeout_watchdog: TimeoutWatchdog
|
timeout_watchdog: TimeoutWatchdog
|
||||||
|
|
||||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||||
@ -81,92 +84,85 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
|
|
||||||
self.client = Client(addr)
|
self.client = Client(addr)
|
||||||
self.context = Context(self.client, options)
|
self.context = Context(self.client, options)
|
||||||
self.layer = layer.NextLayer(self.context)
|
self.transports = {
|
||||||
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
self.client: ConnectionIO(handler=None, reader=reader, writer=writer)
|
||||||
|
}
|
||||||
|
|
||||||
# Ask for the first layer right away.
|
# Ask for the first layer right away.
|
||||||
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
||||||
# on protocols that start with a server greeting.
|
# on protocols that start with a server greeting.
|
||||||
self.layer.ask_now()
|
self.layer = layer.NextLayer(self.context, ask_on_start=True)
|
||||||
|
|
||||||
self.transports = {
|
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
||||||
self.client: StreamIO(reader, writer)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def handle_client(self) -> None:
|
async def handle_client(self) -> None:
|
||||||
# FIXME: Work around log suppression in core.
|
# Hack: Work around log suppression in core.
|
||||||
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
||||||
|
asyncio.get_event_loop().set_debug(True)
|
||||||
watch = asyncio.ensure_future(self.timeout_watchdog.watch())
|
watch = asyncio.ensure_future(self.timeout_watchdog.watch())
|
||||||
|
|
||||||
self.log("[sans-io] clientconnect")
|
self.log("[sans-io] clientconnect")
|
||||||
|
|
||||||
|
handler = asyncio.create_task(
|
||||||
|
self.handle_connection(self.client)
|
||||||
|
)
|
||||||
|
self.transports[self.client].handler = handler
|
||||||
self.server_event(events.Start())
|
self.server_event(events.Start())
|
||||||
await self.handle_connection(self.client)
|
await handler
|
||||||
|
|
||||||
self.log("[sans-io] clientdisconnected")
|
self.log("[sans-io] clientdisconnected")
|
||||||
watch.cancel()
|
watch.cancel()
|
||||||
|
|
||||||
if self.transports:
|
if self.transports:
|
||||||
self.log("[sans-io] closing transports...")
|
self.log("[sans-io] closing transports...")
|
||||||
await asyncio.wait([
|
for x in self.transports.values():
|
||||||
self.close_connection(x)
|
x.handler.cancel()
|
||||||
for x in self.transports
|
await asyncio.wait([x.handler for x in self.transports.values()])
|
||||||
])
|
|
||||||
self.log("[sans-io] transports closed!")
|
self.log("[sans-io] transports closed!")
|
||||||
|
|
||||||
async def on_timeout(self) -> None:
|
|
||||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
|
||||||
await self.close_connection(self.client)
|
|
||||||
|
|
||||||
async def close_connection(self, connection: Connection) -> None:
|
|
||||||
self.log(f"closing {connection}", "debug")
|
|
||||||
connection.state = ConnectionState.CLOSED
|
|
||||||
io = self.transports.pop(connection)
|
|
||||||
io.w.close()
|
|
||||||
await io.w.wait_closed()
|
|
||||||
|
|
||||||
async def shutdown_connection(self, connection: Connection) -> None:
|
|
||||||
assert connection.state & ConnectionState.CAN_WRITE
|
|
||||||
io = self.transports[connection]
|
|
||||||
self.log(f"shutting down {connection}", "debug")
|
|
||||||
|
|
||||||
io.w.write_eof()
|
|
||||||
connection.state &= ~ConnectionState.CAN_WRITE
|
|
||||||
|
|
||||||
async def handle_connection(self, connection: Connection) -> None:
|
|
||||||
reader, writer = self.transports[connection]
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
data = await reader.read(65535)
|
|
||||||
except socket.error:
|
|
||||||
data = b""
|
|
||||||
if data:
|
|
||||||
self.server_event(events.DataReceived(connection, data))
|
|
||||||
else:
|
|
||||||
if connection.state is ConnectionState.CAN_READ:
|
|
||||||
await self.close_connection(connection)
|
|
||||||
else:
|
|
||||||
connection.state &= ~ConnectionState.CAN_READ
|
|
||||||
self.server_event(events.ConnectionClosed(connection))
|
|
||||||
break
|
|
||||||
|
|
||||||
async def open_connection(self, command: commands.OpenConnection) -> None:
|
async def open_connection(self, command: commands.OpenConnection) -> None:
|
||||||
if not command.connection.address:
|
if not command.connection.address:
|
||||||
raise ValueError("Cannot open connection, no hostname given.")
|
raise ValueError("Cannot open connection, no hostname given.")
|
||||||
assert command.connection not in self.transports
|
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.open_connection(
|
reader, writer = await asyncio.open_connection(*command.connection.address)
|
||||||
*command.connection.address
|
except (IOError, asyncio.CancelledError) as e:
|
||||||
)
|
|
||||||
except IOError as e:
|
|
||||||
self.server_event(events.OpenConnectionReply(command, str(e)))
|
self.server_event(events.OpenConnectionReply(command, str(e)))
|
||||||
else:
|
else:
|
||||||
self.log("serverconnect")
|
self.log(f"serverconnect {command.connection.address}")
|
||||||
self.transports[command.connection] = StreamIO(reader, writer)
|
self.transports[command.connection].reader = reader
|
||||||
|
self.transports[command.connection].writer = writer
|
||||||
command.connection.state = ConnectionState.OPEN
|
command.connection.state = ConnectionState.OPEN
|
||||||
self.server_event(events.OpenConnectionReply(command, None))
|
self.server_event(events.OpenConnectionReply(command, None))
|
||||||
await self.handle_connection(command.connection)
|
try:
|
||||||
self.log("serverdisconnected")
|
await self.handle_connection(command.connection)
|
||||||
|
finally:
|
||||||
|
self.log("serverdisconnected")
|
||||||
|
|
||||||
|
async def handle_connection(self, connection: Connection) -> None:
|
||||||
|
reader = self.transports[connection].reader
|
||||||
|
assert reader
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = await reader.read(65535)
|
||||||
|
except (socket.error, asyncio.CancelledError):
|
||||||
|
data = b""
|
||||||
|
if data:
|
||||||
|
self.server_event(events.DataReceived(connection, data))
|
||||||
|
else:
|
||||||
|
connection.state &= ~ConnectionState.CAN_READ
|
||||||
|
self.server_event(events.ConnectionClosed(connection))
|
||||||
|
if connection.state is ConnectionState.CLOSED:
|
||||||
|
self.transports[connection].handler.cancel()
|
||||||
|
await asyncio.Event().wait() # wait for cancellation
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
connection.state = ConnectionState.CLOSED
|
||||||
|
io = self.transports.pop(connection)
|
||||||
|
io.writer.close()
|
||||||
|
|
||||||
|
async def on_timeout(self) -> None:
|
||||||
|
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||||
|
self.transports[self.client].handler.cancel()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||||
@ -180,31 +176,24 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
try:
|
try:
|
||||||
layer_commands = self.layer.handle_event(event)
|
layer_commands = self.layer.handle_event(event)
|
||||||
for command in layer_commands:
|
for command in layer_commands:
|
||||||
|
|
||||||
if isinstance(command, commands.OpenConnection):
|
if isinstance(command, commands.OpenConnection):
|
||||||
asyncio.ensure_future(
|
assert command.connection not in self.transports
|
||||||
|
handler = asyncio.create_task(
|
||||||
self.open_connection(command)
|
self.open_connection(command)
|
||||||
)
|
)
|
||||||
|
self.transports[command.connection] = ConnectionIO(handler=handler)
|
||||||
|
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
|
||||||
|
return # The connection has already been closed.
|
||||||
elif isinstance(command, commands.SendData):
|
elif isinstance(command, commands.SendData):
|
||||||
try:
|
self.transports[command.connection].writer.write(command.data)
|
||||||
io = self.transports[command.connection]
|
|
||||||
except KeyError:
|
|
||||||
raise RuntimeError(f"Cannot write to closed connection: {command.connection}")
|
|
||||||
else:
|
|
||||||
io.w.write(command.data)
|
|
||||||
elif isinstance(command, commands.CloseConnection):
|
elif isinstance(command, commands.CloseConnection):
|
||||||
if command.connection == self.client:
|
self.close_our_end(command.connection)
|
||||||
asyncio.ensure_future(
|
|
||||||
self.close_connection(command.connection)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
asyncio.ensure_future(
|
|
||||||
self.shutdown_connection(command.connection)
|
|
||||||
)
|
|
||||||
elif isinstance(command, commands.GetSocket):
|
elif isinstance(command, commands.GetSocket):
|
||||||
socket = self.transports[command.connection].w.get_extra_info("socket")
|
socket = self.transports[command.connection].writer.get_extra_info("socket")
|
||||||
self.server_event(events.GetSocketReply(command, socket))
|
self.server_event(events.GetSocketReply(command, socket))
|
||||||
elif isinstance(command, commands.Hook):
|
elif isinstance(command, commands.Hook):
|
||||||
asyncio.ensure_future(
|
asyncio.create_task(
|
||||||
self.handle_hook(command)
|
self.handle_hook(command)
|
||||||
)
|
)
|
||||||
elif isinstance(command, commands.Log):
|
elif isinstance(command, commands.Log):
|
||||||
@ -214,6 +203,22 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
|||||||
except Exception:
|
except Exception:
|
||||||
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
|
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
|
||||||
|
|
||||||
|
def close_our_end(self, connection):
|
||||||
|
assert connection.state & ConnectionState.CAN_WRITE
|
||||||
|
self.log(f"shutting down {connection}", "debug")
|
||||||
|
try:
|
||||||
|
self.transports[connection].writer.write_eof()
|
||||||
|
except socket.error:
|
||||||
|
connection.state = ConnectionState.CLOSED
|
||||||
|
connection.state &= ~ConnectionState.CAN_WRITE
|
||||||
|
|
||||||
|
# if we are closing the client connection, we should destroy everything.
|
||||||
|
if connection == self.client:
|
||||||
|
self.transports[connection].handler.cancel()
|
||||||
|
# If we have already received a close, let's finish everything.
|
||||||
|
elif connection.state is ConnectionState.CLOSED:
|
||||||
|
self.transports[connection].handler.cancel()
|
||||||
|
|
||||||
|
|
||||||
class SimpleConnectionHandler(ConnectionHandler):
|
class SimpleConnectionHandler(ConnectionHandler):
|
||||||
"""Simple handler that does not really process any hooks."""
|
"""Simple handler that does not really process any hooks."""
|
||||||
@ -253,10 +258,10 @@ if __name__ == "__main__":
|
|||||||
async def handle(reader, writer):
|
async def handle(reader, writer):
|
||||||
layer_stack = [
|
layer_stack = [
|
||||||
lambda ctx: layers.ServerTLSLayer(ctx),
|
lambda ctx: layers.ServerTLSLayer(ctx),
|
||||||
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.regular),
|
lambda ctx: layers.HttpLayer(ctx, HTTPMode.regular),
|
||||||
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
|
lambda ctx: setattr(ctx.server, "tls", True) or layers.ServerTLSLayer(ctx),
|
||||||
lambda ctx: layers.ClientTLSLayer(ctx),
|
lambda ctx: layers.ClientTLSLayer(ctx),
|
||||||
lambda ctx: layers.HTTPLayer(ctx, HTTPMode.transparent)
|
lambda ctx: layers.HttpLayer(ctx, HTTPMode.transparent)
|
||||||
]
|
]
|
||||||
|
|
||||||
def next_layer(nl: layer.NextLayer):
|
def next_layer(nl: layer.NextLayer):
|
||||||
|
@ -6,13 +6,15 @@ import pytest
|
|||||||
from mitmproxy.net.websockets import Frame, OPCODE
|
from mitmproxy.net.websockets import Frame, OPCODE
|
||||||
from mitmproxy.proxy2 import commands, events
|
from mitmproxy.proxy2 import commands, events
|
||||||
from mitmproxy.proxy2.layers.old import websocket
|
from mitmproxy.proxy2.layers.old import websocket
|
||||||
|
|
||||||
|
from mitmproxy.proxy2.context import ConnectionState
|
||||||
from mitmproxy.test import tflow
|
from mitmproxy.test import tflow
|
||||||
from .. import tutils
|
from .. import tutils
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ws_playbook(tctx):
|
def ws_playbook(tctx):
|
||||||
tctx.server.connected = True
|
tctx.server.state = ConnectionState.OPEN
|
||||||
playbook = tutils.Playbook(
|
playbook = tutils.Playbook(
|
||||||
websocket.WebsocketLayer(
|
websocket.WebsocketLayer(
|
||||||
tctx,
|
tctx,
|
||||||
|
@ -6,7 +6,7 @@ from mitmproxy.proxy2 import layer
|
|||||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||||
from mitmproxy.proxy2.layers import http, tls
|
from mitmproxy.proxy2.layers import http, tls
|
||||||
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_establish_server_tls, reply_next_layer
|
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer
|
||||||
|
|
||||||
|
|
||||||
def test_http_proxy(tctx):
|
def test_http_proxy(tctx):
|
||||||
@ -14,7 +14,7 @@ def test_http_proxy(tctx):
|
|||||||
server = Placeholder()
|
server = Placeholder()
|
||||||
flow = Placeholder()
|
flow = Placeholder()
|
||||||
assert (
|
assert (
|
||||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
|
Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< http.HttpRequestHeadersHook(flow)
|
<< http.HttpRequestHeadersHook(flow)
|
||||||
>> reply()
|
>> reply()
|
||||||
@ -39,7 +39,7 @@ def test_https_proxy(strategy, tctx):
|
|||||||
"""Test a CONNECT request, followed by a HTTP GET /"""
|
"""Test a CONNECT request, followed by a HTTP GET /"""
|
||||||
server = Placeholder()
|
server = Placeholder()
|
||||||
flow = Placeholder()
|
flow = Placeholder()
|
||||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular))
|
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
|
||||||
tctx.options.connection_strategy = strategy
|
tctx.options.connection_strategy = strategy
|
||||||
|
|
||||||
(playbook
|
(playbook
|
||||||
@ -54,7 +54,7 @@ def test_https_proxy(strategy, tctx):
|
|||||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||||
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET /foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< layer.NextLayerHook(Placeholder())
|
<< layer.NextLayerHook(Placeholder())
|
||||||
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||||
<< http.HttpRequestHeadersHook(flow)
|
<< http.HttpRequestHeadersHook(flow)
|
||||||
>> reply()
|
>> reply()
|
||||||
<< http.HttpRequestHook(flow)
|
<< http.HttpRequestHook(flow)
|
||||||
@ -77,12 +77,15 @@ def test_https_proxy(strategy, tctx):
|
|||||||
@pytest.mark.parametrize("https_client", [False, True])
|
@pytest.mark.parametrize("https_client", [False, True])
|
||||||
@pytest.mark.parametrize("https_server", [False, True])
|
@pytest.mark.parametrize("https_server", [False, True])
|
||||||
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
|
@pytest.mark.parametrize("strategy", ["lazy", "eager"])
|
||||||
def test_redirect(strategy, https_server, https_client, tctx):
|
def test_redirect(strategy, https_server, https_client, tctx, monkeypatch):
|
||||||
"""Test redirects between http:// and https:// in regular proxy mode."""
|
"""Test redirects between http:// and https:// in regular proxy mode."""
|
||||||
server = Placeholder()
|
server = Placeholder()
|
||||||
flow = Placeholder()
|
flow = Placeholder()
|
||||||
tctx.options.connection_strategy = strategy
|
tctx.options.connection_strategy = strategy
|
||||||
p = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
p = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
|
||||||
|
if https_server:
|
||||||
|
monkeypatch.setattr(tls, "ServerTLSLayer", tls.MockTLSLayer)
|
||||||
|
|
||||||
def redirect(flow: HTTPFlow):
|
def redirect(flow: HTTPFlow):
|
||||||
if https_server:
|
if https_server:
|
||||||
@ -98,16 +101,13 @@ def test_redirect(strategy, https_server, https_client, tctx):
|
|||||||
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
p << SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||||
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
p >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
p << layer.NextLayerHook(Placeholder())
|
p << layer.NextLayerHook(Placeholder())
|
||||||
p >> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
p >> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||||
else:
|
else:
|
||||||
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
p >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
p << http.HttpRequestHook(flow)
|
p << http.HttpRequestHook(flow)
|
||||||
p >> reply(side_effect=redirect)
|
p >> reply(side_effect=redirect)
|
||||||
p << OpenConnection(server)
|
p << OpenConnection(server)
|
||||||
p >> reply(None)
|
p >> reply(None)
|
||||||
if https_server:
|
|
||||||
pass # p << tls.EstablishServerTLS(server)
|
|
||||||
# p >> reply_establish_server_tls()
|
|
||||||
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
p << SendData(server, b"GET / HTTP/1.1\r\nHost: redirected.site\r\n\r\n")
|
||||||
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
p >> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||||
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
p << SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
|
||||||
@ -123,7 +123,7 @@ def test_multiple_server_connections(tctx):
|
|||||||
"""Test multiple requests being rewritten to different targets."""
|
"""Test multiple requests being rewritten to different targets."""
|
||||||
server1 = Placeholder()
|
server1 = Placeholder()
|
||||||
server2 = Placeholder()
|
server2 = Placeholder()
|
||||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
|
||||||
def redirect(to: str):
|
def redirect(to: str):
|
||||||
def side_effect(flow: HTTPFlow):
|
def side_effect(flow: HTTPFlow):
|
||||||
@ -164,7 +164,7 @@ def test_http_reply_from_proxy(tctx):
|
|||||||
flow.response = HTTPResponse.make(418)
|
flow.response = HTTPResponse.make(418)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< http.HttpRequestHook(Placeholder())
|
<< http.HttpRequestHook(Placeholder())
|
||||||
>> reply(side_effect=reply_from_proxy)
|
>> reply(side_effect=reply_from_proxy)
|
||||||
@ -176,7 +176,7 @@ def test_response_until_eof(tctx):
|
|||||||
"""Test scenario where the server response body is terminated by EOF."""
|
"""Test scenario where the server response body is terminated by EOF."""
|
||||||
server = Placeholder()
|
server = Placeholder()
|
||||||
assert (
|
assert (
|
||||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< OpenConnection(server)
|
<< OpenConnection(server)
|
||||||
>> reply(None)
|
>> reply(None)
|
||||||
@ -197,7 +197,7 @@ def test_disconnect_while_intercept(tctx):
|
|||||||
flow = Placeholder()
|
flow = Placeholder()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
|
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
|
||||||
<< http.HttpConnectHook(Placeholder())
|
<< http.HttpConnectHook(Placeholder())
|
||||||
>> reply()
|
>> reply()
|
||||||
@ -206,7 +206,7 @@ def test_disconnect_while_intercept(tctx):
|
|||||||
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
<< SendData(tctx.client, b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||||
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< layer.NextLayerHook(Placeholder())
|
<< layer.NextLayerHook(Placeholder())
|
||||||
>> reply_next_layer(lambda ctx: http.HTTPLayer(ctx, HTTPMode.transparent))
|
>> reply_next_layer(lambda ctx: http.HttpLayer(ctx, HTTPMode.transparent))
|
||||||
<< http.HttpRequestHook(flow)
|
<< http.HttpRequestHook(flow)
|
||||||
>> ConnectionClosed(server1)
|
>> ConnectionClosed(server1)
|
||||||
>> reply(to=-2)
|
>> reply(to=-2)
|
||||||
@ -229,7 +229,7 @@ def test_response_streaming(tctx):
|
|||||||
flow.response.stream = lambda x: x.upper()
|
flow.response.stream = lambda x: x.upper()
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/largefile HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
<< OpenConnection(server)
|
<< OpenConnection(server)
|
||||||
>> reply(None)
|
>> reply(None)
|
||||||
@ -252,7 +252,7 @@ def test_request_streaming(tctx, response):
|
|||||||
"""
|
"""
|
||||||
server = Placeholder()
|
server = Placeholder()
|
||||||
flow = Placeholder()
|
flow = Placeholder()
|
||||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
|
|
||||||
def enable_streaming(flow: HTTPFlow):
|
def enable_streaming(flow: HTTPFlow):
|
||||||
flow.request.stream = lambda x: x.upper()
|
flow.request.stream = lambda x: x.upper()
|
||||||
@ -324,7 +324,7 @@ def test_server_aborts(tctx, data):
|
|||||||
server = Placeholder()
|
server = Placeholder()
|
||||||
flow = Placeholder()
|
flow = Placeholder()
|
||||||
err = Placeholder()
|
err = Placeholder()
|
||||||
playbook = Playbook(http.HTTPLayer(tctx, HTTPMode.regular), hooks=False)
|
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
|
||||||
|
from mitmproxy.proxy2.context import ConnectionState
|
||||||
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
|
||||||
from mitmproxy.proxy2.layers import tcp
|
from mitmproxy.proxy2.layers import tcp
|
||||||
from ..tutils import Placeholder, Playbook, reply
|
from ..tutils import Placeholder, Playbook, reply
|
||||||
@ -14,7 +15,7 @@ def test_open_connection(tctx):
|
|||||||
<< OpenConnection(tctx.server)
|
<< OpenConnection(tctx.server)
|
||||||
)
|
)
|
||||||
|
|
||||||
tctx.server.connected = True
|
tctx.server.state = ConnectionState.OPEN
|
||||||
assert (
|
assert (
|
||||||
Playbook(tcp.TCPLayer(tctx, True))
|
Playbook(tcp.TCPLayer(tctx, True))
|
||||||
<< None
|
<< None
|
||||||
|
@ -5,6 +5,7 @@ import pytest
|
|||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from mitmproxy.proxy2 import commands, context, events, layer
|
from mitmproxy.proxy2 import commands, context, events, layer
|
||||||
|
from mitmproxy.proxy2.context import ConnectionState
|
||||||
from mitmproxy.proxy2.layers import tls
|
from mitmproxy.proxy2.layers import tls
|
||||||
from mitmproxy.utils import data
|
from mitmproxy.utils import data
|
||||||
from test.mitmproxy.proxy2 import tutils
|
from test.mitmproxy.proxy2 import tutils
|
||||||
@ -109,7 +110,6 @@ class TlsEchoLayer(tutils.EchoLayer):
|
|||||||
|
|
||||||
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
def _handle_event(self, event: events.Event) -> layer.CommandGenerator[None]:
|
||||||
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
|
if isinstance(event, events.DataReceived) and event.data == b"open-connection":
|
||||||
# noinspection PyTypeChecker
|
|
||||||
err = yield commands.OpenConnection(self.context.server)
|
err = yield commands.OpenConnection(self.context.server)
|
||||||
if err:
|
if err:
|
||||||
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
|
yield commands.SendData(event.connection, f"open-connection failed: {err}".encode())
|
||||||
@ -180,23 +180,20 @@ def reply_tls_start(alpn: typing.Optional[bytes] = None, *args, **kwargs) -> tut
|
|||||||
|
|
||||||
|
|
||||||
class TestServerTLS:
|
class TestServerTLS:
|
||||||
def test_no_tls(self, tctx: context.Context):
|
def test_not_connected(self, tctx: context.Context):
|
||||||
"""Test TLS layer without TLS"""
|
"""Test that we don't do anything if no server connection exists."""
|
||||||
layer = tls.ServerTLSLayer(tctx)
|
layer = tls.ServerTLSLayer(tctx)
|
||||||
layer.child_layer = TlsEchoLayer(tctx)
|
layer.child_layer = TlsEchoLayer(tctx)
|
||||||
|
|
||||||
# Handshake
|
|
||||||
assert (
|
assert (
|
||||||
tutils.Playbook(layer)
|
tutils.Playbook(layer)
|
||||||
>> events.DataReceived(tctx.client, b"Hello World")
|
>> events.DataReceived(tctx.client, b"Hello World")
|
||||||
<< commands.SendData(tctx.client, b"hello world")
|
<< commands.SendData(tctx.client, b"hello world")
|
||||||
>> events.DataReceived(tctx.server, b"Foo")
|
|
||||||
<< commands.SendData(tctx.server, b"foo")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_simple(self, tctx):
|
def test_simple(self, tctx):
|
||||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||||
tctx.server.connected = True
|
tctx.server.state = ConnectionState.OPEN
|
||||||
tctx.server.address = ("example.mitmproxy.org", 443)
|
tctx.server.address = ("example.mitmproxy.org", 443)
|
||||||
tctx.server.sni = b"example.mitmproxy.org"
|
tctx.server.sni = b"example.mitmproxy.org"
|
||||||
|
|
||||||
@ -250,7 +247,6 @@ class TestServerTLS:
|
|||||||
def test_untrusted_cert(self, tctx):
|
def test_untrusted_cert(self, tctx):
|
||||||
"""If the certificate is not trusted, we should fail."""
|
"""If the certificate is not trusted, we should fail."""
|
||||||
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
playbook = tutils.Playbook(tls.ServerTLSLayer(tctx))
|
||||||
tctx.server.connected = True
|
|
||||||
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
|
tctx.server.address = ("wrong.host.mitmproxy.org", 443)
|
||||||
tctx.server.sni = b"wrong.host.mitmproxy.org"
|
tctx.server.sni = b"wrong.host.mitmproxy.org"
|
||||||
|
|
||||||
@ -260,9 +256,11 @@ class TestServerTLS:
|
|||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.client, b"establish-server-tls")
|
>> events.DataReceived(tctx.client, b"open-connection")
|
||||||
<< layer.NextLayerHook(tutils.Placeholder())
|
<< layer.NextLayerHook(tutils.Placeholder())
|
||||||
>> tutils.reply_next_layer(TlsEchoLayer)
|
>> tutils.reply_next_layer(TlsEchoLayer)
|
||||||
|
<< commands.OpenConnection(tctx.server)
|
||||||
|
>> tutils.reply(None)
|
||||||
<< tls.TlsStartHook(tutils.Placeholder())
|
<< tls.TlsStartHook(tutils.Placeholder())
|
||||||
>> reply_tls_start()
|
>> reply_tls_start()
|
||||||
<< commands.SendData(tctx.server, data)
|
<< commands.SendData(tctx.server, data)
|
||||||
@ -278,7 +276,8 @@ class TestServerTLS:
|
|||||||
>> events.DataReceived(tctx.server, tssl.out.read())
|
>> events.DataReceived(tctx.server, tssl.out.read())
|
||||||
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
|
<< commands.Log("Server TLS handshake failed. Certificate verify failed: Hostname mismatch", "warn")
|
||||||
<< commands.CloseConnection(tctx.server)
|
<< commands.CloseConnection(tctx.server)
|
||||||
<< commands.SendData(tctx.client, b"server-tls-failed: Certificate verify failed: Hostname mismatch")
|
<< commands.SendData(tctx.client,
|
||||||
|
b"open-connection failed: Certificate verify failed: Hostname mismatch")
|
||||||
)
|
)
|
||||||
assert not tctx.server.tls_established
|
assert not tctx.server.tls_established
|
||||||
|
|
||||||
@ -334,10 +333,11 @@ class TestClientTLS:
|
|||||||
|
|
||||||
# Echo
|
# Echo
|
||||||
_test_echo(playbook, tssl_client, tctx.client)
|
_test_echo(playbook, tssl_client, tctx.client)
|
||||||
|
other_server = context.Server(None)
|
||||||
assert (
|
assert (
|
||||||
playbook
|
playbook
|
||||||
>> events.DataReceived(tctx.server, b"Plaintext")
|
>> events.DataReceived(other_server, b"Plaintext")
|
||||||
<< commands.SendData(tctx.server, b"plaintext")
|
<< commands.SendData(other_server, b"plaintext")
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_server_required(self, tctx):
|
def test_server_required(self, tctx):
|
||||||
@ -419,6 +419,7 @@ class TestClientTLS:
|
|||||||
def test_mitmproxy_ca_is_untrusted(self, tctx: context.Context):
|
def test_mitmproxy_ca_is_untrusted(self, tctx: context.Context):
|
||||||
"""Test the scenario where the client doesn't trust the mitmproxy CA."""
|
"""Test the scenario where the client doesn't trust the mitmproxy CA."""
|
||||||
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
|
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
|
||||||
|
playbook.logs = True
|
||||||
|
|
||||||
data = tutils.Placeholder()
|
data = tutils.Placeholder()
|
||||||
assert (
|
assert (
|
||||||
@ -440,6 +441,7 @@ class TestClientTLS:
|
|||||||
<< commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate "
|
<< commands.Log("Client TLS handshake failed. The client does not trust the proxy's certificate "
|
||||||
"for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn")
|
"for wrong.host.mitmproxy.org (sslv3 alert bad certificate)", "warn")
|
||||||
<< commands.CloseConnection(tctx.client)
|
<< commands.CloseConnection(tctx.client)
|
||||||
|
>> events.ConnectionClosed(tctx.client)
|
||||||
)
|
)
|
||||||
assert not tctx.client.tls_established
|
assert not tctx.client.tls_established
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import typing
|
import typing
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -20,6 +21,7 @@ class TCommand(commands.Command):
|
|||||||
self.x = x
|
self.x = x
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class TCommandReply(events.CommandReply):
|
class TCommandReply(events.CommandReply):
|
||||||
command: TCommand
|
command: TCommand
|
||||||
|
|
||||||
@ -157,7 +159,7 @@ def test_command_reply(tplaybook):
|
|||||||
tplaybook
|
tplaybook
|
||||||
>> TEvent()
|
>> TEvent()
|
||||||
<< TCommand()
|
<< TCommand()
|
||||||
>> tutils.reply(42)
|
>> tutils.reply()
|
||||||
)
|
)
|
||||||
assert tplaybook.actual[1] == tplaybook.actual[2].command
|
assert tplaybook.actual[1] == tplaybook.actual[2].command
|
||||||
|
|
||||||
|
@ -10,8 +10,7 @@ from mitmproxy.proxy2 import commands, context, layer
|
|||||||
from mitmproxy.proxy2 import events
|
from mitmproxy.proxy2 import events
|
||||||
from mitmproxy.proxy2.context import ConnectionState
|
from mitmproxy.proxy2.context import ConnectionState
|
||||||
from mitmproxy.proxy2.events import command_reply_subclasses
|
from mitmproxy.proxy2.events import command_reply_subclasses
|
||||||
from mitmproxy.proxy2.layer import Layer, NextLayer
|
from mitmproxy.proxy2.layer import Layer
|
||||||
from mitmproxy.proxy2.layers import tls
|
|
||||||
|
|
||||||
PlaybookEntry = typing.Union[commands.Command, events.Event]
|
PlaybookEntry = typing.Union[commands.Command, events.Event]
|
||||||
PlaybookEntryList = typing.List[PlaybookEntry]
|
PlaybookEntryList = typing.List[PlaybookEntry]
|
||||||
@ -101,7 +100,7 @@ class Playbook:
|
|||||||
|
|
||||||
assert playbook(tcp.TCPLayer(tctx)) \
|
assert playbook(tcp.TCPLayer(tctx)) \
|
||||||
<< commands.OpenConnection(tctx.server)
|
<< commands.OpenConnection(tctx.server)
|
||||||
>> events.OpenConnectionReply(-1, "ok") # -1 = reply to command in previous line.
|
>> reply(None)
|
||||||
<< None # this line is optional.
|
<< None # this line is optional.
|
||||||
|
|
||||||
This is syntactic sugar for the following:
|
This is syntactic sugar for the following:
|
||||||
@ -351,15 +350,3 @@ def reply_next_layer(
|
|||||||
next_layer.layer = child_layer(next_layer.context)
|
next_layer.layer = child_layer(next_layer.context)
|
||||||
|
|
||||||
return reply(*args, side_effect=set_layer, **kwargs)
|
return reply(*args, side_effect=set_layer, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def reply_establish_server_tls(**kwargs) -> reply:
|
|
||||||
"""Helper function to simplify the syntax for EstablishServerTls events to this:
|
|
||||||
<< tls.EstablishServerTLS(server)
|
|
||||||
>> tutils.reply_establish_server_tls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
def fake_tls(cmd: tls.EstablishServerTLS) -> None:
|
|
||||||
cmd.connection.tls_established = True
|
|
||||||
|
|
||||||
return reply(None, side_effect=fake_tls, **kwargs)
|
|
||||||
|
@ -84,7 +84,7 @@ class TestConnectionHandler:
|
|||||||
def ask(_, x):
|
def ask(_, x):
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
|
||||||
channel.ask = ask
|
channel._ask = ask
|
||||||
c = ConnectionHandler(
|
c = ConnectionHandler(
|
||||||
mock.MagicMock(),
|
mock.MagicMock(),
|
||||||
("127.0.0.1", 8080),
|
("127.0.0.1", 8080),
|
||||||
|
Loading…
Reference in New Issue
Block a user