Merge pull request #1592 from cortesi/ws

docs and API-related cleanups
This commit is contained in:
Aldo Cortesi 2016-10-04 10:54:15 +11:00 committed by GitHub
commit 3d5b811994
11 changed files with 66 additions and 53 deletions

View File

@ -128,7 +128,7 @@ HTTP Events
WebSockets Events
^^^^^^^^^^^^^^^^^
.. py:function:: websockets_handshake(context, flow)
.. py:function:: websocket_handshake(context, flow)
Called when a client wants to establish a WebSockets connection.
The WebSockets-specific headers can be manipulated to manipulate the handshake.

View File

@ -25,10 +25,11 @@ Events = frozenset([
"tcp_close",
"request",
"requestheaders",
"response",
"responseheaders",
"websockets_handshake",
"websocket_handshake",
"next_layer",
@ -89,10 +90,11 @@ class Master(object):
mitmproxy_ctx.master = None
mitmproxy_ctx.log = None
def add_log(self, e, level="info"):
def add_log(self, e, level):
"""
level: debug, info, warn, error
"""
pass
def add_server(self, server):
# We give a Channel to the server which can be used to communicate with the master

View File

@ -18,6 +18,7 @@ from mitmproxy.protocol import http_replay
def event_sequence(f):
if isinstance(f, models.HTTPFlow):
if f.request:
yield "requestheaders", f
yield "request", f
if f.response:
yield "responseheaders", f
@ -215,6 +216,10 @@ class FlowMaster(controller.Master):
def error(self, f):
self.state.update_flow(f)
@controller.handler
def requestheaders(self, f):
pass
@controller.handler
def request(self, f):
if f.live:
@ -246,7 +251,7 @@ class FlowMaster(controller.Master):
self.state.update_flow(f)
@controller.handler
def websockets_handshake(self, f):
def websocket_handshake(self, f):
pass
def handle_intercept(self, f):

View File

@ -16,23 +16,23 @@ from netlib import websockets
class _HttpTransmissionLayer(base.Layer):
def read_request(self):
def read_request_headers(self):
raise NotImplementedError()
def read_request_body(self, request):
raise NotImplementedError()
def read_request(self):
request = self.read_request_headers()
request.data.content = b"".join(
self.read_request_body(request)
)
request.timestamp_end = time.time()
return request
def send_request(self, request):
raise NotImplementedError()
def read_response(self, request):
response = self.read_response_headers()
response.data.content = b"".join(
self.read_response_body(request, response)
)
return response
def read_response_headers(self):
raise NotImplementedError()
@ -40,6 +40,13 @@ class _HttpTransmissionLayer(base.Layer):
raise NotImplementedError()
yield "this is a generator" # pragma: no cover
def read_response(self, request):
response = self.read_response_headers()
response.data.content = b"".join(
self.read_response_body(request, response)
)
return response
def send_response(self, response):
if response.data.content is None:
raise netlib.exceptions.HttpException("Cannot assemble flow with missing content")
@ -140,8 +147,9 @@ class HttpLayer(base.Layer):
self.__initial_server_tls = self.server_tls
self.__initial_server_conn = self.server_conn
while True:
flow = models.HTTPFlow(self.client_conn, self.server_conn, live=self)
try:
request = self.get_request_from_client()
request = self.get_request_from_client(flow)
# Make sure that the incoming request matches our expectations
self.validate_request(request)
except netlib.exceptions.HttpReadDisconnect:
@ -168,7 +176,6 @@ class HttpLayer(base.Layer):
if not (self.http_authenticated or self.authenticate(request)):
return
flow = models.HTTPFlow(self.client_conn, self.server_conn, live=self)
flow.request = request
try:
@ -200,7 +207,7 @@ class HttpLayer(base.Layer):
if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
# We only support RFC6455 with WebSockets version 13
# allow inline scripts to manipulate the client handshake
self.channel.ask("websockets_handshake", flow)
self.channel.ask("websocket_handshake", flow)
if not flow.response:
self.establish_server_connection(
@ -212,10 +219,10 @@ class HttpLayer(base.Layer):
else:
# response was set by an inline script.
# we now need to emulate the responseheaders hook.
flow = self.channel.ask("responseheaders", flow)
self.channel.ask("responseheaders", flow)
self.log("response", "debug", [repr(flow.response)])
flow = self.channel.ask("response", flow)
self.channel.ask("response", flow)
self.send_response_to_client(flow)
if self.check_close_connection(flow):
@ -243,13 +250,16 @@ class HttpLayer(base.Layer):
if flow:
flow.live = False
def get_request_from_client(self):
def get_request_from_client(self, flow):
request = self.read_request()
flow.request = request
self.channel.ask("requestheaders", flow)
if request.headers.get("expect", "").lower() == "100-continue":
# TODO: We may have to use send_response_headers for HTTP2 here.
self.send_response(models.expect_continue_response)
request.headers.pop("expect")
request.body = b"".join(self.read_request_body(request))
request.timestamp_end = time.time()
return request
def send_error_response(self, code, message, headers=None):
@ -329,7 +339,7 @@ class HttpLayer(base.Layer):
# call the appropriate script hook - this is an opportunity for an
# inline script to set flow.stream = True
flow = self.channel.ask("responseheaders", flow)
self.channel.ask("responseheaders", flow)
if flow.response.stream:
flow.response.data.content = None
@ -362,11 +372,7 @@ class HttpLayer(base.Layer):
if host_header:
flow.request.headers["host"] = host_header
flow.request.scheme = "https" if self.__initial_server_tls else "http"
request_reply = self.channel.ask("request", flow)
if isinstance(request_reply, models.HTTPResponse):
flow.response = request_reply
return
self.channel.ask("request", flow)
def establish_server_connection(self, host, port, scheme):
address = tcp.Address((host, port))

View File

@ -11,11 +11,10 @@ class Http1Layer(http._HttpTransmissionLayer):
super(Http1Layer, self).__init__(ctx)
self.mode = mode
def read_request(self):
req = http1.read_request(
self.client_conn.rfile, body_size_limit=self.config.options.body_size_limit
def read_request_headers(self):
return models.HTTPRequest.wrap(
http1.read_request_head(self.client_conn.rfile)
)
return models.HTTPRequest.wrap(req)
def read_request_body(self, request):
expected_size = http1.expected_http_body_size(request)

View File

@ -162,6 +162,7 @@ class Http2Layer(base.Layer):
self.streams[eid].priority_weight = event.priority_updated.weight
self.streams[eid].handled_priority_event = event.priority_updated
self.streams[eid].start()
self.streams[eid].request_arrived.set()
return True
def _handle_response_received(self, eid, event):
@ -248,6 +249,7 @@ class Http2Layer(base.Layer):
self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time()
self.streams[event.pushed_stream_id].request_arrived.set()
self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].start()
return True
@ -376,6 +378,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.timestamp_start = None
self.timestamp_end = None
self.request_arrived = threading.Event()
self.request_data_queue = queue.Queue()
self.request_queued_data_length = 0
self.request_data_finished = threading.Event()
@ -396,6 +399,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
if not self.zombie:
self.zombie = time.time()
self.request_data_finished.set()
self.request_arrived.set()
self.response_arrived.set()
self.response_data_finished.set()
@ -446,16 +450,10 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
raise exceptions.Http2ZombieException("Connection already dead")
@detect_zombie_stream
def read_request(self):
self.request_data_finished.wait()
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
data = b"".join(data)
def read_request_headers(self):
self.request_arrived.wait()
self.raise_zombie()
first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
return models.HTTPRequest(
first_line_format,
method,
@ -465,13 +463,18 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
path,
b"HTTP/2.0",
self.request_headers,
data,
None,
timestamp_start=self.timestamp_start,
timestamp_end=self.timestamp_end,
)
def read_request_body(self, request): # pragma: no cover
raise NotImplementedError()
@detect_zombie_stream
def read_request_body(self, request):
self.request_data_finished.wait()
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
return data
@detect_zombie_stream
def send_request(self, message):

View File

@ -112,7 +112,6 @@ class RootContext(object):
"""
Send a log message to the master.
"""
full_msg = [
"{}: {}".format(repr(self.client_conn.address), msg)
]

View File

@ -174,12 +174,12 @@ class PathodHandler(tcp.BaseHandler):
m = utils.MemBool()
valid_websockets_handshake = websockets.check_handshake(headers)
valid_websocket_handshake = websockets.check_handshake(headers)
self.settings.websocket_key = websockets.get_client_key(headers)
# If this is a websocket initiation, we respond with a proper
# server response, unless over-ridden.
if valid_websockets_handshake:
if valid_websocket_handshake:
anchor_gen = language.parse_pathod("ws")
else:
anchor_gen = None
@ -226,7 +226,7 @@ class PathodHandler(tcp.BaseHandler):
spec,
lg
)
if nexthandler and valid_websockets_handshake:
if nexthandler and valid_websocket_handshake:
self.protocol = protocols.websockets.WebsocketsProtocol(self)
return self.protocol.handle_websocket, retlog
else:

View File

@ -250,13 +250,13 @@ class HTTP2StateProtocol(object):
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
print(">> " + repr(frm))
def read_frame(self, hide=False):
while True:
frm = http2.parse_frame(*http2.read_raw_frame(self.tcp_handler.rfile))
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
print("<< " + repr(frm))
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
@ -341,7 +341,7 @@ class HTTP2StateProtocol(object):
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
print(">> ", repr(frm))
return [frm.serialize() for frm in frms]
@ -359,7 +359,7 @@ class HTTP2StateProtocol(object):
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
print(">> ", repr(frm))
return [frm.serialize() for frm in frms]

View File

@ -175,7 +175,7 @@ class TestScriptLoader(mastertest.MasterTest):
), [f]
)
evts = [i[1] for i in sc.ns.call_log]
assert evts == ['start', 'request', 'responseheaders', 'response', 'done']
assert evts == ['start', 'requestheaders', 'request', 'responseheaders', 'response', 'done']
with m.handlecontext():
tutils.raises(

View File

@ -807,8 +807,7 @@ class TestStreamRequest(tservers.HTTPProxyTest):
class MasterFakeResponse(tservers.TestMaster):
@controller.handler
def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
f.reply.send(resp)
f.response = HTTPResponse.wrap(netlib.tutils.tresp())
class TestFakeResponse(tservers.HTTPProxyTest):
@ -889,7 +888,7 @@ class MasterIncomplete(tservers.TestMaster):
def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp())
resp.content = None
f.reply.send(resp)
f.response = resp
class TestIncompleteResponse(tservers.HTTPProxyTest):