diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 244df343e..e89f9c5a1 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -157,7 +157,7 @@ class Proxyserver: if isinstance(f, http.HTTPFlow): if f.websocket: event = WebSocketMessageInjected( - f.client_conn if from_client else f.server_conn, + f, websocket.WebSocketMessage( Opcode.TEXT, from_client, message.encode() ) diff --git a/mitmproxy/proxy/events.py b/mitmproxy/proxy/events.py index 4593d9d49..1b6b6ea02 100644 --- a/mitmproxy/proxy/events.py +++ b/mitmproxy/proxy/events.py @@ -8,6 +8,7 @@ import typing import warnings from dataclasses import dataclass, is_dataclass +from mitmproxy import flow from mitmproxy.proxy import commands from mitmproxy.connection import Connection @@ -112,8 +113,9 @@ T = typing.TypeVar('T') @dataclass -class MessageInjected(ConnectionEvent, typing.Generic[T]): +class MessageInjected(Event, typing.Generic[T]): """ The user has injected a custom WebSocket/TCP/... message. """ + flow: flow.Flow message: T diff --git a/mitmproxy/proxy/layers/http/__init__.py b/mitmproxy/proxy/layers/http/__init__.py index a0a7682ea..78606b21e 100644 --- a/mitmproxy/proxy/layers/http/__init__.py +++ b/mitmproxy/proxy/layers/http/__init__.py @@ -589,10 +589,8 @@ class HttpLayer(layer.Layer): yield from self.event_to_child(stream, event) elif isinstance(event, events.MessageInjected): # For injected messages we pass the HTTP stacks entirely and directly address the stream. - conn = self.connections[event.connection] - if isinstance(conn, Http1Server): - stream_id = conn.stream_id - elif isinstance(conn, HttpStream): + conn = self.connections[event.flow.server_conn] + if isinstance(conn, HttpStream): stream_id = conn.stream_id else: # We reach to the end of the connection's child stack to get the HTTP/1 client layer, diff --git a/mitmproxy/proxy/layers/websocket.py b/mitmproxy/proxy/layers/websocket.py index 328ad34ee..f55ab07d6 100644 --- a/mitmproxy/proxy/layers/websocket.py +++ b/mitmproxy/proxy/layers/websocket.py @@ -128,10 +128,16 @@ class WebsocketLayer(layer.Layer): _handle_event = start @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected) - def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: + def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]: assert self.flow.websocket # satisfy type checker - from_client = event.connection == self.context.client + if isinstance(event, events.ConnectionEvent): + from_client = event.connection == self.context.client + elif isinstance(event, WebSocketMessageInjected): + from_client = event.message.from_client + else: + raise AssertionError(f"Unexpected event: {event}") + from_str = 'client' if from_client else 'server' if from_client: src_ws = self.client_ws diff --git a/test/mitmproxy/addons/test_proxyserver.py b/test/mitmproxy/addons/test_proxyserver.py index 049008211..d71b23615 100644 --- a/test/mitmproxy/addons/test_proxyserver.py +++ b/test/mitmproxy/addons/test_proxyserver.py @@ -56,10 +56,10 @@ async def test_start_stop(): proxy_addr = ps.server.sockets[0].getsockname()[:2] reader, writer = await asyncio.open_connection(*proxy_addr) - assert repr(ps) == "ProxyServer(running, 1 active conns)" req = f"GET http://{addr[0]}:{addr[1]}/hello HTTP/1.1\r\n\r\n" writer.write(req.encode()) assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n" + assert repr(ps) == "ProxyServer(running, 1 active conns)" tctx.configure(ps, server=False) await tctx.master.await_log("Stopping server", level="info") diff --git a/test/mitmproxy/proxy/layers/test_websocket.py b/test/mitmproxy/proxy/layers/test_websocket.py index 9b4ceb660..7e3ef5cd8 100644 --- a/test/mitmproxy/proxy/layers/test_websocket.py +++ b/test/mitmproxy/proxy/layers/test_websocket.py @@ -328,7 +328,7 @@ def test_inject_message(ws_testdata): playbook << websocket.WebsocketStartHook(flow) >> reply() - >> WebSocketMessageInjected(tctx.server, WebSocketMessage(Opcode.TEXT, False, b"hello")) + >> WebSocketMessageInjected(flow, WebSocketMessage(Opcode.TEXT, False, b"hello")) << websocket.WebsocketMessageHook(flow) ) assert flow.websocket.messages[-1].content == b"hello"