[sans-io] h2++, typeize Placeholder

This commit is contained in:
Maximilian Hils 2020-06-26 01:01:00 +02:00
parent 742efae419
commit 0646a4f1ba
14 changed files with 687 additions and 159 deletions

View File

@ -283,9 +283,7 @@ class HttpStream(layer.Layer):
self.flow.error = flow.Error(event.message) self.flow.error = flow.Error(event.message)
yield HttpErrorHook(self.flow) yield HttpErrorHook(self.flow)
if isinstance(event, RequestProtocolError): if isinstance(event, ResponseProtocolError):
yield SendHttp(event, self.context.server)
else:
yield SendHttp(event, self.context.client) yield SendHttp(event, self.context.client)
def make_server_connection(self) -> layer.CommandGenerator[bool]: def make_server_connection(self) -> layer.CommandGenerator[bool]:

View File

@ -1,41 +1,50 @@
import time import time
from typing import ClassVar
import h2.connection import h2.connection
import h2.config import h2.config
import h2.events import h2.events
import h2.exceptions import h2.exceptions
import h2.settings import h2.settings
import h2.errors
import h2.utilities
from mitmproxy import http from mitmproxy import http
from mitmproxy.net import http as net_http from mitmproxy.net import http as net_http
from mitmproxy.net.http import http2 from mitmproxy.net.http import http2
from . import RequestEndOfMessage, RequestHeaders, ResponseData, ResponseEndOfMessage, ResponseHeaders from . import RequestData, RequestEndOfMessage, RequestHeaders, RequestProtocolError, ResponseData, \
ResponseEndOfMessage, ResponseHeaders, ResponseProtocolError
from ._base import HttpConnection, HttpEvent, ReceiveHttp from ._base import HttpConnection, HttpEvent, ReceiveHttp
from ._http_h2 import BufferedH2Connection, H2ConnectionLogger from ._http_h2 import BufferedH2Connection, H2ConnectionLogger
from ...commands import SendData from ...commands import CloseConnection, Log, SendData
from ...context import Context from ...context import Connection, Context
from ...events import DataReceived, Event, Start from ...events import ConnectionClosed, DataReceived, Event, Start
from ...layer import CommandGenerator from ...layer import CommandGenerator
h2_events_we_dont_care_about = (
h2.events.RemoteSettingsChanged, class Http2Connection(HttpConnection):
h2.events.SettingsAcknowledged h2_conf: ClassVar[h2.config.H2Configuration]
) h2_conn: BufferedH2Connection
def __init__(self, context: Context, conn: Connection):
super().__init__(context, conn)
self.h2_conn = BufferedH2Connection(self.h2_conf)
class Http2Server(HttpConnection): class Http2Server(Http2Connection):
def __init__(self, context: Context):
super().__init__(context, context.client)
# noinspection PyTypeChecker # noinspection PyTypeChecker
self.h2_conf = h2.config.H2Configuration( h2_conf = h2.config.H2Configuration(
client_side=False, client_side=False,
header_encoding=False, header_encoding=False,
validate_outbound_headers=False, validate_outbound_headers=False,
validate_inbound_headers=False, validate_inbound_headers=False,
normalize_inbound_headers=False,
normalize_outbound_headers=False,
logger=H2ConnectionLogger("server") # type: ignore logger=H2ConnectionLogger("server") # type: ignore
) )
self.h2_conn = BufferedH2Connection(self.h2_conf)
def __init__(self, context: Context):
super().__init__(context, context.client)
def _handle_event(self, event: Event) -> CommandGenerator[None]: def _handle_event(self, event: Event) -> CommandGenerator[None]:
if isinstance(event, Start): if isinstance(event, Start):
@ -44,16 +53,30 @@ class Http2Server(HttpConnection):
elif isinstance(event, HttpEvent): elif isinstance(event, HttpEvent):
if isinstance(event, ResponseHeaders): if isinstance(event, ResponseHeaders):
headers = event.response.headers.copy() headers = (
headers.insert(0, ":status", str(event.response.status_code)) (b":status", b"%d" % event.response.status_code),
*event.response.headers.fields
)
if event.response.data.http_version != b"HTTP/2":
# HTTP/1 servers commonly send capitalized headers (Content-Length vs content-length),
# which isn't valid HTTP/2. As such we normalize.
headers = h2.utilities.normalize_outbound_headers(
headers,
h2.utilities.HeaderValidationFlags(False, False, True, False)
)
# make sure that this is not just an iterator but an iterable,
# otherwise hyper-h2 will silently drop headers.
headers = list(headers)
self.h2_conn.send_headers( self.h2_conn.send_headers(
event.stream_id, event.stream_id,
headers.fields, headers,
) )
elif isinstance(event, ResponseData): elif isinstance(event, ResponseData):
self.h2_conn.send_data(event.stream_id, event.data) self.h2_conn.send_data(event.stream_id, event.data)
elif isinstance(event, ResponseEndOfMessage): elif isinstance(event, ResponseEndOfMessage):
self.h2_conn.send_data(event.stream_id, b"", end_stream=True) self.h2_conn.send_data(event.stream_id, b"", end_stream=True)
elif isinstance(event, ResponseProtocolError):
self.h2_conn.reset_stream(event.stream_id, h2.errors.ErrorCodes.PROTOCOL_ERROR)
else: else:
raise NotImplementedError(f"Unknown HTTP event: {event}") raise NotImplementedError(f"Unknown HTTP event: {event}")
yield SendData(self.conn, self.h2_conn.data_to_send()) yield SendData(self.conn, self.h2_conn.data_to_send())
@ -65,9 +88,7 @@ class Http2Server(HttpConnection):
events = [e] events = [e]
for h2_event in events: for h2_event in events:
if isinstance(h2_event, h2_events_we_dont_care_about): if isinstance(h2_event, h2.events.RequestReceived):
pass
elif isinstance(h2_event, h2.events.RequestReceived):
headers = net_http.Headers([(k, v) for k, v in h2_event.headers]) headers = net_http.Headers([(k, v) for k, v in h2_event.headers])
first_line_format, method, scheme, host, port, path = http2.parse_headers(headers) first_line_format, method, scheme, host, port, path = http2.parse_headers(headers)
headers["Host"] = headers.pop(":authority") # FIXME: temporary workaround headers["Host"] = headers.pop(":authority") # FIXME: temporary workaround
@ -78,18 +99,47 @@ class Http2Server(HttpConnection):
host, host,
port, port,
path, path,
b"HTTP/1.1", b"HTTP/1.1", # FIXME: Figure out how to smooth h2 <-> h1.
headers, headers,
None, None,
timestamp_start=time.time(), timestamp_start=time.time(),
) )
yield ReceiveHttp(RequestHeaders(h2_event.stream_id, request)) yield ReceiveHttp(RequestHeaders(h2_event.stream_id, request))
elif isinstance(h2_event, h2.events.DataReceived):
yield ReceiveHttp(RequestData(h2_event.stream_id, h2_event.data))
self.h2_conn.acknowledge_received_data(len(h2_event.data), h2_event.stream_id)
elif isinstance(h2_event, h2.events.StreamEnded): elif isinstance(h2_event, h2.events.StreamEnded):
yield ReceiveHttp(RequestEndOfMessage(h2_event.stream_id)) yield ReceiveHttp(RequestEndOfMessage(h2_event.stream_id))
elif isinstance(h2_event, h2.exceptions.ProtocolError):
yield CloseConnection(self.conn)
yield from self._notify_close(f"HTTP/2 protocol error: {h2_event}")
return
elif isinstance(h2_event, h2.events.ConnectionTerminated):
yield CloseConnection(self.conn)
yield from self._notify_close(f"HTTP/2 connection closed: {h2_event!r}")
return
elif isinstance(h2_event, h2.events.StreamReset):
yield ReceiveHttp(RequestProtocolError(h2_event.stream_id, "EOF"))
elif isinstance(h2_event, h2.events.RemoteSettingsChanged):
pass
elif isinstance(h2_event, h2.events.SettingsAcknowledged):
pass
else: else:
raise NotImplementedError(f"Unknown event: {h2_event!r}") raise NotImplementedError(f"Unknown event: {h2_event!r}")
yield SendData(self.conn, self.h2_conn.data_to_send()) data_to_send = self.h2_conn.data_to_send()
if data_to_send:
yield SendData(self.conn, data_to_send)
elif isinstance(event, ConnectionClosed):
yield CloseConnection(self.conn)
yield from self._notify_close("peer closed connection")
else:
raise NotImplementedError(f"Unexpected event: {event!r}")
def _notify_close(self, err: str) -> CommandGenerator[None]:
for stream_id, stream in self.h2_conn.streams.items():
if stream.open:
yield ReceiveHttp(RequestProtocolError(stream_id, err))
class Http2Client: class Http2Client:

