Merge pull request #2702 from Kriechi/fix-2640

improve websocket dumps
This commit is contained in:
Thomas Kriechbaumer 2017-12-18 21:31:00 +01:00 committed by GitHub
commit 9faad6bc9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 128 additions and 19 deletions

View File

@ -75,6 +75,15 @@ class Save:
self.stream.add(flow)
self.active_flows.discard(flow)
def websocket_start(self, flow):
if self.stream:
self.active_flows.add(flow)
def websocket_end(self, flow):
if self.stream:
self.stream.add(flow)
self.active_flows.discard(flow)
def response(self, flow):
if self.stream:
self.stream.add(flow)

View File

@ -0,0 +1,13 @@
from . import compat
from . import connection
from . import events
from . import extensions
from . import frame_protocol
__all__ = [
'compat',
'connection',
'events',
'extensions',
'frame_protocol',
]

View File

@ -1,3 +1,5 @@
# type: ignore
# -*- coding: utf-8 -*-
"""
wsproto/extensions

View File

@ -1,3 +1,5 @@
# type: ignore
# -*- coding: utf-8 -*-
"""
wsproto/frame_protocol

View File

@ -322,8 +322,10 @@ class FDomain(_Rex):
flags = re.IGNORECASE
is_binary = False
@only(http.HTTPFlow)
@only(http.HTTPFlow, websocket.WebSocketFlow)
def __call__(self, f):
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
return bool(
self.re.search(f.request.host) or
self.re.search(f.request.pretty_host)
@ -342,9 +344,11 @@ class FUrl(_Rex):
toks = toks[1:]
return klass(*toks)
@only(http.HTTPFlow)
@only(http.HTTPFlow, websocket.WebSocketFlow)
def __call__(self, f):
if not f.request:
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
if not f or not f.request:
return False
return self.re.search(f.request.pretty_url)

View File

@ -9,6 +9,7 @@ from mitmproxy import eventsequence
from mitmproxy import exceptions
from mitmproxy import command
from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import log
from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol import http_replay
@ -41,6 +42,7 @@ class Master:
self.should_exit = threading.Event()
self._server = None
self.first_tick = True
self.waiting_flows = []
@property
def server(self):
@ -117,15 +119,33 @@ class Master:
self.should_exit.set()
self.addons.trigger("done")
def _change_reverse_host(self, f):
"""
When we load flows in reverse proxy mode, we adjust the target host to
the reverse proxy destination for all flows we load. This makes it very
easy to replay saved flows against a different host.
"""
if self.options.mode.startswith("reverse:"):
_, upstream_spec = server_spec.parse_with_mode(self.options.mode)
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme
def load_flow(self, f):
"""
Loads a flow
Loads a flow and links websocket & handshake flows
"""
if isinstance(f, http.HTTPFlow):
if self.options.mode.startswith("reverse:"):
_, upstream_spec = server_spec.parse_with_mode(self.options.mode)
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme
self._change_reverse_host(f)
if 'websocket' in f.metadata:
self.waiting_flows.append(f)
if isinstance(f, websocket.WebSocketFlow):
hf = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']][0]
f.handshake_flow = hf
self.waiting_flows.remove(hf)
self._change_reverse_host(f.handshake_flow)
f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o)

View File

@ -321,6 +321,7 @@ class HttpLayer(base.Layer):
try:
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
f.metadata['websocket'] = True
# We only support RFC6455 with WebSocket version 13
# allow inline scripts to manipulate the client handshake
self.channel.ask("websocket_handshake", f)

View File

@ -1,10 +1,11 @@
import socket
from OpenSSL import SSL
from mitmproxy.contrib import wsproto
from mitmproxy.contrib.wsproto import events
from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
from mitmproxy.contrib.wsproto.frame_protocol import Opcode
from mitmproxy import exceptions
from mitmproxy import flow
@ -93,11 +94,14 @@ class WebSocketLayer(base.Layer):
if event.message_finished:
original_chunk_sizes = [len(f) for f in fb]
message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
if message_type == Opcode.TEXT:
if isinstance(event, events.TextReceived):
message_type = wsproto.frame_protocol.Opcode.TEXT
payload = ''.join(fb)
else:
message_type = wsproto.frame_protocol.Opcode.BINARY
payload = b''.join(fb)
fb.clear()
websocket_message = WebSocketMessage(message_type, not is_server, payload)

