Merge pull request #4486 from mhils/websocket

Merge WebSocketFlow into HTTPFlow, add WebSocket UI
This commit is contained in:
Maximilian Hils 2021-03-11 11:02:40 +01:00 committed by GitHub
commit 70223163de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 501 additions and 614 deletions

View File

@ -5,17 +5,17 @@ from mitmproxy import ctx
def websocket_message(flow): def websocket_message(flow):
# get the latest message # get the latest message
message = flow.messages[-1] message = flow.websocket.messages[-1]
# was the message sent from the client or server? # was the message sent from the client or server?
if message.from_client: if message.from_client:
ctx.log.info(f"Client sent a message: {message.content}") ctx.log.info(f"Client sent a message: {message.content!r}")
else: else:
ctx.log.info(f"Server sent a message: {message.content}") ctx.log.info(f"Server sent a message: {message.content!r}")
# manipulate the message content # manipulate the message content
message.content = re.sub(r'^Hello', 'HAPPY', message.content) message.content = re.sub(rb'^Hello', b'HAPPY', message.content)
if 'FOOBAR' in message.content: if b'FOOBAR' in message.content:
# kill the message and not send it to the other endpoint # kill the message and not send it to the other endpoint
message.content = "" message.content = ""

View File

@ -145,6 +145,8 @@ class ClientPlayback:
return "Can't replay flow with missing request." return "Can't replay flow with missing request."
if f.request.raw_content is None: if f.request.raw_content is None:
return "Can't replay flow with missing content." return "Can't replay flow with missing content."
if f.websocket is not None:
return "Can't replay WebSocket flows."
else: else:
return "Can only replay HTTP flows." return "Can only replay HTTP flows."
return None return None

View File