View File

@ -28,12 +28,13 @@ class H2ConnectionLogger(h2.config.DummyLogger):
class SendH2Data(NamedTuple): class SendH2Data(NamedTuple):
data: bytes data: bytes
end_stream: bool end_stream: bool
pad_length: Optional[int]
class BufferedH2Connection(h2.connection.H2Connection): class BufferedH2Connection(h2.connection.H2Connection):
""" """
This class wrap's hyper-h2's H2Connection and adds internal send buffers. This class wrap's hyper-h2's H2Connection and adds internal send buffers.
To simplify implementation, padding is unsupported.
""" """
stream_buffers: DefaultDict[int, Deque[SendH2Data]] stream_buffers: DefaultDict[int, Deque[SendH2Data]]
@ -46,37 +47,32 @@ class BufferedH2Connection(h2.connection.H2Connection):
stream_id: int, stream_id: int,
data: bytes, data: bytes,
end_stream: bool = False, end_stream: bool = False,
pad_length: Optional[int] = None pad_length: None = None
) -> None: ) -> None:
""" """
Send data on a given stream. Send data on a given stream.
In contrast to plain h2, this method will not emit In contrast to plain hyper-h2, this method will not raise if the data cannot be sent immediately.
either FlowControlError or FrameTooLargeError. Data is split up and buffered internally.
Instead, data is buffered and split up.
""" """
frame_size = len(data) frame_size = len(data)
if pad_length is not None: assert pad_length is None
frame_size += pad_length + 1
while frame_size > self.max_outbound_frame_size: while frame_size > self.max_outbound_frame_size:
chunk_1 = data[:self.max_outbound_frame_size] chunk_data = data[:self.max_outbound_frame_size]
pad_1 = max(0, self.max_outbound_frame_size - len(data)) self.send_data(stream_id, chunk_data, end_stream=False)
self.send_data(stream_id, chunk_1, end_stream=False, pad_length=pad_1 or None)
data = data[self.max_outbound_frame_size:] data = data[self.max_outbound_frame_size:]
if pad_length: frame_size -= len(chunk_data)
pad_length -= pad_1
frame_size -= len(chunk_1) + pad_1
available_window = self.local_flow_control_window(stream_id) available_window = self.local_flow_control_window(stream_id)
if frame_size > available_window: if frame_size <= available_window:
self.stream_buffers[stream_id].append( super().send_data(stream_id, data, end_stream)
SendH2Data(data, end_stream, pad_length)
)
else: else:
# We can't send right now, so we buffer. # We can't send right now, so we buffer.
super().send_data(stream_id, data, end_stream, pad_length) self.stream_buffers[stream_id].append(
SendH2Data(data, end_stream)
)
def receive_data(self, data: bytes): def receive_data(self, data: bytes):
events = super().receive_data(data) events = super().receive_data(data)
@ -112,16 +108,14 @@ class BufferedH2Connection(h2.connection.H2Connection):
SendH2Data( SendH2Data(
data=chunk.data[available_window:], data=chunk.data[available_window:],
end_stream=chunk.end_stream, end_stream=chunk.end_stream,
pad_length=chunk.pad_length,
) )
) )
chunk = SendH2Data( chunk = SendH2Data(
data=chunk.data[:available_window], data=chunk.data[:available_window],
end_stream=False, end_stream=False,
pad_length=None,
) )
super().send_data(stream_id, data=chunk.data, end_stream=chunk.end_stream, pad_length=chunk.pad_length) self.send_data(stream_id, data=chunk.data, end_stream=chunk.end_stream)
available_window -= len(chunk.data) available_window -= len(chunk.data)
if not self.stream_buffers[stream_id]: if not self.stream_buffers[stream_id]:

