Merge pull request #1592 from cortesi/ws

docs and API-related cleanups
This commit is contained in:
Aldo Cortesi 2016-10-04 10:54:15 +11:00 committed by GitHub
commit 3d5b811994
11 changed files with 66 additions and 53 deletions

View File

@ -128,7 +128,7 @@ HTTP Events
WebSockets Events WebSockets Events
^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^
.. py:function:: websockets_handshake(context, flow) .. py:function:: websocket_handshake(context, flow)
Called when a client wants to establish a WebSockets connection. Called when a client wants to establish a WebSockets connection.
The WebSockets-specific headers can be manipulated to manipulate the handshake. The WebSockets-specific headers can be manipulated to manipulate the handshake.

View File

@ -25,10 +25,11 @@ Events = frozenset([
"tcp_close", "tcp_close",
"request", "request",
"requestheaders",
"response", "response",
"responseheaders", "responseheaders",
"websockets_handshake", "websocket_handshake",
"next_layer", "next_layer",
@ -89,10 +90,11 @@ class Master(object):
mitmproxy_ctx.master = None mitmproxy_ctx.master = None
mitmproxy_ctx.log = None mitmproxy_ctx.log = None
def add_log(self, e, level="info"): def add_log(self, e, level):
""" """
level: debug, info, warn, error level: debug, info, warn, error
""" """
pass
def add_server(self, server): def add_server(self, server):
# We give a Channel to the server which can be used to communicate with the master # We give a Channel to the server which can be used to communicate with the master

View File

@ -18,6 +18,7 @@ from mitmproxy.protocol import http_replay
def event_sequence(f): def event_sequence(f):
if isinstance(f, models.HTTPFlow): if isinstance(f, models.HTTPFlow):
if f.request: if f.request:
yield "requestheaders", f
yield "request", f yield "request", f
if f.response: if f.response:
yield "responseheaders", f yield "responseheaders", f
@ -215,6 +216,10 @@ class FlowMaster(controller.Master):
def error(self, f): def error(self, f):
self.state.update_flow(f) self.state.update_flow(f)
@controller.handler
def requestheaders(self, f):
pass
@controller.handler @controller.handler
def request(self, f): def request(self, f):
if f.live: if f.live:
@ -246,7 +251,7 @@ class FlowMaster(controller.Master):
self.state.update_flow(f) self.state.update_flow(f)
@controller.handler @controller.handler
def websockets_handshake(self, f): def websocket_handshake(self, f):
pass pass
def handle_intercept(self, f): def handle_intercept(self, f):

View File

@ -16,23 +16,23 @@ from netlib import websockets
class _HttpTransmissionLayer(base.Layer): class _HttpTransmissionLayer(base.Layer):
def read_request_headers(self):
def read_request(self):
raise NotImplementedError() raise NotImplementedError()
def read_request_body(self, request): def read_request_body(self, request):
raise NotImplementedError() raise NotImplementedError()
def read_request(self):
request = self.read_request_headers()
request.data.content = b"".join(
self.read_request_body(request)
)
request.timestamp_end = time.time()
return request
def send_request(self, request): def send_request(self, request):
raise NotImplementedError() raise NotImplementedError()
def read_response(self, request):
response = self.read_response_headers()
response.data.content = b"".join(
self.read_response_body(request, response)
)
return response
def read_response_headers(self): def read_response_headers(self):
raise NotImplementedError() raise NotImplementedError()
@ -40,6 +40,13 @@ class _HttpTransmissionLayer(base.Layer):
raise NotImplementedError() raise NotImplementedError()
yield "this is a generator" # pragma: no cover yield "this is a generator" # pragma: no cover
def read_response(self, request):
response = self.read_response_headers()
response.data.content = b"".join(
self.read_response_body(request, response)
)
return response
def send_response(self, response): def send_response(self, response):
if response.data.content is None: if response.data.content is None:
raise netlib.exceptions.HttpException("Cannot assemble flow with missing content") raise netlib.exceptions.HttpException("Cannot assemble flow with missing content")
@ -140,8 +147,9 @@ class HttpLayer(base.Layer):
self.__initial_server_tls = self.server_tls self.__initial_server_tls = self.server_tls
self.__initial_server_conn = self.server_conn self.__initial_server_conn = self.server_conn
while True: while True:
flow = models.HTTPFlow(self.client_conn, self.server_conn, live=self)
try: try:
request = self.get_request_from_client() request = self.get_request_from_client(flow)
# Make sure that the incoming request matches our expectations # Make sure that the incoming request matches our expectations
self.validate_request(request) self.validate_request(request)
except netlib.exceptions.HttpReadDisconnect: except netlib.exceptions.HttpReadDisconnect:
@ -168,7 +176,6 @@ class HttpLayer(base.Layer):
if not (self.http_authenticated or self.authenticate(request)): if not (self.http_authenticated or self.authenticate(request)):
return return
flow = models.HTTPFlow(self.client_conn, self.server_conn, live=self)
flow.request = request flow.request = request
try: try:
@ -200,7 +207,7 @@ class HttpLayer(base.Layer):
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers): if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
# We only support RFC6455 with WebSockets version 13 # We only support RFC6455 with WebSockets version 13
# allow inline scripts to manipulate the client handshake # allow inline scripts to manipulate the client handshake
self.channel.ask("websockets_handshake", flow) self.channel.ask("websocket_handshake", flow)
if not flow.response: if not flow.response:
self.establish_server_connection( self.establish_server_connection(
@ -212,10 +219,10 @@ class HttpLayer(base.Layer):
else: else:
# response was set by an inline script. # response was set by an inline script.
# we now need to emulate the responseheaders hook. # we now need to emulate the responseheaders hook.
flow = self.channel.ask("responseheaders", flow) self.channel.ask("responseheaders", flow)
self.log("response", "debug", [repr(flow.response)]) self.log("response", "debug", [repr(flow.response)])
flow = self.channel.ask("response", flow) self.channel.ask("response", flow)
self.send_response_to_client(flow) self.send_response_to_client(flow)
if self.check_close_connection(flow): if self.check_close_connection(flow):
@ -243,13 +250,16 @@ class HttpLayer(base.Layer):
if flow: if flow:
flow.live = False flow.live = False
def get_request_from_client(self): def get_request_from_client(self, flow):
request = self.read_request() request = self.read_request()
flow.request = request
self.channel.ask("requestheaders", flow)
if request.headers.get("expect", "").lower() == "100-continue": if request.headers.get("expect", "").lower() == "100-continue":
# TODO: We may have to use send_response_headers for HTTP2 here. # TODO: We may have to use send_response_headers for HTTP2 here.
self.send_response(models.expect_continue_response) self.send_response(models.expect_continue_response)
request.headers.pop("expect") request.headers.pop("expect")
request.body = b"".join(self.read_request_body(request)) request.body = b"".join(self.read_request_body(request))
request.timestamp_end = time.time()
return request return request
def send_error_response(self, code, message, headers=None): def send_error_response(self, code, message, headers=None):
@ -329,7 +339,7 @@ class HttpLayer(base.Layer):
# call the appropriate script hook - this is an opportunity for an # call the appropriate script hook - this is an opportunity for an
# inline script to set flow.stream = True # inline script to set flow.stream = True
flow = self.channel.ask("responseheaders", flow) self.channel.ask("responseheaders", flow)
if flow.response.stream: if flow.response.stream:
flow.response.data.content = None flow.response.data.content = None
@ -362,11 +372,7 @@ class HttpLayer(base.Layer):
if host_header: if host_header:
flow.request.headers["host"] = host_header flow.request.headers["host"] = host_header
flow.request.scheme = "https" if self.__initial_server_tls else "http" flow.request.scheme = "https" if self.__initial_server_tls else "http"
self.channel.ask("request", flow)
request_reply = self.channel.ask("request", flow)
if isinstance(request_reply, models.HTTPResponse):
flow.response = request_reply
return
def establish_server_connection(self, host, port, scheme): def establish_server_connection(self, host, port, scheme):
address = tcp.Address((host, port)) address = tcp.Address((host, port))

View File

@ -11,11 +11,10 @@ class Http1Layer(http._HttpTransmissionLayer):
super(Http1Layer, self).__init__(ctx) super(Http1Layer, self).__init__(ctx)
self.mode = mode self.mode = mode
def read_request(self): def read_request_headers(self):
req = http1.read_request( return models.HTTPRequest.wrap(
self.client_conn.rfile, body_size_limit=self.config.options.body_size_limit http1.read_request_head(self.client_conn.rfile)
) )
return models.HTTPRequest.wrap(req)
def read_request_body(self, request): def read_request_body(self, request):
expected_size = http1.expected_http_body_size(request) expected_size = http1.expected_http_body_size(request)

View File

@ -162,6 +162,7 @@ class Http2Layer(base.Layer):
self.streams[eid].priority_weight = event.priority_updated.weight self.streams[eid].priority_weight = event.priority_updated.weight
self.streams[eid].handled_priority_event = event.priority_updated self.streams[eid].handled_priority_event = event.priority_updated
self.streams[eid].start() self.streams[eid].start()
self.streams[eid].request_arrived.set()
return True return True
def _handle_response_received(self, eid, event): def _handle_response_received(self, eid, event):
@ -248,6 +249,7 @@ class Http2Layer(base.Layer):
self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time() self.streams[event.pushed_stream_id].timestamp_end = time.time()
self.streams[event.pushed_stream_id].request_arrived.set()
self.streams[event.pushed_stream_id].request_data_finished.set() self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].start() self.streams[event.pushed_stream_id].start()
return True return True
@ -376,6 +378,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.timestamp_start = None self.timestamp_start = None
self.timestamp_end = None self.timestamp_end = None
self.request_arrived = threading.Event()
self.request_data_queue = queue.Queue() self.request_data_queue = queue.Queue()
self.request_queued_data_length = 0 self.request_queued_data_length = 0
self.request_data_finished = threading.Event() self.request_data_finished = threading.Event()
@ -396,6 +399,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
if not self.zombie: if not self.zombie:
self.zombie = time.time() self.zombie = time.time()
self.request_data_finished.set() self.request_data_finished.set()
self.request_arrived.set()
self.response_arrived.set() self.response_arrived.set()
self.response_data_finished.set() self.response_data_finished.set()
@ -446,16 +450,10 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
raise exceptions.Http2ZombieException("Connection already dead") raise exceptions.Http2ZombieException("Connection already dead")
@detect_zombie_stream @detect_zombie_stream
def read_request(self): def read_request_headers(self):
self.request_data_finished.wait() self.request_arrived.wait()
self.raise_zombie()
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
data = b"".join(data)
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers) first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
return models.HTTPRequest( return models.HTTPRequest(
first_line_format, first_line_format,
method, method,
@ -465,13 +463,18 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
path, path,
b"HTTP/2.0", b"HTTP/2.0",
self.request_headers, self.request_headers,
data, None,
timestamp_start=self.timestamp_start, timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end, timestamp_end=self.timestamp_end,
) )
def read_request_body(self, request): # pragma: no cover @detect_zombie_stream
raise NotImplementedError() def read_request_body(self, request):
self.request_data_finished.wait()
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
return data
@detect_zombie_stream @detect_zombie_stream
def send_request(self, message): def send_request(self, message):

View File

@ -112,7 +112,6 @@ class RootContext(object):
""" """
Send a log message to the master. Send a log message to the master.
""" """
full_msg = [ full_msg = [
"{}: {}".format(repr(self.client_conn.address), msg) "{}: {}".format(repr(self.client_conn.address), msg)
] ]

View File

@ -174,12 +174,12 @@ class PathodHandler(tcp.BaseHandler):
m = utils.MemBool() m = utils.MemBool()
valid_websockets_handshake = websockets.check_handshake(headers) valid_websocket_handshake = websockets.check_handshake(headers)
self.settings.websocket_key = websockets.get_client_key(headers) self.settings.websocket_key = websockets.get_client_key(headers)
# If this is a websocket initiation, we respond with a proper # If this is a websocket initiation, we respond with a proper
# server response, unless over-ridden. # server response, unless over-ridden.
if valid_websockets_handshake: if valid_websocket_handshake:
anchor_gen = language.parse_pathod("ws") anchor_gen = language.parse_pathod("ws")
else: else:
anchor_gen = None anchor_gen = None
@ -226,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
spec, spec,
lg lg
) )
if nexthandler and valid_websockets_handshake: if nexthandler and valid_websocket_handshake:
self.protocol = protocols.websockets.WebsocketsProtocol(self) self.protocol = protocols.websockets.WebsocketsProtocol(self)
return self.protocol.handle_websocket, retlog return self.protocol.handle_websocket, retlog
else: else:

View File

@ -250,13 +250,13 @@ class HTTP2StateProtocol(object):
self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush() self.tcp_handler.wfile.flush()
if not hide and self.dump_frames: # pragma no cover if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>")) print(">> " + repr(frm))
def read_frame(self, hide=False): def read_frame(self, hide=False):
while True: while True:
frm = http2.parse_frame(*http2.read_raw_frame(self.tcp_handler.rfile)) frm = http2.parse_frame(*http2.read_raw_frame(self.tcp_handler.rfile))
if not hide and self.dump_frames: # pragma no cover if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<")) print("<< " + repr(frm))
if isinstance(frm, hyperframe.frame.PingFrame): if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
@ -341,7 +341,7 @@ class HTTP2StateProtocol(object):
if self.dump_frames: # pragma no cover if self.dump_frames: # pragma no cover
for frm in frms: for frm in frms:
print(frm.human_readable(">>")) print(">> ", repr(frm))
return [frm.serialize() for frm in frms] return [frm.serialize() for frm in frms]
@ -359,7 +359,7 @@ class HTTP2StateProtocol(object):
if self.dump_frames: # pragma no cover if self.dump_frames: # pragma no cover
for frm in frms: for frm in frms:
print(frm.human_readable(">>")) print(">> ", repr(frm))
return [frm.serialize() for frm in frms] return [frm.serialize() for frm in frms]

