Merge pull request #2114 from mitmproxy/fix-websocket-serialization

make websocket flows serializable
This commit is contained in:
Thomas Kriechbaumer 2017-03-10 21:15:46 +01:00 committed by GitHub
commit e9746c5182
7 changed files with 71 additions and 32 deletions

View File

@ -78,7 +78,7 @@ class Flow(stateobject.StateObject):
self._backup = None # type: typing.Optional[Flow]
self.reply = None # type: typing.Optional[controller.Reply]
self.marked = False # type: bool
self.metadata = dict() # type: typing.Dict[str, str]
self.metadata = dict() # type: typing.Dict[str, typing.Any]
_stateobject_attributes = dict(
id=str,

View File

@ -140,8 +140,8 @@ class WebSocketLayer(base.Layer):
def __call__(self):
self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
self.flow.metadata['websocket_handshake'] = self.handshake_flow
self.handshake_flow.metadata['websocket_flow'] = self.flow
self.flow.metadata['websocket_handshake'] = self.handshake_flow.id
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.channel.ask("websocket_start", self.flow)
client = self.client_conn.connection

View File

@ -39,6 +39,14 @@ class StateObject(serializable.Serializable):
state[attr] = val.get_state()
elif _is_list(cls):
state[attr] = [x.get_state() for x in val]
elif isinstance(val, dict):
s = {}
for k, v in val.items():
if hasattr(v, "get_state"):
s[k] = v.get_state()
else:
s[k] = v
state[attr] = s
else:
state[attr] = val
return state

View File

@ -70,6 +70,7 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,
handshake_flow.response = resp
f = websocket.WebSocketFlow(client_conn, server_conn, handshake_flow)
handshake_flow.metadata['websocket_flow'] = f
if messages is True:
messages = [

View File

@ -2,7 +2,6 @@ import time
from typing import List, Optional
from mitmproxy import flow
from mitmproxy.http import HTTPFlow
from mitmproxy.net import websockets
from mitmproxy.types import serializable
from mitmproxy.utils import strutils
@ -44,6 +43,22 @@ class WebSocketFlow(flow.Flow):
self.close_code = '(status code missing)'
self.close_message = '(message missing)'
self.close_reason = 'unknown status code'
if handshake_flow:
self.client_key = websockets.get_client_key(handshake_flow.request.headers)
self.client_protocol = websockets.get_protocol(handshake_flow.request.headers)
self.client_extensions = websockets.get_extensions(handshake_flow.request.headers)
self.server_accept = websockets.get_server_accept(handshake_flow.response.headers)
self.server_protocol = websockets.get_protocol(handshake_flow.response.headers)
self.server_extensions = websockets.get_extensions(handshake_flow.response.headers)
else:
self.client_key = ''
self.client_protocol = ''
self.client_extensions = ''
self.server_accept = ''
self.server_protocol = ''
self.server_extensions = ''
self.handshake_flow = handshake_flow
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
@ -53,7 +68,15 @@ class WebSocketFlow(flow.Flow):
close_code=str,
close_message=str,
close_reason=str,
handshake_flow=HTTPFlow,
client_key=str,
client_protocol=str,
client_extensions=str,
server_accept=str,
server_protocol=str,
server_extensions=str,
# Do not include handshake_flow, to prevent recursive serialization!
# Since mitmproxy-console currently only displays HTTPFlows,
# dumping the handshake_flow will include the WebSocketFlow too.
)
@classmethod
@ -65,30 +88,6 @@ class WebSocketFlow(flow.Flow):
def __repr__(self):
return "<WebSocketFlow ({} messages)>".format(len(self.messages))
@property
def client_key(self):
return websockets.get_client_key(self.handshake_flow.request.headers)
@property
def client_protocol(self):
return websockets.get_protocol(self.handshake_flow.request.headers)
@property
def client_extensions(self):
return websockets.get_extensions(self.handshake_flow.request.headers)
@property
def server_accept(self):
return websockets.get_server_accept(self.handshake_flow.response.headers)
@property
def server_protocol(self):
return websockets.get_protocol(self.handshake_flow.response.headers)
@property
def server_extensions(self):
return websockets.get_extensions(self.handshake_flow.response.headers)
def message_info(self, message: WebSocketMessage) -> str:
return "{client} {direction} WebSocket {type} message {direction} {server}{endpoint}".format(
type=message.type,

View File

@ -26,10 +26,12 @@ class Container(StateObject):
def __init__(self):
self.child = None
self.children = None
self.dictionary = None
_stateobject_attributes = dict(
child=Child,
children=List[Child],
dictionary=dict,
)
@classmethod
@ -62,12 +64,30 @@ def test_container_list():
a.children = [Child(42), Child(44)]
assert a.get_state() == {
"child": None,
"children": [{"x": 42}, {"x": 44}]
"children": [{"x": 42}, {"x": 44}],
"dictionary": None,
}
copy = a.copy()
assert len(copy.children) == 2
assert copy.children is not a.children
assert copy.children[0] is not a.children[0]
assert Container.from_state(a.get_state())
def test_container_dict():
a = Container()
a.dictionary = dict()
a.dictionary['foo'] = 'bar'
a.dictionary['bar'] = Child(44)
assert a.get_state() == {
"child": None,
"children": None,
"dictionary": {'bar': {'x': 44}, 'foo': 'bar'},
}
copy = a.copy()
assert len(copy.dictionary) == 2
assert copy.dictionary is not a.dictionary
assert copy.dictionary['bar'] is not a.dictionary['bar']
def test_too_much_state():

View File

@ -1,5 +1,7 @@
import io
import pytest
from mitmproxy.contrib import tnetstring
from mitmproxy import flowfilter
from mitmproxy.test import tflow
@ -14,8 +16,6 @@ class TestWebSocketFlow:
b = f2.get_state()
del a["id"]
del b["id"]
del a["handshake_flow"]["id"]
del b["handshake_flow"]["id"]
assert a == b
assert not f == f2
assert f is not f2
@ -60,3 +60,14 @@ class TestWebSocketFlow:
assert 'WebSocketFlow' in repr(f)
assert 'binary message: ' in repr(f.messages[0])
assert 'text message: ' in repr(f.messages[1])
def test_serialize(self):
b = io.BytesIO()
d = tflow.twebsocketflow().get_state()
tnetstring.dump(d, b)
assert b.getvalue()
b = io.BytesIO()
d = tflow.twebsocketflow().handshake_flow.get_state()
tnetstring.dump(d, b)
assert b.getvalue()