mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 16:17:49 +00:00
366 lines
14 KiB
Python
366 lines
14 KiB
Python
|
from __future__ import (absolute_import, print_function, division)
|
||
|
|
||
|
import struct
|
||
|
import threading
|
||
|
import time
|
||
|
import Queue
|
||
|
|
||
|
from netlib.tcp import ssl_read_select
|
||
|
from netlib.exceptions import HttpException
|
||
|
from netlib.http import Headers
|
||
|
|
||
|
import h2
|
||
|
from h2.connection import H2Connection
|
||
|
from h2.events import *
|
||
|
|
||
|
from .base import Layer
|
||
|
from .http import _HttpTransmissionLayer, HttpLayer
|
||
|
from .. import utils
|
||
|
from ..models import HTTPRequest, HTTPResponse
|
||
|
|
||
|
|
||
|
class SafeH2Connection(H2Connection):
    """H2Connection whose ``safe_*`` helpers change state and write under one lock.

    Every helper performs its h2 operation and immediately sends the bytes
    returned by ``data_to_send()`` to ``conn`` while holding ``self.lock``,
    so callers on different threads do not interleave their output.
    """

    def __init__(self, conn, *args, **kwargs):
        """Wrap ``conn``; remaining arguments are passed to ``H2Connection``."""
        super(SafeH2Connection, self).__init__(*args, **kwargs)
        self.conn = conn
        # Re-entrant, so code that already holds the lock may call safe_* helpers.
        self.lock = threading.RLock()

    def safe_close_connection(self, error_code):
        """Close the HTTP/2 connection with ``error_code`` and flush the result."""
        with self.lock:
            self.close_connection(error_code)
            self.conn.send(self.data_to_send())

    def safe_increment_flow_control(self, stream_id, length):
        """Grow the flow-control windows by ``length`` bytes.

        The window is first incremented without a stream id, then for
        ``stream_id`` itself if that stream is still known and not closed.
        A ``length`` of 0 is a no-op.
        """
        if length == 0:
            return

        with self.lock:
            self.increment_flow_control_window(length)
            self.conn.send(self.data_to_send())
        with self.lock:
            if stream_id in self.streams and not self.streams[stream_id].closed:
                self.increment_flow_control_window(length, stream_id=stream_id)
                self.conn.send(self.data_to_send())

    def safe_reset_stream(self, stream_id, error_code):
        """Reset ``stream_id``; a stream that is already closed is not an error."""
        with self.lock:
            try:
                self.reset_stream(stream_id, error_code)
            except h2.exceptions.StreamClosedError:
                # stream is already closed - good
                pass
            self.conn.send(self.data_to_send())

    def safe_update_settings(self, new_settings):
        """Apply ``new_settings`` via ``update_settings`` and flush the result."""
        with self.lock:
            self.update_settings(new_settings)
            self.conn.send(self.data_to_send())

    def safe_send_headers(self, is_zombie, stream_id, headers):
        """Send ``headers`` on ``stream_id`` unless ``is_zombie`` reports it dead.

        ``is_zombie`` is called as ``is_zombie(self, stream_id)`` while the
        lock is held.
        """
        with self.lock:
            if is_zombie(self, stream_id):
                return
            self.send_headers(stream_id, headers)
            self.conn.send(self.data_to_send())

    def safe_send_body(self, is_zombie, stream_id, chunks):
        """Send every chunk as DATA frames on ``stream_id``, then end the stream.

        Each chunk is split into pieces of at most ``max_outbound_frame_size``
        bytes. A piece is only sent once the flow-control window reported by
        ``local_flow_control_window`` can hold it; until then the lock is
        dropped and the thread yields. Sending stops silently as soon as
        ``is_zombie(self, stream_id)`` is true.

        ``chunks`` must be an iterable of byte strings.
        """
        for chunk in chunks:
            position = 0
            while position < len(chunk):
                # The ``with`` block guarantees the lock is released on the
                # zombie early-return and on exceptions; a bare acquire()
                # would leave it held forever on those paths.
                with self.lock:
                    if is_zombie(self, stream_id):
                        return
                    max_outbound_frame_size = self.max_outbound_frame_size
                    frame_chunk = chunk[position:position + max_outbound_frame_size]
                    window_open = (
                        self.local_flow_control_window(stream_id) >= len(frame_chunk)
                    )
                    if window_open:
                        self.send_data(stream_id, frame_chunk)
                        self.conn.send(self.data_to_send())
                        position += max_outbound_frame_size
                if not window_open:
                    # Yield outside the lock so the reader thread can process
                    # frames that enlarge the window.
                    time.sleep(0)
        with self.lock:
            if is_zombie(self, stream_id):
                return
            self.end_stream(stream_id)
            self.conn.send(self.data_to_send())