View File

@ -175,7 +175,7 @@ class TestScriptLoader(mastertest.MasterTest):
), [f] ), [f]
) )
evts = [i[1] for i in sc.ns.call_log] evts = [i[1] for i in sc.ns.call_log]
assert evts == ['start', 'request', 'responseheaders', 'response', 'done'] assert evts == ['start', 'requestheaders', 'request', 'responseheaders', 'response', 'done']
with m.handlecontext(): with m.handlecontext():
tutils.raises( tutils.raises(

View File

@ -807,8 +807,7 @@ class TestStreamRequest(tservers.HTTPProxyTest):
class MasterFakeResponse(tservers.TestMaster): class MasterFakeResponse(tservers.TestMaster):
@controller.handler @controller.handler
def request(self, f): def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp()) f.response = HTTPResponse.wrap(netlib.tutils.tresp())
f.reply.send(resp)
class TestFakeResponse(tservers.HTTPProxyTest): class TestFakeResponse(tservers.HTTPProxyTest):
@ -889,7 +888,7 @@ class MasterIncomplete(tservers.TestMaster):
def request(self, f): def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp = HTTPResponse.wrap(netlib.tutils.tresp())
resp.content = None resp.content = None
f.reply.send(resp) f.response = resp
class TestIncompleteResponse(tservers.HTTPProxyTest): class TestIncompleteResponse(tservers.HTTPProxyTest):