diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 0e42d619a..ee1057813 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -96,15 +96,17 @@ class Http2Layer(base.Layer): self.server_to_client_stream_ids = dict([(0, 0)]) self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False, header_encoding=False) - # make sure that we only pass actual SSL.Connection objects in here, - # because otherwise ssl_read_select fails! - self.active_conns = [self.client_conn.connection] - def _initiate_server_conn(self): - self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False) - self.server_conn.h2.initiate_connection() - self.server_conn.send(self.server_conn.h2.data_to_send()) - self.active_conns.append(self.server_conn.connection) + if self.server_conn: + self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False) + self.server_conn.h2.initiate_connection() + self.server_conn.send(self.server_conn.h2.data_to_send()) + + def _complete_handshake(self): + preamble = self.client_conn.rfile.read(24) + self.client_conn.h2.initiate_connection() + self.client_conn.h2.receive_data(preamble) + self.client_conn.send(self.client_conn.h2.data_to_send()) def next_layer(self): # pragma: no cover # WebSockets over HTTP/2? @@ -301,20 +303,19 @@ class Http2Layer(base.Layer): stream.data_finished.set() def __call__(self): - if self.server_conn: - self._initiate_server_conn() + self._initiate_server_conn() + self._complete_handshake() - preamble = self.client_conn.rfile.read(24) - self.client_conn.h2.initiate_connection() - self.client_conn.h2.receive_data(preamble) - self.client_conn.send(self.client_conn.h2.data_to_send()) + client = self.client_conn.connection + server = self.server_conn.connection + conns = [client, server] try: while True: - r = tcp.ssl_read_select(self.active_conns, 1) + r = tcp.ssl_read_select(conns, 1) for conn in r: - source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn - other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + source_conn = self.client_conn if conn == client else self.server_conn + other_conn = self.server_conn if conn == client else self.client_conn is_server = (conn == self.server_conn.connection) with source_conn.h2.lock: diff --git a/test/mitmproxy/protocol/test_http2.py b/test/mitmproxy/protocol/test_http2.py index 1eabebf11..4dce698d4 100644 --- a/test/mitmproxy/protocol/test_http2.py +++ b/test/mitmproxy/protocol/test_http2.py @@ -130,6 +130,7 @@ class _Http2TestBase(object): b"\r\n" % (self.server.server.address.port, self.server.server.address.port) ) client.wfile.flush() + # TODO: rewrite as http.Request object with http.assemble_request # read CONNECT response while client.rfile.readline() != b"\r\n":