[sans-io] remove unused code

This commit is contained in:
Maximilian Hils 2020-12-12 14:34:27 +01:00
parent efacbca0ca
commit 8e7cbb3991
4 changed files with 11 additions and 430 deletions

View File

@ -2,32 +2,10 @@ import re
import time import time
from typing import Iterable, List, Optional, Tuple from typing import Iterable, List, Optional, Tuple
from mitmproxy.net import check
from mitmproxy.net.http import headers, request, response, url from mitmproxy.net.http import headers, request, response, url
from mitmproxy.net.http.http1 import read from mitmproxy.net.http.http1 import read
def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]:
"""
Returns (host, port) if hostport is a valid authority-form host specification.
http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
Raises:
ValueError, if the input is malformed
"""
try:
host, port_str = hostport.rsplit(b":", 1)
if host.startswith(b"[") and host.endswith(b"]"):
host = host[1:-1]
port = int(port_str)
if not check.is_valid_host(host) or not check.is_valid_port(port):
raise ValueError
except ValueError:
raise ValueError(f"Invalid host specification: {hostport!r}")
return host, port
def raise_if_http_version_unknown(http_version: bytes) -> None: def raise_if_http_version_unknown(http_version: bytes) -> None:
if not re.match(br"^HTTP/\d\.\d$", http_version): if not re.match(br"^HTTP/\d\.\d$", http_version):
raise ValueError(f"Unknown HTTP version: {http_version!r}") raise ValueError(f"Unknown HTTP version: {http_version!r}")

View File

@ -1,5 +1,6 @@
import uuid import uuid
import warnings import warnings
from abc import ABCMeta
from enum import Flag from enum import Flag
from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING
@ -25,7 +26,7 @@ class ConnectionState(Flag):
Address = Tuple[str, int] Address = Tuple[str, int]
class Connection(serializable.Serializable): class Connection(serializable.Serializable, metaclass=ABCMeta):
""" """
Connections exposed to the layers only contain metadata, no socket objects. Connections exposed to the layers only contain metadata, no socket objects.
""" """
@ -87,7 +88,7 @@ class Connection(serializable.Serializable):
return f"{type(self).__name__}({attrs})" return f"{type(self).__name__}({attrs})"
@property @property
def alpn_proto_negotiated(self) -> Optional[bytes]: def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover
warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning) warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning)
return self.alpn return self.alpn
@ -164,22 +165,22 @@ class Client(Connection):
self.cipher_list = state["cipher_list"] self.cipher_list = state["cipher_list"]
@property @property
def address(self): def address(self): # pragma: no cover
warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning) warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning)
return self.peername return self.peername
@address.setter @address.setter
def address(self, x): def address(self, x): # pragma: no cover
warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning) warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning)
self.peername = x self.peername = x
@property @property
def cipher_name(self) -> Optional[str]: def cipher_name(self) -> Optional[str]: # pragma: no cover
warnings.warn("Client.cipher_name is deprecated, use Client.cipher instead.", PendingDeprecationWarning) warnings.warn("Client.cipher_name is deprecated, use Client.cipher instead.", PendingDeprecationWarning)
return self.cipher return self.cipher
@property @property
def clientcert(self) -> Optional[certs.Cert]: def clientcert(self) -> Optional[certs.Cert]: # pragma: no cover
warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning) warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning)
if self.certificate_list: if self.certificate_list:
return self.certificate_list[0] return self.certificate_list[0]
@ -187,7 +188,7 @@ class Client(Connection):
return None return None
@clientcert.setter @clientcert.setter
def clientcert(self, val): def clientcert(self, val): # pragma: no cover
warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning) warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning)
if val: if val:
self.certificate_list = [val] self.certificate_list = [val]
@ -268,12 +269,12 @@ class Server(Connection):
self.via = state["via2"] self.via = state["via2"]
@property @property
def ip_address(self) -> Optional[Address]: def ip_address(self) -> Optional[Address]: # pragma: no cover
warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning) warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning)
return self.peername return self.peername
@property @property
def cert(self) -> Optional[certs.Cert]: def cert(self) -> Optional[certs.Cert]: # pragma: no cover
warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning) warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning)
if self.certificate_list: if self.certificate_list:
return self.certificate_list[0] return self.certificate_list[0]
@ -281,7 +282,7 @@ class Server(Connection):
return None return None
@cert.setter @cert.setter
def cert(self, val): def cert(self, val): # pragma: no cover
warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning) warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning)
if val: if val:
self.certificate_list = [val] self.certificate_list = [val]

View File

