From 0c52b4e3b9137e22f6c8c649d6b584d44d7f4b75 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 14 Nov 2014 00:26:22 +0100 Subject: [PATCH] handle script hooks in replay, fix tests, fix #402 --- libmproxy/flow.py | 24 +++++++------- libmproxy/protocol/http.py | 64 ++++++++++++++++++++++---------------- test/test_flow.py | 8 +++-- 3 files changed, 54 insertions(+), 42 deletions(-) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 3d5a6a360..007136985 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -169,6 +169,7 @@ class ClientPlaybackState: def __init__(self, flows, exit): self.flows, self.exit = flows, exit self.current = None + self.testing = False # Disables actual replay for testing. def count(self): return len(self.flows) @@ -186,19 +187,16 @@ class ClientPlaybackState: if flow is self.current: self.current = None - def tick(self, master, testing=False): - """ - testing: Disables actual replay for testing. - """ + def tick(self, master): if self.flows and not self.current: - n = self.flows.pop(0).copy() - n.response = None - n.reply = controller.DummyReply() - self.current = master.handle_request(n) - if not testing and not self.current.response: - master.replay_request(self.current) # pragma: no cover - elif self.current.response: - master.handle_response(self.current) + self.current = self.flows.pop(0).copy() + if not self.testing: + master.replay_request(self.current) + else: + self.current.reply = controller.DummyReply() + master.handle_request(self.current) + if self.current.response: + master.handle_response(self.current) class ServerPlaybackState: @@ -371,6 +369,8 @@ class State(object): """ Add a request to the state. Returns the matching flow. """ + if flow in self._flow_list: # catch flow replay + return flow self._flow_list.append(flow) if flow.match(self._limit): self.view.append(flow) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index c8974d251..26a94040c 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -1040,7 +1040,7 @@ class HTTPHandler(ProtocolHandler): # inline script to set flow.stream = True flow = self.c.channel.ask("responseheaders", flow) if flow == KILL: - raise KillSignal + raise KillSignal() else: # now get the rest of the request body, if body still needs to be # read but not streaming this response @@ -1085,7 +1085,7 @@ class HTTPHandler(ProtocolHandler): self.process_server_address(flow) # The inline script may have changed request.host if request_reply is None or request_reply == KILL: - return False + raise KillSignal() if isinstance(request_reply, HTTPResponse): flow.response = request_reply @@ -1099,7 +1099,7 @@ class HTTPHandler(ProtocolHandler): self.c.log("response", "debug", [flow.response._assemble_first_line()]) response_reply = self.c.channel.ask("response", flow) if response_reply is None or response_reply == KILL: - return False + raise KillSignal() self.send_response_to_client(flow) @@ -1140,7 +1140,6 @@ class HTTPHandler(ProtocolHandler): self.handle_error(e, flow) except KillSignal: self.c.log("Connection killed", "info") - flow.live = None finally: flow.live = None # Connection is not live anymore. return False @@ -1437,32 +1436,43 @@ class RequestReplayThread(threading.Thread): r = self.flow.request form_out_backup = r.form_out try: - # In all modes, we directly connect to the server displayed - if self.config.mode == "upstream": - server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:] - server = ServerConnection(server_address) - server.connect() - if r.scheme == "https": - send_connect_request(server, r.host, r.port) - server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni) - r.form_out = "relative" - else: - r.form_out = "absolute" + self.flow.response = None + request_reply = self.channel.ask("request", self.flow) + if request_reply is None or request_reply == KILL: + raise KillSignal() + elif isinstance(request_reply, HTTPResponse): + self.flow.response = request_reply else: - server_address = (r.host, r.port) - server = ServerConnection(server_address) - server.connect() - if r.scheme == "https": - server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni) - r.form_out = "relative" + # In all modes, we directly connect to the server displayed + if self.config.mode == "upstream": + server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:] + server = ServerConnection(server_address) + server.connect() + if r.scheme == "https": + send_connect_request(server, r.host, r.port) + server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni) + r.form_out = "relative" + else: + r.form_out = "absolute" + else: + server_address = (r.host, r.port) + server = ServerConnection(server_address) + server.connect() + if r.scheme == "https": + server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni) + r.form_out = "relative" - server.send(r.assemble()) - self.flow.server_conn = server - self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, - body_size_limit=self.config.body_size_limit) - self.channel.ask("response", self.flow) - except (proxy.ProxyError, http.HttpError, tcp.NetLibError), v: + server.send(r.assemble()) + self.flow.server_conn = server + self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, + body_size_limit=self.config.body_size_limit) + response_reply = self.channel.ask("response", self.flow) + if response_reply is None or response_reply == KILL: + raise KillSignal() + except (proxy.ProxyError, http.HttpError, tcp.NetLibError) as v: self.flow.error = Error(repr(v)) self.channel.ask("error", self.flow) + except KillSignal: + self.channel.tell("log", proxy.Log("Connection killed", "info")) finally: r.form_out = form_out_backup diff --git a/test/test_flow.py b/test/test_flow.py index 8c197153d..22abb4d41 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -86,19 +86,20 @@ class TestClientPlaybackState: fm = flow.FlowMaster(None, s) fm.start_client_playback([first, tutils.tflow()], True) c = fm.client_playback + c.testing = True assert not c.done() assert not s.flow_count() assert c.count() == 2 - c.tick(fm, testing=True) + c.tick(fm) assert s.flow_count() assert c.count() == 1 - c.tick(fm, testing=True) + c.tick(fm) assert c.count() == 1 c.clear(c.current) - c.tick(fm, testing=True) + c.tick(fm) assert c.count() == 0 c.clear(c.current) assert c.done() @@ -696,6 +697,7 @@ class TestFlowMaster: fm = flow.FlowMaster(DummyServer(ProxyConfig()), s) assert not fm.start_server_playback(pb, False, [], False, False, None, False) assert not fm.start_client_playback(pb, False) + fm.client_playback.testing = True q = Queue.Queue() assert not fm.state.flow_count()