Merge pull request #2702 from Kriechi/fix-2640

improve websocket dumps
This commit is contained in:
Thomas Kriechbaumer 2017-12-18 21:31:00 +01:00 committed by GitHub
commit 9faad6bc9c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 128 additions and 19 deletions

View File

@ -75,6 +75,15 @@ class Save:
self.stream.add(flow) self.stream.add(flow)
self.active_flows.discard(flow) self.active_flows.discard(flow)
def websocket_start(self, flow):
if self.stream:
self.active_flows.add(flow)
def websocket_end(self, flow):
if self.stream:
self.stream.add(flow)
self.active_flows.discard(flow)
def response(self, flow): def response(self, flow):
if self.stream: if self.stream:
self.stream.add(flow) self.stream.add(flow)

View File

@ -0,0 +1,13 @@
from . import compat
from . import connection
from . import events
from . import extensions
from . import frame_protocol
__all__ = [
'compat',
'connection',
'events',
'extensions',
'frame_protocol',
]

View File

@ -1,3 +1,5 @@
# type: ignore
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
wsproto/extensions wsproto/extensions

View File

@ -1,3 +1,5 @@
# type: ignore
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
wsproto/frame_protocol wsproto/frame_protocol

View File

@ -322,8 +322,10 @@ class FDomain(_Rex):
flags = re.IGNORECASE flags = re.IGNORECASE
is_binary = False is_binary = False
@only(http.HTTPFlow) @only(http.HTTPFlow, websocket.WebSocketFlow)
def __call__(self, f): def __call__(self, f):
if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
return bool( return bool(
self.re.search(f.request.host) or self.re.search(f.request.host) or
self.re.search(f.request.pretty_host) self.re.search(f.request.pretty_host)
@ -342,9 +344,11 @@ class FUrl(_Rex):
toks = toks[1:] toks = toks[1:]
return klass(*toks) return klass(*toks)
@only(http.HTTPFlow) @only(http.HTTPFlow, websocket.WebSocketFlow)
def __call__(self, f): def __call__(self, f):
if not f.request: if isinstance(f, websocket.WebSocketFlow):
f = f.handshake_flow
if not f or not f.request:
return False return False
return self.re.search(f.request.pretty_url) return self.re.search(f.request.pretty_url)

View File

@ -9,6 +9,7 @@ from mitmproxy import eventsequence
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import command from mitmproxy import command
from mitmproxy import http from mitmproxy import http
from mitmproxy import websocket
from mitmproxy import log from mitmproxy import log
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.proxy.protocol import http_replay from mitmproxy.proxy.protocol import http_replay
@ -41,6 +42,7 @@ class Master:
self.should_exit = threading.Event() self.should_exit = threading.Event()
self._server = None self._server = None
self.first_tick = True self.first_tick = True
self.waiting_flows = []
@property @property
def server(self): def server(self):
@ -117,15 +119,33 @@ class Master:
self.should_exit.set() self.should_exit.set()
self.addons.trigger("done") self.addons.trigger("done")
def _change_reverse_host(self, f):
"""
When we load flows in reverse proxy mode, we adjust the target host to
the reverse proxy destination for all flows we load. This makes it very
easy to replay saved flows against a different host.
"""
if self.options.mode.startswith("reverse:"):
_, upstream_spec = server_spec.parse_with_mode(self.options.mode)
f.request.host, f.request.port = upstream_spec.address
f.request.scheme = upstream_spec.scheme
def load_flow(self, f): def load_flow(self, f):
""" """
Loads a flow Loads a flow and links websocket & handshake flows
""" """
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if self.options.mode.startswith("reverse:"): self._change_reverse_host(f)
_, upstream_spec = server_spec.parse_with_mode(self.options.mode) if 'websocket' in f.metadata:
f.request.host, f.request.port = upstream_spec.address self.waiting_flows.append(f)
f.request.scheme = upstream_spec.scheme
if isinstance(f, websocket.WebSocketFlow):
hf = [hf for hf in self.waiting_flows if hf.id == f.metadata['websocket_handshake']][0]
f.handshake_flow = hf
self.waiting_flows.remove(hf)
self._change_reverse_host(f.handshake_flow)
f.reply = controller.DummyReply() f.reply = controller.DummyReply()
for e, o in eventsequence.iterate(f): for e, o in eventsequence.iterate(f):
self.addons.handle_lifecycle(e, o) self.addons.handle_lifecycle(e, o)

View File

@ -321,6 +321,7 @@ class HttpLayer(base.Layer):
try: try:
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
f.metadata['websocket'] = True
# We only support RFC6455 with WebSocket version 13 # We only support RFC6455 with WebSocket version 13
# allow inline scripts to manipulate the client handshake # allow inline scripts to manipulate the client handshake
self.channel.ask("websocket_handshake", f) self.channel.ask("websocket_handshake", f)

View File

@ -1,10 +1,11 @@
import socket import socket
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy.contrib import wsproto
from mitmproxy.contrib.wsproto import events from mitmproxy.contrib.wsproto import events
from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection from mitmproxy.contrib.wsproto.connection import ConnectionType, WSConnection
from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate from mitmproxy.contrib.wsproto.extensions import PerMessageDeflate
from mitmproxy.contrib.wsproto.frame_protocol import Opcode
from mitmproxy import exceptions from mitmproxy import exceptions
from mitmproxy import flow from mitmproxy import flow
@ -93,11 +94,14 @@ class WebSocketLayer(base.Layer):
if event.message_finished: if event.message_finished:
original_chunk_sizes = [len(f) for f in fb] original_chunk_sizes = [len(f) for f in fb]
message_type = Opcode.TEXT if isinstance(event, events.TextReceived) else Opcode.BINARY
if message_type == Opcode.TEXT: if isinstance(event, events.TextReceived):
message_type = wsproto.frame_protocol.Opcode.TEXT
payload = ''.join(fb) payload = ''.join(fb)
else: else:
message_type = wsproto.frame_protocol.Opcode.BINARY
payload = b''.join(fb) payload = b''.join(fb)
fb.clear() fb.clear()
websocket_message = WebSocketMessage(message_type, not is_server, payload) websocket_message = WebSocketMessage(message_type, not is_server, payload)

View File

@ -44,7 +44,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
"GET", "GET",
"http", "http",
"example.com", "example.com",
"80", 80,
"/ws", "/ws",
"HTTP/1.1", "HTTP/1.1",
headers=net_http.Headers( headers=net_http.Headers(
@ -75,7 +75,9 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
handshake_flow.response = resp handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow) f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
handshake_flow.metadata['websocket_flow'] = f f.metadata['websocket_handshake'] = handshake_flow.id
handshake_flow.metadata['websocket_flow'] = f.id
handshake_flow.metadata['websocket'] = True
if messages is True: if messages is True:
messages = [ messages = [

View File

@ -1,6 +1,8 @@
import time import time
from typing import List, Optional from typing import List, Optional
from mitmproxy.contrib import wsproto
from mitmproxy import flow from mitmproxy import flow
from mitmproxy.net import websockets from mitmproxy.net import websockets
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
@ -11,7 +13,7 @@ class WebSocketMessage(serializable.Serializable):
def __init__( def __init__(
self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None self, type: int, from_client: bool, content: bytes, timestamp: Optional[int]=None
) -> None: ) -> None:
self.type = type self.type = wsproto.frame_protocol.Opcode(type) # type: ignore
self.from_client = from_client self.from_client = from_client
self.content = content self.content = content
self.timestamp = timestamp or int(time.time()) # type: int self.timestamp = timestamp or int(time.time()) # type: int
@ -21,13 +23,14 @@ class WebSocketMessage(serializable.Serializable):
return cls(*state) return cls(*state)
def get_state(self): def get_state(self):
return self.type, self.from_client, self.content, self.timestamp return int(self.type), self.from_client, self.content, self.timestamp
def set_state(self, state): def set_state(self, state):
self.type, self.from_client, self.content, self.timestamp = state self.type, self.from_client, self.content, self.timestamp = state
self.type = wsproto.frame_protocol.Opcode(self.type) # replace enum with bare int
def __repr__(self): def __repr__(self):
if self.type == websockets.OPCODE.TEXT: if self.type == wsproto.frame_protocol.Opcode.TEXT:
return "text message: {}".format(repr(self.content)) return "text message: {}".format(repr(self.content))
else: else:
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content)) return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
@ -42,7 +45,7 @@ class WebSocketFlow(flow.Flow):
super().__init__("websocket", client_conn, server_conn, live) super().__init__("websocket", client_conn, server_conn, live)
self.messages = [] # type: List[WebSocketMessage] self.messages = [] # type: List[WebSocketMessage]
self.close_sender = 'client' self.close_sender = 'client'
self.close_code = '(status code missing)' self.close_code = wsproto.frame_protocol.CloseReason.NORMAL_CLOSURE
self.close_message = '(message missing)' self.close_message = '(message missing)'
self.close_reason = 'unknown status code' self.close_reason = 'unknown status code'
self.stream = False self.stream = False
@ -69,7 +72,7 @@ class WebSocketFlow(flow.Flow):
_stateobject_attributes.update(dict( _stateobject_attributes.update(dict(
messages=List[WebSocketMessage], messages=List[WebSocketMessage],
close_sender=str, close_sender=str,
close_code=str, close_code=int,
close_message=str, close_message=str,
close_reason=str, close_reason=str,
client_key=str, client_key=str,
@ -83,6 +86,11 @@ class WebSocketFlow(flow.Flow):
# dumping the handshake_flow will include the WebSocketFlow too. # dumping the handshake_flow will include the WebSocketFlow too.
)) ))
def get_state(self):
d = super().get_state()
d['close_code'] = int(d['close_code']) # replace enum with bare int
return d
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, None, None) f = cls(None, None, None)

View File

@ -19,6 +19,9 @@ exclude_lines =
pragma: no cover pragma: no cover
raise NotImplementedError() raise NotImplementedError()
[mypy-mitmproxy.contrib.*]
ignore_errors = True
[tool:full_coverage] [tool:full_coverage]
exclude = exclude =
mitmproxy/proxy/protocol/base.py mitmproxy/proxy/protocol/base.py

View File

@ -44,6 +44,19 @@ def test_tcp(tmpdir):
assert rd(p) assert rd(p)
def test_websocket(tmpdir):
sa = save.Save()
with taddons.context() as tctx:
p = str(tmpdir.join("foo"))
tctx.configure(sa, save_stream_file=p)
f = tflow.twebsocketflow()
sa.websocket_start(f)
sa.websocket_end(f)
tctx.configure(sa, save_stream_file=None)
assert rd(p)
def test_save_command(tmpdir): def test_save_command(tmpdir):
sa = save.Save() sa = save.Save()
with taddons.context() as tctx: with taddons.context() as tctx:

View File

@ -97,7 +97,7 @@ class TestSerialize:
class TestFlowMaster: class TestFlowMaster:
def test_load_flow_reverse(self): def test_load_http_flow_reverse(self):
s = tservers.TestState() s = tservers.TestState()
opts = options.Options( opts = options.Options(
mode="reverse:https://use-this-domain" mode="reverse:https://use-this-domain"
@ -108,6 +108,20 @@ class TestFlowMaster:
fm.load_flow(f) fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain" assert s.flows[0].request.host == "use-this-domain"
def test_load_websocket_flow(self):
s = tservers.TestState()
opts = options.Options(
mode="reverse:https://use-this-domain"
)
fm = master.Master(opts)
fm.addons.add(s)
f = tflow.twebsocketflow()
fm.load_flow(f.handshake_flow)
fm.load_flow(f)
assert s.flows[0].request.host == "use-this-domain"
assert s.flows[1].handshake_flow == f.handshake_flow
assert len(s.flows[1].messages) == len(f.messages)
def test_replay(self): def test_replay(self):
opts = options.Options() opts = options.Options()
fm = master.Master(opts) fm = master.Master(opts)

View File

@ -420,6 +420,20 @@ class TestMatchingWebSocketFlow:
e = self.err() e = self.err()
assert self.q("~e", e) assert self.q("~e", e)
def test_domain(self):
q = self.flow()
assert self.q("~d example.com", q)
assert not self.q("~d none", q)
def test_url(self):
q = self.flow()
assert self.q("~u example.com", q)
assert self.q("~u example.com/ws", q)
assert not self.q("~u moo/path", q)
q.handshake_flow = None
assert not self.q("~u example.com", q)
def test_body(self): def test_body(self):
f = self.flow() f = self.flow()