mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
Merge pull request #67 from Kriechi/http2-wip
HTTP/2: preparations for pathod
This commit is contained in:
commit
4fbe406e2e
@ -26,55 +26,80 @@ class HTTP2Protocol(object):
|
||||
)
|
||||
|
||||
# "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
|
||||
CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'
|
||||
CLIENT_CONNECTION_PREFACE =\
|
||||
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')
|
||||
|
||||
ALPN_PROTO_H2 = 'h2'
|
||||
|
||||
def __init__(self, tcp_client):
|
||||
self.tcp_client = tcp_client
|
||||
def __init__(self, tcp_handler, is_server=False):
|
||||
self.tcp_handler = tcp_handler
|
||||
self.is_server = is_server
|
||||
|
||||
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
|
||||
self.current_stream_id = None
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.connection_preface_performed = False
|
||||
|
||||
def check_alpn(self):
|
||||
alp = self.tcp_client.get_alpn_proto_negotiated()
|
||||
alp = self.tcp_handler.get_alpn_proto_negotiated()
|
||||
if alp != self.ALPN_PROTO_H2:
|
||||
raise NotImplementedError(
|
||||
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
|
||||
return True
|
||||
|
||||
def perform_connection_preface(self):
|
||||
self.tcp_client.wfile.write(
|
||||
bytes(self.CLIENT_CONNECTION_PREFACE.decode('hex')))
|
||||
self.send_frame(frame.SettingsFrame(state=self))
|
||||
|
||||
# read server settings frame
|
||||
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
|
||||
def _receive_settings(self):
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
assert isinstance(frm, frame.SettingsFrame)
|
||||
self._apply_settings(frm.settings)
|
||||
|
||||
# read setting ACK frame
|
||||
def _read_settings_ack(self):
|
||||
settings_ack_frame = self.read_frame()
|
||||
assert isinstance(settings_ack_frame, frame.SettingsFrame)
|
||||
assert settings_ack_frame.flags & frame.Frame.FLAG_ACK
|
||||
assert len(settings_ack_frame.settings) == 0
|
||||
|
||||
def perform_server_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
|
||||
magic = self.tcp_handler.rfile.safe_read(magic_length)
|
||||
assert magic == self.CLIENT_CONNECTION_PREFACE
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self))
|
||||
self._receive_settings()
|
||||
self._read_settings_ack()
|
||||
|
||||
def perform_client_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self))
|
||||
self._receive_settings()
|
||||
self._read_settings_ack()
|
||||
|
||||
def next_stream_id(self):
|
||||
if self.current_stream_id is None:
|
||||
self.current_stream_id = 1
|
||||
if self.is_server:
|
||||
# servers must use even stream ids
|
||||
self.current_stream_id = 2
|
||||
else:
|
||||
# clients must use odd stream ids
|
||||
self.current_stream_id = 1
|
||||
else:
|
||||
self.current_stream_id += 2
|
||||
return self.current_stream_id
|
||||
|
||||
def send_frame(self, frame):
|
||||
raw_bytes = frame.to_bytes()
|
||||
self.tcp_client.wfile.write(raw_bytes)
|
||||
self.tcp_client.wfile.flush()
|
||||
self.tcp_handler.wfile.write(raw_bytes)
|
||||
self.tcp_handler.wfile.flush()
|
||||
|
||||
def read_frame(self):
|
||||
frm = frame.Frame.from_file(self.tcp_client.rfile, self)
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
if isinstance(frm, frame.SettingsFrame):
|
||||
self._apply_settings(frm.settings)
|
||||
|
||||
@ -127,10 +152,13 @@ class HTTP2Protocol(object):
|
||||
if headers is None:
|
||||
headers = []
|
||||
|
||||
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
|
||||
headers = [
|
||||
(b':method', bytes(method)),
|
||||
(b':path', bytes(path)),
|
||||
(b':scheme', b'https')] + headers
|
||||
(b':scheme', b'https'),
|
||||
(b':authority', authority),
|
||||
] + headers
|
||||
|
||||
stream_id = self.next_stream_id()
|
||||
|
||||
@ -139,25 +167,50 @@ class HTTP2Protocol(object):
|
||||
self._create_body(body, stream_id)))
|
||||
|
||||
def read_response(self):
|
||||
headers, body = self._receive_transmission()
|
||||
return headers[':status'], headers, body
|
||||
|
||||
def read_request(self):
|
||||
return self._receive_transmission()
|
||||
|
||||
def _receive_transmission(self):
|
||||
body_expected = True
|
||||
|
||||
header_block_fragment = b''
|
||||
body = b''
|
||||
|
||||
while True:
|
||||
frm = self.read_frame()
|
||||
if isinstance(frm, frame.HeadersFrame):
|
||||
if isinstance(frm, frame.HeadersFrame)\
|
||||
or isinstance(frm, frame.ContinuationFrame):
|
||||
header_block_fragment += frm.header_block_fragment
|
||||
if frm.flags | frame.Frame.FLAG_END_HEADERS:
|
||||
if frm.flags & frame.Frame.FLAG_END_HEADERS:
|
||||
if frm.flags & frame.Frame.FLAG_END_STREAM:
|
||||
body_expected = False
|
||||
break
|
||||
|
||||
while True:
|
||||
while body_expected:
|
||||
frm = self.read_frame()
|
||||
if isinstance(frm, frame.DataFrame):
|
||||
body += frm.payload
|
||||
if frm.flags | frame.Frame.FLAG_END_STREAM:
|
||||
if frm.flags & frame.Frame.FLAG_END_STREAM:
|
||||
break
|
||||
# TODO: implement window update & flow
|
||||
|
||||
headers = {}
|
||||
for header, value in self.decoder.decode(header_block_fragment):
|
||||
headers[header] = value
|
||||
|
||||
return headers[':status'], headers, body
|
||||
return headers, body
|
||||
|
||||
def create_response(self, code, headers=None, body=None):
|
||||
if headers is None:
|
||||
headers = []
|
||||
|
||||
headers = [(b':status', bytes(str(code)))] + headers
|
||||
|
||||
stream_id = self.next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(body is None)),
|
||||
self._create_body(body, stream_id)))
|
||||
|
@ -19,6 +19,9 @@ SSLv2_METHOD = SSL.SSLv2_METHOD
|
||||
SSLv3_METHOD = SSL.SSLv3_METHOD
|
||||
SSLv23_METHOD = SSL.SSLv23_METHOD
|
||||
TLSv1_METHOD = SSL.TLSv1_METHOD
|
||||
TLSv1_1_METHOD = SSL.TLSv1_1_METHOD
|
||||
TLSv1_2_METHOD = SSL.TLSv1_2_METHOD
|
||||
|
||||
OP_NO_SSLv2 = SSL.OP_NO_SSLv2
|
||||
OP_NO_SSLv3 = SSL.OP_NO_SSLv3
|
||||
|
||||
@ -376,7 +379,7 @@ class _Connection(object):
|
||||
alpn_select=None,
|
||||
):
|
||||
"""
|
||||
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD or TLSv1_1_METHOD
|
||||
:param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
|
||||
:param options: A bit field consisting of OpenSSL.SSL.OP_* values
|
||||
:param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
|
||||
:rtype : SSL.Context
|
||||
@ -404,16 +407,17 @@ class _Connection(object):
|
||||
context.set_info_callback(log_ssl_key)
|
||||
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
|
||||
# advertise application layer protocols
|
||||
if alpn_protos is not None:
|
||||
# advertise application layer protocols
|
||||
context.set_alpn_protos(alpn_protos)
|
||||
|
||||
# select application layer protocol
|
||||
if alpn_select is not None:
|
||||
def alpn_select_f(conn, options):
|
||||
return bytes(alpn_select)
|
||||
|
||||
context.set_alpn_select_callback(alpn_select_f)
|
||||
elif alpn_select is not None:
|
||||
# select application layer protocol
|
||||
def alpn_select_callback(conn, options):
|
||||
if alpn_select in options:
|
||||
return bytes(alpn_select)
|
||||
else: # pragma no cover
|
||||
return options[0]
|
||||
context.set_alpn_select_callback(alpn_select_callback)
|
||||
|
||||
return context
|
||||
|
||||
@ -499,9 +503,9 @@ class TCPClient(_Connection):
|
||||
return self.connection.gettimeout()
|
||||
|
||||
def get_alpn_proto_negotiated(self):
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
else: # pragma no cover
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -531,7 +535,6 @@ class BaseHandler(_Connection):
|
||||
request_client_cert=None,
|
||||
chain_file=None,
|
||||
dhparams=None,
|
||||
alpn_select=None,
|
||||
**sslctx_kwargs):
|
||||
"""
|
||||
cert: A certutils.SSLCert object.
|
||||
@ -558,9 +561,7 @@ class BaseHandler(_Connection):
|
||||
until then we're conservative.
|
||||
"""
|
||||
|
||||
context = self._create_ssl_context(
|
||||
alpn_select=alpn_select,
|
||||
**sslctx_kwargs)
|
||||
context = self._create_ssl_context(**sslctx_kwargs)
|
||||
|
||||
context.use_privatekey(key)
|
||||
context.use_certificate(cert.x509)
|
||||
@ -585,7 +586,7 @@ class BaseHandler(_Connection):
|
||||
|
||||
return context
|
||||
|
||||
def convert_to_ssl(self, cert, key, alpn_select=None, **sslctx_kwargs):
|
||||
def convert_to_ssl(self, cert, key, **sslctx_kwargs):
|
||||
"""
|
||||
Convert connection to SSL.
|
||||
For a list of parameters, see BaseHandler._create_ssl_context(...)
|
||||
@ -594,7 +595,6 @@ class BaseHandler(_Connection):
|
||||
context = self.create_ssl_context(
|
||||
cert,
|
||||
key,
|
||||
alpn_select=alpn_select,
|
||||
**sslctx_kwargs)
|
||||
self.connection = SSL.Connection(context, self.connection)
|
||||
self.connection.set_accept_state()
|
||||
@ -612,6 +612,12 @@ class BaseHandler(_Connection):
|
||||
def settimeout(self, n):
|
||||
self.connection.settimeout(n)
|
||||
|
||||
def get_alpn_proto_negotiated(self):
|
||||
if OpenSSL._util.lib.Cryptography_HAS_ALPN and self.ssl_established:
|
||||
return self.connection.get_alpn_proto_negotiated()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class TCPServer(object):
|
||||
request_queue_size = 20
|
||||
|
@ -1,4 +1,3 @@
|
||||
|
||||
import OpenSSL
|
||||
|
||||
from netlib import http2
|
||||
@ -50,7 +49,39 @@ class TestCheckALPNMismatch(test.ServerTestBase):
|
||||
tutils.raises(NotImplementedError, protocol.check_alpn)
|
||||
|
||||
|
||||
class TestPerformConnectionPreface(test.ServerTestBase):
|
||||
class TestPerformServerConnectionPreface(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
# send magic
|
||||
self.wfile.write(
|
||||
'505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
# send empty settings frame
|
||||
self.wfile.write('000000040000000000'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
# check empty settings frame
|
||||
assert self.rfile.read(9) ==\
|
||||
'000000040000000000'.decode('hex')
|
||||
|
||||
# check settings acknowledgement
|
||||
assert self.rfile.read(9) == \
|
||||
'000000040100000000'.decode('hex')
|
||||
|
||||
# send settings acknowledgement
|
||||
self.wfile.write('000000040100000000'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
def test_perform_server_connection_preface(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.perform_server_connection_preface()
|
||||
|
||||
|
||||
class TestPerformClientConnectionPreface(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
@ -74,21 +105,18 @@ class TestPerformConnectionPreface(test.ServerTestBase):
|
||||
self.wfile.write('000000040100000000'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_perform_connection_preface(self):
|
||||
def test_perform_client_connection_preface(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.perform_connection_preface()
|
||||
protocol.perform_client_connection_preface()
|
||||
|
||||
|
||||
class TestStreamIds():
|
||||
class TestClientStreamIds():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
|
||||
def test_stream_ids(self):
|
||||
def test_client_stream_ids(self):
|
||||
assert self.protocol.current_stream_id is None
|
||||
assert self.protocol.next_stream_id() == 1
|
||||
assert self.protocol.current_stream_id == 1
|
||||
@ -98,6 +126,20 @@ class TestStreamIds():
|
||||
assert self.protocol.current_stream_id == 5
|
||||
|
||||
|
||||
class TestServerStreamIds():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
|
||||
def test_server_stream_ids(self):
|
||||
assert self.protocol.current_stream_id is None
|
||||
assert self.protocol.next_stream_id() == 2
|
||||
assert self.protocol.current_stream_id == 2
|
||||
assert self.protocol.next_stream_id() == 4
|
||||
assert self.protocol.current_stream_id == 4
|
||||
assert self.protocol.next_stream_id() == 6
|
||||
assert self.protocol.current_stream_id == 6
|
||||
|
||||
|
||||
class TestApplySettings(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
@ -180,14 +222,14 @@ class TestCreateRequest():
|
||||
def test_create_request_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/')
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] == '000003010500000001828487'.decode('hex')
|
||||
assert bytes[0] == '00000c0105000000018284874187089d5c0b8170ff'.decode('hex')
|
||||
|
||||
def test_create_request_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).create_request(
|
||||
'GET', '/', [(b'foo', b'bar')], 'foobar')
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'00000b010400000001828487408294e7838c767f'.decode('hex')
|
||||
'0000140104000000018284874187089d5c0b8170ff408294e7838c767f'.decode('hex')
|
||||
assert bytes[1] ==\
|
||||
'000006000100000001666f6f626172'.decode('hex')
|
||||
|
||||
@ -213,5 +255,71 @@ class TestReadResponse(test.ServerTestBase):
|
||||
status, headers, body = protocol.read_response()
|
||||
|
||||
assert headers == {':status': '200', 'etag': 'foobar'}
|
||||
assert status == '200'
|
||||
assert status == "200"
|
||||
assert body == b'foobar'
|
||||
|
||||
|
||||
class TestReadEmptyResponse(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(
|
||||
b'00000801050000000188628594e78c767f'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_read_empty_response(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
|
||||
status, headers, body = protocol.read_response()
|
||||
|
||||
assert headers == {':status': '200', 'etag': 'foobar'}
|
||||
assert status == "200"
|
||||
assert body == b''
|
||||
|
||||
|
||||
class TestReadRequest(test.ServerTestBase):
|
||||
class handler(tcp.BaseHandler):
|
||||
|
||||
def handle(self):
|
||||
self.wfile.write(
|
||||
b'000003010400000001828487'.decode('hex'))
|
||||
self.wfile.write(
|
||||
b'000006000100000001666f6f626172'.decode('hex'))
|
||||
self.wfile.flush()
|
||||
|
||||
ssl = True
|
||||
|
||||
def test_read_request(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
|
||||
headers, body = protocol.read_request()
|
||||
|
||||
assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'}
|
||||
assert body == b'foobar'
|
||||
|
||||
|
||||
class TestCreateResponse():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_create_request_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] ==\
|
||||
'00000101050000000288'.decode('hex')
|
||||
|
||||
def test_create_request_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
|
||||
200, [(b'foo', b'bar')], 'foobar')
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'00000901040000000288408294e7838c767f'.decode('hex')
|
||||
assert bytes[1] ==\
|
||||
'000006000100000002666f6f626172'.decode('hex')
|
||||
|
@ -376,6 +376,11 @@ class TestALPN(test.ServerTestBase):
|
||||
c.convert_to_ssl(alpn_protos=["foobar"])
|
||||
assert c.get_alpn_proto_negotiated() == "foobar"
|
||||
|
||||
def test_no_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
assert c.get_alpn_proto_negotiated() == None
|
||||
|
||||
else:
|
||||
def test_none_alpn(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
|
Loading…
Reference in New Issue
Block a user