View File

@ -44,7 +44,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
"GET",
"http",
"example.com",
"80",
80,
"/ws",
"HTTP/1.1",
headers=net_http.Headers(
@ -75,7 +75,9 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
handshake_flow.metadata['websocket_flow'] = f
f.metadata['websocket_handshake'] = handshake_flow.id
handshake_flow.metadata['websocket_flow'] = f.id
handshake_flow.metadata['websocket'] = True
if messages is True:
messages = [

View File

@ -1,6 +1,8 @@
import time
from typing import List, Optional
from mitmproxy.contrib import wsproto
from mitmproxy import flow
from mitmproxy.net import websockets
from mitmproxy.coretypes import serializable
@ -11,7 +13,7 @@ class WebSocketMessage(serializable.Serializable):
def __init__(
self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None
) -> None:
self.type = type
self.type = wsproto.frame_protocol.Opcode(type) # type: ignore
self.from_client = from_client
self.content = content
self.timestamp = timestamp or int(time.time()) # type: int
@ -21,13 +23,14 @@ class WebSocketMessage(serializable.Serializable):
return cls(*state)
def get_state(self):
return self.type, self.from_client, self.content, self.timestamp
return int(self.type), self.from_client, self.content, self.timestamp
def set_state(self, state):
self.type, self.from_client, self.content, self.timestamp = state
self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int
def __repr__(self):
if self.type == websockets.OPCODE.TEXT:
if self.type == wsproto.frame_protocol.Opcode.TEXT:
return "text message: {}".format(repr(self.content))
else:
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
@ -42,7 +45,7 @@ class WebSocketFlow(flow.Flow):
super().__init__("websocket", client_conn, server_conn, live)
self.messages = [] # type: List[WebSocketMessage]
self.close_sender = 'client'
self.close_code = '(status code missing)'
self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE
self.close_message = '(message missing)'
self.close_reason = 'unknown status code'
self.stream = False
@ -69,7 +72,7 @@ class WebSocketFlow(flow.Flow):
_stateobject_attributes.update(dict(
messages=List[WebSocketMessage],
close_sender=str,
close_code=str,
close_code=int,
close_message=str,
close_reason=str,
client_key=str,
@ -83,6 +86,11 @@ class WebSocketFlow(flow.Flow):
# dumping the handshake_flow will include the WebSocketFlow too.
))
def get_state(self):
d = super().get_state()
d['close_code'] = int(d['close_code']) # replace enum with bare int
return d
@classmethod
def from_state(cls, state):
f = cls(None, None, None)

View File

@ -19,6 +19,9 @@ exclude_lines =
pragma: no cover
raise NotImplementedError()
[mypy-mitmproxy.contrib.*]
ignore_errors = True
[tool:full_coverage]
exclude =
mitmproxy/proxy/protocol/base.py

View File

@ -44,6 +44,19 @@ def test_tcp(tmpdir):
assert rd(p)
def test_websocket(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
f = tflow.twebsocketflow()
sa.websocket_start(f)
sa.websocket_end(f)
tctx.configure(sa, save_stream_file=None)
assert rd(p)
def test_save_command(tmpdir):
sa = save.Save()
with taddons.context() as tctx:

View File

@ -97,7 +97,7 @@ class TestSerialize:
class TestFlowMaster:
def test_load_flow_reverse(self):
def test_load_http_flow_reverse(self):
s = tservers.TestState()
opts = options.Options(
mode="reverse:https://use-this-domain"
@ -108,6 +108,20 @@ class TestFlowMaster:
fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
def test_load_websocket_flow(self):
s = tservers.TestState()
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow)
fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
def test_replay(self):
opts = options.Options()
fm = master.Master(opts)

View File

@ -420,6 +420,20 @@ class TestMatchingWebSocketFlow:
e = self.err()
assert self.q("~e", e)
def test_domain(self):
q = self.flow()
assert self.q("~d example.com", q)
assert not self.q("~d none", q)
def test_url(self):
q = self.flow()
assert self.q("~u example.com", q)
assert self.q("~u example.com/ws", q)
assert not self.q("~u moo/path", q)
q.handshake_flow = None
assert not self.q("~u example.com", q)
def test_body(self):
f = self.flow()