@ -1,201 +0,0 @@
import wsproto
from wsproto import events as wsevents
from wsproto import ConnectionType, WSConnection
from wsproto.extensions import PerMessageDeflate
from mitmproxy import websocket, http, flow
from mitmproxy.proxy2 import events, commands, layer
from mitmproxy.proxy2.context import Context
from mitmproxy.proxy2.utils import expect
class WebsocketLayer(layer.Layer):
"""
WebSocket layer that intercepts and relays messages.
"""
context: Context = None
flow: websocket.WebSocketFlow
def __init__(self, context: Context, handshake_flow: http.HTTPFlow):
super().__init__(context)
self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow)
self.flow.metadata['websocket_handshake'] = handshake_flow.id
self.handshake_flow = handshake_flow
self.handshake_flow.metadata['websocket_flow'] = self.flow.id
self.client_frame_buffer = []
self.server_frame_buffer = []
assert context.server.connected
@expect(events.Start)
def start(self, _) -> layer.CommandGenerator[None]:
extensions = []
if 'Sec-WebSocket-Extensions' in self.handshake_flow.response.headers:
if PerMessageDeflate.name in self.handshake_flow.response.headers['Sec-WebSocket-Extensions']:
extensions = [PerMessageDeflate()]
self.client_conn = WSConnection(ConnectionType.SERVER,
extensions=extensions)
self.server_conn = WSConnection(ConnectionType.CLIENT,
host=self.handshake_flow.request.host,
resource=self.handshake_flow.request.path,
extensions=extensions)
if extensions:
self.client_conn.extensions[0].finalize(self.client_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions'])
self.server_conn.extensions[0].finalize(self.server_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions'])
data = self.server_conn.bytes_to_send()
self.client_conn.receive_bytes(data)
event = next(self.client_conn.events())
assert isinstance(event, wsevents.ConnectionRequested)
self.client_conn.accept(event)
self.server_conn.receive_bytes(self.client_conn.bytes_to_send())
assert isinstance(next(self.server_conn.events()), wsevents.ConnectionEstablished)
yield commands.Hook("websocket_start", self.flow)
self._handle_event = self.process_data
_handle_event = start
@expect(events.DataReceived, events.ConnectionClosed)
def process_data(self, event: events.Event) -> layer.CommandGenerator[None]:
if isinstance(event, events.DataReceived):
from_client = event.connection == self.context.client
if from_client:
source = self.client_conn
other = self.server_conn
fb = self.client_frame_buffer
send_to = self.context.server
else:
source = self.server_conn
other = self.client_conn
fb = self.server_frame_buffer
send_to = self.context.client
source.receive_bytes(event.data)
closing = False
received_ws_events = list(source.events())
for ws_event in received_ws_events:
if isinstance(ws_event, wsevents.DataReceived):
yield from self._handle_data_received(ws_event, source, other, send_to, from_client, fb)
elif isinstance(ws_event, wsevents.PingReceived):
yield from self._handle_ping_received(ws_event, source, other, send_to, from_client)
elif isinstance(ws_event, wsevents.PongReceived):
yield from self._handle_pong_received(ws_event, source, other, send_to, from_client)
elif isinstance(ws_event, wsevents.ConnectionClosed):
yield from self._handle_connection_closed(ws_event, source, other, send_to, from_client)
closing = True
else:
yield commands.Log(
"WebSocket unhandled event: from {}: {}".format("client" if from_client else "server", ws_event)
)
if closing:
yield commands.Hook("websocket_end", self.flow)
if not from_client:
yield commands.CloseConnection(self.context.client)
self._handle_event = self.done
# TODO: elif isinstance(event, events.InjectMessage):
# TODO: come up with a solid API to inject messages
elif isinstance(event, events.ConnectionClosed):
yield commands.Log("Connection closed abnormally", "error")
self.flow.error = flow.Error(
"WebSocket connection closed unexpectedly by {}".format(
"client" if event.connection == self.context.client else "server"
)
)
if event.connection == self.context.server:
yield commands.CloseConnection(self.context.client)
yield commands.Hook("websocket_error", self.flow)
yield commands.Hook("websocket_end", self.flow)
self._handle_event = self.done
@expect(events.DataReceived, events.ConnectionClosed)
def done(self, _):
yield from ()
def _handle_data_received(self, ws_event, source, other, send_to, from_client, fb):
fb.append(ws_event.data)
if ws_event.message_finished:
original_chunk_sizes = [len(f) for f in fb]
if isinstance(ws_event, wsevents.TextReceived):
message_type = wsproto.frame_protocol.Opcode.TEXT
payload = ''.join(fb)
else:
message_type = wsproto.frame_protocol.Opcode.BINARY
payload = b''.join(fb)
fb.clear()
websocket_message = websocket.WebSocketMessage(message_type, from_client, payload)
length = len(websocket_message.content)
self.flow.messages.append(websocket_message)
yield commands.Hook("websocket_message", self.flow)
if not self.flow.stream and not websocket_message.killed:
def get_chunk(payload):
if len(payload) == length:
# message has the same length, we can reuse the same sizes
pos = 0
for s in original_chunk_sizes:
yield (payload[pos:pos + s], True if pos + s == length else False)
pos += s
else:
# just re-chunk everything into 4kB frames
# header len = 4 bytes without masking key and 8 bytes with masking key
chunk_size = 4088 if from_client else 4092
chunks = range(0, len(payload), chunk_size)
for i in chunks:
yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False)
for chunk, final in get_chunk(websocket_message.content):
other.send_data(chunk, final)
yield commands.SendData(send_to, other.bytes_to_send())
if self.flow.stream:
other.send_data(ws_event.data, ws_event.message_finished)
yield commands.SendData(send_to, other.bytes_to_send())
def _handle_ping_received(self, ws_event, source, other, send_to, from_client):
yield commands.Log(
"WebSocket PING received from {}: {}".format("client" if from_client else "server",
ws_event.payload.decode() or "<no payload>")
)
# We do not forward the PING payload, as this might leak information!
other.ping()
yield commands.SendData(send_to, other.bytes_to_send())
# PING is automatically answered with a PONG by wsproto
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
def _handle_pong_received(self, ws_event, source, other, send_to, from_client):
yield commands.Log(
"WebSocket PONG received from {}: {}".format("client" if from_client else "server",
ws_event.payload.decode() or "<no payload>")
)
def _handle_connection_closed(self, ws_event, source, other, send_to, from_client):
self.flow.close_sender = "client" if from_client else "server"
self.flow.close_code = ws_event.code
self.flow.close_reason = ws_event.reason
other.close(ws_event.code, ws_event.reason)
yield commands.SendData(send_to, other.bytes_to_send())
# FIXME: Wait for other end to actually send the closing frame
# FIXME: https://github.com/python-hyper/wsproto/pull/50
yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send())
if ws_event.code != 1000:
self.flow.error = flow.Error(
"WebSocket connection closed unexpectedly by {}: {}".format(
"client" if from_client else "server",
ws_event.reason
)
)
yield commands.Hook("websocket_error", self.flow)