View File

@ -6,7 +6,7 @@ from mitmproxy.proxy2 import context
@pytest.fixture @pytest.fixture
def tctx(): def tctx() -> context.Context:
opts = options.Options() opts = options.Options()
Proxyserver().load(opts) Proxyserver().load(opts)
return context.Context( return context.Context(

View File

@ -0,0 +1,179 @@
# This file has been copied from https://github.com/python-hyper/hyper-h2/blob/master/test/helpers.py,
# MIT License
# -*- coding: utf-8 -*-
"""
helpers
~~~~~~~
This module contains helpers for the h2 tests.
"""
from hyperframe.frame import (
HeadersFrame, DataFrame, SettingsFrame, WindowUpdateFrame, PingFrame,
GoAwayFrame, RstStreamFrame, PushPromiseFrame, PriorityFrame,
ContinuationFrame, AltSvcFrame
)
from hpack.hpack import Encoder
SAMPLE_SETTINGS = {
SettingsFrame.HEADER_TABLE_SIZE: 4096,
SettingsFrame.ENABLE_PUSH: 1,
SettingsFrame.MAX_CONCURRENT_STREAMS: 2,
}
class FrameFactory(object):
"""
A class containing lots of helper methods and state to build frames. This
allows test cases to easily build correct HTTP/2 frames to feed to
hyper-h2.
"""
def __init__(self):
self.encoder = Encoder()
def refresh_encoder(self):
self.encoder = Encoder()
def preamble(self):
return b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
def build_headers_frame(self,
headers,
flags=[],
stream_id=1,
**priority_kwargs):
"""
Builds a single valid headers frame out of the contained headers.
"""
f = HeadersFrame(stream_id)
f.data = self.encoder.encode(headers)
f.flags.add('END_HEADERS')
for flag in flags:
f.flags.add(flag)
for k, v in priority_kwargs.items():
setattr(f, k, v)
return f
def build_continuation_frame(self, header_block, flags=[], stream_id=1):
"""
Builds a single continuation frame out of the binary header block.
"""
f = ContinuationFrame(stream_id)
f.data = header_block
f.flags = set(flags)
return f
def build_data_frame(self, data, flags=None, stream_id=1, padding_len=0):
"""
Builds a single data frame out of a chunk of data.
"""
flags = set(flags) if flags is not None else set()
f = DataFrame(stream_id)
f.data = data
f.flags = flags
if padding_len:
flags.add('PADDED')
f.pad_length = padding_len
return f
def build_settings_frame(self, settings, ack=False):
"""
Builds a single settings frame.
"""
f = SettingsFrame(0)
if ack:
f.flags.add('ACK')
f.settings = settings
return f
def build_window_update_frame(self, stream_id, increment):
"""
Builds a single WindowUpdate frame.
"""
f = WindowUpdateFrame(stream_id)
f.window_increment = increment
return f
def build_ping_frame(self, ping_data, flags=None):
"""
Builds a single Ping frame.
"""
f = PingFrame(0)
f.opaque_data = ping_data
if flags:
f.flags = set(flags)
return f
def build_goaway_frame(self,
last_stream_id,
error_code=0,
additional_data=b''):
"""
Builds a single GOAWAY frame.
"""
f = GoAwayFrame(0)
f.error_code = error_code
f.last_stream_id = last_stream_id
f.additional_data = additional_data
return f
def build_rst_stream_frame(self, stream_id, error_code=0):
"""
Builds a single RST_STREAM frame.
"""
f = RstStreamFrame(stream_id)
f.error_code = error_code
return f
def build_push_promise_frame(self,
stream_id,
promised_stream_id,
headers,
flags=[]):
"""
Builds a single PUSH_PROMISE frame.
"""
f = PushPromiseFrame(stream_id)
f.promised_stream_id = promised_stream_id
f.data = self.encoder.encode(headers)
f.flags = set(flags)
f.flags.add('END_HEADERS')
return f
def build_priority_frame(self,
stream_id,
weight,
depends_on=0,
exclusive=False):
"""
Builds a single priority frame.
"""
f = PriorityFrame(stream_id)
f.depends_on = depends_on
f.stream_weight = weight
f.exclusive = exclusive
return f
def build_alt_svc_frame(self, stream_id, origin, field):
"""
Builds a single ALTSVC frame.
"""
f = AltSvcFrame(stream_id)
f.origin = origin
f.field = field
return f
def change_table_size(self, new_size):
"""
Causes the encoder to send a dynamic size update in the next header
block it sends.
"""
self.encoder.header_table_size = new_size

View File

@ -1,9 +1,12 @@
from typing import Callable
import pytest import pytest
from mitmproxy.http import HTTPFlow, HTTPResponse from mitmproxy.http import HTTPFlow, HTTPResponse
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import layer from mitmproxy.proxy2 import layer
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.context import Server
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import TCPLayer, http, tls from mitmproxy.proxy2.layers import TCPLayer, http, tls
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_next_layer
@ -11,8 +14,8 @@ from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply, reply_nex
def test_http_proxy(tctx): def test_http_proxy(tctx):
"""Test a simple HTTP GET / request""" """Test a simple HTTP GET / request"""
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular)) Playbook(http.HttpLayer(tctx, HTTPMode.regular))
>> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/foo?hello=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")
@ -37,8 +40,8 @@ def test_http_proxy(tctx):
@pytest.mark.parametrize("strategy", ["lazy", "eager"]) @pytest.mark.parametrize("strategy", ["lazy", "eager"])
def test_https_proxy(strategy, tctx): def test_https_proxy(strategy, tctx):
"""Test a CONNECT request, followed by a HTTP GET /""" """Test a CONNECT request, followed by a HTTP GET /"""
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular)) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
tctx.options.connection_strategy = strategy tctx.options.connection_strategy = strategy
@ -79,8 +82,8 @@ def test_https_proxy(strategy, tctx):
@pytest.mark.parametrize("strategy", ["lazy", "eager"]) @pytest.mark.parametrize("strategy", ["lazy", "eager"])
def test_redirect(strategy, https_server, https_client, tctx, monkeypatch): def test_redirect(strategy, https_server, https_client, tctx, monkeypatch):
"""Test redirects between http:// and https:// in regular proxy mode.""" """Test redirects between http:// and https:// in regular proxy mode."""
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
tctx.options.connection_strategy = strategy tctx.options.connection_strategy = strategy
p = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) p = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
@ -121,8 +124,8 @@ def test_redirect(strategy, https_server, https_client, tctx, monkeypatch):
def test_multiple_server_connections(tctx): def test_multiple_server_connections(tctx):
"""Test multiple requests being rewritten to different targets.""" """Test multiple requests being rewritten to different targets."""
server1 = Placeholder() server1 = Placeholder(Server)
server2 = Placeholder() server2 = Placeholder(Server)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
def redirect(to: str): def redirect(to: str):
@ -174,7 +177,7 @@ def test_http_reply_from_proxy(tctx):
def test_response_until_eof(tctx): def test_response_until_eof(tctx):
"""Test scenario where the server response body is terminated by EOF.""" """Test scenario where the server response body is terminated by EOF."""
server = Placeholder() server = Placeholder(Server)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
@ -192,14 +195,14 @@ def test_disconnect_while_intercept(tctx):
"""Test a server disconnect while a request is intercepted.""" """Test a server disconnect while a request is intercepted."""
tctx.options.connection_strategy = "eager" tctx.options.connection_strategy = "eager"
server1 = Placeholder() server1 = Placeholder(Server)
server2 = Placeholder() server2 = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n") >> DataReceived(tctx.client, b"CONNECT example.com:80 HTTP/1.1\r\n\r\n")
<< http.HttpConnectHook(Placeholder()) << http.HttpConnectHook(Placeholder(HTTPFlow))
>> reply() >> reply()
<< OpenConnection(server1) << OpenConnection(server1)
>> reply(None) >> reply(None)
@ -222,8 +225,8 @@ def test_disconnect_while_intercept(tctx):
def test_response_streaming(tctx): def test_response_streaming(tctx):
"""Test HTTP response streaming""" """Test HTTP response streaming"""
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
def enable_streaming(flow: HTTPFlow): def enable_streaming(flow: HTTPFlow):
flow.response.stream = lambda x: x.upper() flow.response.stream = lambda x: x.upper()
@ -250,8 +253,8 @@ def test_request_streaming(tctx, response):
This is a bit more contrived as we may receive server data while we are still sending the request. This is a bit more contrived as we may receive server data while we are still sending the request.
""" """
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
def enable_streaming(flow: HTTPFlow): def enable_streaming(flow: HTTPFlow):
@ -301,7 +304,7 @@ def test_request_streaming(tctx, response):
<< CloseConnection(tctx.client) << CloseConnection(tctx.client)
) )
elif response == "early kill": elif response == "early kill":
err = Placeholder() err = Placeholder(bytes)
assert ( assert (
playbook playbook
>> ConnectionClosed(server) >> ConnectionClosed(server)
@ -318,9 +321,9 @@ def test_request_streaming(tctx, response):
def test_server_unreachable(tctx, connect): def test_server_unreachable(tctx, connect):
"""Test the scenario where the target server is unreachable.""" """Test the scenario where the target server is unreachable."""
tctx.options.connection_strategy = "eager" tctx.options.connection_strategy = "eager"
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
err = Placeholder() err = Placeholder(bytes)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
if connect: if connect:
playbook >> DataReceived(tctx.client, b"CONNECT example.com:443 HTTP/1.1\r\n\r\n") playbook >> DataReceived(tctx.client, b"CONNECT example.com:443 HTTP/1.1\r\n\r\n")
@ -352,9 +355,9 @@ def test_server_unreachable(tctx, connect):
]) ])
def test_server_aborts(tctx, data): def test_server_aborts(tctx, data):
"""Test the scenario where the server doesn't serve a response""" """Test the scenario where the server doesn't serve a response"""
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
err = Placeholder() err = Placeholder(bytes)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
assert ( assert (
playbook playbook
@ -383,9 +386,9 @@ def test_server_aborts(tctx, data):
@pytest.mark.parametrize("strategy", ["eager", "lazy"]) @pytest.mark.parametrize("strategy", ["eager", "lazy"])
def test_upstream_proxy(tctx, redirect, scheme, strategy): def test_upstream_proxy(tctx, redirect, scheme, strategy):
"""Test that an upstream HTTP proxy is used.""" """Test that an upstream HTTP proxy is used."""
server = Placeholder() server = Placeholder(Server)
server2 = Placeholder() server2 = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
tctx.options.mode = "upstream:http://proxy:8080" tctx.options.mode = "upstream:http://proxy:8080"
tctx.options.connection_strategy = strategy tctx.options.connection_strategy = strategy
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.upstream), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.upstream), hooks=False)
@ -471,8 +474,8 @@ def test_upstream_proxy(tctx, redirect, scheme, strategy):
@pytest.mark.parametrize("strategy", ["eager", "lazy"]) @pytest.mark.parametrize("strategy", ["eager", "lazy"])
def test_http_proxy_tcp(tctx, mode, strategy): def test_http_proxy_tcp(tctx, mode, strategy):
"""Test TCP over HTTP CONNECT.""" """Test TCP over HTTP CONNECT."""
server = Placeholder() server = Placeholder(Server)
flow = Placeholder() flow = Placeholder(HTTPFlow)
if mode == "upstream": if mode == "upstream":
tctx.options.mode = "upstream:http://proxy:8080" tctx.options.mode = "upstream:http://proxy:8080"
@ -526,7 +529,7 @@ def test_http_proxy_tcp(tctx, mode, strategy):
@pytest.mark.parametrize("strategy", ["eager", "lazy"]) @pytest.mark.parametrize("strategy", ["eager", "lazy"])
def test_proxy_chain(tctx, strategy): def test_proxy_chain(tctx, strategy):
server = Placeholder() server = Placeholder(Server)
tctx.options.connection_strategy = strategy tctx.options.connection_strategy = strategy
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
@ -551,7 +554,7 @@ def test_proxy_chain(tctx, strategy):
def test_no_headers(tctx): def test_no_headers(tctx):
"""Test that we can correctly reassemble requests/responses with no headers.""" """Test that we can correctly reassemble requests/responses with no headers."""
server = Placeholder() server = Placeholder(Server)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\n\r\n") >> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\n\r\n")
@ -566,7 +569,7 @@ def test_no_headers(tctx):
def test_http_proxy_relative_request(tctx): def test_http_proxy_relative_request(tctx):
"""Test handling of a relative-form "GET /" in regular proxy mode.""" """Test handling of a relative-form "GET /" in regular proxy mode."""
server = Placeholder() server = Placeholder(Server)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") >> DataReceived(tctx.client, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
@ -593,7 +596,7 @@ def test_http_proxy_relative_request_no_host_header(tctx):
def test_http_expect(tctx): def test_http_expect(tctx):
"""Test handling of a 'Expect: 100-continue' header.""" """Test handling of a 'Expect: 100-continue' header."""
server = Placeholder() server = Placeholder(Server)
assert ( assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
>> DataReceived(tctx.client, b"PUT http://example.com/large-file HTTP/1.1\r\n" >> DataReceived(tctx.client, b"PUT http://example.com/large-file HTTP/1.1\r\n"
@ -612,3 +615,47 @@ def test_http_expect(tctx):
<< SendData(tctx.client, b"HTTP/1.1 201 Created\r\nContent-Length: 0\r\n\r\n") << SendData(tctx.client, b"HTTP/1.1 201 Created\r\nContent-Length: 0\r\n\r\n")
) )
assert server().address == ("example.com", 80) assert server().address == ("example.com", 80)
@pytest.mark.parametrize("stream", [True, False])
def test_http_client_aborts(tctx, stream):
"""Test handling of the case where a client aborts during request transmission."""
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=True)
def enable_streaming(flow: HTTPFlow):
flow.request.stream = True
assert (
playbook
>> DataReceived(tctx.client, b"POST http://example.com/ HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 6\r\n\r\n"
b"abc")
<< http.HttpRequestHeadersHook(flow)
)
if stream:
assert (
playbook
>> reply(side_effect=enable_streaming)
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 6\r\n\r\n"
b"abc")
)
else:
assert playbook >> reply()
assert (
playbook
>> ConnectionClosed(tctx.client)
<< CloseConnection(tctx.client)
<< http.HttpErrorHook(flow)
>> reply()
)
flow: Callable[[], HTTPFlow]
assert "peer closed connection" in flow().error.msg

