mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] remove unused code
This commit is contained in:
parent
efacbca0ca
commit
8e7cbb3991
@ -2,32 +2,10 @@ import re
|
|||||||
import time
|
import time
|
||||||
from typing import Iterable, List, Optional, Tuple
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
from mitmproxy.net import check
|
|
||||||
from mitmproxy.net.http import headers, request, response, url
|
from mitmproxy.net.http import headers, request, response, url
|
||||||
from mitmproxy.net.http.http1 import read
|
from mitmproxy.net.http.http1 import read
|
||||||
|
|
||||||
|
|
||||||
def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]:
|
|
||||||
"""
|
|
||||||
Returns (host, port) if hostport is a valid authority-form host specification.
|
|
||||||
http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError, if the input is malformed
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
host, port_str = hostport.rsplit(b":", 1)
|
|
||||||
if host.startswith(b"[") and host.endswith(b"]"):
|
|
||||||
host = host[1:-1]
|
|
||||||
port = int(port_str)
|
|
||||||
if not check.is_valid_host(host) or not check.is_valid_port(port):
|
|
||||||
raise ValueError
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(f"Invalid host specification: {hostport!r}")
|
|
||||||
|
|
||||||
return host, port
|
|
||||||
|
|
||||||
|
|
||||||
def raise_if_http_version_unknown(http_version: bytes) -> None:
|
def raise_if_http_version_unknown(http_version: bytes) -> None:
|
||||||
if not re.match(br"^HTTP/\d\.\d$", http_version):
|
if not re.match(br"^HTTP/\d\.\d$", http_version):
|
||||||
raise ValueError(f"Unknown HTTP version: {http_version!r}")
|
raise ValueError(f"Unknown HTTP version: {http_version!r}")
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
|
from abc import ABCMeta
|
||||||
from enum import Flag
|
from enum import Flag
|
||||||
from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING
|
from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
@ -25,7 +26,7 @@ class ConnectionState(Flag):
|
|||||||
Address = Tuple[str, int]
|
Address = Tuple[str, int]
|
||||||
|
|
||||||
|
|
||||||
class Connection(serializable.Serializable):
|
class Connection(serializable.Serializable, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Connections exposed to the layers only contain metadata, no socket objects.
|
Connections exposed to the layers only contain metadata, no socket objects.
|
||||||
"""
|
"""
|
||||||
@ -87,7 +88,7 @@ class Connection(serializable.Serializable):
|
|||||||
return f"{type(self).__name__}({attrs})"
|
return f"{type(self).__name__}({attrs})"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alpn_proto_negotiated(self) -> Optional[bytes]:
|
def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
|
||||||
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning)
|
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning)
|
||||||
return self.alpn
|
return self.alpn
|
||||||
|
|
||||||
@ -164,22 +165,22 @@ class Client(Connection):
|
|||||||
self.cipher_list = state["cipher_list"]
|
self.cipher_list = state["cipher_list"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def address(self):
|
def address(self): # pragma: no cover
|
||||||
warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning)
|
warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning)
|
||||||
return self.peername
|
return self.peername
|
||||||
|
|
||||||
@address.setter
|
@address.setter
|
||||||
def address(self, x):
|
def address(self, x): # pragma: no cover
|
||||||
warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning)
|
warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning)
|
||||||
self.peername = x
|
self.peername = x
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cipher_name(self) -> Optional[str]:
|
def cipher_name(self) -> Optional[str]: # pragma: no cover
|
||||||
warnings.warn("Client.cipher_name is deprecated, use Client.cipher instead.", PendingDeprecationWarning)
|
warnings.warn("Client.cipher_name is deprecated, use Client.cipher instead.", PendingDeprecationWarning)
|
||||||
return self.cipher
|
return self.cipher
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def clientcert(self) -> Optional[certs.Cert]:
|
def clientcert(self) -> Optional[certs.Cert]: # pragma: no cover
|
||||||
warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning)
|
warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning)
|
||||||
if self.certificate_list:
|
if self.certificate_list:
|
||||||
return self.certificate_list[0]
|
return self.certificate_list[0]
|
||||||
@ -187,7 +188,7 @@ class Client(Connection):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@clientcert.setter
|
@clientcert.setter
|
||||||
def clientcert(self, val):
|
def clientcert(self, val): # pragma: no cover
|
||||||
warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning)
|
warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning)
|
||||||
if val:
|
if val:
|
||||||
self.certificate_list = [val]
|
self.certificate_list = [val]
|
||||||
@ -268,12 +269,12 @@ class Server(Connection):
|
|||||||
self.via = state["via2"]
|
self.via = state["via2"]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def ip_address(self) -> Optional[Address]:
|
def ip_address(self) -> Optional[Address]: # pragma: no cover
|
||||||
warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning)
|
warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning)
|
||||||
return self.peername
|
return self.peername
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def cert(self) -> Optional[certs.Cert]:
|
def cert(self) -> Optional[certs.Cert]: # pragma: no cover
|
||||||
warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning)
|
warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning)
|
||||||
if self.certificate_list:
|
if self.certificate_list:
|
||||||
return self.certificate_list[0]
|
return self.certificate_list[0]
|
||||||
@ -281,7 +282,7 @@ class Server(Connection):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@cert.setter
|
@cert.setter
|
||||||
def cert(self, val):
|
def cert(self, val): # pragma: no cover
|
||||||
warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning)
|
warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning)
|
||||||
if val:
|
if val:
|
||||||
self.certificate_list = [val]
|
self.certificate_list = [val]
|
||||||
|
@ -1,201 +0,0 @@
|
|||||||
import wsproto
|
|
||||||
from wsproto import events as wsevents
|
|
||||||
from wsproto import ConnectionType, WSConnection
|
|
||||||
from wsproto.extensions import PerMessageDeflate
|
|
||||||
|
|
||||||
from mitmproxy import websocket, http, flow
|
|
||||||
from mitmproxy.proxy2 import events, commands, layer
|
|
||||||
from mitmproxy.proxy2.context import Context
|
|
||||||
from mitmproxy.proxy2.utils import expect
|
|
||||||
|
|
||||||
|
|
||||||
class WebsocketLayer(layer.Layer):
|
|
||||||
"""
|
|
||||||
WebSocket layer that intercepts and relays messages.
|
|
||||||
"""
|
|
||||||
context: Context = None
|
|
||||||
flow: websocket.WebSocketFlow
|
|
||||||
|
|
||||||
def __init__(self, context: Context, handshake_flow: http.HTTPFlow):
|
|
||||||
super().__init__(context)
|
|
||||||
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
|
|
||||||
self.flow.metadata['websocket_handshake'] = handshake_flow.id
|
|
||||||
self.handshake_flow = handshake_flow
|
|
||||||
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
|
|
||||||
self.client_frame_buffer = []
|
|
||||||
self.server_frame_buffer = []
|
|
||||||
|
|
||||||
assert context.server.connected
|
|
||||||
|
|
||||||
@expect(events.Start)
|
|
||||||
def start(self, _) -> layer.CommandGenerator[None]:
|
|
||||||
extensions = []
|
|
||||||
if 'Sec-WebSocket-Extensions' in self.handshake_flow.response.headers:
|
|
||||||
if PerMessageDeflate.name in self.handshake_flow.response.headers['Sec-WebSocket-Extensions']:
|
|
||||||
extensions = [PerMessageDeflate()]
|
|
||||||
self.client_conn = WSConnection(ConnectionType.SERVER,
|
|
||||||
extensions=extensions)
|
|
||||||
self.server_conn = WSConnection(ConnectionType.CLIENT,
|
|
||||||
host=self.handshake_flow.request.host,
|
|
||||||
resource=self.handshake_flow.request.path,
|
|
||||||
extensions=extensions)
|
|
||||||
if extensions:
|
|
||||||
self.client_conn.extensions[0].finalize(self.client_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions'])
|
|
||||||
self.server_conn.extensions[0].finalize(self.server_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions'])
|
|
||||||
|
|
||||||
data = self.server_conn.bytes_to_send()
|
|
||||||
self.client_conn.receive_bytes(data)
|
|
||||||
|
|
||||||
event = next(self.client_conn.events())
|
|
||||||
assert isinstance(event, wsevents.ConnectionRequested)
|
|
||||||
|
|
||||||
self.client_conn.accept(event)
|
|
||||||
self.server_conn.receive_bytes(self.client_conn.bytes_to_send())
|
|
||||||
assert isinstance(next(self.server_conn.events()), wsevents.ConnectionEstablished)
|
|
||||||
|
|
||||||
yield commands.Hook("websocket_start", self.flow)
|
|
||||||
self._handle_event = self.process_data
|
|
||||||
|
|
||||||
_handle_event = start
|
|
||||||
|
|
||||||
@expect(events.DataReceived, events.ConnectionClosed)
|
|
||||||
def process_data(self, event: events.Event) -> layer.CommandGenerator[None]:
|
|
||||||
if isinstance(event, events.DataReceived):
|
|
||||||
from_client = event.connection == self.context.client
|
|
||||||
if from_client:
|
|
||||||
source = self.client_conn
|
|
||||||
other = self.server_conn
|
|
||||||
fb = self.client_frame_buffer
|
|
||||||
send_to = self.context.server
|
|
||||||
else:
|
|
||||||
source = self.server_conn
|
|
||||||
other = self.client_conn
|
|
||||||
fb = self.server_frame_buffer
|
|
||||||
send_to = self.context.client
|
|
||||||
|
|
||||||
source.receive_bytes(event.data)
|
|
||||||
|
|
||||||
closing = False
|
|
||||||
received_ws_events = list(source.events())
|
|
||||||
for ws_event in received_ws_events:
|
|
||||||
if isinstance(ws_event, wsevents.DataReceived):
|
|
||||||
yield from self._handle_data_received(ws_event, source, other, send_to, from_client, fb)
|
|
||||||
elif isinstance(ws_event, wsevents.PingReceived):
|
|
||||||
yield from self._handle_ping_received(ws_event, source, other, send_to, from_client)
|
|
||||||
elif isinstance(ws_event, wsevents.PongReceived):
|
|
||||||
yield from self._handle_pong_received(ws_event, source, other, send_to, from_client)
|
|
||||||
elif isinstance(ws_event, wsevents.ConnectionClosed):
|
|
||||||
yield from self._handle_connection_closed(ws_event, source, other, send_to, from_client)
|
|
||||||
closing = True
|
|
||||||
else:
|
|
||||||
yield commands.Log(
|
|
||||||
"WebSocket unhandled event: from {}: {}".format("client" if from_client else "server", ws_event)
|
|
||||||
)
|
|
||||||
|
|
||||||
if closing:
|
|
||||||
yield commands.Hook("websocket_end", self.flow)
|
|
||||||
if not from_client:
|
|
||||||
yield commands.CloseConnection(self.context.client)
|
|
||||||
self._handle_event = self.done
|
|
||||||
|
|
||||||
# TODO: elif isinstance(event, events.InjectMessage):
|
|
||||||
# TODO: come up with a solid API to inject messages
|
|
||||||
|
|
||||||
elif isinstance(event, events.ConnectionClosed):
|
|
||||||
yield commands.Log("Connection closed abnormally", "error")
|
|
||||||
self.flow.error = flow.Error(
|
|
||||||
"WebSocket connection closed unexpectedly by {}".format(
|
|
||||||
"client" if event.connection == self.context.client else "server"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if event.connection == self.context.server:
|
|
||||||
yield commands.CloseConnection(self.context.client)
|
|
||||||
yield commands.Hook("websocket_error", self.flow)
|
|
||||||
yield commands.Hook("websocket_end", self.flow)
|
|
||||||
self._handle_event = self.done
|
|
||||||
|
|
||||||
@expect(events.DataReceived, events.ConnectionClosed)
|
|
||||||
def done(self, _):
|
|
||||||
yield from ()
|
|
||||||
|
|
||||||
def _handle_data_received(self, ws_event, source, other, send_to, from_client, fb):
|
|
||||||
fb.append(ws_event.data)
|
|
||||||
|
|
||||||
if ws_event.message_finished:
|
|
||||||
original_chunk_sizes = [len(f) for f in fb]
|
|
||||||
|
|
||||||
if isinstance(ws_event, wsevents.TextReceived):
|
|
||||||
message_type = wsproto.frame_protocol.Opcode.TEXT
|
|
||||||
payload = ''.join(fb)
|
|
||||||
else:
|
|
||||||
message_type = wsproto.frame_protocol.Opcode.BINARY
|
|
||||||
payload = b''.join(fb)
|
|
||||||
|
|
||||||
fb.clear()
|
|
||||||
|
|
||||||
websocket_message = websocket.WebSocketMessage(message_type, from_client, payload)
|
|
||||||
length = len(websocket_message.content)
|
|
||||||
self.flow.messages.append(websocket_message)
|
|
||||||
yield commands.Hook("websocket_message", self.flow)
|
|
||||||
|
|
||||||
if not self.flow.stream and not websocket_message.killed:
|
|
||||||
def get_chunk(payload):
|
|
||||||
if len(payload) == length:
|
|
||||||
# message has the same length, we can reuse the same sizes
|
|
||||||
pos = 0
|
|
||||||
for s in original_chunk_sizes:
|
|
||||||
yield (payload[pos:pos + s], True if pos + s == length else False)
|
|
||||||
pos += s
|
|
||||||
else:
|
|
||||||
# just re-chunk everything into 4kB frames
|
|
||||||
# header len = 4 bytes without masking key and 8 bytes with masking key
|
|
||||||
chunk_size = 4088 if from_client else 4092
|
|
||||||
chunks = range(0, len(payload), chunk_size)
|
|
||||||
for i in chunks:
|
|
||||||
yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
|
|
||||||
|
|
||||||
for chunk, final in get_chunk(websocket_message.content):
|
|
||||||
other.send_data(chunk, final)
|
|
||||||
yield commands.SendData(send_to, other.bytes_to_send())
|
|
||||||
|
|
||||||
if self.flow.stream:
|
|
||||||
other.send_data(ws_event.data, ws_event.message_finished)
|
|
||||||
yield commands.SendData(send_to, other.bytes_to_send())
|
|
||||||
|
|
||||||
def _handle_ping_received(self, ws_event, source, other, send_to, from_client):
|
|
||||||
yield commands.Log(
|
|
||||||
"WebSocket PING received from {}: {}".format("client" if from_client else "server",
|
|
||||||
ws_event.payload.decode() or "<no payload>")
|
|
||||||
)
|
|
||||||
# We do not forward the PING payload, as this might leak information!
|
|
||||||
other.ping()
|
|
||||||
yield commands.SendData(send_to, other.bytes_to_send())
|
|
||||||
# PING is automatically answered with a PONG by wsproto
|
|
||||||
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
|
|
||||||
|
|
||||||
def _handle_pong_received(self, ws_event, source, other, send_to, from_client):
|
|
||||||
yield commands.Log(
|
|
||||||
"WebSocket PONG received from {}: {}".format("client" if from_client else "server",
|
|
||||||
ws_event.payload.decode() or "<no payload>")
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_connection_closed(self, ws_event, source, other, send_to, from_client):
|
|
||||||
self.flow.close_sender = "client" if from_client else "server"
|
|
||||||
self.flow.close_code = ws_event.code
|
|
||||||
self.flow.close_reason = ws_event.reason
|
|
||||||
|
|
||||||
other.close(ws_event.code, ws_event.reason)
|
|
||||||
yield commands.SendData(send_to, other.bytes_to_send())
|
|
||||||
|
|
||||||
# FIXME: Wait for other end to actually send the closing frame
|
|
||||||
# FIXME: https://github.com/python-hyper/wsproto/pull/50
|
|
||||||
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
|
|
||||||
|
|
||||||
if ws_event.code != 1000:
|
|
||||||
self.flow.error = flow.Error(
|
|
||||||
"WebSocket connection closed unexpectedly by {}: {}".format(
|
|
||||||
"client" if from_client else "server",
|
|
||||||
ws_event.reason
|
|
||||||
)
|
|
||||||
)
|
|
||||||
yield commands.Hook("websocket_error", self.flow)
|
|
@ -1,197 +0,0 @@
|
|||||||
import struct
|
|
||||||
from unittest import mock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from mitmproxy.proxy2.layers.old import websocket
|
|
||||||
|
|
||||||
from mitmproxy.net.websockets import Frame, OPCODE
|
|
||||||
from mitmproxy.proxy2 import commands, events
|
|
||||||
from mitmproxy.proxy2.context import ConnectionState
|
|
||||||
from mitmproxy.test import tflow
|
|
||||||
from .. import tutils
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def ws_playbook(tctx):
|
|
||||||
tctx.server.state = ConnectionState.OPEN
|
|
||||||
playbook = tutils.Playbook(
|
|
||||||
websocket.WebsocketLayer(
|
|
||||||
tctx,
|
|
||||||
tflow.twebsocketflow().handshake_flow
|
|
||||||
),
|
|
||||||
ignore_log=False,
|
|
||||||
)
|
|
||||||
with mock.patch("os.urandom") as m:
|
|
||||||
m.return_value = b"\x10\x11\x12\x13"
|
|
||||||
yield playbook
|
|
||||||
|
|
||||||
|
|
||||||
def test_simple(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
|
|
||||||
frames = [
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, payload=b'client-foobar')),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')),
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.TEXT, payload=b'fail')),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.client, frames[0])
|
|
||||||
<< commands.Hook("websocket_message", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.SendData(tctx.server, frames[0])
|
|
||||||
>> events.DataReceived(tctx.server, frames[1])
|
|
||||||
<< commands.Hook("websocket_message", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.SendData(tctx.client, frames[1])
|
|
||||||
>> events.DataReceived(tctx.client, frames[2])
|
|
||||||
<< commands.SendData(tctx.server, frames[2])
|
|
||||||
<< commands.SendData(tctx.client, frames[3])
|
|
||||||
<< commands.Hook("websocket_end", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.server, frames[4])
|
|
||||||
<< None
|
|
||||||
)
|
|
||||||
|
|
||||||
assert len(f().messages) == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_server_close(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
|
|
||||||
frames = [
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.server, frames[0])
|
|
||||||
<< commands.SendData(tctx.client, frames[0])
|
|
||||||
<< commands.SendData(tctx.server, frames[1])
|
|
||||||
<< commands.Hook("websocket_end", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.CloseConnection(tctx.client)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_connection_closed(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.ConnectionClosed(tctx.server)
|
|
||||||
<< commands.Log("error", "Connection closed abnormally")
|
|
||||||
<< commands.CloseConnection(tctx.client)
|
|
||||||
<< commands.Hook("websocket_error", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.Hook("websocket_end", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
assert f().error
|
|
||||||
|
|
||||||
|
|
||||||
def test_connection_failed(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
|
|
||||||
frames = [
|
|
||||||
b'Not a valid frame',
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.client, frames[0])
|
|
||||||
<< commands.SendData(tctx.server, frames[1])
|
|
||||||
<< commands.SendData(tctx.client, frames[2])
|
|
||||||
<< commands.Hook("websocket_error", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.Hook("websocket_end", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ping_pong(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
|
|
||||||
frames = [
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PING)),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.PONG)),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.client, frames[0])
|
|
||||||
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
|
|
||||||
<< commands.SendData(tctx.server, frames[0])
|
|
||||||
<< commands.SendData(tctx.client, frames[1])
|
|
||||||
>> events.DataReceived(tctx.server, frames[1])
|
|
||||||
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_ping_pong_hidden_payload(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
|
|
||||||
frames = [
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'foobar')),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'')),
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'foobar')),
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'')),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.server, frames[0])
|
|
||||||
<< commands.Log("info", "WebSocket PING received from server: foobar")
|
|
||||||
<< commands.SendData(tctx.client, frames[1])
|
|
||||||
<< commands.SendData(tctx.server, frames[2])
|
|
||||||
>> events.DataReceived(tctx.client, frames[3])
|
|
||||||
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_extension(tctx, ws_playbook):
|
|
||||||
f = tutils.Placeholder()
|
|
||||||
|
|
||||||
ws_playbook.layer.handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;"
|
|
||||||
ws_playbook.layer.handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;"
|
|
||||||
|
|
||||||
frames = [
|
|
||||||
bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')),
|
|
||||||
bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x00\x11\x00\x00')),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert (
|
|
||||||
ws_playbook
|
|
||||||
<< commands.Hook("websocket_start", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
>> events.DataReceived(tctx.client, frames[0])
|
|
||||||
<< commands.Hook("websocket_message", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.SendData(tctx.server, frames[0])
|
|
||||||
>> events.DataReceived(tctx.server, frames[1])
|
|
||||||
<< commands.Hook("websocket_message", f)
|
|
||||||
>> events.HookReply(-1)
|
|
||||||
<< commands.SendData(tctx.client, frames[2])
|
|
||||||
)
|
|
||||||
assert len(f().messages) == 2
|
|
||||||
assert f().messages[0].content == "Hello"
|
|
||||||
assert f().messages[1].content == "Hello"
|
|
Loading…
Reference in New Issue
Block a user