handle script hooks in replay, fix tests, fix #402

This commit is contained in:
Maximilian Hils 2014-11-14 00:26:22 +01:00
parent 9b5a8af12d
commit 0c52b4e3b9
3 changed files with 54 additions and 42 deletions

View File

@ -169,6 +169,7 @@ class ClientPlaybackState:
def __init__(self, flows, exit):
self.flows, self.exit = flows, exit
self.current = None
self.testing = False # Disables actual replay for testing.
def count(self):
return len(self.flows)
@ -186,19 +187,16 @@ class ClientPlaybackState:
if flow is self.current:
self.current = None
def tick(self, master, testing=False):
"""
testing: Disables actual replay for testing.
"""
def tick(self, master):
if self.flows and not self.current:
n = self.flows.pop(0).copy()
n.response = None
n.reply = controller.DummyReply()
self.current = master.handle_request(n)
if not testing and not self.current.response:
master.replay_request(self.current) # pragma: no cover
elif self.current.response:
master.handle_response(self.current)
self.current = self.flows.pop(0).copy()
if not self.testing:
master.replay_request(self.current)
else:
self.current.reply = controller.DummyReply()
master.handle_request(self.current)
if self.current.response:
master.handle_response(self.current)
class ServerPlaybackState:
@ -371,6 +369,8 @@ class State(object):
"""
Add a request to the state. Returns the matching flow.
"""
if flow in self._flow_list: # catch flow replay
return flow
self._flow_list.append(flow)
if flow.match(self._limit):
self.view.append(flow)

View File

@ -1040,7 +1040,7 @@ class HTTPHandler(ProtocolHandler):
# inline script to set flow.stream = True
flow = self.c.channel.ask("responseheaders", flow)
if flow == KILL:
raise KillSignal
raise KillSignal()
else:
# now get the rest of the request body, if body still needs to be
# read but not streaming this response
@ -1085,7 +1085,7 @@ class HTTPHandler(ProtocolHandler):
self.process_server_address(flow) # The inline script may have changed request.host
if request_reply is None or request_reply == KILL:
return False
raise KillSignal()
if isinstance(request_reply, HTTPResponse):
flow.response = request_reply
@ -1099,7 +1099,7 @@ class HTTPHandler(ProtocolHandler):
self.c.log("response", "debug", [flow.response._assemble_first_line()])
response_reply = self.c.channel.ask("response", flow)
if response_reply is None or response_reply == KILL:
return False
raise KillSignal()
self.send_response_to_client(flow)
@ -1140,7 +1140,6 @@ class HTTPHandler(ProtocolHandler):
self.handle_error(e, flow)
except KillSignal:
self.c.log("Connection killed", "info")
flow.live = None
finally:
flow.live = None # Connection is not live anymore.
return False
@ -1437,32 +1436,43 @@ class RequestReplayThread(threading.Thread):
r = self.flow.request
form_out_backup = r.form_out
try:
# In all modes, we directly connect to the server displayed
if self.config.mode == "upstream":
server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:]
server = ServerConnection(server_address)
server.connect()
if r.scheme == "https":
send_connect_request(server, r.host, r.port)
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
r.form_out = "relative"
else:
r.form_out = "absolute"
self.flow.response = None
request_reply = self.channel.ask("request", self.flow)
if request_reply is None or request_reply == KILL:
raise KillSignal()
elif isinstance(request_reply, HTTPResponse):
self.flow.response = request_reply
else:
server_address = (r.host, r.port)
server = ServerConnection(server_address)
server.connect()
if r.scheme == "https":
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
r.form_out = "relative"
# In all modes, we directly connect to the server displayed
if self.config.mode == "upstream":
server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:]
server = ServerConnection(server_address)
server.connect()
if r.scheme == "https":
send_connect_request(server, r.host, r.port)
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
r.form_out = "relative"
else:
r.form_out = "absolute"
else:
server_address = (r.host, r.port)
server = ServerConnection(server_address)
server.connect()
if r.scheme == "https":
server.establish_ssl(self.config.clientcerts, sni=self.flow.server_conn.sni)
r.form_out = "relative"
server.send(r.assemble())
self.flow.server_conn = server
self.flow.response = HTTPResponse.from_stream(server.rfile, r.method,
body_size_limit=self.config.body_size_limit)
self.channel.ask("response", self.flow)
except (proxy.ProxyError, http.HttpError, tcp.NetLibError), v:
server.send(r.assemble())
self.flow.server_conn = server
self.flow.response = HTTPResponse.from_stream(server.rfile, r.method,
body_size_limit=self.config.body_size_limit)
response_reply = self.channel.ask("response", self.flow)
if response_reply is None or response_reply == KILL:
raise KillSignal()
except (proxy.ProxyError, http.HttpError, tcp.NetLibError) as v:
self.flow.error = Error(repr(v))
self.channel.ask("error", self.flow)
except KillSignal:
self.channel.tell("log", proxy.Log("Connection killed", "info"))
finally:
r.form_out = form_out_backup

View File

@ -86,19 +86,20 @@ class TestClientPlaybackState:
fm = flow.FlowMaster(None, s)
fm.start_client_playback([first, tutils.tflow()], True)
c = fm.client_playback
c.testing = True
assert not c.done()
assert not s.flow_count()
assert c.count() == 2
c.tick(fm, testing=True)
c.tick(fm)
assert s.flow_count()
assert c.count() == 1
c.tick(fm, testing=True)
c.tick(fm)
assert c.count() == 1
c.clear(c.current)
c.tick(fm, testing=True)
c.tick(fm)
assert c.count() == 0
c.clear(c.current)
assert c.done()
@ -696,6 +697,7 @@ class TestFlowMaster:
fm = flow.FlowMaster(DummyServer(ProxyConfig()), s)
assert not fm.start_server_playback(pb, False, [], False, False, None, False)
assert not fm.start_client_playback(pb, False)
fm.client_playback.testing = True
q = Queue.Queue()
assert not fm.state.flow_count()