View File

@ -1,197 +0,0 @@
import struct
from unittest import mock
import pytest
from mitmproxy.proxy2.layers.old import websocket
from mitmproxy.net.websockets import Frame, OPCODE
from mitmproxy.proxy2 import commands, events
from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.test import tflow
from .. import tutils
@pytest.fixture
def ws_playbook(tctx):
tctx.server.state = ConnectionState.OPEN
playbook = tutils.Playbook(
websocket.WebsocketLayer(
tctx,
tflow.twebsocketflow().handshake_flow
),
ignore_log=False,
)
with mock.patch("os.urandom") as m:
m.return_value = b"\x10\x11\x12\x13"
yield playbook
def test_simple(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, payload=b'client-foobar')),
bytes(Frame(fin=1, opcode=OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
bytes(Frame(fin=1, opcode=OPCODE.TEXT, payload=b'fail')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.client, frames[2])
<< commands.SendData(tctx.server, frames[2])
<< commands.SendData(tctx.client, frames[3])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[4])
<< None
)
assert len(f().messages) == 2
def test_server_close(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
<< commands.CloseConnection(tctx.client)
)
def test_connection_closed(tctx, ws_playbook):
f = tutils.Placeholder()
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.ConnectionClosed(tctx.server)
<< commands.Log("error", "Connection closed abnormally")
<< commands.CloseConnection(tctx.client)
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
)
assert f().error
def test_connection_failed(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
b'Not a valid frame',
bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')),
bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.SendData(tctx.server, frames[1])
<< commands.SendData(tctx.client, frames[2])
<< commands.Hook("websocket_error", f)
>> events.HookReply(-1)
<< commands.Hook("websocket_end", f)
>> events.HookReply(-1)
)
def test_ping_pong(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PING)),
bytes(Frame(fin=1, opcode=OPCODE.PONG)),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Log("info", "WebSocket PING received from client: <no payload>")
<< commands.SendData(tctx.server, frames[0])
<< commands.SendData(tctx.client, frames[1])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Log("info", "WebSocket PONG received from server: <no payload>")
)
def test_ping_pong_hidden_payload(tctx, ws_playbook):
f = tutils.Placeholder()
frames = [
bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'foobar')),
bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'')),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'foobar')),
bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.server, frames[0])
<< commands.Log("info", "WebSocket PING received from server: foobar")
<< commands.SendData(tctx.client, frames[1])
<< commands.SendData(tctx.server, frames[2])
>> events.DataReceived(tctx.client, frames[3])
<< commands.Log("info", "WebSocket PONG received from client: <no payload>")
)
def test_extension(tctx, ws_playbook):
f = tutils.Placeholder()
ws_playbook.layer.handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;"
ws_playbook.layer.handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;"
frames = [
bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')),
bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')),
bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x00\x11\x00\x00')),
]
assert (
ws_playbook
<< commands.Hook("websocket_start", f)
>> events.HookReply(-1)
>> events.DataReceived(tctx.client, frames[0])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.server, frames[0])
>> events.DataReceived(tctx.server, frames[1])
<< commands.Hook("websocket_message", f)
>> events.HookReply(-1)
<< commands.SendData(tctx.client, frames[2])
)
assert len(f().messages) == 2
assert f().messages[0].content == "Hello"
assert f().messages[1].content == "Hello"