|
||
|
|
||
|
|
||
|
class Http2Layer(Layer):
    """Layer that relays HTTP/2 traffic between the client and server connections.

    Frames are read from whichever connection is readable, fed into that
    connection's ``SafeH2Connection`` and the resulting h2 events are
    dispatched in ``_handle_event``. Each request (or pushed) stream is
    handled by its own ``Http2SingleStreamLayer`` thread.
    """

    def __init__(self, ctx, mode):
        super(Http2Layer, self).__init__(ctx)
        self.mode = mode
        # stream id used on the client connection -> Http2SingleStreamLayer
        self.streams = {}
        # stream id used on the server connection -> id on the client
        # connection; seeded with 0 -> 0.
        self.server_to_client_stream_ids = {0: 0}
        self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False)

        # make sure that we only pass actual SSL.Connection objects in here,
        # because otherwise ssl_read_select fails!
        self.active_conns = [self.client_conn.connection]

    def _initiate_server_conn(self):
        """Attach an h2 state machine to the server connection and start it."""
        self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True)
        self.server_conn.h2.initiate_connection()
        self.server_conn.send(self.server_conn.h2.data_to_send())
        self.active_conns.append(self.server_conn.connection)

    def connect(self):
        """Connect upstream and initialise HTTP/2 on the new server connection."""
        self.ctx.connect()
        self.server_conn.connect()
        self._initiate_server_conn()

    def set_server(self):
        raise NotImplementedError("Cannot change server for HTTP2 connections.")

    def disconnect(self):
        raise NotImplementedError("Cannot dis- or reconnect in HTTP2 connections.")

    def next_layer(self):
        # WebSockets over HTTP/2?
        # CONNECT for proxying?
        raise NotImplementedError()

    def _handle_event(self, event, source_conn, other_conn, is_server):
        """Dispatch one h2 event received on ``source_conn``.

        :param event: event object returned by ``receive_data``.
        :param source_conn: connection the event was read from.
        :param other_conn: the opposite connection.
        :param is_server: True if ``source_conn`` is the server connection.
        :returns: False when the connection was terminated and the main loop
            should stop, True otherwise.
        :raises HttpException: if the queued body exceeds ``body_size_limit``.
        :raises NotImplementedError: for ``TrailersReceived``.
        """
        if hasattr(event, 'stream_id'):
            # Odd stream ids seen on the server connection were opened by us
            # and must be translated back to the client's stream id.
            if is_server and event.stream_id % 2 == 1:
                eid = self.server_to_client_stream_ids[event.stream_id]
            else:
                eid = event.stream_id

        if isinstance(event, RequestReceived):
            headers = Headers([[str(k), str(v)] for k, v in event.headers])
            self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
            self.streams[eid].timestamp_start = time.time()
            self.streams[eid].start()
        elif isinstance(event, ResponseReceived):
            headers = Headers([[str(k), str(v)] for k, v in event.headers])
            stream = self.streams[eid]
            # Restart the body accounting for the response direction.
            stream.queued_data_length = 0
            stream.timestamp_start = time.time()
            stream.response_headers = headers
            stream.response_arrived.set()
        elif isinstance(event, DataReceived):
            stream = self.streams[eid]
            limit = self.config.body_size_limit
            if limit and stream.queued_data_length > limit:
                raise HttpException("HTTP body too large. Limit is {}.".format(self.config.body_size_limit))
            stream.data_queue.put(event.data)
            stream.queued_data_length += len(event.data)
            # Give the sender back the window space this data consumed.
            source_conn.h2.safe_increment_flow_control(event.stream_id, len(event.data))
        elif isinstance(event, StreamEnded):
            self.streams[eid].timestamp_end = time.time()
            self.streams[eid].data_finished.set()
        elif isinstance(event, StreamReset):
            # The stream may be unknown or already cleaned up; look it up
            # before touching it instead of raising KeyError.
            stream = self.streams.get(eid)
            if stream is not None:
                stream.zombie = time.time()
                # 0x8 is the HTTP/2 CANCEL error code (RFC 7540, section 7);
                # only that code is propagated to the other side.
                if event.error_code == 0x8:
                    if is_server:
                        other_stream_id = stream.client_stream_id
                    else:
                        other_stream_id = stream.server_stream_id
                    other_conn.h2.safe_reset_stream(other_stream_id, event.error_code)
        elif isinstance(event, RemoteSettingsChanged):
            new_settings = {
                setting: change.new_value
                for setting, change in event.changed_settings.iteritems()
            }
            other_conn.h2.safe_update_settings(new_settings)
        elif isinstance(event, ConnectionTerminated):
            other_conn.h2.safe_close_connection(event.error_code)
            return False
        elif isinstance(event, PushedStreamReceived):
            # pushed stream ids should be uniq and not dependent on race conditions
            # only the parent stream id must be looked up first
            parent_eid = self.server_to_client_stream_ids[event.parent_stream_id]
            with self.client_conn.h2.lock:
                self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers)
                # Flush the PUSH_PROMISE right away, like every other write
                # performed under this lock.
                self.client_conn.send(self.client_conn.h2.data_to_send())

            headers = Headers([[str(k), str(v)] for k, v in event.headers])
            headers['x-mitmproxy-pushed'] = 'true'
            pushed = Http2SingleStreamLayer(self, event.pushed_stream_id, headers)
            self.streams[event.pushed_stream_id] = pushed
            pushed.timestamp_start = time.time()
            pushed.pushed = True
            pushed.parent_stream_id = parent_eid
            # The promised request is marked complete immediately, so
            # read_request() on the new thread does not block.
            pushed.timestamp_end = time.time()
            pushed.data_finished.set()
            pushed.start()
        elif isinstance(event, TrailersReceived):
            raise NotImplementedError()

        return True

    def _cleanup_streams(self):
        """Drop streams that have been marked zombie for at least 10 seconds."""
        death_time = time.time() - 10
        # Iterate over a snapshot because entries are removed in the loop.
        for stream_id in list(self.streams):
            zombie = self.streams[stream_id].zombie
            if zombie and zombie <= death_time:
                self.streams.pop(stream_id, None)

    def __call__(self):
        """Run the relay loop until a connection is terminated."""
        if self.server_conn:
            self._initiate_server_conn()

        # The client connection preface is 24 bytes long (RFC 7540, 3.5).
        preamble = self.client_conn.rfile.read(24)
        self.client_conn.h2.initiate_connection()
        self.client_conn.h2.receive_data(preamble)
        self.client_conn.send(self.client_conn.h2.data_to_send())

        while True:
            readable = ssl_read_select(self.active_conns, 1)
            for conn in readable:
                from_client = (conn == self.client_conn.connection)
                source_conn = self.client_conn if from_client else self.server_conn
                other_conn = self.server_conn if from_client else self.client_conn
                is_server = (conn == self.server_conn.connection)

                # A frame is a 9-byte header, whose first 3 bytes hold the
                # payload length, followed by the payload (RFC 7540, 4.1).
                # Read exactly one whole frame.
                field = source_conn.rfile.peek(3)
                length = int(field.encode('hex'), 16)
                raw_frame = source_conn.rfile.safe_read(9 + length)

                with source_conn.h2.lock:
                    events = source_conn.h2.receive_data(raw_frame)
                    source_conn.send(source_conn.h2.data_to_send())

                for event in events:
                    if not self._handle_event(event, source_conn, other_conn, is_server):
                        return

            self._cleanup_streams()