@ -13,7 +13,7 @@ from mitmproxy import http
from mitmproxy.tcp import TCPFlow, TCPMessage from mitmproxy.tcp import TCPFlow, TCPMessage
from mitmproxy.utils import human from mitmproxy.utils import human
from mitmproxy.utils import strutils from mitmproxy.utils import strutils
from mitmproxy.websocket import WebSocketFlow, WebSocketMessage from mitmproxy.websocket import WebSocketMessage
def indent(n: int, text: str) -> str: def indent(n: int, text: str) -> str:
@ -98,7 +98,7 @@ class Dumper:
def _echo_message( def _echo_message(
self, self,
message: Union[http.Message, TCPMessage, WebSocketMessage], message: Union[http.Message, TCPMessage, WebSocketMessage],
flow: Union[http.HTTPFlow, TCPFlow, WebSocketFlow] flow: Union[http.HTTPFlow, TCPFlow]
): ):
_, lines, error = contentviews.get_message_content_view( _, lines, error = contentviews.get_message_content_view(
ctx.options.dumper_default_contentview, ctx.options.dumper_default_contentview,
@ -277,37 +277,36 @@ class Dumper:
if self.match(f): if self.match(f):
self.echo_flow(f) self.echo_flow(f)
def websocket_error(self, f): def websocket_error(self, f: http.HTTPFlow):
self.echo_error( self.echo_error(
"Error in WebSocket connection to {}: {}".format( f"Error in WebSocket connection to {human.format_address(f.server_conn.address)}: {f.error}",
human.format_address(f.server_conn.address), f.error
),
fg="red" fg="red"
) )
def websocket_message(self, f): def websocket_message(self, f: http.HTTPFlow):
assert f.websocket is not None # satisfy type checker
if self.match(f): if self.match(f):
message = f.messages[-1] message = f.websocket.messages[-1]
self.echo(f.message_info(message))
direction = "->" if message.from_client else "<-"
self.echo(
f"{human.format_address(f.client_conn.peername)} "
f"{direction} WebSocket {message.type.name.lower()} message "
f"{direction} {human.format_address(f.server_conn.address)}{f.request.path}"
)
if ctx.options.flow_detail >= 3: if ctx.options.flow_detail >= 3:
message = message.from_state(message.get_state())
message.content = message.content.encode() if isinstance(message.content, str) else message.content
self._echo_message(message, f) self._echo_message(message, f)
def websocket_end(self, f): def websocket_end(self, f: http.HTTPFlow):
assert f.websocket is not None # satisfy type checker
if self.match(f): if self.match(f):
self.echo("WebSocket connection closed by {}: {} {}, {}".format( c = 'client' if f.websocket.closed_by_client else 'server'
f.close_sender, self.echo(f"WebSocket connection closed by {c}: {f.websocket.close_code} {f.websocket.close_reason}")
f.close_code,
f.close_message,
f.close_reason))
def tcp_error(self, f): def tcp_error(self, f):
if self.match(f): if self.match(f):
self.echo_error( self.echo_error(
"Error in TCP connection to {}: {}".format( f"Error in TCP connection to {human.format_address(f.server_conn.address)}: {f.error}",
human.format_address(f.server_conn.address), f.error
),
fg="red" fg="red"
) )

View File

@ -7,6 +7,7 @@ from mitmproxy import flowfilter
from mitmproxy import io from mitmproxy import io
from mitmproxy import ctx from mitmproxy import ctx
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import http
import mitmproxy.types import mitmproxy.types
@ -88,28 +89,26 @@ class Save:
def tcp_error(self, flow): def tcp_error(self, flow):
self.tcp_end(flow) self.tcp_end(flow)
def websocket_start(self, flow): def websocket_end(self, flow: http.HTTPFlow):
if self.stream:
self.active_flows.add(flow)
def websocket_end(self, flow):
if self.stream: if self.stream:
self.stream.add(flow) self.stream.add(flow)
self.active_flows.discard(flow) self.active_flows.discard(flow)
def websocket_error(self, flow): def websocket_error(self, flow: http.HTTPFlow):
self.websocket_end(flow) self.websocket_end(flow)
def request(self, flow): def request(self, flow: http.HTTPFlow):
if self.stream: if self.stream:
self.active_flows.add(flow) self.active_flows.add(flow)
def response(self, flow): def response(self, flow: http.HTTPFlow):
if self.stream: # websocket flows will receive either websocket_end or websocket_error,
# we don't want to persist them here already
if self.stream and flow.websocket is None:
self.stream.add(flow) self.stream.add(flow)
self.active_flows.discard(flow) self.active_flows.discard(flow)
def error(self, flow): def error(self, flow: http.HTTPFlow):
self.response(flow) self.response(flow)
def done(self): def done(self):

View File

@ -25,7 +25,7 @@ from . import (
from .base import View, KEY_MAX, format_text, format_dict, TViewResult from .base import View, KEY_MAX, format_text, format_dict, TViewResult
from ..http import HTTPFlow from ..http import HTTPFlow
from ..tcp import TCPMessage, TCPFlow from ..tcp import TCPMessage, TCPFlow
from ..websocket import WebSocketMessage, WebSocketFlow from ..websocket import WebSocketMessage
views: List[View] = [] views: List[View] = []
@ -67,7 +67,7 @@ def safe_to_print(lines, encoding="utf8"):
def get_message_content_view( def get_message_content_view(
viewname: str, viewname: str,
message: Union[http.Message, TCPMessage, WebSocketMessage], message: Union[http.Message, TCPMessage, WebSocketMessage],
flow: Union[HTTPFlow, TCPFlow, WebSocketFlow], flow: Union[HTTPFlow, TCPFlow],
): ):
""" """
Like get_content_view, but also handles message encoding. Like get_content_view, but also handles message encoding.
@ -79,7 +79,7 @@ def get_message_content_view(
content: Optional[bytes] content: Optional[bytes]
try: try:
content = message.content # type: ignore content = message.content
except ValueError: except ValueError:
assert isinstance(message, http.Message) assert isinstance(message, http.Message)
content = message.raw_content content = message.raw_content

View File

@ -1,11 +1,10 @@
from typing import Iterator, Any, Dict, Type, Callable from typing import Any, Callable, Dict, Iterator, Type
from mitmproxy import controller from mitmproxy import controller
from mitmproxy import hooks
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import hooks
from mitmproxy import http from mitmproxy import http
from mitmproxy import tcp from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.proxy import layers from mitmproxy.proxy import layers
TEventGenerator = Iterator[hooks.Hook] TEventGenerator = Iterator[hooks.Hook]
@ -18,24 +17,21 @@ def _iterate_http(f: http.HTTPFlow) -> TEventGenerator:
if f.response: if f.response:
yield layers.http.HttpResponseHeadersHook(f) yield layers.http.HttpResponseHeadersHook(f)
yield layers.http.HttpResponseHook(f) yield layers.http.HttpResponseHook(f)
if f.error: if f.websocket:
message_queue = f.websocket.messages
f.websocket.messages = []
yield layers.websocket.WebsocketStartHook(f)
for m in message_queue:
f.websocket.messages.append(m)
yield layers.websocket.WebsocketMessageHook(f)
if f.error:
yield layers.websocket.WebsocketErrorHook(f)
else:
yield layers.websocket.WebsocketEndHook(f)
elif f.error:
yield layers.http.HttpErrorHook(f) yield layers.http.HttpErrorHook(f)
def _iterate_websocket(f: websocket.WebSocketFlow) -> TEventGenerator:
messages = f.messages
f.messages = []
f.reply = controller.DummyReply()
yield layers.websocket.WebsocketStartHook(f)
while messages:
f.messages.append(messages.pop(0))
yield layers.websocket.WebsocketMessageHook(f)
if f.error:
yield layers.websocket.WebsocketErrorHook(f)
else:
yield layers.websocket.WebsocketEndHook(f)
def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator: def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator:
messages = f.messages messages = f.messages
f.messages = [] f.messages = []
@ -52,7 +48,6 @@ def _iterate_tcp(f: tcp.TCPFlow) -> TEventGenerator:
_iterate_map: Dict[Type[flow.Flow], Callable[[Any], TEventGenerator]] = { _iterate_map: Dict[Type[flow.Flow], Callable[[Any], TEventGenerator]] = {
http.HTTPFlow: _iterate_http, http.HTTPFlow: _iterate_http,
websocket.WebSocketFlow: _iterate_websocket,
tcp.TCPFlow: _iterate_tcp, tcp.TCPFlow: _iterate_tcp,
} }

View File

@ -39,8 +39,7 @@ from typing import Callable, ClassVar, Optional, Sequence, Type
import pyparsing as pp import pyparsing as pp
from mitmproxy import flow, http, tcp, websocket from mitmproxy import flow, http, tcp
from mitmproxy.net.websocket import check_handshake
def only(*types): def only(*types):
@ -102,15 +101,11 @@ class FHTTP(_Action):
class FWebSocket(_Action): class FWebSocket(_Action):
code = "websocket" code = "websocket"
help = "Match WebSocket flows (and HTTP-WebSocket handshake flows)" help = "Match WebSocket flows"
@only(http.HTTPFlow, websocket.WebSocketFlow) @only(http.HTTPFlow)
def __call__(self, f): def __call__(self, f: http.HTTPFlow):
m = ( return f.websocket is not None
(isinstance(f, http.HTTPFlow) and f.request and check_handshake(f.request.headers))
or isinstance(f, websocket.WebSocketFlow)
)
return m
class FTCP(_Action): class FTCP(_Action):
@ -258,7 +253,7 @@ class FBod(_Rex):
help = "Body" help = "Body"
flags = re.DOTALL flags = re.DOTALL
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow) @only(http.HTTPFlow, tcp.TCPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content: if f.request and f.request.raw_content:
@ -267,7 +262,11 @@ class FBod(_Rex):
if f.response and f.response.raw_content: if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)): if self.re.search(f.response.get_content(strict=False)):
return True return True
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow): if f.websocket:
for msg in f.websocket.messages:
if self.re.search(msg.content):
return True
elif isinstance(f, tcp.TCPFlow):
for msg in f.messages: for msg in f.messages:
if self.re.search(msg.content): if self.re.search(msg.content):
return True return True
@ -279,13 +278,17 @@ class FBodRequest(_Rex):
help = "Request body" help = "Request body"
flags = re.DOTALL flags = re.DOTALL
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow) @only(http.HTTPFlow, tcp.TCPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.request and f.request.raw_content: if f.request and f.request.raw_content:
if self.re.search(f.request.get_content(strict=False)): if self.re.search(f.request.get_content(strict=False)):
return True return True
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow): if f.websocket:
for msg in f.websocket.messages:
if msg.from_client and self.re.search(msg.content):
return True
elif isinstance(f, tcp.TCPFlow):
for msg in f.messages: for msg in f.messages:
if msg.from_client and self.re.search(msg.content): if msg.from_client and self.re.search(msg.content):
return True return True
@ -296,13 +299,17 @@ class FBodResponse(_Rex):
help = "Response body" help = "Response body"
flags = re.DOTALL flags = re.DOTALL
@only(http.HTTPFlow, websocket.WebSocketFlow, tcp.TCPFlow) @only(http.HTTPFlow, tcp.TCPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.response and f.response.raw_content: if f.response and f.response.raw_content:
if self.re.search(f.response.get_content(strict=False)): if self.re.search(f.response.get_content(strict=False)):
return True return True
elif isinstance(f, websocket.WebSocketFlow) or isinstance(f, tcp.TCPFlow): if f.websocket:
for msg in f.websocket.messages:
if not msg.from_client and self.re.search(msg.content):
return True
elif isinstance(f, tcp.TCPFlow):
for msg in f.messages: for msg in f.messages:
if not msg.from_client and self.re.search(msg.content): if not msg.from_client and self.re.search(msg.content):
return True return True
@ -324,10 +331,8 @@ class FDomain(_Rex):
flags = re.IGNORECASE flags = re.IGNORECASE
is_binary = False is_binary = False
@only(http.HTTPFlow, websocket.WebSocketFlow) @only(http.HTTPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
return bool( return bool(
self.re.search(f.request.host) or self.re.search(f.request.host) or
self.re.search(f.request.pretty_host) self.re.search(f.request.pretty_host)
@ -347,10 +352,8 @@ class FUrl(_Rex):
toks = toks[1:] toks = toks[1:]
return klass(*toks) return klass(*toks)
@only(http.HTTPFlow, websocket.WebSocketFlow) @only(http.HTTPFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
if not f or not f.request: if not f or not f.request:
return False return False
return self.re.search(f.request.pretty_url) return self.re.search(f.request.pretty_url)
@ -482,9 +485,9 @@ def _make():
unicode_words = pp.CharsNotIn("()~'\"" + pp.ParserElement.DEFAULT_WHITE_CHARS) unicode_words = pp.CharsNotIn("()~'\"" + pp.ParserElement.DEFAULT_WHITE_CHARS)
unicode_words.skipWhitespace = True unicode_words.skipWhitespace = True
regex = ( regex = (
unicode_words unicode_words
| pp.QuotedString('"', escChar='\\') | pp.QuotedString('"', escChar='\\')
| pp.QuotedString("'", escChar='\\') | pp.QuotedString("'", escChar='\\')
) )
for cls in filter_rex: for cls in filter_rex:
f = pp.Literal(f"~{cls.code}") + pp.WordEnd() + regex.copy() f = pp.Literal(f"~{cls.code}") + pp.WordEnd() + regex.copy()

View File

@ -18,6 +18,7 @@ from typing import Union
from typing import cast from typing import cast
from mitmproxy import flow from mitmproxy import flow
from mitmproxy.websocket import WebSocketData
from mitmproxy.coretypes import multidict from mitmproxy.coretypes import multidict
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
from mitmproxy.net import encoding from mitmproxy.net import encoding
@ -1169,6 +1170,11 @@ class HTTPFlow(flow.Flow):
from the server, but there was an error sending it back to the client. from the server, but there was an error sending it back to the client.
""" """
websocket: Optional[WebSocketData] = None
"""
If this HTTP flow initiated a WebSocket connection, this attribute contains all associated WebSocket data.
"""
def __init__(self, client_conn, server_conn, live=None, mode="regular"): def __init__(self, client_conn, server_conn, live=None, mode="regular"):
super().__init__("http", client_conn, server_conn, live) super().__init__("http", client_conn, server_conn, live)
self.mode = mode self.mode = mode
@ -1178,12 +1184,13 @@ class HTTPFlow(flow.Flow):
_stateobject_attributes.update(dict( _stateobject_attributes.update(dict(
request=Request, request=Request,
response=Response, response=Response,
websocket=WebSocketData,
mode=str mode=str
)) ))
def __repr__(self): def __repr__(self):
s = "<HTTPFlow" s = "<HTTPFlow"
for a in ("request", "response", "error", "client_conn", "server_conn"): for a in ("request", "response", "websocket", "error", "client_conn", "server_conn"):
if getattr(self, a, False): if getattr(self, a, False):
s += f"\r\n {a} = {{flow.{a}}}" s += f"\r\n {a} = {{flow.{a}}}"
s += ">" s += ">"

View File

@ -250,6 +250,55 @@ def convert_10_11(data):
return data return data
_websocket_handshakes = {}
def convert_11_12(data):
data["version"] = 12
if "websocket" in data["metadata"]:
_websocket_handshakes[data["id"]] = data
if "websocket_handshake" in data["metadata"]:
ws_flow = data
try:
data = _websocket_handshakes.pop(data["metadata"]["websocket_handshake"])
except KeyError:
# The handshake flow is missing, which should never really happen. We make up a dummy.
data = {
'client_conn': data["client_conn"],
'error': data["error"],
'id': data["id"],
'intercepted': data["intercepted"],
'is_replay': data["is_replay"],
'marked': data["marked"],
'metadata': {},
'mode': 'transparent',
'request': {'authority': b'', 'content': None, 'headers': [], 'host': b'unknown',
'http_version': b'HTTP/1.1', 'method': b'GET', 'path': b'/', 'port': 80, 'scheme': b'http',
'timestamp_end': 0, 'timestamp_start': 0, 'trailers': None, },
'response': None,
'server_conn': data["server_conn"],
'type': 'http',
'version': 12
}
data["metadata"]["duplicated"] = (
"This WebSocket flow has been migrated from an old file format version "
"and may appear duplicated."
)
data["websocket"] = {
"messages": ws_flow["messages"],
"closed_by_client": ws_flow["close_sender"] == "client",
"close_code": ws_flow["close_code"],
"close_reason": ws_flow["close_reason"],
}
else:
data["websocket"] = None
return data
def _convert_dict_keys(o: Any) -> Any: def _convert_dict_keys(o: Any) -> Any:
if isinstance(o, dict): if isinstance(o, dict):
return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()} return {strutils.always_str(k): _convert_dict_keys(v) for k, v in o.items()}
@ -308,6 +357,7 @@ converters = {
8: convert_8_9, 8: convert_8_9,
9: convert_9_10, 9: convert_9_10,
10: convert_10_11, 10: convert_10_11,
11: convert_11_12,
} }
@ -325,8 +375,8 @@ def migrate_flow(flow_data: Dict[Union[bytes, str], Any]) -> Dict[Union[bytes, s
flow_data = converters[flow_version](flow_data) flow_data = converters[flow_version](flow_data)
else: else:
should_upgrade = ( should_upgrade = (
isinstance(flow_version, int) isinstance(flow_version, int)
and flow_version > version.FLOW_FORMAT_VERSION and flow_version > version.FLOW_FORMAT_VERSION
) )
raise ValueError( raise ValueError(
"{} cannot read files with flow format version {}{}.".format( "{} cannot read files with flow format version {}{}.".format(

View File

@ -1,19 +1,16 @@
import os import os
from typing import Type, Iterable, Dict, Union, Any, cast # noqa from typing import Any, Dict, Iterable, Type, Union, cast # noqa
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flow from mitmproxy import flow
from mitmproxy import flowfilter from mitmproxy import flowfilter
from mitmproxy import http from mitmproxy import http
from mitmproxy import tcp from mitmproxy import tcp
from mitmproxy import websocket
from mitmproxy.io import compat from mitmproxy.io import compat
from mitmproxy.io import tnetstring from mitmproxy.io import tnetstring
FLOW_TYPES: Dict[str, Type[flow.Flow]] = dict( FLOW_TYPES: Dict[str, Type[flow.Flow]] = dict(
http=http.HTTPFlow, http=http.HTTPFlow,
websocket=websocket.WebSocketFlow,
tcp=tcp.TCPFlow, tcp=tcp.TCPFlow,
) )

View File

@ -11,7 +11,6 @@ from mitmproxy import eventsequence
from mitmproxy import http from mitmproxy import http
from mitmproxy import log from mitmproxy import log
from mitmproxy import options from mitmproxy import options
from mitmproxy import websocket
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from . import ctx as mitmproxy_ctx from . import ctx as mitmproxy_ctx
@ -34,7 +33,6 @@ class Master:
self.commands = command.CommandManager(self) self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self) self.addons = addonmanager.AddonManager(self)
self._server = None self._server = None
self.waiting_flows = []
self.log = log.Log(self) self.log = log.Log(self)
mitmproxy_ctx.master = self mitmproxy_ctx.master = self
@ -111,24 +109,11 @@ class Master:
async def load_flow(self, f): async def load_flow(self, f):
""" """
Loads a flow and links websocket & handshake flows Loads a flow
""" """
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
self._change_reverse_host(f) self._change_reverse_host(f)
if 'websocket' in f.metadata:
self.waiting_flows.append(f)
if isinstance(f, websocket.WebSocketFlow):
hfs = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']]
if hfs:
hf = hfs[0]
f.handshake_flow = hf
self.waiting_flows.remove(hf)
self._change_reverse_host(f.handshake_flow)
else:
# this will fail - but at least it will load the remaining flows
f.handshake_flow = http.HTTPFlow(None, None)
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e in eventsequence.iterate(f): for e in eventsequence.iterate(f):

View File

@ -1,28 +0,0 @@
"""
Collection of WebSocket protocol utility functions (RFC6455)
Spec: https://tools.ietf.org/html/rfc6455
"""
def check_handshake(headers):
return (
"upgrade" in headers.get("connection", "").lower() and
headers.get("upgrade", "").lower() == "websocket" and
(headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
)
def get_extensions(headers):
return headers.get("sec-websocket-extensions", None)
def get_protocol(headers):
return headers.get("sec-websocket-protocol", None)
def get_client_key(headers):
return headers.get("sec-websocket-key", None)
def get_server_accept(headers):
return headers.get("sec-websocket-accept", None)

View File

@ -4,6 +4,8 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import DefaultDict, Dict, List, Optional, Tuple, Union from typing import DefaultDict, Dict, List, Optional, Tuple, Union
import wsproto.handshake
from mitmproxy import flow, http from mitmproxy import flow, http
from mitmproxy.connection import Connection, Server from mitmproxy.connection import Connection, Server
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
@ -13,6 +15,7 @@ from mitmproxy.proxy.layers import tcp, tls, websocket
from mitmproxy.proxy.layers.http import _upstream_proxy from mitmproxy.proxy.layers.http import _upstream_proxy
from mitmproxy.proxy.utils import expect from mitmproxy.proxy.utils import expect
from mitmproxy.utils import human from mitmproxy.utils import human
from mitmproxy.websocket import WebSocketData
from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId from ._base import HttpCommand, HttpConnection, ReceiveHttp, StreamId
from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \ from ._events import HttpEvent, RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
@ -308,6 +311,21 @@ class HttpStream(layer.Layer):
"""We have either consumed the entire response from the server or the response was set by an addon.""" """We have either consumed the entire response from the server or the response was set by an addon."""
assert self.flow.response assert self.flow.response
self.flow.response.timestamp_end = time.time() self.flow.response.timestamp_end = time.time()
is_websocket = (
self.flow.response.status_code == 101
and
self.flow.response.headers.get("upgrade", "").lower() == "websocket"
and
self.flow.request.headers.get("Sec-WebSocket-Version", "").encode() == wsproto.handshake.WEBSOCKET_VERSION
and
self.context.options.websocket
)
if is_websocket:
# We need to set this before calling the response hook
# so that addons can determine if a WebSocket connection is following up.
self.flow.websocket = WebSocketData()
yield HttpResponseHook(self.flow) yield HttpResponseHook(self.flow)
self.server_state = self.state_done self.server_state = self.state_done
if (yield from self.check_killed(False)): if (yield from self.check_killed(False)):
@ -322,12 +340,7 @@ class HttpStream(layer.Layer):
yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client) yield SendHttp(ResponseEndOfMessage(self.stream_id), self.context.client)
if self.flow.response.status_code == 101: if self.flow.response.status_code == 101:
is_websocket = ( if is_websocket:
self.flow.response.headers.get("upgrade", "").lower() == "websocket"
and
self.flow.request.headers.get("Sec-WebSocket-Version", "") == "13"
)
if is_websocket and self.context.options.websocket:
self.child_layer = websocket.WebsocketLayer(self.context, self.flow) self.child_layer = websocket.WebsocketLayer(self.context, self.flow)
elif self.context.options.rawtcp: elif self.context.options.rawtcp:
self.child_layer = tcp.TCPLayer(self.context) self.child_layer = tcp.TCPLayer(self.context)

View File

@ -1,11 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union, List, Iterator from typing import Iterator, List
import wsproto import wsproto
import wsproto.extensions import wsproto.extensions
import wsproto.frame_protocol import wsproto.frame_protocol
import wsproto.utilities import wsproto.utilities
from mitmproxy import flow, websocket, http, connection from mitmproxy import connection, flow, http, websocket
from mitmproxy.proxy import commands, events, layer from mitmproxy.proxy import commands, events, layer
from mitmproxy.proxy.commands import StartHook from mitmproxy.proxy.commands import StartHook
from mitmproxy.proxy.context import Context from mitmproxy.proxy.context import Context
@ -19,7 +19,7 @@ class WebsocketStartHook(StartHook):
""" """
A WebSocket connection has commenced. A WebSocket connection has commenced.
""" """
flow: websocket.WebSocketFlow flow: http.HTTPFlow
@dataclass @dataclass
@ -30,7 +30,7 @@ class WebsocketMessageHook(StartHook):
message is user-modifiable. Currently there are two types of message is user-modifiable. Currently there are two types of
messages, corresponding to the BINARY and TEXT frame types. messages, corresponding to the BINARY and TEXT frame types.
""" """
flow: websocket.WebSocketFlow flow: http.HTTPFlow
@dataclass @dataclass
@ -39,7 +39,7 @@ class WebsocketEndHook(StartHook):
A WebSocket connection has ended. A WebSocket connection has ended.
""" """
flow: websocket.WebSocketFlow flow: http.HTTPFlow
@dataclass @dataclass
@ -49,7 +49,7 @@ class WebsocketErrorHook(StartHook):
Every WebSocket flow will receive either a websocket_error or a websocket_end event, but not both. Every WebSocket flow will receive either a websocket_error or a websocket_end event, but not both.
""" """
flow: websocket.WebSocketFlow flow: http.HTTPFlow
class WebsocketConnection(wsproto.Connection): class WebsocketConnection(wsproto.Connection):
@ -61,7 +61,7 @@ class WebsocketConnection(wsproto.Connection):
- we wrap .send() so that we can directly yield it. - we wrap .send() so that we can directly yield it.
""" """
conn: connection.Connection conn: connection.Connection
frame_buf: List[Union[str, bytes]] frame_buf: List[bytes]
def __init__(self, *args, conn: connection.Connection, **kwargs): def __init__(self, *args, conn: connection.Connection, **kwargs):
super(WebsocketConnection, self).__init__(*args, **kwargs) super(WebsocketConnection, self).__init__(*args, **kwargs)
@ -80,13 +80,13 @@ class WebsocketLayer(layer.Layer):
""" """
WebSocket layer that intercepts and relays messages. WebSocket layer that intercepts and relays messages.
""" """
flow: websocket.WebSocketFlow flow: http.HTTPFlow
client_ws: WebsocketConnection client_ws: WebsocketConnection
server_ws: WebsocketConnection server_ws: WebsocketConnection
def __init__(self, context: Context, handshake_flow: http.HTTPFlow): def __init__(self, context: Context, flow: http.HTTPFlow):
super().__init__(context) super().__init__(context)
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow) self.flow = flow
assert context.server.connected assert context.server.connected
@expect(events.Start) @expect(events.Start)
@ -96,7 +96,8 @@ class WebsocketLayer(layer.Layer):
server_extensions = [] server_extensions = []
# Parse extension headers. We only support deflate at the moment and ignore everything else. # Parse extension headers. We only support deflate at the moment and ignore everything else.
ext_header = self.flow.handshake_flow.response.headers.get("Sec-WebSocket-Extensions", "") assert self.flow.response # satisfy type checker
ext_header = self.flow.response.headers.get("Sec-WebSocket-Extensions", "")
if ext_header: if ext_header:
for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")): for ext in wsproto.utilities.split_comma_header(ext_header.encode("ascii", "replace")):
ext_name = ext.split(";", 1)[0].strip() ext_name = ext.split(";", 1)[0].strip()
@ -115,15 +116,14 @@ class WebsocketLayer(layer.Layer):
yield WebsocketStartHook(self.flow) yield WebsocketStartHook(self.flow)
if self.flow.stream: # pragma: no cover
raise NotImplementedError("WebSocket streaming is not supported at the moment.")
self._handle_event = self.relay_messages self._handle_event = self.relay_messages
_handle_event = start _handle_event = start
@expect(events.DataReceived, events.ConnectionClosed) @expect(events.DataReceived, events.ConnectionClosed)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]:
assert self.flow.websocket # satisfy type checker
from_client = event.connection == self.context.client from_client = event.connection == self.context.client
from_str = 'client' if from_client else 'server' from_str = 'client' if from_client else 'server'
if from_client: if from_client:
@ -142,27 +142,27 @@ class WebsocketLayer(layer.Layer):
for ws_event in src_ws.events(): for ws_event in src_ws.events():
if isinstance(ws_event, wsproto.events.Message): if isinstance(ws_event, wsproto.events.Message):
src_ws.frame_buf.append(ws_event.data) is_text = isinstance(ws_event.data, str)
if is_text:
typ = Opcode.TEXT
src_ws.frame_buf.append(ws_event.data.encode())
else:
typ = Opcode.BINARY
src_ws.frame_buf.append(ws_event.data)
if ws_event.message_finished: if ws_event.message_finished:
if isinstance(ws_event, wsproto.events.TextMessage): content = b"".join(src_ws.frame_buf)
frame_type = Opcode.TEXT
content = "".join(src_ws.frame_buf) # type: ignore
else:
frame_type = Opcode.BINARY
content = b"".join(src_ws.frame_buf) # type: ignore
fragmentizer = Fragmentizer(src_ws.frame_buf) fragmentizer = Fragmentizer(src_ws.frame_buf, is_text)
src_ws.frame_buf.clear() src_ws.frame_buf.clear()
message = websocket.WebSocketMessage(frame_type, from_client, content) message = websocket.WebSocketMessage(typ, from_client, content)
self.flow.messages.append(message) self.flow.websocket.messages.append(message)
yield WebsocketMessageHook(self.flow) yield WebsocketMessageHook(self.flow)
assert not message.killed # this is deprecated, instead we should have .content set to emptystr. if not message.killed:
for msg in fragmentizer(message.content):
for msg in fragmentizer(message.content): yield dst_ws.send2(msg)
yield dst_ws.send2(msg)
elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)): elif isinstance(ws_event, (wsproto.events.Ping, wsproto.events.Pong)):
yield commands.Log( yield commands.Log(
@ -171,9 +171,9 @@ class WebsocketLayer(layer.Layer):
) )
yield dst_ws.send2(ws_event) yield dst_ws.send2(ws_event)
elif isinstance(ws_event, wsproto.events.CloseConnection): elif isinstance(ws_event, wsproto.events.CloseConnection):
self.flow.close_sender = from_str self.flow.websocket.closed_by_client = from_client
self.flow.close_code = ws_event.code self.flow.websocket.close_code = ws_event.code
self.flow.close_reason = ws_event.reason self.flow.websocket.close_reason = ws_event.reason
for ws in [self.server_ws, self.client_ws]: for ws in [self.server_ws, self.client_ws]:
if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}: if ws.state in {ConnectionState.OPEN, ConnectionState.REMOTE_CLOSING}:
@ -215,27 +215,35 @@ class Fragmentizer:
As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks. As a workaround, we either retain the original chunking or, if the payload has been modified, use ~4kB chunks.
""" """
# A bit less than 4kb to accomodate for headers. # A bit less than 4kb to accommodate for headers.
FRAGMENT_SIZE = 4000 FRAGMENT_SIZE = 4000
def __init__(self, fragments: List[Union[str, bytes]]): def __init__(self, fragments: List[bytes], is_text: bool):
assert fragments assert fragments
self.fragment_lengths = [len(x) for x in fragments] self.fragment_lengths = [len(x) for x in fragments]
self.is_text = is_text
def __call__(self, content: Union[str, bytes]) -> Iterator[wsproto.events.Message]: def msg(self, data: bytes, message_finished: bool):
if self.is_text:
data_str = data.decode(errors="replace")
return wsproto.events.TextMessage(data_str, message_finished=message_finished)
else:
return wsproto.events.BytesMessage(data, message_finished=message_finished)
def __call__(self, content: bytes) -> Iterator[wsproto.events.Message]:
if not content: if not content:
return return
if len(content) == sum(self.fragment_lengths): if len(content) == sum(self.fragment_lengths):
# message has the same length, we can reuse the same sizes # message has the same length, we can reuse the same sizes
offset = 0 offset = 0
for fl in self.fragment_lengths[:-1]: for fl in self.fragment_lengths[:-1]:
yield wsproto.events.Message(content[offset:offset + fl], message_finished=False) yield self.msg(content[offset:offset + fl], False)
offset += fl offset += fl
yield wsproto.events.Message(content[offset:], message_finished=True) yield self.msg(content[offset:], True)
else: else:
offset = 0 offset = 0
total = len(content) - self.FRAGMENT_SIZE total = len(content) - self.FRAGMENT_SIZE
while offset < total: while offset < total:
yield wsproto.events.Message(content[offset:offset + self.FRAGMENT_SIZE], message_finished=False) yield self.msg(content[offset:offset + self.FRAGMENT_SIZE], False)
offset += self.FRAGMENT_SIZE offset += self.FRAGMENT_SIZE
yield wsproto.events.Message(content[offset:], message_finished=True) yield self.msg(content[offset:], True)

View File

@ -6,7 +6,6 @@ from mitmproxy import flow
from mitmproxy import http from mitmproxy import http
from mitmproxy import tcp from mitmproxy import tcp
from mitmproxy import websocket from mitmproxy import websocket
from mitmproxy.net.http import status_codes
from mitmproxy.test import tutils from mitmproxy.test import tutils
from wsproto.frame_protocol import Opcode from wsproto.frame_protocol import Opcode
@ -31,68 +30,55 @@ def ttcpflow(client_conn=True, server_conn=True, messages=True, err=None):
return f return f
def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None, handshake_flow=True): def twebsocketflow(messages=True, err=None) -> http.HTTPFlow:
if client_conn is True: flow = http.HTTPFlow(tclient_conn(), tserver_conn())
client_conn = tclient_conn() flow.request = http.Request(
if server_conn is True: "example.com",
server_conn = tserver_conn() 80,
if handshake_flow is True: b"GET",
req = http.Request( b"http",
"example.com", b"example.com",
80, b"/ws",
b"GET", b"HTTP/1.1",
b"http", headers=http.Headers(
b"example.com", connection="upgrade",
b"/ws", upgrade="websocket",
b"HTTP/1.1", sec_websocket_version="13",
headers=http.Headers( sec_websocket_key="1234",
connection="upgrade", ),
upgrade="websocket", content=b'',
sec_websocket_version="13", trailers=None,
sec_websocket_key="1234", timestamp_start=946681200,
), timestamp_end=946681201,
content=b'',
trailers=None,
timestamp_start=946681200,
timestamp_end=946681201,
) )
resp = http.Response( flow.response = http.Response(
b"HTTP/1.1", b"HTTP/1.1",
101, 101,
reason=status_codes.RESPONSES.get(101), reason=b"Switching Protocols",
headers=http.Headers( headers=http.Headers(
connection='upgrade', connection='upgrade',
upgrade='websocket', upgrade='websocket',
sec_websocket_accept=b'', sec_websocket_accept=b'',
), ),
content=b'', content=b'',
trailers=None, trailers=None,
timestamp_start=946681202, timestamp_start=946681202,
timestamp_end=946681203, timestamp_end=946681203,
) )
handshake_flow = http.HTTPFlow(client_conn, server_conn) flow.websocket = websocket.WebSocketData()
handshake_flow.request = req
handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
f.metadata['websocket_handshake'] = handshake_flow.id
handshake_flow.metadata['websocket_flow'] = f.id
handshake_flow.metadata['websocket'] = True
if messages is True: if messages is True:
messages = [ flow.websocket.messages = [
websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary"), websocket.WebSocketMessage(Opcode.BINARY, True, b"hello binary", 946681203),
websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text"), websocket.WebSocketMessage(Opcode.TEXT, True, b"hello text", 946681204),
websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me"), websocket.WebSocketMessage(Opcode.TEXT, False, b"it's me", 946681205),
] ]
if err is True: if err is True:
err = terr() flow.error = terr()
f.messages = messages flow.reply = controller.DummyReply()
f.error = err return flow
f.reply = controller.DummyReply()
return f
def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None): def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None):

