mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
Merge pull request #2114 from mitmproxy/fix-websocket-serialization
make websocket flows serializable
This commit is contained in:
commit
e9746c5182
@ -78,7 +78,7 @@ class Flow(stateobject.StateObject):
|
||||
self._backup = None # type: typing.Optional[Flow]
|
||||
self.reply = None # type: typing.Optional[controller.Reply]
|
||||
self.marked = False # type: bool
|
||||
self.metadata = dict() # type: typing.Dict[str, str]
|
||||
self.metadata = dict() # type: typing.Dict[str, typing.Any]
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
id=str,
|
||||
|
@ -140,8 +140,8 @@ class WebSocketLayer(base.Layer):
|
||||
|
||||
def __call__(self):
|
||||
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
|
||||
self.flow.metadata['websocket_handshake'] = self.handshake_flow
|
||||
self.handshake_flow.metadata['websocket_flow'] = self.flow
|
||||
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
|
||||
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
||||
self.channel.ask("websocket_start", self.flow)
|
||||
|
||||
client = self.client_conn.connection
|
||||
|
@ -39,6 +39,14 @@ class StateObject(serializable.Serializable):
|
||||
state[attr] = val.get_state()
|
||||
elif _is_list(cls):
|
||||
state[attr] = [x.get_state() for x in val]
|
||||
elif isinstance(val, dict):
|
||||
s = {}
|
||||
for k, v in val.items():
|
||||
if hasattr(v, "get_state"):
|
||||
s[k] = v.get_state()
|
||||
else:
|
||||
s[k] = v
|
||||
state[attr] = s
|
||||
else:
|
||||
state[attr] = val
|
||||
return state
|
||||
|
@ -70,6 +70,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
|
||||
handshake_flow.response = resp
|
||||
|
||||
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
|
||||
handshake_flow.metadata['websocket_flow'] = f
|
||||
|
||||
if messages is True:
|
||||
messages = [
|
||||
|
@ -2,7 +2,6 @@ import time
|
||||
from typing import List, Optional
|
||||
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.net import websockets
|
||||
from mitmproxy.types import serializable
|
||||
from mitmproxy.utils import strutils
|
||||
@ -44,6 +43,22 @@ class WebSocketFlow(flow.Flow):
|
||||
self.close_code = '(status code missing)'
|
||||
self.close_message = '(message missing)'
|
||||
self.close_reason = 'unknown status code'
|
||||
|
||||
if handshake_flow:
|
||||
self.client_key = websockets.get_client_key(handshake_flow.request.headers)
|
||||
self.client_protocol = websockets.get_protocol(handshake_flow.request.headers)
|
||||
self.client_extensions = websockets.get_extensions(handshake_flow.request.headers)
|
||||
self.server_accept = websockets.get_server_accept(handshake_flow.response.headers)
|
||||
self.server_protocol = websockets.get_protocol(handshake_flow.response.headers)
|
||||
self.server_extensions = websockets.get_extensions(handshake_flow.response.headers)
|
||||
else:
|
||||
self.client_key = ''
|
||||
self.client_protocol = ''
|
||||
self.client_extensions = ''
|
||||
self.server_accept = ''
|
||||
self.server_protocol = ''
|
||||
self.server_extensions = ''
|
||||
|
||||
self.handshake_flow = handshake_flow
|
||||
|
||||
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
|
||||
@ -53,7 +68,15 @@ class WebSocketFlow(flow.Flow):
|
||||
close_code=str,
|
||||
close_message=str,
|
||||
close_reason=str,
|
||||
handshake_flow=HTTPFlow,
|
||||
client_key=str,
|
||||
client_protocol=str,
|
||||
client_extensions=str,
|
||||
server_accept=str,
|
||||
server_protocol=str,
|
||||
server_extensions=str,
|
||||
# Do not include handshake_flow, to prevent recursive serialization!
|
||||
# Since mitmproxy-console currently only displays HTTPFlows,
|
||||
# dumping the handshake_flow will include the WebSocketFlow too.
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -65,30 +88,6 @@ class WebSocketFlow(flow.Flow):
|
||||
def __repr__(self):
|
||||
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
|
||||
|
||||
@property
|
||||
def client_key(self):
|
||||
return websockets.get_client_key(self.handshake_flow.request.headers)
|
||||
|
||||
@property
|
||||
def client_protocol(self):
|
||||
return websockets.get_protocol(self.handshake_flow.request.headers)
|
||||
|
||||
@property
|
||||
def client_extensions(self):
|
||||
return websockets.get_extensions(self.handshake_flow.request.headers)
|
||||
|
||||
@property
|
||||
def server_accept(self):
|
||||
return websockets.get_server_accept(self.handshake_flow.response.headers)
|
||||
|
||||
@property
|
||||
def server_protocol(self):
|
||||
return websockets.get_protocol(self.handshake_flow.response.headers)
|
||||
|
||||
@property
|
||||
def server_extensions(self):
|
||||
return websockets.get_extensions(self.handshake_flow.response.headers)
|
||||
|
||||
def message_info(self, message: WebSocketMessage) -> str:
|
||||
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
|
||||
type=message.type,
|
||||
|
@ -26,10 +26,12 @@ class Container(StateObject):
|
||||
def __init__(self):
|
||||
self.child = None
|
||||
self.children = None
|
||||
self.dictionary = None
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
child=Child,
|
||||
children=List[Child],
|
||||
dictionary=dict,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -62,12 +64,30 @@ def test_container_list():
|
||||
a.children = [Child(42), Child(44)]
|
||||
assert a.get_state() == {
|
||||
"child": None,
|
||||
"children": [{"x": 42}, {"x": 44}]
|
||||
"children": [{"x": 42}, {"x": 44}],
|
||||
"dictionary": None,
|
||||
}
|
||||
copy = a.copy()
|
||||
assert len(copy.children) == 2
|
||||
assert copy.children is not a.children
|
||||
assert copy.children[0] is not a.children[0]
|
||||
assert Container.from_state(a.get_state())
|
||||
|
||||
|
||||
def test_container_dict():
|
||||
a = Container()
|
||||
a.dictionary = dict()
|
||||
a.dictionary['foo'] = 'bar'
|
||||
a.dictionary['bar'] = Child(44)
|
||||
assert a.get_state() == {
|
||||
"child": None,
|
||||
"children": None,
|
||||
"dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
|
||||
}
|
||||
copy = a.copy()
|
||||
assert len(copy.dictionary) == 2
|
||||
assert copy.dictionary is not a.dictionary
|
||||
assert copy.dictionary['bar'] is not a.dictionary['bar']
|
||||
|
||||
|
||||
def test_too_much_state():
|
||||
|
@ -1,5 +1,7 @@
|
||||
import io
|
||||
import pytest
|
||||
|
||||
from mitmproxy.contrib import tnetstring
|
||||
from mitmproxy import flowfilter
|
||||
from mitmproxy.test import tflow
|
||||
|
||||
@ -14,8 +16,6 @@ class TestWebSocketFlow:
|
||||
b = f2.get_state()
|
||||
del a["id"]
|
||||
del b["id"]
|
||||
del a["handshake_flow"]["id"]
|
||||
del b["handshake_flow"]["id"]
|
||||
assert a == b
|
||||
assert not f == f2
|
||||
assert f is not f2
|
||||
@ -60,3 +60,14 @@ class TestWebSocketFlow:
|
||||
assert 'WebSocketFlow' in repr(f)
|
||||
assert 'binary message: ' in repr(f.messages[0])
|
||||
assert 'text message: ' in repr(f.messages[1])
|
||||
|
||||
def test_serialize(self):
|
||||
b = io.BytesIO()
|
||||
d = tflow.twebsocketflow().get_state()
|
||||
tnetstring.dump(d, b)
|
||||
assert b.getvalue()
|
||||
|
||||
b = io.BytesIO()
|
||||
d = tflow.twebsocketflow().handshake_flow.get_state()
|
||||
tnetstring.dump(d, b)
|
||||
assert b.getvalue()
|
||||
|
Loading…
Reference in New Issue
Block a user