diff --git a/mitmproxy/net/http/http1/read_sansio.py b/mitmproxy/net/http/http1/read_sansio.py index a4cc5b671..fe9253382 100644 --- a/mitmproxy/net/http/http1/read_sansio.py +++ b/mitmproxy/net/http/http1/read_sansio.py @@ -2,32 +2,10 @@ import re import time from typing import Iterable, List, Optional, Tuple -from mitmproxy.net import check from mitmproxy.net.http import headers, request, response, url from mitmproxy.net.http.http1 import read -def _parse_authority_form(hostport: bytes) -> Tuple[bytes, int]: - """ - Returns (host, port) if hostport is a valid authority-form host specification. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - - Raises: - ValueError, if the input is malformed - """ - try: - host, port_str = hostport.rsplit(b":", 1) - if host.startswith(b"[") and host.endswith(b"]"): - host = host[1:-1] - port = int(port_str) - if not check.is_valid_host(host) or not check.is_valid_port(port): - raise ValueError - except ValueError: - raise ValueError(f"Invalid host specification: {hostport!r}") - - return host, port - - def raise_if_http_version_unknown(http_version: bytes) -> None: if not re.match(br"^HTTP/\d\.\d$", http_version): raise ValueError(f"Unknown HTTP version: {http_version!r}") diff --git a/mitmproxy/proxy2/context.py b/mitmproxy/proxy2/context.py index dcb4ce73b..c12793924 100644 --- a/mitmproxy/proxy2/context.py +++ b/mitmproxy/proxy2/context.py @@ -1,5 +1,6 @@ import uuid import warnings +from abc import ABCMeta from enum import Flag from typing import List, Literal, Optional, Sequence, Tuple, Union, TYPE_CHECKING @@ -25,7 +26,7 @@ class ConnectionState(Flag): Address = Tuple[str, int] -class Connection(serializable.Serializable): +class Connection(serializable.Serializable, metaclass=ABCMeta): """ Connections exposed to the layers only contain metadata, no socket objects. """ @@ -87,7 +88,7 @@ class Connection(serializable.Serializable): return f"{type(self).__name__}({attrs})" @property - def alpn_proto_negotiated(self) -> Optional[bytes]: + def alpn_proto_negotiated(self) -> Optional[bytes]: # pragma: no cover warnings.warn("Server.alpn_proto_negotiated is deprecated, use Server.alpn instead.", PendingDeprecationWarning) return self.alpn @@ -164,22 +165,22 @@ class Client(Connection): self.cipher_list = state["cipher_list"] @property - def address(self): + def address(self): # pragma: no cover warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning) return self.peername @address.setter - def address(self, x): + def address(self, x): # pragma: no cover warnings.warn("Client.address is deprecated, use Client.peername instead.", PendingDeprecationWarning) self.peername = x @property - def cipher_name(self) -> Optional[str]: + def cipher_name(self) -> Optional[str]: # pragma: no cover warnings.warn("Client.cipher_name is deprecated, use Client.cipher instead.", PendingDeprecationWarning) return self.cipher @property - def clientcert(self) -> Optional[certs.Cert]: + def clientcert(self) -> Optional[certs.Cert]: # pragma: no cover warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning) if self.certificate_list: return self.certificate_list[0] @@ -187,7 +188,7 @@ class Client(Connection): return None @clientcert.setter - def clientcert(self, val): + def clientcert(self, val): # pragma: no cover warnings.warn("Client.clientcert is deprecated, use Client.certificate_list instead.", PendingDeprecationWarning) if val: self.certificate_list = [val] @@ -268,12 +269,12 @@ class Server(Connection): self.via = state["via2"] @property - def ip_address(self) -> Optional[Address]: + def ip_address(self) -> Optional[Address]: # pragma: no cover warnings.warn("Server.ip_address is deprecated, use Server.peername instead.", PendingDeprecationWarning) return self.peername @property - def cert(self) -> Optional[certs.Cert]: + def cert(self) -> Optional[certs.Cert]: # pragma: no cover warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning) if self.certificate_list: return self.certificate_list[0] @@ -281,7 +282,7 @@ class Server(Connection): return None @cert.setter - def cert(self, val): + def cert(self, val): # pragma: no cover warnings.warn("Server.cert is deprecated, use Server.certificate_list instead.", PendingDeprecationWarning) if val: self.certificate_list = [val] diff --git a/mitmproxy/proxy2/layers/outdated/websocket.py b/mitmproxy/proxy2/layers/outdated/websocket.py deleted file mode 100644 index e9f506147..000000000 --- a/mitmproxy/proxy2/layers/outdated/websocket.py +++ /dev/null @@ -1,201 +0,0 @@ -import wsproto -from wsproto import events as wsevents -from wsproto import ConnectionType, WSConnection -from wsproto.extensions import PerMessageDeflate - -from mitmproxy import websocket, http, flow -from mitmproxy.proxy2 import events, commands, layer -from mitmproxy.proxy2.context import Context -from mitmproxy.proxy2.utils import expect - - -class WebsocketLayer(layer.Layer): - """ - WebSocket layer that intercepts and relays messages. - """ - context: Context = None - flow: websocket.WebSocketFlow - - def __init__(self, context: Context, handshake_flow: http.HTTPFlow): - super().__init__(context) - self.flow = websocket.WebSocketFlow(context.client, context.server, handshake_flow) - self.flow.metadata['websocket_handshake'] = handshake_flow.id - self.handshake_flow = handshake_flow - self.handshake_flow.metadata['websocket_flow'] = self.flow.id - self.client_frame_buffer = [] - self.server_frame_buffer = [] - - assert context.server.connected - - @expect(events.Start) - def start(self, _) -> layer.CommandGenerator[None]: - extensions = [] - if 'Sec-WebSocket-Extensions' in self.handshake_flow.response.headers: - if PerMessageDeflate.name in self.handshake_flow.response.headers['Sec-WebSocket-Extensions']: - extensions = [PerMessageDeflate()] - self.client_conn = WSConnection(ConnectionType.SERVER, - extensions=extensions) - self.server_conn = WSConnection(ConnectionType.CLIENT, - host=self.handshake_flow.request.host, - resource=self.handshake_flow.request.path, - extensions=extensions) - if extensions: - self.client_conn.extensions[0].finalize(self.client_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions']) - self.server_conn.extensions[0].finalize(self.server_conn, self.handshake_flow.response.headers['Sec-WebSocket-Extensions']) - - data = self.server_conn.bytes_to_send() - self.client_conn.receive_bytes(data) - - event = next(self.client_conn.events()) - assert isinstance(event, wsevents.ConnectionRequested) - - self.client_conn.accept(event) - self.server_conn.receive_bytes(self.client_conn.bytes_to_send()) - assert isinstance(next(self.server_conn.events()), wsevents.ConnectionEstablished) - - yield commands.Hook("websocket_start", self.flow) - self._handle_event = self.process_data - - _handle_event = start - - @expect(events.DataReceived, events.ConnectionClosed) - def process_data(self, event: events.Event) -> layer.CommandGenerator[None]: - if isinstance(event, events.DataReceived): - from_client = event.connection == self.context.client - if from_client: - source = self.client_conn - other = self.server_conn - fb = self.client_frame_buffer - send_to = self.context.server - else: - source = self.server_conn - other = self.client_conn - fb = self.server_frame_buffer - send_to = self.context.client - - source.receive_bytes(event.data) - - closing = False - received_ws_events = list(source.events()) - for ws_event in received_ws_events: - if isinstance(ws_event, wsevents.DataReceived): - yield from self._handle_data_received(ws_event, source, other, send_to, from_client, fb) - elif isinstance(ws_event, wsevents.PingReceived): - yield from self._handle_ping_received(ws_event, source, other, send_to, from_client) - elif isinstance(ws_event, wsevents.PongReceived): - yield from self._handle_pong_received(ws_event, source, other, send_to, from_client) - elif isinstance(ws_event, wsevents.ConnectionClosed): - yield from self._handle_connection_closed(ws_event, source, other, send_to, from_client) - closing = True - else: - yield commands.Log( - "WebSocket unhandled event: from {}: {}".format("client" if from_client else "server", ws_event) - ) - - if closing: - yield commands.Hook("websocket_end", self.flow) - if not from_client: - yield commands.CloseConnection(self.context.client) - self._handle_event = self.done - - # TODO: elif isinstance(event, events.InjectMessage): - # TODO: come up with a solid API to inject messages - - elif isinstance(event, events.ConnectionClosed): - yield commands.Log("Connection closed abnormally", "error") - self.flow.error = flow.Error( - "WebSocket connection closed unexpectedly by {}".format( - "client" if event.connection == self.context.client else "server" - ) - ) - if event.connection == self.context.server: - yield commands.CloseConnection(self.context.client) - yield commands.Hook("websocket_error", self.flow) - yield commands.Hook("websocket_end", self.flow) - self._handle_event = self.done - - @expect(events.DataReceived, events.ConnectionClosed) - def done(self, _): - yield from () - - def _handle_data_received(self, ws_event, source, other, send_to, from_client, fb): - fb.append(ws_event.data) - - if ws_event.message_finished: - original_chunk_sizes = [len(f) for f in fb] - - if isinstance(ws_event, wsevents.TextReceived): - message_type = wsproto.frame_protocol.Opcode.TEXT - payload = ''.join(fb) - else: - message_type = wsproto.frame_protocol.Opcode.BINARY - payload = b''.join(fb) - - fb.clear() - - websocket_message = websocket.WebSocketMessage(message_type, from_client, payload) - length = len(websocket_message.content) - self.flow.messages.append(websocket_message) - yield commands.Hook("websocket_message", self.flow) - - if not self.flow.stream and not websocket_message.killed: - def get_chunk(payload): - if len(payload) == length: - # message has the same length, we can reuse the same sizes - pos = 0 - for s in original_chunk_sizes: - yield (payload[pos:pos + s], True if pos + s == length else False) - pos += s - else: - # just re-chunk everything into 4kB frames - # header len = 4 bytes without masking key and 8 bytes with masking key - chunk_size = 4088 if from_client else 4092 - chunks = range(0, len(payload), chunk_size) - for i in chunks: - yield (payload[i:i + chunk_size], True if i + chunk_size >= len(payload) else False) - - for chunk, final in get_chunk(websocket_message.content): - other.send_data(chunk, final) - yield commands.SendData(send_to, other.bytes_to_send()) - - if self.flow.stream: - other.send_data(ws_event.data, ws_event.message_finished) - yield commands.SendData(send_to, other.bytes_to_send()) - - def _handle_ping_received(self, ws_event, source, other, send_to, from_client): - yield commands.Log( - "WebSocket PING received from {}: {}".format("client" if from_client else "server", - ws_event.payload.decode() or "") - ) - # We do not forward the PING payload, as this might leak information! - other.ping() - yield commands.SendData(send_to, other.bytes_to_send()) - # PING is automatically answered with a PONG by wsproto - yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send()) - - def _handle_pong_received(self, ws_event, source, other, send_to, from_client): - yield commands.Log( - "WebSocket PONG received from {}: {}".format("client" if from_client else "server", - ws_event.payload.decode() or "") - ) - - def _handle_connection_closed(self, ws_event, source, other, send_to, from_client): - self.flow.close_sender = "client" if from_client else "server" - self.flow.close_code = ws_event.code - self.flow.close_reason = ws_event.reason - - other.close(ws_event.code, ws_event.reason) - yield commands.SendData(send_to, other.bytes_to_send()) - - # FIXME: Wait for other end to actually send the closing frame - # FIXME: https://github.com/python-hyper/wsproto/pull/50 - yield commands.SendData(self.context.client if from_client else self.context.server, source.bytes_to_send()) - - if ws_event.code != 1000: - self.flow.error = flow.Error( - "WebSocket connection closed unexpectedly by {}: {}".format( - "client" if from_client else "server", - ws_event.reason - ) - ) - yield commands.Hook("websocket_error", self.flow) diff --git a/test/mitmproxy/proxy2/layers/_test_websocket.py b/test/mitmproxy/proxy2/layers/_test_websocket.py deleted file mode 100644 index ffc5f1bdf..000000000 --- a/test/mitmproxy/proxy2/layers/_test_websocket.py +++ /dev/null @@ -1,197 +0,0 @@ -import struct -from unittest import mock - -import pytest -from mitmproxy.proxy2.layers.old import websocket - -from mitmproxy.net.websockets import Frame, OPCODE -from mitmproxy.proxy2 import commands, events -from mitmproxy.proxy2.context import ConnectionState -from mitmproxy.test import tflow -from .. import tutils - - -@pytest.fixture -def ws_playbook(tctx): - tctx.server.state = ConnectionState.OPEN - playbook = tutils.Playbook( - websocket.WebsocketLayer( - tctx, - tflow.twebsocketflow().handshake_flow - ), - ignore_log=False, - ) - with mock.patch("os.urandom") as m: - m.return_value = b"\x10\x11\x12\x13" - yield playbook - - -def test_simple(tctx, ws_playbook): - f = tutils.Placeholder() - - frames = [ - bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, payload=b'client-foobar')), - bytes(Frame(fin=1, opcode=OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')), - bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))), - bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))), - bytes(Frame(fin=1, opcode=OPCODE.TEXT, payload=b'fail')), - ] - - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.client, frames[0]) - << commands.Hook("websocket_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.server, frames[0]) - >> events.DataReceived(tctx.server, frames[1]) - << commands.Hook("websocket_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.client, frames[1]) - >> events.DataReceived(tctx.client, frames[2]) - << commands.SendData(tctx.server, frames[2]) - << commands.SendData(tctx.client, frames[3]) - << commands.Hook("websocket_end", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.server, frames[4]) - << None - ) - - assert len(f().messages) == 2 - - -def test_server_close(tctx, ws_playbook): - f = tutils.Placeholder() - - frames = [ - bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))), - bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1000))), - ] - - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.server, frames[0]) - << commands.SendData(tctx.client, frames[0]) - << commands.SendData(tctx.server, frames[1]) - << commands.Hook("websocket_end", f) - >> events.HookReply(-1) - << commands.CloseConnection(tctx.client) - ) - - -def test_connection_closed(tctx, ws_playbook): - f = tutils.Placeholder() - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.ConnectionClosed(tctx.server) - << commands.Log("error", "Connection closed abnormally") - << commands.CloseConnection(tctx.client) - << commands.Hook("websocket_error", f) - >> events.HookReply(-1) - << commands.Hook("websocket_end", f) - >> events.HookReply(-1) - ) - - assert f().error - - -def test_connection_failed(tctx, ws_playbook): - f = tutils.Placeholder() - - frames = [ - b'Not a valid frame', - bytes(Frame(fin=1, mask=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')), - bytes(Frame(fin=1, opcode=OPCODE.CLOSE, payload=struct.pack('>H', 1002) + b'Invalid opcode 0xe')), - ] - - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.client, frames[0]) - << commands.SendData(tctx.server, frames[1]) - << commands.SendData(tctx.client, frames[2]) - << commands.Hook("websocket_error", f) - >> events.HookReply(-1) - << commands.Hook("websocket_end", f) - >> events.HookReply(-1) - ) - - -def test_ping_pong(tctx, ws_playbook): - f = tutils.Placeholder() - - frames = [ - bytes(Frame(fin=1, mask=1, opcode=OPCODE.PING)), - bytes(Frame(fin=1, opcode=OPCODE.PONG)), - ] - - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.client, frames[0]) - << commands.Log("info", "WebSocket PING received from client: ") - << commands.SendData(tctx.server, frames[0]) - << commands.SendData(tctx.client, frames[1]) - >> events.DataReceived(tctx.server, frames[1]) - << commands.Log("info", "WebSocket PONG received from server: ") - ) - - -def test_ping_pong_hidden_payload(tctx, ws_playbook): - f = tutils.Placeholder() - - frames = [ - bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'foobar')), - bytes(Frame(fin=1, opcode=OPCODE.PING, payload=b'')), - bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'foobar')), - bytes(Frame(fin=1, mask=1, opcode=OPCODE.PONG, payload=b'')), - ] - - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.server, frames[0]) - << commands.Log("info", "WebSocket PING received from server: foobar") - << commands.SendData(tctx.client, frames[1]) - << commands.SendData(tctx.server, frames[2]) - >> events.DataReceived(tctx.client, frames[3]) - << commands.Log("info", "WebSocket PONG received from client: ") - ) - - -def test_extension(tctx, ws_playbook): - f = tutils.Placeholder() - - ws_playbook.layer.handshake_flow.request.headers["sec-websocket-extensions"] = "permessage-deflate;" - ws_playbook.layer.handshake_flow.response.headers["sec-websocket-extensions"] = "permessage-deflate;" - - frames = [ - bytes(Frame(fin=1, mask=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')), - bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x48\xcd\xc9\xc9\x07\x00')), - bytes(Frame(fin=1, opcode=OPCODE.TEXT, rsv1=True, payload=b'\xf2\x00\x11\x00\x00')), - ] - - assert ( - ws_playbook - << commands.Hook("websocket_start", f) - >> events.HookReply(-1) - >> events.DataReceived(tctx.client, frames[0]) - << commands.Hook("websocket_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.server, frames[0]) - >> events.DataReceived(tctx.server, frames[1]) - << commands.Hook("websocket_message", f) - >> events.HookReply(-1) - << commands.SendData(tctx.client, frames[2]) - ) - assert len(f().messages) == 2 - assert f().messages[0].content == "Hello" - assert f().messages[1].content == "Hello"