View File

@ -119,6 +119,8 @@ else:
SCHEME_STYLES = { SCHEME_STYLES = {
'http': 'scheme_http', 'http': 'scheme_http',
'https': 'scheme_https', 'https': 'scheme_https',
'ws': 'scheme_ws',
'wss': 'scheme_wss',
'tcp': 'scheme_tcp', 'tcp': 'scheme_tcp',
} }
HTTP_REQUEST_METHOD_STYLES = { HTTP_REQUEST_METHOD_STYLES = {
@ -297,12 +299,8 @@ def colorize_url(url):
parts = url.split('/', 3) parts = url.split('/', 3)
if len(parts) < 4 or len(parts[1]) > 0 or parts[0][-1:] != ':': if len(parts) < 4 or len(parts[1]) > 0 or parts[0][-1:] != ':':
return [('error', len(url))] # bad URL return [('error', len(url))] # bad URL
schemes = {
'http:': 'scheme_http',
'https:': 'scheme_https',
}
return [ return [
(schemes.get(parts[0], "scheme_other"), len(parts[0]) - 1), (SCHEME_STYLES.get(parts[0], "scheme_other"), len(parts[0]) - 1),
('url_punctuation', 3), # :// ('url_punctuation', 3), # ://
] + colorize_host(parts[2]) + colorize_req('/' + parts[3]) ] + colorize_host(parts[2]) + colorize_req('/' + parts[3])
@ -699,6 +697,13 @@ def format_flow(
response_content_type = None response_content_type = None
duration = None duration = None
scheme = f.request.scheme
if f.websocket is not None:
if scheme == "https":
scheme = "wss"
elif scheme == "http":
scheme = "ws"
if render_mode in (RenderMode.LIST, RenderMode.DETAILVIEW): if render_mode in (RenderMode.LIST, RenderMode.DETAILVIEW):
render_func = format_http_flow_list render_func = format_http_flow_list
else: else:
@ -709,7 +714,7 @@ def format_flow(
marked=f.marked, marked=f.marked,
is_replay=f.is_replay, is_replay=f.is_replay,
request_method=f.request.method, request_method=f.request.method,
request_scheme=f.request.scheme, request_scheme=scheme,
request_host=f.request.pretty_host if hostheader else f.request.host, request_host=f.request.pretty_host if hostheader else f.request.host,
request_path=f.request.path, request_path=f.request.path,
request_url=f.request.pretty_url if hostheader else f.request.url, request_url=f.request.pretty_url if hostheader else f.request.url,

View File

@ -42,25 +42,6 @@ console_flowlist_layout = [
] ]
class UnsupportedLog:
"""
A small addon to dump info on flow types we don't support yet.
"""
def websocket_message(self, f):
message = f.messages[-1]
ctx.log.info(f.message_info(message))
ctx.log.debug(
message.content if isinstance(message.content, str) else strutils.bytes_to_escaped_str(message.content))
def websocket_end(self, f):
ctx.log.info("WebSocket connection closed by {}: {} {}, {}".format(
f.close_sender,
f.close_code,
f.close_message,
f.close_reason))
class ConsoleAddon: class ConsoleAddon:
""" """
An addon that exposes console-specific commands, and hooks into required An addon that exposes console-specific commands, and hooks into required
@ -226,11 +207,11 @@ class ConsoleAddon:
@command.command("console.choose") @command.command("console.choose")
def console_choose( def console_choose(
self, self,
prompt: str, prompt: str,
choices: typing.Sequence[str], choices: typing.Sequence[str],
cmd: mitmproxy.types.Cmd, cmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs *args: mitmproxy.types.CmdArgs
) -> None: ) -> None:
""" """
Prompt the user to choose from a specified list of strings, then Prompt the user to choose from a specified list of strings, then
@ -252,11 +233,11 @@ class ConsoleAddon:
@command.command("console.choose.cmd") @command.command("console.choose.cmd")
def console_choose_cmd( def console_choose_cmd(
self, self,
prompt: str, prompt: str,
choicecmd: mitmproxy.types.Cmd, choicecmd: mitmproxy.types.Cmd,
subcmd: mitmproxy.types.Cmd, subcmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs *args: mitmproxy.types.CmdArgs
) -> None: ) -> None:
""" """
Prompt the user to choose from a list of strings returned by a Prompt the user to choose from a list of strings returned by a
@ -415,8 +396,8 @@ class ConsoleAddon:
flow.backup() flow.backup()
require_dummy_response = ( require_dummy_response = (
flow_part in ("response-headers", "response-body", "set-cookies") and flow_part in ("response-headers", "response-body", "set-cookies") and
flow.response is None flow.response is None
) )
if require_dummy_response: if require_dummy_response:
flow.response = http.Response.make() flow.response = http.Response.make()
@ -584,11 +565,11 @@ class ConsoleAddon:
@command.command("console.key.bind") @command.command("console.key.bind")
def key_bind( def key_bind(
self, self,
contexts: typing.Sequence[str], contexts: typing.Sequence[str],
key: str, key: str,
cmd: mitmproxy.types.Cmd, cmd: mitmproxy.types.Cmd,
*args: mitmproxy.types.CmdArgs *args: mitmproxy.types.CmdArgs
) -> None: ) -> None:
""" """
Bind a shortcut key. Bind a shortcut key.

View File

@ -60,14 +60,23 @@ class FlowDetails(tabs.Tabs):
return self.master.view.focus.flow return self.master.view.focus.flow
def focus_changed(self): def focus_changed(self):
if self.flow: f = self.flow
if isinstance(self.flow, http.HTTPFlow): if f:
self.tabs = [ if isinstance(f, http.HTTPFlow):
(self.tab_http_request, self.view_request), if f.websocket:
(self.tab_http_response, self.view_response), self.tabs = [
(self.tab_details, self.view_details), (self.tab_http_request, self.view_request),
] (self.tab_http_response, self.view_response),
elif isinstance(self.flow, tcp.TCPFlow): (self.tab_websocket_messages, self.view_websocket_messages),
(self.tab_details, self.view_details),
]
else:
self.tabs = [
(self.tab_http_request, self.view_request),
(self.tab_http_response, self.view_response),
(self.tab_details, self.view_details),
]
elif isinstance(f, tcp.TCPFlow):
self.tabs = [ self.tabs = [
(self.tab_tcp_stream, self.view_tcp_stream), (self.tab_tcp_stream, self.view_tcp_stream),
(self.tab_details, self.view_details), (self.tab_details, self.view_details),
@ -95,6 +104,9 @@ class FlowDetails(tabs.Tabs):
def tab_tcp_stream(self): def tab_tcp_stream(self):
return "TCP Stream" return "TCP Stream"
def tab_websocket_messages(self):
return "WebSocket Messages"
def tab_details(self): def tab_details(self):
return "Detail" return "Detail"
@ -128,6 +140,36 @@ class FlowDetails(tabs.Tabs):
contentview_status_bar = urwid.AttrWrap(urwid.Columns(cols), "heading") contentview_status_bar = urwid.AttrWrap(urwid.Columns(cols), "heading")
return contentview_status_bar return contentview_status_bar
def view_websocket_messages(self):
flow = self.flow
assert isinstance(flow, http.HTTPFlow)
assert flow.websocket is not None
if not flow.websocket.messages:
return searchable.Searchable([urwid.Text(("highlight", "No messages."))])
viewmode = self.master.commands.call("console.flowview.mode")
widget_lines = []
for m in flow.websocket.messages:
_, lines, _ = contentviews.get_message_content_view(viewmode, m, flow)
for line in lines:
if m.from_client:
line.insert(0, ("from_client", f"{common.SYMBOL_FROM_CLIENT} "))
else:
line.insert(0, ("to_client", f"{common.SYMBOL_TO_CLIENT} "))
widget_lines.append(urwid.Text(line))
if flow.intercepted:
markup = widget_lines[-1].get_text()[0]
widget_lines[-1].set_text(("intercept", markup))
widget_lines.insert(0, self._contentview_status_bar(viewmode.capitalize(), viewmode))
return searchable.Searchable(widget_lines)
def view_tcp_stream(self) -> urwid.Widget: def view_tcp_stream(self) -> urwid.Widget:
flow = self.flow flow = self.flow
assert isinstance(flow, tcp.TCPFlow) assert isinstance(flow, tcp.TCPFlow)

View File

@ -53,7 +53,6 @@ class ConsoleMaster(master.Master):
intercept.Intercept(), intercept.Intercept(),
self.view, self.view,
self.events, self.events,
consoleaddons.UnsupportedLog(),
readfile.ReadFile(), readfile.ReadFile(),
consoleaddons.ConsoleAddon(self), consoleaddons.ConsoleAddon(self),
keymap.KeymapConfig(), keymap.KeymapConfig(),

View File

@ -23,7 +23,7 @@ class Palette:
# List and Connections # List and Connections
'method_get', 'method_post', 'method_delete', 'method_other', 'method_head', 'method_put', 'method_http2_push', 'method_get', 'method_post', 'method_delete', 'method_other', 'method_head', 'method_put', 'method_http2_push',
'scheme_http', 'scheme_https', 'scheme_tcp', 'scheme_other', 'scheme_http', 'scheme_https', 'scheme_ws', 'scheme_wss', 'scheme_tcp', 'scheme_other',
'url_punctuation', 'url_domain', 'url_filename', 'url_extension', 'url_query_key', 'url_query_value', 'url_punctuation', 'url_domain', 'url_filename', 'url_extension', 'url_query_key', 'url_query_value',
'content_none', 'content_text', 'content_script', 'content_media', 'content_data', 'content_raw', 'content_other', 'content_none', 'content_text', 'content_script', 'content_media', 'content_data', 'content_raw', 'content_other',
'focus', 'focus',
@ -136,6 +136,8 @@ class LowDark(Palette):
scheme_http = ('dark cyan', 'default'), scheme_http = ('dark cyan', 'default'),
scheme_https = ('dark green', 'default'), scheme_https = ('dark green', 'default'),
scheme_ws=('brown', 'default'),
scheme_wss=('dark magenta', 'default'),
scheme_tcp=('dark magenta', 'default'), scheme_tcp=('dark magenta', 'default'),
scheme_other = ('dark magenta', 'default'), scheme_other = ('dark magenta', 'default'),
@ -245,6 +247,8 @@ class LowLight(Palette):
scheme_http = ('dark cyan', 'default'), scheme_http = ('dark cyan', 'default'),
scheme_https = ('light green', 'default'), scheme_https = ('light green', 'default'),
scheme_ws=('brown', 'default'),
scheme_wss=('light magenta', 'default'),
scheme_tcp=('light magenta', 'default'), scheme_tcp=('light magenta', 'default'),
scheme_other = ('light magenta', 'default'), scheme_other = ('light magenta', 'default'),
@ -373,6 +377,8 @@ class SolarizedLight(LowLight):
scheme_http = (sol_cyan, 'default'), scheme_http = (sol_cyan, 'default'),
scheme_https = ('light green', 'default'), scheme_https = ('light green', 'default'),
scheme_ws=(sol_orange, 'default'),
scheme_wss=('light magenta', 'default'),
scheme_tcp=('light magenta', 'default'), scheme_tcp=('light magenta', 'default'),
scheme_other = ('light magenta', 'default'), scheme_other = ('light magenta', 'default'),

View File

@ -7,7 +7,7 @@ MITMPROXY = "mitmproxy " + VERSION
# Serialization format version. This is displayed nowhere, it just needs to be incremented by one # Serialization format version. This is displayed nowhere, it just needs to be incremented by one
# for each change in the file format. # for each change in the file format.
FLOW_FORMAT_VERSION = 11 FLOW_FORMAT_VERSION = 12
def get_dev_version() -> str: def get_dev_version() -> str:

View File

@ -1,168 +1,126 @@
""" """
*Deprecation Notice:* Mitmproxy's WebSocket API is going to change soon, Mitmproxy used to have its own WebSocketFlow type until mitmproxy 6, but now WebSocket connections now are represented
see <https://github.com/mitmproxy/mitmproxy/issues/4425>. as HTTP flows as well. They can be distinguished from regular HTTP requests by having the
`mitmproxy.http.HTTPFlow.websocket` attribute set.
This module only defines the classes for individual `WebSocketMessage`s and the `WebSocketData` container.
""" """
import queue
import time import time
import warnings from typing import List, Tuple, Union
from typing import List
from typing import Optional from typing import Optional
from typing import Union
from mitmproxy import flow from mitmproxy import stateobject
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
from mitmproxy.net import websocket
from mitmproxy.utils import human
from mitmproxy.utils import strutils
from wsproto.frame_protocol import CloseReason
from wsproto.frame_protocol import Opcode from wsproto.frame_protocol import Opcode
WebSocketMessageState = Tuple[int, bool, bytes, float, bool]
class WebSocketMessage(serializable.Serializable): class WebSocketMessage(serializable.Serializable):
""" """
A WebSocket message sent from one endpoint to the other. A single WebSocket message sent from one peer to the other.
Fragmented WebSocket messages are reassembled by mitmproxy and the
represented as a single instance of this class.
The [WebSocket RFC](https://tools.ietf.org/html/rfc6455) specifies both
text and binary messages. To avoid a whole class of nasty type confusion bugs,
mitmproxy stores all message contents as binary. If you need text, you can decode the `content` property:
>>> from wsproto.frame_protocol import Opcode
>>> if message.type == Opcode.TEXT:
>>> text = message.content.decode()
Per the WebSocket spec, text messages always use UTF-8 encoding.
""" """
type: Opcode
"""indicates either TEXT or BINARY (from wsproto.frame_protocol.Opcode)."""
from_client: bool from_client: bool
"""True if this messages was sent by the client.""" """True if this messages was sent by the client."""
content: Union[bytes, str] type: Opcode
"""
The message type, as per RFC 6455's [opcode](https://tools.ietf.org/html/rfc6455#section-5.2).
Note that mitmproxy will always store the message contents as *bytes*.
A dedicated `.text` property for text messages is planned, see https://github.com/mitmproxy/mitmproxy/pull/4486.
"""
content: bytes
"""A byte-string representing the content of this message.""" """A byte-string representing the content of this message."""
timestamp: float timestamp: float
"""Timestamp of when this message was received or created.""" """Timestamp of when this message was received or created."""
killed: bool killed: bool
"""True if this messages was killed and should not be sent to the other endpoint.""" """True if the message has not been forwarded by mitmproxy, False otherwise."""
def __init__( def __init__(
self, self,
type: int, type: Union[int, Opcode],
from_client: bool, from_client: bool,
content: Union[bytes, str], content: bytes,
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
killed: bool = False killed: bool = False,
) -> None: ) -> None:
self.type = Opcode(type) # type: ignore
self.from_client = from_client self.from_client = from_client
self.type = Opcode(type)
self.content = content self.content = content
self.timestamp: float = timestamp or time.time() self.timestamp: float = timestamp or time.time()
self.killed = killed self.killed = killed
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state: WebSocketMessageState):
return cls(*state) return cls(*state)
def get_state(self): def get_state(self) -> WebSocketMessageState:
return int(self.type), self.from_client, self.content, self.timestamp, self.killed return int(self.type), self.from_client, self.content, self.timestamp, self.killed
def set_state(self, state): def set_state(self, state: WebSocketMessageState) -> None:
self.type, self.from_client, self.content, self.timestamp, self.killed = state typ, self.from_client, self.content, self.timestamp, self.killed = state
self.type = Opcode(self.type) # replace enum with bare int self.type = Opcode(typ)
def __repr__(self): def __repr__(self):
if self.type == Opcode.TEXT: if self.type == Opcode.TEXT:
return "text message: {}".format(repr(self.content)) return repr(self.content.decode(errors="replace"))
else: else:
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) return repr(self.content)
def kill(self): # pragma: no cover def kill(self):
""" # Likely to be replaced with .drop() in the future, see https://github.com/mitmproxy/mitmproxy/pull/4486
Kill this message. self.killed = True
It will not be sent to the other endpoint. This has no effect in streaming mode.
"""
warnings.warn(
"WebSocketMessage.kill is deprecated, set an empty content instead.",
DeprecationWarning,
stacklevel=2,
)
# empty str or empty bytes.
self.content = type(self.content)()
class WebSocketFlow(flow.Flow): class WebSocketData(stateobject.StateObject):
""" """
A WebSocketFlow is a simplified representation of a WebSocket connection. A data container for everything related to a single WebSocket connection.
This is typically accessed as `mitmproxy.http.HTTPFlow.websocket`.
""" """
def __init__(self, client_conn, server_conn, handshake_flow, live=None): messages: List[WebSocketMessage]
super().__init__("websocket", client_conn, server_conn, live) """All `WebSocketMessage`s transferred over this connection."""
self.messages: List[WebSocketMessage] = [] closed_by_client: Optional[bool] = None
"""A list containing all WebSocketMessage's.""" """
self.close_sender = 'client' True if the client closed the connection,
"""'client' if the client initiated connection closing.""" False if the server closed the connection,
self.close_code = CloseReason.NORMAL_CLOSURE None if the connection is active.
"""WebSocket close code.""" """
self.close_message = '(message missing)' close_code: Optional[int] = None
"""WebSocket close message.""" """[Close Code](https://tools.ietf.org/html/rfc6455#section-7.1.5)"""
self.close_reason = 'unknown status code' close_reason: Optional[str] = None
"""WebSocket close reason.""" """[Close Reason](https://tools.ietf.org/html/rfc6455#section-7.1.6)"""
self.stream = False
"""True of this connection is streaming directly to the other endpoint."""
self.handshake_flow = handshake_flow
"""The HTTP flow containing the initial WebSocket handshake."""
self.ended = False
"""True when the WebSocket connection has been closed."""
self._inject_messages_client = queue.Queue(maxsize=1) _stateobject_attributes = dict(
self._inject_messages_server = queue.Queue(maxsize=1)
if handshake_flow:
self.client_key = websocket.get_client_key(handshake_flow.request.headers)
self.client_protocol = websocket.get_protocol(handshake_flow.request.headers)
self.client_extensions = websocket.get_extensions(handshake_flow.request.headers)
self.server_accept = websocket.get_server_accept(handshake_flow.response.headers)
self.server_protocol = websocket.get_protocol(handshake_flow.response.headers)
self.server_extensions = websocket.get_extensions(handshake_flow.response.headers)
else:
self.client_key = ''
self.client_protocol = ''
self.client_extensions = ''
self.server_accept = ''
self.server_protocol = ''
self.server_extensions = ''
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
# mypy doesn't support update with kwargs
_stateobject_attributes.update(dict(
messages=List[WebSocketMessage], messages=List[WebSocketMessage],
close_sender=str, closed_by_client=bool,
close_code=int, close_code=int,
close_message=str,
close_reason=str, close_reason=str,
client_key=str, )
client_protocol=str,
client_extensions=str,
server_accept=str,
server_protocol=str,
server_extensions=str,
# Do not include handshake_flow, to prevent recursive serialization!
# Since mitmproxy-console currently only displays HTTPFlows,
# dumping the handshake_flow will include the WebSocketFlow too.
))
def get_state(self): def __init__(self):
d = super().get_state() self.messages = []
d['close_code'] = int(d['close_code']) # replace enum with bare int
return d def __repr__(self):
return f"<WebSocketData ({len(self.messages)} messages)>"
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, None, None) d = WebSocketData()
f.set_state(state) d.set_state(state)
return f return d
def __repr__(self):
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
def message_info(self, message: WebSocketMessage) -> str:
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=message.type,
client=human.format_address(self.client_conn.peername),
server=human.format_address(self.server_conn.address),
direction="->" if message.from_client else "<-",
endpoint=self.handshake_flow.request.path,
)

View File

@ -110,7 +110,7 @@ async def test_start_stop(tdata):
assert cp.count() == 1 assert cp.count() == 1
cp.start_replay([tflow.twebsocketflow()]) cp.start_replay([tflow.twebsocketflow()])
await tctx.master.await_log("Can only replay HTTP flows.", level="warn") await tctx.master.await_log("Can't replay WebSocket flows.", level="warn")
assert cp.count() == 1 assert cp.count() == 1
cp.stop_replay() cp.stop_replay()

View File

@ -233,7 +233,7 @@ def test_websocket():
d.websocket_end(f) d.websocket_end(f)
assert "WebSocket connection closed by" in sio.getvalue() assert "WebSocket connection closed by" in sio.getvalue()
f = tflow.twebsocketflow(client_conn=True, err=True) f = tflow.twebsocketflow(err=True)
d.websocket_error(f) d.websocket_error(f)
assert "Error in WebSocket" in sio_err.getvalue() assert "Error in WebSocket" in sio_err.getvalue()

View File

@ -55,11 +55,11 @@ def test_websocket(tmpdir):
tctx.configure(sa, save_stream_file=p) tctx.configure(sa, save_stream_file=p)
f = tflow.twebsocketflow() f = tflow.twebsocketflow()
sa.websocket_start(f) sa.request(f)
sa.websocket_end(f) sa.websocket_end(f)
f = tflow.twebsocketflow() f = tflow.twebsocketflow()
sa.websocket_start(f) sa.request(f)
sa.websocket_error(f) sa.websocket_error(f)
tctx.configure(sa, save_stream_file=None) tctx.configure(sa, save_stream_file=None)

File diff suppressed because one or more lines are too long

View File

@ -8,7 +8,7 @@ from mitmproxy import exceptions
["dumpfile-011.mitm", "https://example.com/", 1], ["dumpfile-011.mitm", "https://example.com/", 1],
["dumpfile-018.mitm", "https://www.example.com/", 1], ["dumpfile-018.mitm", "https://www.example.com/", 1],
["dumpfile-019.mitm", "https://webrv.rtb-seller.com/", 1], ["dumpfile-019.mitm", "https://webrv.rtb-seller.com/", 1],
["dumpfile-7-websocket.mitm", "https://echo.websocket.org/", 5], ["dumpfile-7-websocket.mitm", "https://echo.websocket.org/", 6],
]) ])
def test_load(tdata, dumpfile, url, count): def test_load(tdata, dumpfile, url, count):
with open(tdata.path("mitmproxy/data/" + dumpfile), "rb") as f: with open(tdata.path("mitmproxy/data/" + dumpfile), "rb") as f:

View File

@ -1,39 +0,0 @@
from mitmproxy.net import websocket
def test_check_handshake():
assert not websocket.check_handshake({
"connection": "upgrade",
"upgrade": "webFOOsocket",
"sec-websocket-key": "foo",
})
assert websocket.check_handshake({
"connection": "upgrade",
"upgrade": "websocket",
"sec-websocket-key": "foo",
})
assert websocket.check_handshake({
"connection": "upgrade",
"upgrade": "websocket",
"sec-websocket-accept": "bar",
})
def test_get_extensions():
assert websocket.get_extensions({}) is None
assert websocket.get_extensions({"sec-websocket-extensions": "foo"}) == "foo"
def test_get_protocol():
assert websocket.get_protocol({}) is None
assert websocket.get_protocol({"sec-websocket-protocol": "foo"}) == "foo"
def test_get_client_key():
assert websocket.get_client_key({}) is None
assert websocket.get_client_key({"sec-websocket-key": "foo"}) == "foo"
def test_get_server_accept():
assert websocket.get_server_accept({}) is None
assert websocket.get_server_accept({"sec-websocket-accept": "foo"}) == "foo"

View File

@ -12,7 +12,6 @@ from mitmproxy.proxy.layers import TCPLayer, http, tls
from mitmproxy.proxy.layers.tcp import TcpStartHook from mitmproxy.proxy.layers.tcp import TcpStartHook
from mitmproxy.proxy.layers.websocket import WebsocketStartHook from mitmproxy.proxy.layers.websocket import WebsocketStartHook
from mitmproxy.tcp import TCPFlow from mitmproxy.tcp import TCPFlow
from mitmproxy.websocket import WebSocketFlow
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply, reply_next_layer
@ -960,7 +959,7 @@ def test_upgrade(tctx, proto):
b"\r\n") b"\r\n")
) )
if proto == "websocket": if proto == "websocket":
assert playbook << WebsocketStartHook(Placeholder(WebSocketFlow)) assert playbook << WebsocketStartHook(http_flow)
elif proto == "tcp": elif proto == "tcp":
assert playbook << TcpStartHook(Placeholder(TCPFlow)) assert playbook << TcpStartHook(Placeholder(TCPFlow))
else: else:

View File

@ -11,8 +11,9 @@ from mitmproxy.proxy.commands import SendData, CloseConnection, Log
from mitmproxy.connection import ConnectionState from mitmproxy.connection import ConnectionState
from mitmproxy.proxy.events import DataReceived, ConnectionClosed from mitmproxy.proxy.events import DataReceived, ConnectionClosed
from mitmproxy.proxy.layers import http, websocket from mitmproxy.proxy.layers import http, websocket
from mitmproxy.websocket import WebSocketFlow from mitmproxy.websocket import WebSocketData
from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply from test.mitmproxy.proxy.tutils import Placeholder, Playbook, reply
from wsproto.frame_protocol import Opcode
@dataclass @dataclass
@ -53,8 +54,7 @@ def test_upgrade(tctx):
"""Test a HTTP -> WebSocket upgrade""" """Test a HTTP -> WebSocket upgrade"""
tctx.server.address = ("example.com", 80) tctx.server.address = ("example.com", 80)
tctx.server.state = ConnectionState.OPEN tctx.server.state = ConnectionState.OPEN
http_flow = Placeholder(HTTPFlow) flow = Placeholder(HTTPFlow)
flow = Placeholder(WebSocketFlow)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.transparent)) Playbook(http.HttpLayer(tctx, HTTPMode.transparent))
>> DataReceived(tctx.client, >> DataReceived(tctx.client,
@ -63,9 +63,9 @@ def test_upgrade(tctx):
b"Upgrade: websocket\r\n" b"Upgrade: websocket\r\n"
b"Sec-WebSocket-Version: 13\r\n" b"Sec-WebSocket-Version: 13\r\n"
b"\r\n") b"\r\n")
<< http.HttpRequestHeadersHook(http_flow) << http.HttpRequestHeadersHook(flow)
>> reply() >> reply()
<< http.HttpRequestHook(http_flow) << http.HttpRequestHook(flow)
>> reply() >> reply()
<< SendData(tctx.server, b"GET / HTTP/1.1\r\n" << SendData(tctx.server, b"GET / HTTP/1.1\r\n"
b"Connection: upgrade\r\n" b"Connection: upgrade\r\n"
@ -76,9 +76,9 @@ def test_upgrade(tctx):
b"Upgrade: websocket\r\n" b"Upgrade: websocket\r\n"
b"Connection: Upgrade\r\n" b"Connection: Upgrade\r\n"
b"\r\n") b"\r\n")
<< http.HttpResponseHeadersHook(http_flow) << http.HttpResponseHeadersHook(flow)
>> reply() >> reply()
<< http.HttpResponseHook(http_flow) << http.HttpResponseHook(flow)
>> reply() >> reply()
<< SendData(tctx.client, b"HTTP/1.1 101 Switching Protocols\r\n" << SendData(tctx.client, b"HTTP/1.1 101 Switching Protocols\r\n"
b"Upgrade: websocket\r\n" b"Upgrade: websocket\r\n"
@ -95,12 +95,13 @@ def test_upgrade(tctx):
>> reply() >> reply()
<< SendData(tctx.client, b"\x82\nhello back") << SendData(tctx.client, b"\x82\nhello back")
) )
assert flow().handshake_flow == http_flow() assert len(flow().websocket.messages) == 2
assert len(flow().messages) == 2 assert flow().websocket.messages[0].content == b"hello world"
assert flow().messages[0].content == "hello world" assert flow().websocket.messages[0].from_client
assert flow().messages[0].from_client assert flow().websocket.messages[0].type == Opcode.TEXT
assert flow().messages[1].content == b"hello back" assert flow().websocket.messages[1].content == b"hello back"
assert flow().messages[1].from_client is False assert flow().websocket.messages[1].from_client is False
assert flow().websocket.messages[1].type == Opcode.BINARY
@pytest.fixture() @pytest.fixture()
@ -120,12 +121,12 @@ def ws_testdata(tctx):
"Connection": "upgrade", "Connection": "upgrade",
"Upgrade": "websocket", "Upgrade": "websocket",
}) })
return tctx, Playbook(websocket.WebsocketLayer(tctx, flow)) flow.websocket = WebSocketData()
return tctx, Playbook(websocket.WebsocketLayer(tctx, flow)), flow
def test_modify_message(ws_testdata): def test_modify_message(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -133,7 +134,7 @@ def test_modify_message(ws_testdata):
>> DataReceived(tctx.server, b"\x81\x03foo") >> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow) << websocket.WebsocketMessageHook(flow)
) )
flow().messages[-1].content = flow().messages[-1].content.replace("foo", "foobar") flow.websocket.messages[-1].content = flow.websocket.messages[-1].content.replace(b"foo", b"foobar")
assert ( assert (
playbook playbook
>> reply() >> reply()
@ -142,8 +143,7 @@ def test_modify_message(ws_testdata):
def test_drop_message(ws_testdata): def test_drop_message(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -151,7 +151,7 @@ def test_drop_message(ws_testdata):
>> DataReceived(tctx.server, b"\x81\x03foo") >> DataReceived(tctx.server, b"\x81\x03foo")
<< websocket.WebsocketMessageHook(flow) << websocket.WebsocketMessageHook(flow)
) )
flow().messages[-1].content = "" flow.websocket.messages[-1].kill()
assert ( assert (
playbook playbook
>> reply() >> reply()
@ -160,8 +160,7 @@ def test_drop_message(ws_testdata):
def test_fragmented(ws_testdata): def test_fragmented(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -173,12 +172,11 @@ def test_fragmented(ws_testdata):
<< SendData(tctx.client, b"\x01\x03foo") << SendData(tctx.client, b"\x01\x03foo")
<< SendData(tctx.client, b"\x80\x03bar") << SendData(tctx.client, b"\x80\x03bar")
) )
assert flow().messages[-1].content == "foobar" assert flow.websocket.messages[-1].content == b"foobar"
def test_protocol_error(ws_testdata): def test_protocol_error(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -193,12 +191,11 @@ def test_protocol_error(ws_testdata):
>> reply() >> reply()
) )
assert not flow().messages assert not flow.websocket.messages
def test_ping(ws_testdata): def test_ping(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -210,12 +207,11 @@ def test_ping(ws_testdata):
<< Log("Received WebSocket pong from server (payload: b'pong-with-payload')") << Log("Received WebSocket pong from server (payload: b'pong-with-payload')")
<< SendData(tctx.client, b"\x8a\x11pong-with-payload") << SendData(tctx.client, b"\x8a\x11pong-with-payload")
) )
assert not flow().messages assert not flow.websocket.messages
def test_close_normal(ws_testdata): def test_close_normal(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
masked_close = Placeholder(bytes) masked_close = Placeholder(bytes)
close = Placeholder(bytes) close = Placeholder(bytes)
assert ( assert (
@ -235,12 +231,11 @@ def test_close_normal(ws_testdata):
assert masked_close() == masked(b"\x88\x02\x03\xe8") or masked_close() == masked(b"\x88\x00") assert masked_close() == masked(b"\x88\x02\x03\xe8") or masked_close() == masked(b"\x88\x00")
assert close() == b"\x88\x02\x03\xe8" or close() == b"\x88\x00" assert close() == b"\x88\x02\x03\xe8" or close() == b"\x88\x00"
assert flow().close_code == 1005 assert flow.websocket.close_code == 1005
def test_close_disconnect(ws_testdata): def test_close_disconnect(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -253,12 +248,11 @@ def test_close_disconnect(ws_testdata):
>> reply() >> reply()
>> ConnectionClosed(tctx.client) >> ConnectionClosed(tctx.client)
) )
assert "ABNORMAL_CLOSURE" in flow().error.msg assert "ABNORMAL_CLOSURE" in flow.error.msg
def test_close_error(ws_testdata): def test_close_error(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow)
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -271,15 +265,12 @@ def test_close_error(ws_testdata):
<< websocket.WebsocketErrorHook(flow) << websocket.WebsocketErrorHook(flow)
>> reply() >> reply()
) )
assert "UNKNOWN_ERROR=4000" in flow().error.msg assert "UNKNOWN_ERROR=4000" in flow.error.msg
def test_deflate(ws_testdata): def test_deflate(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow) flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10"
# noinspection PyUnresolvedReferences
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
http_flow.response.headers["Sec-WebSocket-Extensions"] = "permessage-deflate; server_max_window_bits=10"
assert ( assert (
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
@ -290,15 +281,12 @@ def test_deflate(ws_testdata):
>> reply() >> reply()
<< SendData(tctx.client, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00")) << SendData(tctx.client, bytes.fromhex("c1 07 f2 48 cd c9 c9 07 00"))
) )
assert flow().messages[0].content == "Hello" assert flow.websocket.messages[0].content == b"Hello"
def test_unknown_ext(ws_testdata): def test_unknown_ext(ws_testdata):
tctx, playbook = ws_testdata tctx, playbook, flow = ws_testdata
flow = Placeholder(WebSocketFlow) flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42"
# noinspection PyUnresolvedReferences
http_flow: HTTPFlow = playbook.layer.flow.handshake_flow
http_flow.response.headers["Sec-WebSocket-Extensions"] = "funky-bits; param=42"
assert ( assert (
playbook playbook
<< Log("Ignoring unknown WebSocket extension 'funky-bits'.") << Log("Ignoring unknown WebSocket extension 'funky-bits'.")
@ -314,20 +302,20 @@ def test_websocket_connection_repr(tctx):
class TestFragmentizer: class TestFragmentizer:
def test_empty(self): def test_empty(self):
f = websocket.Fragmentizer([b"foo"]) f = websocket.Fragmentizer([b"foo"], False)
assert list(f(b"")) == [] assert list(f(b"")) == []
def test_keep_sizes(self): def test_keep_sizes(self):
f = websocket.Fragmentizer([b"foo", b"bar"]) f = websocket.Fragmentizer([b"foo", b"bar"], True)
assert list(f(b"foobaz")) == [ assert list(f(b"foobaz")) == [
wsproto.events.Message(b"foo", message_finished=False), wsproto.events.TextMessage("foo", message_finished=False),
wsproto.events.Message(b"baz", message_finished=True), wsproto.events.TextMessage("baz", message_finished=True),
] ]
def test_rechunk(self): def test_rechunk(self):
f = websocket.Fragmentizer([b"foo"]) f = websocket.Fragmentizer([b"foo"], False)
f.FRAGMENT_SIZE = 4 f.FRAGMENT_SIZE = 4
assert list(f(b"foobar")) == [ assert list(f(b"foobar")) == [
wsproto.events.Message(b"foob", message_finished=False), wsproto.events.BytesMessage(b"foob", message_finished=False),
wsproto.events.Message(b"ar", message_finished=True), wsproto.events.BytesMessage(b"ar", message_finished=True),
] ]

View File

@ -27,14 +27,20 @@ def test_http_flow(resp, err):
def test_websocket_flow(err): def test_websocket_flow(err):
f = tflow.twebsocketflow(err=err) f = tflow.twebsocketflow(err=err)
i = eventsequence.iterate(f) i = eventsequence.iterate(f)
assert isinstance(next(i), layers.http.HttpRequestHeadersHook)
assert isinstance(next(i), layers.http.HttpRequestHook)
assert isinstance(next(i), layers.http.HttpResponseHeadersHook)
assert isinstance(next(i), layers.http.HttpResponseHook)
assert isinstance(next(i), layers.websocket.WebsocketStartHook) assert isinstance(next(i), layers.websocket.WebsocketStartHook)
assert len(f.messages) == 0 assert len(f.websocket.messages) == 0
assert isinstance(next(i), layers.websocket.WebsocketMessageHook) assert isinstance(next(i), layers.websocket.WebsocketMessageHook)
assert len(f.messages) == 1 assert len(f.websocket.messages) == 1
assert isinstance(next(i), layers.websocket.WebsocketMessageHook) assert isinstance(next(i), layers.websocket.WebsocketMessageHook)
assert len(f.messages) == 2 assert len(f.websocket.messages) == 2
assert isinstance(next(i), layers.websocket.WebsocketMessageHook) assert isinstance(next(i), layers.websocket.WebsocketMessageHook)
assert len(f.messages) == 3 assert len(f.websocket.messages) == 3
if err: if err:
assert isinstance(next(i), layers.websocket.WebsocketErrorHook) assert isinstance(next(i), layers.websocket.WebsocketErrorHook)
else: else:

View File

@ -122,20 +122,6 @@ class TestFlowMaster:
await ctx.master.load_flow(f) await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain" assert s.flows[0].request.host == "use-this-domain"
@pytest.mark.asyncio
async def test_load_websocket_flow(self):
opts = options.Options(
mode="reverse:https://use-this-domain"
)
s = State()
with taddons.context(s, options=opts) as ctx:
f = tflow.twebsocketflow()
await ctx.master.load_flow(f.handshake_flow)
await ctx.master.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all(self): async def test_all(self):
opts = options.Options( opts = options.Options(

View File

@ -4,7 +4,7 @@ from unittest.mock import patch
from mitmproxy.test import tflow from mitmproxy.test import tflow
from mitmproxy import flowfilter from mitmproxy import flowfilter, http
class TestParsing: class TestParsing:
@ -424,10 +424,10 @@ class TestMatchingTCPFlow:
class TestMatchingWebSocketFlow: class TestMatchingWebSocketFlow:
def flow(self): def flow(self) -> http.HTTPFlow:
return tflow.twebsocketflow() return tflow.twebsocketflow()
def err(self): def err(self) -> http.HTTPFlow:
return tflow.twebsocketflow(err=True) return tflow.twebsocketflow(err=True)
def q(self, q, o): def q(self, q, o):
@ -437,10 +437,10 @@ class TestMatchingWebSocketFlow:
f = self.flow() f = self.flow()
assert self.q("~websocket", f) assert self.q("~websocket", f)
assert not self.q("~tcp", f) assert not self.q("~tcp", f)
assert not self.q("~http", f) assert self.q("~http", f)
def test_handshake(self): def test_handshake(self):
f = self.flow().handshake_flow f = self.flow()
assert self.q("~websocket", f) assert self.q("~websocket", f)
assert not self.q("~tcp", f) assert not self.q("~tcp", f)
assert self.q("~http", f) assert self.q("~http", f)
@ -465,9 +465,6 @@ class TestMatchingWebSocketFlow:
assert self.q("~u example.com/ws", q) assert self.q("~u example.com/ws", q)
assert not self.q("~u moo/path", q) assert not self.q("~u moo/path", q)
q.handshake_flow = None
assert not self.q("~u example.com", q)
def test_body(self): def test_body(self):
f = self.flow() f = self.flow()

View File

@ -1,88 +1,28 @@
import io from mitmproxy import http
from mitmproxy import websocket
import pytest
from mitmproxy import flowfilter
from mitmproxy.exceptions import ControlException
from mitmproxy.io import tnetstring
from mitmproxy.test import tflow from mitmproxy.test import tflow
from wsproto.frame_protocol import Opcode
class TestWebSocketFlow: class TestWebSocketData:
def test_copy(self):
f = tflow.twebsocketflow()
f.get_state()
f2 = f.copy()
a = f.get_state()
b = f2.get_state()
del a["id"]
del b["id"]
assert a == b
assert not f == f2
assert f is not f2
assert f.client_key == f2.client_key
assert f.client_protocol == f2.client_protocol
assert f.client_extensions == f2.client_extensions
assert f.server_accept == f2.server_accept
assert f.server_protocol == f2.server_protocol
assert f.server_extensions == f2.server_extensions
assert f.messages is not f2.messages
assert f.handshake_flow is not f2.handshake_flow
for m in f.messages:
m2 = m.copy()
m2.set_state(m2.get_state())
assert m is not m2
assert m.get_state() == m2.get_state()
f = tflow.twebsocketflow(err=True)
f2 = f.copy()
assert f is not f2
assert f.handshake_flow is not f2.handshake_flow
assert f.error.get_state() == f2.error.get_state()
assert f.error is not f2.error
def test_kill(self):
f = tflow.twebsocketflow()
with pytest.raises(ControlException):
f.intercept()
f.resume()
f.kill()
f = tflow.twebsocketflow()
f.intercept()
assert f.killable
f.kill()
assert not f.killable
def test_match(self):
f = tflow.twebsocketflow()
assert not flowfilter.match("~b nonexistent", f)
assert flowfilter.match(None, f)
assert not flowfilter.match("~b nonexistent", f)
f = tflow.twebsocketflow(err=True)
assert flowfilter.match("~e", f)
with pytest.raises(ValueError):
flowfilter.match("~", f)
def test_repr(self): def test_repr(self):
assert repr(tflow.twebsocketflow().websocket) == "<WebSocketData (3 messages)>"
def test_state(self):
f = tflow.twebsocketflow() f = tflow.twebsocketflow()
assert f.message_info(f.messages[0]) f2 = http.HTTPFlow.from_state(f.get_state())
assert 'WebSocketFlow' in repr(f) f2.set_state(f.get_state())
assert 'binary message: ' in repr(f.messages[0])
assert 'text message: ' in repr(f.messages[1])
def test_serialize(self):
b = io.BytesIO()
d = tflow.twebsocketflow().get_state()
tnetstring.dump(d, b)
assert b.getvalue()
b = io.BytesIO() class TestWebSocketMessage:
d = tflow.twebsocketflow().handshake_flow.get_state() def test_basic(self):
tnetstring.dump(d, b) m = websocket.WebSocketMessage(Opcode.TEXT, True, b"foo")
assert b.getvalue() m.set_state(m.get_state())
assert m.content == b"foo"
assert repr(m) == "'foo'"
m.type = Opcode.BINARY
assert repr(m) == "b'foo'"
assert not m.killed
m.kill()
assert m.killed