View File

@ -0,0 +1,105 @@
from typing import Callable, List
import hyperframe.frame
import pytest
from mitmproxy.http import HTTPFlow
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.context import Context, Server
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import http
from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
@pytest.fixture
def frame_factory() -> FrameFactory:
return FrameFactory()
example_request_headers = (
(b':authority', b'example.com'),
(b':path', b'/'),
(b':scheme', b'https'),
(b':method', b'GET'),
)
example_response_headers = (
(b':status', b'200'),
(b'content-length', b'12'),
)
def decode_frames(data: bytes) -> List[hyperframe.frame.Frame]:
frames = []
while data:
f, length = hyperframe.frame.Frame.parse_frame_header(data[:9])
f.parse_body(memoryview(data[9:9 + length]))
frames.append(f)
data = data[9 + length:]
return frames
def start_h2(tctx: Context, frame_factory: FrameFactory) -> Playbook:
tctx.client.alpn = b"h2"
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
assert (
playbook
<< SendData(tctx.client, Placeholder()) # initial settings frame
>> DataReceived(tctx.client, frame_factory.preamble())
>> DataReceived(tctx.client, frame_factory.build_settings_frame({}, ack=True).serialize())
)
return playbook
def make_h2(open_connection: OpenConnection) -> None:
open_connection.connection.alpn = b"h2"
@pytest.mark.parametrize("stream", [True, False])
def test_http2_client_aborts(tctx, frame_factory, stream):
"""Test handling of the case where a client aborts during request transmission."""
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
playbook = start_h2(tctx, frame_factory)
def enable_streaming(flow: HTTPFlow):
flow.request.stream = True
assert (
playbook
>> DataReceived(tctx.client, frame_factory.build_headers_frame(example_request_headers).serialize())
<< http.HttpRequestHeadersHook(flow)
)
if stream:
pytest.xfail("h2 client not implemented yet")
assert (
playbook
>> reply(side_effect=enable_streaming)
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 6\r\n\r\n"
b"abc")
)
else:
assert playbook >> reply()
assert (
playbook
>> ConnectionClosed(tctx.client)
<< CloseConnection(tctx.client)
<< http.HttpErrorHook(flow)
>> reply()
)
assert "peer closed connection" in flow().error.msg
@pytest.mark.xfail
def test_no_normalization():
"""Test that we don't normalize headers when we just pass them through."""
raise NotImplementedError

