diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py index 329d1215f..2658c0af4 100644 --- a/mitmproxy/addons.py +++ b/mitmproxy/addons.py @@ -4,7 +4,7 @@ import pprint def _get_name(itm): - return getattr(itm, "name", itm.__class__.__name__) + return getattr(itm, "name", itm.__class__.__name__.lower()) class Addons(object): @@ -13,6 +13,16 @@ class Addons(object): self.master = master master.options.changed.connect(self.options_update) + def get(self, name): + """ + Retrieve an addon by name. Addon names are equal to the .name + attribute on the instance, or the lower case class name if that + does not exist. + """ + for i in self.chain: + if name == _get_name(i): + return i + def options_update(self, options, updated): for i in self.chain: with self.master.handlecontext(): @@ -39,14 +49,6 @@ class Addons(object): for i in self.chain: self.invoke_with_context(i, "done") - def has_addon(self, name): - """ - Is an addon with this name registered? - """ - for i in self.chain: - if _get_name(i) == name: - return True - def __len__(self): return len(self.chain) diff --git a/mitmproxy/builtins/serverplayback.py b/mitmproxy/builtins/serverplayback.py index fe56d68b2..be82cad95 100644 --- a/mitmproxy/builtins/serverplayback.py +++ b/mitmproxy/builtins/serverplayback.py @@ -88,13 +88,14 @@ class ServerPlayback(object): def configure(self, options, updated): self.options = options - if options.server_replay and "server_replay" in updated: - try: - flows = flow.read_flows_from_paths(options.server_replay) - except exceptions.FlowReadException as e: - raise exceptions.OptionsError(str(e)) + if "server_replay" in updated: self.clear() - self.load(flows) + if options.server_replay: + try: + flows = flow.read_flows_from_paths(options.server_replay) + except exceptions.FlowReadException as e: + raise exceptions.OptionsError(str(e)) + self.load(flows) # FIXME: These options have to be renamed to something more sensible - # prefixed with serverplayback_ where appropriate, and playback_ where diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index a6942ca40..1cb3a32b5 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -248,9 +248,6 @@ class ConsoleMaster(flow.FlowMaster): if options.client_replay: self.client_playback_path(options.client_replay) - if options.server_replay: - self.server_playback_path(options.server_replay) - self.view_stack = [] if options.app: @@ -391,21 +388,6 @@ class ConsoleMaster(flow.FlowMaster): if flows: self.start_client_playback(flows, False) - def server_playback_path(self, path): - if not isinstance(path, list): - path = [path] - flows = self._readflows(path) - if flows: - self.start_server_playback( - flows, - self.options.kill, self.options.rheaders, - False, self.options.nopop, - self.options.replay_ignore_params, - self.options.replay_ignore_content, - self.options.replay_ignore_payload_params, - self.options.replay_ignore_host - ) - def spawn_editor(self, data): text = not isinstance(data, bytes) fd, name = tempfile.mkstemp('', "mproxy", text=text) diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py index 43d68d51a..6c4cc8b52 100644 --- a/mitmproxy/console/statusbar.py +++ b/mitmproxy/console/statusbar.py @@ -147,14 +147,12 @@ class StatusBar(urwid.WidgetWrap): if self.master.client_playback: r.append("[") r.append(("heading_key", "cplayback")) - r.append(":%s to go]" % self.master.client_playback.count()) - if self.master.server_playback: + r.append(":%s]" % self.master.client_playback.count()) + if self.master.options.server_replay: r.append("[") r.append(("heading_key", "splayback")) - if self.master.options.nopop: - r.append(":%s in file]" % self.master.server_playback.count()) - else: - r.append(":%s to go]" % self.master.server_playback.count()) + a = self.master.addons.get("serverplayback") + r.append(":%s]" % a.count()) if self.master.options.ignore_hosts: r.append("[") r.append(("heading_key", "I")) diff --git a/mitmproxy/console/window.py b/mitmproxy/console/window.py index 355936431..159f68ed6 100644 --- a/mitmproxy/console/window.py +++ b/mitmproxy/console/window.py @@ -57,13 +57,11 @@ class Window(urwid.Frame): callback = self.master.stop_client_playback_prompt, ) elif k == "s": - if not self.master.server_playback: - signals.status_prompt_path.send( - self, - prompt = "Server replay path", - callback = self.master.server_playback_path - ) - else: + a = self.master.addons.get("serverplayback") + if a.count(): + def stop_server_playback(response): + if response == "y": + self.master.options.server_replay = [] signals.status_prompt_onekey.send( self, prompt = "Stop current server replay?", @@ -71,7 +69,13 @@ class Window(urwid.Frame): ("yes", "y"), ("no", "n"), ), - callback = self.master.stop_server_playback_prompt, + callback = stop_server_playback + ) + else: + signals.status_prompt_path.send( + self, + prompt = "Server playback path", + callback = lambda x: self.master.options.setter("server_replay")([x]) ) def keypress(self, size, k): diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py index bfde06c5c..877eaa22c 100644 --- a/mitmproxy/protocol/http_replay.py +++ b/mitmproxy/protocol/http_replay.py @@ -33,6 +33,7 @@ class RequestReplayThread(basethread.BaseThread): def run(self): r = self.flow.request first_line_format_backup = r.first_line_format + server = None try: self.flow.response = None @@ -103,3 +104,5 @@ class RequestReplayThread(basethread.BaseThread): self.channel.tell("log", Log(traceback.format_exc(), "error")) finally: r.first_line_format = first_line_format_backup + if server: + server.finish() diff --git a/test/mitmproxy/protocol/test_http1.py b/test/mitmproxy/protocol/test_http1.py index 7d04c56b7..2fc4ac635 100644 --- a/test/mitmproxy/protocol/test_http1.py +++ b/test/mitmproxy/protocol/test_http1.py @@ -18,14 +18,15 @@ class TestInvalidRequests(tservers.HTTPProxyTest): def test_double_connect(self): p = self.pathoc() - r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) + with p.connect(): + r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) assert r.status_code == 400 assert b"Invalid HTTP request form" in r.content def test_relative_request(self): p = self.pathoc_raw() - p.connect() - r = p.request("get:/p/200") + with p.connect(): + r = p.request("get:/p/200") assert r.status_code == 400 assert b"Invalid HTTP request form" in r.content @@ -61,5 +62,8 @@ class TestHeadContentLength(tservers.HTTPProxyTest): def test_head_content_length(self): p = self.pathoc() - resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase) + with p.connect(): + resp = p.request( + """head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase + ) assert resp.headers["Content-Length"] == "42" diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py index a5085ea0b..52d7f07f3 100644 --- a/test/mitmproxy/test_addons.py +++ b/test/mitmproxy/test_addons.py @@ -17,5 +17,5 @@ def test_simple(): m = controller.Master(o) a = addons.Addons(m) a.add(o, TAddon("one")) - assert a.has_addon("one") - assert not a.has_addon("two") + assert a.get("one") + assert not a.get("two") diff --git a/test/mitmproxy/test_fuzzing.py b/test/mitmproxy/test_fuzzing.py index 27ea36a6e..905ba1cde 100644 --- a/test/mitmproxy/test_fuzzing.py +++ b/test/mitmproxy/test_fuzzing.py @@ -11,17 +11,20 @@ class TestFuzzy(tservers.HTTPProxyTest): def test_idna_err(self): req = r'get:"http://localhost:%s":i10,"\xc6"' p = self.pathoc() - assert p.request(req % self.server.port).status_code == 400 + with p.connect(): + assert p.request(req % self.server.port).status_code == 400 def test_nullbytes(self): req = r'get:"http://localhost:%s":i19,"\x00"' p = self.pathoc() - assert p.request(req % self.server.port).status_code == 400 + with p.connect(): + assert p.request(req % self.server.port).status_code == 400 def test_invalid_ipv6_url(self): req = 'get:"http://localhost:%s":i13,"["' p = self.pathoc() - resp = p.request(req % self.server.port) + with p.connect(): + resp = p.request(req % self.server.port) assert resp.status_code == 400 # def test_invalid_upstream(self): diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index e0a8da471..321bb11fc 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -91,11 +91,11 @@ class CommonMixin: def test_invalid_http(self): t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - t.connect() - t.wfile.write(b"invalid\r\n\r\n") - t.wfile.flush() - line = t.rfile.readline() - assert (b"Bad Request" in line) or (b"Bad Gateway" in line) + with t.connect(): + t.wfile.write(b"invalid\r\n\r\n") + t.wfile.flush() + line = t.rfile.readline() + assert (b"Bad Request" in line) or (b"Bad Gateway" in line) def test_sni(self): if not self.ssl: @@ -208,20 +208,22 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): def test_app_err(self): p = self.pathoc() - ret = p.request("get:'http://errapp/'") + with p.connect(): + ret = p.request("get:'http://errapp/'") assert ret.status_code == 500 assert b"ValueError" in ret.content def test_invalid_connect(self): t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - t.connect() - t.wfile.write(b"CONNECT invalid\n\n") - t.wfile.flush() - assert b"Bad Request" in t.rfile.readline() + with t.connect(): + t.wfile.write(b"CONNECT invalid\n\n") + t.wfile.flush() + assert b"Bad Request" in t.rfile.readline() def test_upstream_ssl_error(self): p = self.pathoc() - ret = p.request("get:'https://localhost:%s/'" % self.server.port) + with p.connect(): + ret = p.request("get:'https://localhost:%s/'" % self.server.port) assert ret.status_code == 400 def test_connection_close(self): @@ -232,25 +234,28 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): # Lets sanity check that the connection does indeed stay open by # issuing two requests over the same connection p = self.pathoc() - assert p.request("get:'%s'" % response) - assert p.request("get:'%s'" % response) + with p.connect(): + assert p.request("get:'%s'" % response) + assert p.request("get:'%s'" % response) # Now check that the connection is closed as the client specifies p = self.pathoc() - assert p.request("get:'%s':h'Connection'='close'" % response) - # There's a race here, which means we can get any of a number of errors. - # Rather than introduce yet another sleep into the test suite, we just - # relax the Exception specification. - with raises(Exception): - p.request("get:'%s'" % response) + with p.connect(): + assert p.request("get:'%s':h'Connection'='close'" % response) + # There's a race here, which means we can get any of a number of errors. + # Rather than introduce yet another sleep into the test suite, we just + # relax the Exception specification. + with raises(Exception): + p.request("get:'%s'" % response) def test_reconnect(self): req = "get:'%s/p/200:b@1:da'" % self.server.urlbase p = self.pathoc() - assert p.request(req) - # Server has disconnected. Mitmproxy should detect this, and reconnect. - assert p.request(req) - assert p.request(req) + with p.connect(): + assert p.request(req) + # Server has disconnected. Mitmproxy should detect this, and reconnect. + assert p.request(req) + assert p.request(req) def test_get_connection_switching(self): def switched(l): @@ -260,18 +265,21 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): req = "get:'%s/p/200:b@1'" p = self.pathoc() - assert p.request(req % self.server.urlbase) - assert p.request(req % self.server2.urlbase) + with p.connect(): + assert p.request(req % self.server.urlbase) + assert p.request(req % self.server2.urlbase) assert switched(self.proxy.tlog) def test_blank_leading_line(self): p = self.pathoc() - req = "get:'%s/p/201':i0,'\r\n'" - assert p.request(req % self.server.urlbase).status_code == 201 + with p.connect(): + req = "get:'%s/p/201':i0,'\r\n'" + assert p.request(req % self.server.urlbase).status_code == 201 def test_invalid_headers(self): p = self.pathoc() - resp = p.request("get:'http://foo':h':foo'='bar'") + with p.connect(): + resp = p.request("get:'http://foo':h':foo'='bar'") assert resp.status_code == 400 def test_stream(self): @@ -301,15 +309,16 @@ class TestHTTPAuth(tservers.HTTPProxyTest): self.master.options.auth_singleuser = "test:test" assert self.pathod("202").status_code == 407 p = self.pathoc() - ret = p.request(""" - get - 'http://localhost:%s/p/202' - h'%s'='%s' - """ % ( - self.server.port, - http.authentication.BasicProxyAuth.AUTH_HEADER, - authentication.assemble_http_basic_auth("basic", "test", "test") - )) + with p.connect(): + ret = p.request(""" + get + 'http://localhost:%s/p/202' + h'%s'='%s' + """ % ( + self.server.port, + http.authentication.BasicProxyAuth.AUTH_HEADER, + authentication.assemble_http_basic_auth("basic", "test", "test") + )) assert ret.status_code == 202 @@ -318,14 +327,15 @@ class TestHTTPReverseAuth(tservers.ReverseProxyTest): self.master.options.auth_singleuser = "test:test" assert self.pathod("202").status_code == 401 p = self.pathoc() - ret = p.request(""" - get - '/p/202' - h'%s'='%s' - """ % ( - http.authentication.BasicWebsiteAuth.AUTH_HEADER, - authentication.assemble_http_basic_auth("basic", "test", "test") - )) + with p.connect(): + ret = p.request(""" + get + '/p/202' + h'%s'='%s' + """ % ( + http.authentication.BasicWebsiteAuth.AUTH_HEADER, + authentication.assemble_http_basic_auth("basic", "test", "test") + )) assert ret.status_code == 202 @@ -354,7 +364,8 @@ class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin): def test_error_post_connect(self): p = self.pathoc() - assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 + with p.connect(): + assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 class TestHTTPSCertfile(tservers.HTTPProxyTest, CommonMixin): @@ -389,7 +400,8 @@ class TestHTTPSUpstreamServerVerificationWTrustedCert(tservers.HTTPProxyTest): def _request(self): p = self.pathoc(sni="example.mitmproxy.org") - return p.request("get:/p/242") + with p.connect(): + return p.request("get:/p/242") def test_verification_w_cadir(self): self.config.options.update( @@ -426,7 +438,8 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxyTest): def _request(self): p = self.pathoc(sni="example.mitmproxy.org") - return p.request("get:/p/242") + with p.connect(): + return p.request("get:/p/242") @classmethod def get_options(cls): @@ -481,13 +494,15 @@ class TestSocks5(tservers.SocksModeTest): def test_simple(self): p = self.pathoc() - p.socks_connect(("localhost", self.server.port)) - f = p.request("get:/p/200") + with p.connect(): + p.socks_connect(("localhost", self.server.port)) + f = p.request("get:/p/200") assert f.status_code == 200 def test_with_authentication_only(self): p = self.pathoc() - f = p.request("get:/p/200") + with p.connect(): + f = p.request("get:/p/200") assert f.status_code == 502 assert b"SOCKS5 mode failure" in f.content @@ -496,21 +511,21 @@ class TestSocks5(tservers.SocksModeTest): mitmproxy doesn't support UDP or BIND SOCKS CMDs """ p = self.pathoc() + with p.connect(): + socks.ClientGreeting( + socks.VERSION.SOCKS5, + [socks.METHOD.NO_AUTHENTICATION_REQUIRED] + ).to_file(p.wfile) + socks.Message( + socks.VERSION.SOCKS5, + socks.CMD.BIND, + socks.ATYP.DOMAINNAME, + ("example.com", 8080) + ).to_file(p.wfile) - socks.ClientGreeting( - socks.VERSION.SOCKS5, - [socks.METHOD.NO_AUTHENTICATION_REQUIRED] - ).to_file(p.wfile) - socks.Message( - socks.VERSION.SOCKS5, - socks.CMD.BIND, - socks.ATYP.DOMAINNAME, - ("example.com", 8080) - ).to_file(p.wfile) - - p.wfile.flush() - p.rfile.read(2) # read server greeting - f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway. + p.wfile.flush() + p.rfile.read(2) # read server greeting + f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway. assert f.status_code == 502 assert b"SOCKS5 mode failure" in f.content @@ -531,21 +546,23 @@ class TestHttps2Http(tservers.ReverseProxyTest): p = pathoc.Pathoc( ("localhost", self.proxy.port), ssl=True, sni=sni, fp=None ) - p.connect() return p def test_all(self): p = self.pathoc(ssl=True) - assert p.request("get:'/p/200'").status_code == 200 + with p.connect(): + assert p.request("get:'/p/200'").status_code == 200 def test_sni(self): p = self.pathoc(ssl=True, sni="example.com") - assert p.request("get:'/p/200'").status_code == 200 - assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) + with p.connect(): + assert p.request("get:'/p/200'").status_code == 200 + assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) def test_http(self): p = self.pathoc(ssl=False) - assert p.request("get:'/p/200'").status_code == 200 + with p.connect(): + assert p.request("get:'/p/200'").status_code == 200 class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin): @@ -703,29 +720,29 @@ class TestRedirectRequest(tservers.HTTPProxyTest): self.master.redirect_port = self.server2.port p = self.pathoc() + with p.connect(): + self.server.clear_log() + self.server2.clear_log() + r1 = p.request("get:'/p/200'") + assert r1.status_code == 200 + assert self.server.last_log() + assert not self.server2.last_log() - self.server.clear_log() - self.server2.clear_log() - r1 = p.request("get:'/p/200'") - assert r1.status_code == 200 - assert self.server.last_log() - assert not self.server2.last_log() + self.server.clear_log() + self.server2.clear_log() + r2 = p.request("get:'/p/201'") + assert r2.status_code == 201 + assert not self.server.last_log() + assert self.server2.last_log() - self.server.clear_log() - self.server2.clear_log() - r2 = p.request("get:'/p/201'") - assert r2.status_code == 201 - assert not self.server.last_log() - assert self.server2.last_log() + self.server.clear_log() + self.server2.clear_log() + r3 = p.request("get:'/p/202'") + assert r3.status_code == 202 + assert self.server.last_log() + assert not self.server2.last_log() - self.server.clear_log() - self.server2.clear_log() - r3 = p.request("get:'/p/202'") - assert r3.status_code == 202 - assert self.server.last_log() - assert not self.server2.last_log() - - assert r1.content == r2.content == r3.content + assert r1.content == r2.content == r3.content class MasterStreamRequest(tservers.TestMaster): @@ -743,22 +760,22 @@ class TestStreamRequest(tservers.HTTPProxyTest): def test_stream_simple(self): p = self.pathoc() - - # a request with 100k of data but without content-length - r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) - assert r1.status_code == 200 - assert len(r1.content) > 100000 + with p.connect(): + # a request with 100k of data but without content-length + r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) + assert r1.status_code == 200 + assert len(r1.content) > 100000 def test_stream_multiple(self): p = self.pathoc() + with p.connect(): + # simple request with streaming turned on + r1 = p.request("get:'%s/p/200'" % self.server.urlbase) + assert r1.status_code == 200 - # simple request with streaming turned on - r1 = p.request("get:'%s/p/200'" % self.server.urlbase) - assert r1.status_code == 200 - - # now send back 100k of data, streamed but not chunked - r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase) - assert r1.status_code == 201 + # now send back 100k of data, streamed but not chunked + r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase) + assert r1.status_code == 201 def test_stream_chunked(self): connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -887,7 +904,8 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest, CommonMixin, AppMixin): ("~s", "baz", "ORLY") ] p = self.pathoc() - req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) + with p.connect(): + req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) assert req.content == b"ORLY" assert req.status_code == 418 @@ -948,7 +966,8 @@ class TestUpstreamProxySSL( def test_simple(self): p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") + with p.connect(): + req = p.request("get:'/p/418:b\"content\"'") assert req.content == b"content" assert req.status_code == 418 @@ -1006,48 +1025,49 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): ]) p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") - assert req.content == b"content" - assert req.status_code == 418 + with p.connect(): + req = p.request("get:'/p/418:b\"content\"'") + assert req.content == b"content" + assert req.status_code == 418 - assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request - # CONNECT, failing request, - assert self.chain[0].tmaster.state.flow_count() == 4 - # reCONNECT, request - # failing request, request - assert self.chain[1].tmaster.state.flow_count() == 2 - # (doesn't store (repeated) CONNECTs from chain[0] - # as it is a regular proxy) + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request + # CONNECT, failing request, + assert self.chain[0].tmaster.state.flow_count() == 4 + # reCONNECT, request + # failing request, request + assert self.chain[1].tmaster.state.flow_count() == 2 + # (doesn't store (repeated) CONNECTs from chain[0] + # as it is a regular proxy) - assert not self.chain[1].tmaster.state.flows[0].response # killed - assert self.chain[1].tmaster.state.flows[1].response + assert not self.chain[1].tmaster.state.flows[0].response # killed + assert self.chain[1].tmaster.state.flows[1].response - assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" - assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" + assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" + assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" - assert self.chain[0].tmaster.state.flows[ - 0].request.first_line_format == "authority" - assert self.chain[0].tmaster.state.flows[ - 1].request.first_line_format == "relative" - assert self.chain[0].tmaster.state.flows[ - 2].request.first_line_format == "authority" - assert self.chain[0].tmaster.state.flows[ - 3].request.first_line_format == "relative" + assert self.chain[0].tmaster.state.flows[ + 0].request.first_line_format == "authority" + assert self.chain[0].tmaster.state.flows[ + 1].request.first_line_format == "relative" + assert self.chain[0].tmaster.state.flows[ + 2].request.first_line_format == "authority" + assert self.chain[0].tmaster.state.flows[ + 3].request.first_line_format == "relative" - assert self.chain[1].tmaster.state.flows[ - 0].request.first_line_format == "relative" - assert self.chain[1].tmaster.state.flows[ - 1].request.first_line_format == "relative" + assert self.chain[1].tmaster.state.flows[ + 0].request.first_line_format == "relative" + assert self.chain[1].tmaster.state.flows[ + 1].request.first_line_format == "relative" - req = p.request("get:'/p/418:b\"content2\"'") + req = p.request("get:'/p/418:b\"content2\"'") - assert req.status_code == 502 - assert self.proxy.tmaster.state.flow_count() == 3 # + new request - # + new request, repeated CONNECT from chain[1] - assert self.chain[0].tmaster.state.flow_count() == 6 - # (both terminated) - # nothing happened here - assert self.chain[1].tmaster.state.flow_count() == 2 + assert req.status_code == 502 + assert self.proxy.tmaster.state.flow_count() == 3 # + new request + # + new request, repeated CONNECT from chain[1] + assert self.chain[0].tmaster.state.flow_count() == 6 + # (both terminated) + # nothing happened here + assert self.chain[1].tmaster.state.flow_count() == 2 class AddUpstreamCertsToClientChainMixin: @@ -1066,12 +1086,13 @@ class AddUpstreamCertsToClientChainMixin: d = f.read() upstreamCert = SSLCert.from_pem(d) p = self.pathoc() - upstream_cert_found_in_client_chain = False - for receivedCert in p.server_certs: - if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): - upstream_cert_found_in_client_chain = True - break - assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain) + with p.connect(): + upstream_cert_found_in_client_chain = False + for receivedCert in p.server_certs: + if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): + upstream_cert_found_in_client_chain = True + break + assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain) class TestHTTPSAddUpstreamCertsToClientChainTrue( diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 1597f59cf..4291f7435 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -3,6 +3,7 @@ import threading import tempfile import flask import mock +import sys from mitmproxy.proxy.config import ProxyConfig from mitmproxy.proxy.server import ProxyServer @@ -10,6 +11,7 @@ import pathod.test import pathod.pathoc from mitmproxy import flow, controller, options from mitmproxy import builtins +import netlib.exceptions testapp = flask.Flask(__name__) @@ -104,6 +106,14 @@ class ProxyTestBase(object): cls.server.shutdown() cls.server2.shutdown() + def teardown(self): + try: + self.server.wait_for_silence() + except netlib.exceptions.Timeout: + # FIXME: Track down the Windows sync issues + if sys.platform != "win32": + raise + def setup(self): self.master.clear_log() self.master.state.clear() @@ -125,6 +135,15 @@ class ProxyTestBase(object): ) +class LazyPathoc(pathod.pathoc.Pathoc): + def __init__(self, lazy_connect, *args, **kwargs): + self.lazy_connect = lazy_connect + pathod.pathoc.Pathoc.__init__(self, *args, **kwargs) + + def connect(self): + return pathod.pathoc.Pathoc.connect(self, self.lazy_connect) + + class HTTPProxyTest(ProxyTestBase): def pathoc_raw(self): @@ -134,14 +153,14 @@ class HTTPProxyTest(ProxyTestBase): """ Returns a connected Pathoc instance. """ - p = pathod.pathoc.Pathoc( + if self.ssl: + conn = ("127.0.0.1", self.server.port) + else: + conn = None + return LazyPathoc( + conn, ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ) - if self.ssl: - p.connect(("127.0.0.1", self.server.port)) - else: - p.connect() - return p def pathod(self, spec, sni=None): """ @@ -152,18 +171,20 @@ class HTTPProxyTest(ProxyTestBase): q = "get:'/p/%s'" % spec else: q = "get:'%s/p/%s'" % (self.server.urlbase, spec) - return p.request(q) + with p.connect(): + return p.request(q) def app(self, page): if self.ssl: p = pathod.pathoc.Pathoc( ("127.0.0.1", self.proxy.port), True, fp=None ) - p.connect((options.APP_HOST, options.APP_PORT)) - return p.request("get:'%s'" % page) + with p.connect((options.APP_HOST, options.APP_PORT)): + return p.request("get:'%s'" % page) else: p = self.pathoc() - return p.request("get:'http://%s%s'" % (options.APP_HOST, page)) + with p.connect(): + return p.request("get:'http://%s%s'" % (options.APP_HOST, page)) class TResolver: @@ -210,7 +231,8 @@ class TransparentProxyTest(ProxyTestBase): else: p = self.pathoc() q = "get:'/p/%s'" % spec - return p.request(q) + with p.connect(): + return p.request(q) def pathoc(self, sni=None): """ @@ -219,7 +241,6 @@ class TransparentProxyTest(ProxyTestBase): p = pathod.pathoc.Pathoc( ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ) - p.connect() return p @@ -247,7 +268,6 @@ class ReverseProxyTest(ProxyTestBase): p = pathod.pathoc.Pathoc( ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ) - p.connect() return p def pathod(self, spec, sni=None): @@ -260,7 +280,8 @@ class ReverseProxyTest(ProxyTestBase): else: p = self.pathoc() q = "get:'/p/%s'" % spec - return p.request(q) + with p.connect(): + return p.request(q) class SocksModeTest(HTTPProxyTest):