|
||
|
|
||
|
|
||
|
class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
    """Thread that runs one HTTP request/response exchange on one HTTP/2 stream.

    The parent layer fills ``data_queue`` and sets the events below as frames
    arrive; this thread turns them into ``HTTPRequest`` / ``HTTPResponse``
    objects by running an ``HttpLayer`` on top of itself.
    """

    def __init__(self, ctx, stream_id, request_headers):
        super(Http2SingleStreamLayer, self).__init__(ctx)
        # None while alive; set to a timestamp once the stream is finished or reset.
        self.zombie = None
        self.client_stream_id = stream_id
        self.server_stream_id = None
        self.request_headers = request_headers
        self.response_headers = None
        self.data_queue = Queue.Queue()
        self.queued_data_length = 0

        self.response_arrived = threading.Event()
        self.data_finished = threading.Event()

    def is_zombie(self, h2_conn, stream_id):
        """Return True once this stream has been marked as a zombie.

        ``h2_conn`` and ``stream_id`` are accepted for the callback signature
        used by ``SafeH2Connection`` but are not inspected.
        """
        return bool(self.zombie)

    def read_request(self):
        """Wait for the full request and build an ``HTTPRequest`` from it."""
        self.data_finished.wait()
        # Re-arm the event so it can signal the end of the response body later.
        self.data_finished.clear()

        authority = self.request_headers.get(':authority', '')
        method = self.request_headers.get(':method', 'GET')
        scheme = self.request_headers.get(':scheme', 'https')
        path = self.request_headers.get(':path', '/')
        host = None
        port = None

        if path == '*' or path.startswith("/"):
            form_in = "relative"
        elif method == 'CONNECT':
            form_in = "authority"
            if ":" in authority:
                host, port = authority.split(":", 1)
            else:
                host = authority
        else:
            form_in = "absolute"
            # FIXME: verify if path or :host contains what we need
            scheme, host, port, _ = utils.parse_url(path)

        if host is None:
            host = 'localhost'
        if port is None:
            port = 80 if scheme == 'http' else 443
        port = int(port)

        # Drain everything that was queued before the stream ended.
        data = []
        while self.data_queue.qsize() > 0:
            data.append(self.data_queue.get())
        data = b"".join(data)

        return HTTPRequest(
            form_in,
            method,
            scheme,
            host,
            port,
            path,
            (2, 0),
            self.request_headers,
            data,
            timestamp_start=self.timestamp_start,
            timestamp_end=self.timestamp_end,
        )

    def send_request(self, message):
        """Open a new stream on the server connection and send ``message`` on it."""
        # NOTE(review): this also runs for pushed streams, i.e. the promised
        # request is forwarded on a fresh server stream - confirm that is intended.
        with self.server_conn.h2.lock:
            self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
            self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id

        self.server_conn.h2.safe_send_headers(
            self.is_zombie,
            self.server_stream_id,
            message.headers
        )

        # safe_send_body iterates over chunks. Passing a byte string directly
        # would iterate it piece by piece (one DATA frame per byte) and a
        # missing body would not be iterable at all, so wrap it first.
        body = message.body
        if body is None:
            chunks = []
        elif isinstance(body, bytes):
            chunks = [body]
        else:
            # presumably already an iterable of chunks; passed through as-is
            chunks = body
        self.server_conn.h2.safe_send_body(
            self.is_zombie,
            self.server_stream_id,
            chunks
        )

    def read_response_headers(self):
        """Wait for the response headers and return a body-less ``HTTPResponse``."""
        self.response_arrived.wait()

        status_code = int(self.response_headers.get(':status', 502))

        return HTTPResponse(
            http_version=(2, 0),
            status_code=status_code,
            reason='',
            headers=self.response_headers,
            content=None,
            timestamp_start=self.timestamp_start,
            timestamp_end=self.timestamp_end,
        )

    def read_response_body(self, request, response):
        """Yield response body chunks until the stream ends or becomes a zombie."""
        while True:
            try:
                yield self.data_queue.get(timeout=1)
            except Queue.Empty:
                # Nothing arrived within a second; fall through to the exit checks.
                pass
            if self.data_finished.is_set():
                # Stream ended: hand out whatever is still queued, then stop.
                while self.data_queue.qsize() > 0:
                    yield self.data_queue.get()
                return
            if self.zombie:
                return

    def send_response_headers(self, response):
        """Send the response headers to the client on this stream."""
        self.client_conn.h2.safe_send_headers(
            self.is_zombie,
            self.client_stream_id,
            response.headers
        )

    def send_response_body(self, _response, chunks):
        """Send the response body chunks to the client and end the stream."""
        self.client_conn.h2.safe_send_body(
            self.is_zombie,
            self.client_stream_id,
            chunks
        )

    def check_close_connection(self, flow):
        # This layer only handles a single stream.
        # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream.
        return True

    def connect(self):
        raise ValueError("CONNECT inside an HTTP2 stream is not supported.")

    def set_server(self, *args, **kwargs):
        # do not mess with the server connection - all streams share it.
        pass

    def run(self):
        """Thread entry point: run an ``HttpLayer`` for this stream."""
        layer = HttpLayer(self, self.mode)
        try:
            layer()
        finally:
            # Mark the stream as finished even if the layer raised, so the
            # parent layer's cleanup can eventually drop it.
            self.zombie = time.time()
|