View File

@ -111,7 +111,7 @@ def test_fuzz_request(opts, data):
@example([b'0 OK\r\n\r\n', b'\r\n', b'5\r\n12345\r\n0\r\n\r\n']) @example([b'0 OK\r\n\r\n', b'\r\n', b'5\r\n12345\r\n0\r\n\r\n'])
def test_fuzz_response(opts, data): def test_fuzz_response(opts, data):
tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080)), opts) tctx = context.Context(context.Client(("client", 1234), ("127.0.0.1", 8080)), opts)
server = Placeholder() server = Placeholder(context.Server)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False) playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular), hooks=False)
assert ( assert (
playbook playbook

View File

@ -0,0 +1,125 @@
from typing import List, Tuple
import h2.connection
import h2.events
import h2.config
import hyperframe.frame
import pytest
from mitmproxy.http import HTTPFlow
from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2.commands import OpenConnection, SendData
from mitmproxy.proxy2.context import Context, Server
from mitmproxy.proxy2.events import DataReceived
from mitmproxy.proxy2.layers import http
from test.mitmproxy.proxy2.layers.http.hyper_h2_test_helpers import FrameFactory
from test.mitmproxy.proxy2.layers.http.test_http2 import example_request_headers, example_response_headers, make_h2
from test.mitmproxy.proxy2.tutils import Placeholder, Playbook, reply
h2f = FrameFactory()
def event_types(events):
return [type(x) for x in events]
def h2_client(tctx: Context) -> Tuple[h2.connection.H2Connection, Playbook]:
tctx.client.alpn = b"h2"
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
conn = h2.connection.H2Connection()
conn.initiate_connection()
server_preamble = Placeholder(bytes)
assert (
playbook
<< SendData(tctx.client, server_preamble)
)
assert event_types(conn.receive_data(server_preamble())) == [h2.events.RemoteSettingsChanged]
settings_ack = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client, conn.data_to_send())
<< SendData(tctx.client, settings_ack)
)
assert event_types(conn.receive_data(settings_ack())) == [h2.events.SettingsAcknowledged]
return conn, playbook
def test_h2_to_h1(tctx):
"""Test HTTP/2 -> HTTP/1 request translation"""
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
conn, playbook = h2_client(tctx)
conn.send_headers(1, example_request_headers, end_stream=True)
response = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client, conn.data_to_send())
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None)
<< SendData(server, b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
>> DataReceived(server, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
<< http.HttpResponseHeadersHook(flow)
>> reply()
<< http.HttpResponseHook(flow)
>> reply()
<< SendData(tctx.client, response)
)
events = conn.receive_data(response())
assert event_types(events) == [
h2.events.ResponseReceived, h2.events.DataReceived, h2.events.DataReceived, h2.events.StreamEnded
]
resp: h2.events.ResponseReceived = events[0]
body: h2.events.DataReceived = events[1]
assert resp.headers == [(b':status', b'200'), (b'content-length', b'12')]
assert body.data == b"Hello World!"
@pytest.mark.xfail
def test_h1_to_h2(tctx):
"""Test HTTP/1 -> HTTP/2 request translation"""
server = Placeholder(Server)
flow = Placeholder(HTTPFlow)
playbook = Playbook(http.HttpLayer(tctx, HTTPMode.regular))
conf = h2.config.H2Configuration(client_side=False)
conn = h2.connection.H2Connection(conf)
conn.initiate_connection()
h2_preamble = Placeholder(bytes)
assert (
playbook
>> DataReceived(tctx.client, b"GET http://example.com/ HTTP/1.1\r\nHost: example.com\r\n\r\n")
<< http.HttpRequestHeadersHook(flow)
>> reply()
<< http.HttpRequestHook(flow)
>> reply()
<< OpenConnection(server)
>> reply(None, side_effect=make_h2)
<< SendData(server, h2_preamble)
)
events = conn.receive_data(h2_preamble())
y = h2_preamble()
assert not events # FIXME
request = Placeholder(bytes)
assert (
playbook
>> DataReceived(server, conn.data_to_send())
<< http.HttpResponseHeadersHook(flow)
>> reply()
<< http.HttpResponseHook(flow)
>> reply()
<< SendData(tctx.client, b"HTTP/1.1 200 OK\r\nContent-Length: 12\r\n\r\nHello World!")
)

View File

@ -2,6 +2,7 @@ from mitmproxy.proxy2.commands import CloseConnection, OpenConnection, SendData
from mitmproxy.proxy2.context import ConnectionState from mitmproxy.proxy2.context import ConnectionState
from mitmproxy.proxy2.events import ConnectionClosed, DataReceived from mitmproxy.proxy2.events import ConnectionClosed, DataReceived
from mitmproxy.proxy2.layers import tcp from mitmproxy.proxy2.layers import tcp
from mitmproxy.tcp import TCPFlow
from ..tutils import Placeholder, Playbook, reply from ..tutils import Placeholder, Playbook, reply
@ -23,7 +24,7 @@ def test_open_connection(tctx):
def test_open_connection_err(tctx): def test_open_connection_err(tctx):
f = Placeholder() f = Placeholder(TCPFlow)
assert ( assert (
Playbook(tcp.TCPLayer(tctx)) Playbook(tcp.TCPLayer(tctx))
<< tcp.TcpStartHook(f) << tcp.TcpStartHook(f)
@ -38,7 +39,7 @@ def test_open_connection_err(tctx):
def test_simple(tctx): def test_simple(tctx):
"""open connection, receive data, send it to peer""" """open connection, receive data, send it to peer"""
f = Placeholder() f = Placeholder(TCPFlow)
assert ( assert (
Playbook(tcp.TCPLayer(tctx)) Playbook(tcp.TCPLayer(tctx))

View File

@ -95,7 +95,7 @@ class SSLTest:
def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connection) -> None: def _test_echo(playbook: tutils.Playbook, tssl: SSLTest, conn: context.Connection) -> None:
tssl.obj.write(b"Hello World") tssl.obj.write(b"Hello World")
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> events.DataReceived(conn, tssl.out.read()) >> events.DataReceived(conn, tssl.out.read())
@ -118,7 +118,7 @@ class TlsEchoLayer(tutils.EchoLayer):
def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest): def interact(playbook: tutils.Playbook, conn: context.Connection, tssl: SSLTest):
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> events.DataReceived(conn, tssl.out.read()) >> events.DataReceived(conn, tssl.out.read())
@ -200,7 +200,7 @@ class TestServerTLS:
tssl = SSLTest(server_side=True) tssl = SSLTest(server_side=True)
# send ClientHello # send ClientHello
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
<< tls.TlsStartHook(tutils.Placeholder()) << tls.TlsStartHook(tutils.Placeholder())
@ -253,7 +253,7 @@ class TestServerTLS:
tssl = SSLTest(server_side=True) tssl = SSLTest(server_side=True)
# send ClientHello # send ClientHello
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, b"open-connection") >> events.DataReceived(tctx.client, b"open-connection")
@ -313,7 +313,7 @@ class TestClientTLS:
assert not tctx.client.tls_established assert not tctx.client.tls_established
# Send ClientHello, receive ServerHello # Send ClientHello, receive ServerHello
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl_client.out.read()) >> events.DataReceived(tctx.client, tssl_client.out.read())
@ -349,8 +349,7 @@ class TestClientTLS:
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"]) playbook, client_layer, tssl_client = make_client_tls_layer(tctx, alpn=["quux"])
# We should now get instructed to open a server connection. # We should now get instructed to open a server connection.
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
tls_clienthello = tutils.Placeholder()
def require_server_conn(client_hello: tls.ClientHelloData) -> None: def require_server_conn(client_hello: tls.ClientHelloData) -> None:
client_hello.establish_server_tls_first = True client_hello.establish_server_tls_first = True
@ -358,7 +357,7 @@ class TestClientTLS:
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl_client.out.read()) >> events.DataReceived(tctx.client, tssl_client.out.read())
<< tls.TlsClienthelloHook(tls_clienthello) << tls.TlsClienthelloHook(tutils.Placeholder())
>> tutils.reply(side_effect=require_server_conn) >> tutils.reply(side_effect=require_server_conn)
<< commands.OpenConnection(tctx.server) << commands.OpenConnection(tctx.server)
>> tutils.reply(None) >> tutils.reply(None)
@ -372,7 +371,7 @@ class TestClientTLS:
with pytest.raises(ssl.SSLWantReadError): with pytest.raises(ssl.SSLWantReadError):
tssl_server.obj.do_handshake() tssl_server.obj.do_handshake()
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.server, tssl_server.out.read()) >> events.DataReceived(tctx.server, tssl_server.out.read())
@ -383,7 +382,7 @@ class TestClientTLS:
assert tctx.server.tls_established assert tctx.server.tls_established
# Server TLS is established, we can now reply to the client handshake... # Server TLS is established, we can now reply to the client handshake...
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> reply_tls_start(alpn=b"quux") >> reply_tls_start(alpn=b"quux")
@ -421,7 +420,7 @@ class TestClientTLS:
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org") playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
playbook.logs = True playbook.logs = True
data = tutils.Placeholder() data = tutils.Placeholder(bytes)
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl_client.out.read()) >> events.DataReceived(tctx.client, tssl_client.out.read())
@ -449,7 +448,6 @@ class TestClientTLS:
"""Test the scenario where the client doesn't trust the mitmproxy CA.""" """Test the scenario where the client doesn't trust the mitmproxy CA."""
playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org") playbook, client_layer, tssl_client = make_client_tls_layer(tctx, sni=b"wrong.host.mitmproxy.org")
data = tutils.Placeholder()
assert ( assert (
playbook playbook
>> events.DataReceived(tctx.client, tssl_client.out.read()) >> events.DataReceived(tctx.client, tssl_client.out.read())
@ -457,7 +455,7 @@ class TestClientTLS:
>> tutils.reply() >> tutils.reply()
<< tls.TlsStartHook(tutils.Placeholder()) << tls.TlsStartHook(tutils.Placeholder())
>> reply_tls_start() >> reply_tls_start()
<< commands.SendData(tctx.client, data) << commands.SendData(tctx.client, tutils.Placeholder())
>> events.ConnectionClosed(tctx.client) >> events.ConnectionClosed(tctx.client)
<< commands.Log("Client TLS handshake failed. The client may not trust the proxy's certificate " << commands.Log("Client TLS handshake failed. The client may not trust the proxy's certificate "
"for wrong.host.mitmproxy.org (connection closed without notice)", "warn") "for wrong.host.mitmproxy.org (connection closed without notice)", "warn")

View File

@ -76,8 +76,12 @@ def test_partial_assert(tplaybook):
assert len(tplaybook.actual) == len(tplaybook.expected) == 4 assert len(tplaybook.actual) == len(tplaybook.expected) == 4
def test_placeholder(tplaybook): @pytest.mark.parametrize("typed", [True, False])
def test_placeholder(tplaybook, typed):
"""Developers can specify placeholders for yet unknown attributes.""" """Developers can specify placeholders for yet unknown attributes."""
if typed:
f = tutils.Placeholder(int)
else:
f = tutils.Placeholder() f = tutils.Placeholder()
assert ( assert (
tplaybook tplaybook
@ -87,6 +91,17 @@ def test_placeholder(tplaybook):
assert f() == 42 assert f() == 42
def test_placeholder_type_mismatch(tplaybook):
"""Developers can specify placeholders for yet unknown attributes."""
f = tutils.Placeholder(str)
with pytest.raises(TypeError, match="Placeholder type error for TCommand.x: expected str, got int."):
assert (
tplaybook
>> TEvent([42])
<< TCommand(f)
)
def test_fork(tplaybook): def test_fork(tplaybook):
"""Playbooks can be forked to test multiple execution streams.""" """Playbooks can be forked to test multiple execution streams."""
assert ( assert (
@ -192,5 +207,5 @@ def test_eq_placeholder():
assert a.foo == b.foo() == 42 assert a.foo == b.foo() == 42
assert a.bar() == b.bar == 43 assert a.bar() == b.bar == 43
b.foo.obj = 44 b.foo._obj = 44
assert not tutils._eq(a, b) assert not tutils._eq(a, b)

View File

@ -24,21 +24,23 @@ def _eq(
if type(a) != type(b): if type(a) != type(b):
return False return False
a = a.__dict__ a_dict = a.__dict__
b = b.__dict__ b_dict = b.__dict__
# we can assume a.keys() == b.keys() # we can assume a.keys() == b.keys()
for k in a: for k in a_dict:
if k == "blocking": if k == "blocking":
continue continue
x, y = a[k], b[k] x = a_dict[k]
y = b_dict[k]
# if there's a placeholder, make it x. # if there's a placeholder, make it x.
if isinstance(y, _Placeholder): if isinstance(y, _Placeholder):
x, y = y, x x, y = y, x
if isinstance(x, _Placeholder): if isinstance(x, _Placeholder):
if x.obj is None: try:
x.obj = y x = x.setdefault(y)
x = x.obj except TypeError as e:
raise TypeError(f"Placeholder type error for {type(a).__name__}.{k}: {e}")
if x != y: if x != y:
return False return False
@ -194,6 +196,7 @@ class Playbook:
pos = i + 1 + offset pos = i + 1 + offset
need_to_emulate_log = ( need_to_emulate_log = (
isinstance(cmd, commands.Log) and isinstance(cmd, commands.Log) and
cmd.level in ("debug", "info") and
( (
pos >= len(self.expected) pos >= len(self.expected)
or not isinstance(self.expected[pos], commands.Log) or not isinstance(self.expected[pos], commands.Log)
@ -297,33 +300,46 @@ class _Placeholder:
Placeholder value in playbooks, so that objects (flows in particular) can be referenced before Placeholder value in playbooks, so that objects (flows in particular) can be referenced before
they are known. Example: they are known. Example:
f = Placeholder() f = Placeholder(TCPFlow)
assert ( assert (
playbook(tcp.TCPLayer(tctx)) playbook(tcp.TCPLayer(tctx))
<< commands.Hook("tcp_start", f) # the flow object returned here is generated by the layer. << TcpStartHook(f) # the flow object returned here is generated by the layer.
) )
# We can obtain the flow object now using f(): # We can obtain the flow object now using f():
assert f().messages == 0 assert f().messages == 0
""" """
def __init__(self): def __init__(self, cls: typing.Type):
self.obj = None self._obj = None
self._cls = cls
def __call__(self): def __call__(self):
"""Get the actual object""" """Get the actual object"""
return self.obj return self._obj
def setdefault(self, value):
if self._obj is None:
if self._cls is not typing.Any and not isinstance(value, self._cls):
raise TypeError(f"expected {self._cls.__name__}, got {type(value).__name__}.")
self._obj = value
return self._obj
def __repr__(self): def __repr__(self):
return f"Placeholder:{repr(self.obj)}" return f"Placeholder:{repr(self._obj)}"
def __str__(self): def __str__(self):
return f"Placeholder:{str(self.obj)}" return f"Placeholder:{str(self._obj)}"
T = typing.TypeVar("T")
# noinspection PyPep8Naming # noinspection PyPep8Naming
def Placeholder() -> typing.Any: def Placeholder(cls: typing.Type[T] = typing.Any) -> typing.Union[
return _Placeholder() T, typing.Callable[[], T]
]:
return _Placeholder(cls)
class EchoLayer(Layer): class EchoLayer(Layer):