injection: ConnectionEvent -> Event

This commit is contained in:
Maximilian Hils 2021-03-12 16:42:55 +01:00
parent 07f1bcf543
commit 5921c590e3
6 changed files with 16 additions and 10 deletions

View File

@ -157,7 +157,7 @@ class Proxyserver:
if isinstance(f, http.HTTPFlow): if isinstance(f, http.HTTPFlow):
if f.websocket: if f.websocket:
event = WebSocketMessageInjected( event = WebSocketMessageInjected(
f.client_conn if from_client else f.server_conn, f,
websocket.WebSocketMessage( websocket.WebSocketMessage(
Opcode.TEXT, from_client, message.encode() Opcode.TEXT, from_client, message.encode()
) )

View File

@ -8,6 +8,7 @@ import typing
import warnings import warnings
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from mitmproxy import flow
from mitmproxy.proxy import commands from mitmproxy.proxy import commands
from mitmproxy.connection import Connection from mitmproxy.connection import Connection
@ -112,8 +113,9 @@ T = typing.TypeVar('T')
@dataclass @dataclass
class MessageInjected(ConnectionEvent, typing.Generic[T]): class MessageInjected(Event, typing.Generic[T]):
""" """
The user has injected a custom WebSocket/TCP/... message. The user has injected a custom WebSocket/TCP/... message.
""" """
flow: flow.Flow
message: T message: T

View File

@ -589,10 +589,8 @@ class HttpLayer(layer.Layer):
yield from self.event_to_child(stream, event) yield from self.event_to_child(stream, event)
elif isinstance(event, events.MessageInjected): elif isinstance(event, events.MessageInjected):
# For injected messages we pass the HTTP stacks entirely and directly address the stream. # For injected messages we pass the HTTP stacks entirely and directly address the stream.
conn = self.connections[event.connection] conn = self.connections[event.flow.server_conn]
if isinstance(conn, Http1Server): if isinstance(conn, HttpStream):
stream_id = conn.stream_id
elif isinstance(conn, HttpStream):
stream_id = conn.stream_id stream_id = conn.stream_id
else: else:
# We reach to the end of the connection's child stack to get the HTTP/1 client layer, # We reach to the end of the connection's child stack to get the HTTP/1 client layer,

View File

@ -128,10 +128,16 @@ class WebsocketLayer(layer.Layer):
_handle_event = start _handle_event = start
@expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected) @expect(events.DataReceived, events.ConnectionClosed, WebSocketMessageInjected)
def relay_messages(self, event: events.ConnectionEvent) -> layer.CommandGenerator[None]: def relay_messages(self, event: events.Event) -> layer.CommandGenerator[None]:
assert self.flow.websocket # satisfy type checker assert self.flow.websocket # satisfy type checker
from_client = event.connection == self.context.client if isinstance(event, events.ConnectionEvent):
from_client = event.connection == self.context.client
elif isinstance(event, WebSocketMessageInjected):
from_client = event.message.from_client
else:
raise AssertionError(f"Unexpected event: {event}")
from_str = 'client' if from_client else 'server' from_str = 'client' if from_client else 'server'
if from_client: if from_client:
src_ws = self.client_ws src_ws = self.client_ws

View File

@ -56,10 +56,10 @@ async def test_start_stop():
proxy_addr = ps.server.sockets[0].getsockname()[:2] proxy_addr = ps.server.sockets[0].getsockname()[:2]
reader, writer = await asyncio.open_connection(*proxy_addr) reader, writer = await asyncio.open_connection(*proxy_addr)
assert repr(ps) == "ProxyServer(running, 1 active conns)"
req = f"GET http://{addr[0]}:{addr[1]}/hello HTTP/1.1\r\n\r\n" req = f"GET http://{addr[0]}:{addr[1]}/hello HTTP/1.1\r\n\r\n"
writer.write(req.encode()) writer.write(req.encode())
assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n" assert await reader.readuntil(b"\r\n\r\n") == b"HTTP/1.1 204 No Content\r\n\r\n"
assert repr(ps) == "ProxyServer(running, 1 active conns)"
tctx.configure(ps, server=False) tctx.configure(ps, server=False)
await tctx.master.await_log("Stopping server", level="info") await tctx.master.await_log("Stopping server", level="info")

View File

@ -328,7 +328,7 @@ def test_inject_message(ws_testdata):
playbook playbook
<< websocket.WebsocketStartHook(flow) << websocket.WebsocketStartHook(flow)
>> reply() >> reply()
>> WebSocketMessageInjected(tctx.server, WebSocketMessage(Opcode.TEXT, False, b"hello")) >> WebSocketMessageInjected(flow, WebSocketMessage(Opcode.TEXT, False, b"hello"))
<< websocket.WebsocketMessageHook(flow) << websocket.WebsocketMessageHook(flow)
) )
assert flow.websocket.messages[-1].content == b"hello" assert flow.websocket.messages[-1].content == b"hello"