mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
websocket: add tests
This commit is contained in:
parent
4beb693c9c
commit
5dfc199086
@ -9,7 +9,6 @@ from mitmproxy import http
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.proxy.protocol import base
|
||||
from mitmproxy.proxy.protocol.websocket import WebSocketLayer
|
||||
import mitmproxy.net.http
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy.net import websockets
|
||||
|
||||
|
@ -3,11 +3,9 @@ import socket
|
||||
import struct
|
||||
from OpenSSL import SSL
|
||||
|
||||
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy import flow
|
||||
from mitmproxy.proxy.protocol import base
|
||||
from mitmproxy.utils import strutils
|
||||
from mitmproxy.net import tcp
|
||||
from mitmproxy.net import websockets
|
||||
from mitmproxy.websocket import WebSocketFlow, WebSocketBinaryMessage, WebSocketTextMessage
|
||||
|
@ -47,6 +47,7 @@ class WebSocketBinaryMessage(WebSocketMessage):
|
||||
def __repr__(self):
|
||||
return "binary message: {}".format(strutils.bytes_to_escaped_str(self.content))
|
||||
|
||||
|
||||
class WebSocketTextMessage(WebSocketMessage):
|
||||
|
||||
type = 'text'
|
||||
@ -72,7 +73,6 @@ class WebSocketFlow(flow.Flow):
|
||||
self.server_protocol = websockets.get_protocol(self.handshake_flow.response.headers)
|
||||
self.server_extensions = websockets.get_extensions(self.handshake_flow.response.headers)
|
||||
|
||||
|
||||
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
|
||||
_stateobject_attributes.update(
|
||||
messages=List[WebSocketMessage],
|
||||
|
@ -5,6 +5,8 @@ import traceback
|
||||
|
||||
from mitmproxy import options
|
||||
from mitmproxy import exceptions
|
||||
from mitmproxy.http import HTTPFlow
|
||||
from mitmproxy.websocket import WebSocketFlow
|
||||
from mitmproxy.proxy.config import ProxyConfig
|
||||
|
||||
import mitmproxy.net
|
||||
@ -147,6 +149,10 @@ class TestSimple(_WebSocketTest):
|
||||
wfile.write(bytes(frame))
|
||||
wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(rfile)
|
||||
wfile.write(bytes(frame))
|
||||
wfile.flush()
|
||||
|
||||
def test_simple(self):
|
||||
client = self._setup_connection()
|
||||
|
||||
@ -159,9 +165,31 @@ class TestSimple(_WebSocketTest):
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.payload == b'client-foobar'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
|
||||
client.wfile.flush()
|
||||
|
||||
frame = websockets.Frame.from_file(client.rfile)
|
||||
assert frame.payload == b'\xde\xad\xbe\xef'
|
||||
|
||||
client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
|
||||
client.wfile.flush()
|
||||
|
||||
assert len(self.master.state.flows) == 2
|
||||
assert isinstance(self.master.state.flows[0], HTTPFlow)
|
||||
assert isinstance(self.master.state.flows[1], WebSocketFlow)
|
||||
assert len(self.master.state.flows[1].messages) == 5
|
||||
assert self.master.state.flows[1].messages[0].content == b'server-foobar'
|
||||
assert self.master.state.flows[1].messages[0].type == 'text'
|
||||
assert self.master.state.flows[1].messages[1].content == b'client-foobar'
|
||||
assert self.master.state.flows[1].messages[1].type == 'text'
|
||||
assert self.master.state.flows[1].messages[2].content == b'client-foobar'
|
||||
assert self.master.state.flows[1].messages[2].type == 'text'
|
||||
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
|
||||
assert self.master.state.flows[1].messages[3].type == 'binary'
|
||||
assert self.master.state.flows[1].messages[4].content == b'\xde\xad\xbe\xef'
|
||||
assert self.master.state.flows[1].messages[4].type == 'binary'
|
||||
assert [m.info for m in self.master.state.flows[1].messages]
|
||||
|
||||
|
||||
class TestSimpleTLS(_WebSocketTest):
|
||||
ssl = True
|
||||
|
@ -26,6 +26,14 @@ class TestState:
|
||||
if f not in self.flows:
|
||||
self.flows.append(f)
|
||||
|
||||
def websocket_start(self, f):
|
||||
if f not in self.flows:
|
||||
self.flows.append(f)
|
||||
|
||||
def tcp_start(self, f):
|
||||
if f not in self.flows:
|
||||
self.flows.append(f)
|
||||
|
||||
# FIXME: compat with old state - remove in favor of len(state.flows)
|
||||
def flow_count(self):
|
||||
return len(self.flows)
|
||||
|
Loading…
Reference in New Issue
Block a user