Roll out synchronisation for mitmproxy tests

This extends some of the work I did for pathod and netlib to the mitmproxy test
suite. It also fixes what may be a leak in replays.

Failing on connection leak is disabled on Windows for the moment.

Fixes #1535
This commit is contained in:
Aldo Cortesi 2016-09-10 09:18:11 +12:00
parent ea49b8a2e2
commit 4ff8a72521
11 changed files with 254 additions and 215 deletions

View File

@ -4,7 +4,7 @@ import pprint
def _get_name(itm): def _get_name(itm):
return getattr(itm, "name", itm.__class__.__name__) return getattr(itm, "name", itm.__class__.__name__.lower())
class Addons(object): class Addons(object):
@ -13,6 +13,16 @@ class Addons(object):
self.master = master self.master = master
master.options.changed.connect(self.options_update) master.options.changed.connect(self.options_update)
def get(self, name):
"""
Retrieve an addon by name. Addon names are equal to the .name
attribute on the instance, or the lower case class name if that
does not exist.
"""
for i in self.chain:
if name == _get_name(i):
return i
def options_update(self, options, updated): def options_update(self, options, updated):
for i in self.chain: for i in self.chain:
with self.master.handlecontext(): with self.master.handlecontext():
@ -39,14 +49,6 @@ class Addons(object):
for i in self.chain: for i in self.chain:
self.invoke_with_context(i, "done") self.invoke_with_context(i, "done")
def has_addon(self, name):
"""
Is an addon with this name registered?
"""
for i in self.chain:
if _get_name(i) == name:
return True
def __len__(self): def __len__(self):
return len(self.chain) return len(self.chain)

View File

@ -88,13 +88,14 @@ class ServerPlayback(object):
def configure(self, options, updated): def configure(self, options, updated):
self.options = options self.options = options
if options.server_replay and "server_replay" in updated: if "server_replay" in updated:
try:
flows = flow.read_flows_from_paths(options.server_replay)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(str(e))
self.clear() self.clear()
self.load(flows) if options.server_replay:
try:
flows = flow.read_flows_from_paths(options.server_replay)
except exceptions.FlowReadException as e:
raise exceptions.OptionsError(str(e))
self.load(flows)
# FIXME: These options have to be renamed to something more sensible - # FIXME: These options have to be renamed to something more sensible -
# prefixed with serverplayback_ where appropriate, and playback_ where # prefixed with serverplayback_ where appropriate, and playback_ where

View File

@ -248,9 +248,6 @@ class ConsoleMaster(flow.FlowMaster):
if options.client_replay: if options.client_replay:
self.client_playback_path(options.client_replay) self.client_playback_path(options.client_replay)
if options.server_replay:
self.server_playback_path(options.server_replay)
self.view_stack = [] self.view_stack = []
if options.app: if options.app:
@ -391,21 +388,6 @@ class ConsoleMaster(flow.FlowMaster):
if flows: if flows:
self.start_client_playback(flows, False) self.start_client_playback(flows, False)
def server_playback_path(self, path):
if not isinstance(path, list):
path = [path]
flows = self._readflows(path)
if flows:
self.start_server_playback(
flows,
self.options.kill, self.options.rheaders,
False, self.options.nopop,
self.options.replay_ignore_params,
self.options.replay_ignore_content,
self.options.replay_ignore_payload_params,
self.options.replay_ignore_host
)
def spawn_editor(self, data): def spawn_editor(self, data):
text = not isinstance(data, bytes) text = not isinstance(data, bytes)
fd, name = tempfile.mkstemp('', "mproxy", text=text) fd, name = tempfile.mkstemp('', "mproxy", text=text)

View File

@ -147,14 +147,12 @@ class StatusBar(urwid.WidgetWrap):
if self.master.client_playback: if self.master.client_playback:
r.append("[") r.append("[")
r.append(("heading_key", "cplayback")) r.append(("heading_key", "cplayback"))
r.append(":%s to go]" % self.master.client_playback.count()) r.append(":%s]" % self.master.client_playback.count())
if self.master.server_playback: if self.master.options.server_replay:
r.append("[") r.append("[")
r.append(("heading_key", "splayback")) r.append(("heading_key", "splayback"))
if self.master.options.nopop: a = self.master.addons.get("serverplayback")
r.append(":%s in file]" % self.master.server_playback.count()) r.append(":%s]" % a.count())
else:
r.append(":%s to go]" % self.master.server_playback.count())
if self.master.options.ignore_hosts: if self.master.options.ignore_hosts:
r.append("[") r.append("[")
r.append(("heading_key", "I")) r.append(("heading_key", "I"))

View File

@ -57,13 +57,11 @@ class Window(urwid.Frame):
callback = self.master.stop_client_playback_prompt, callback = self.master.stop_client_playback_prompt,
) )
elif k == "s": elif k == "s":
if not self.master.server_playback: a = self.master.addons.get("serverplayback")
signals.status_prompt_path.send( if a.count():
self, def stop_server_playback(response):
prompt = "Server replay path", if response == "y":
callback = self.master.server_playback_path self.master.options.server_replay = []
)
else:
signals.status_prompt_onekey.send( signals.status_prompt_onekey.send(
self, self,
prompt = "Stop current server replay?", prompt = "Stop current server replay?",
@ -71,7 +69,13 @@ class Window(urwid.Frame):
("yes", "y"), ("yes", "y"),
("no", "n"), ("no", "n"),
), ),
callback = self.master.stop_server_playback_prompt, callback = stop_server_playback
)
else:
signals.status_prompt_path.send(
self,
prompt = "Server playback path",
callback = lambda x: self.master.options.setter("server_replay")([x])
) )
def keypress(self, size, k): def keypress(self, size, k):

View File

@ -33,6 +33,7 @@ class RequestReplayThread(basethread.BaseThread):
def run(self): def run(self):
r = self.flow.request r = self.flow.request
first_line_format_backup = r.first_line_format first_line_format_backup = r.first_line_format
server = None
try: try:
self.flow.response = None self.flow.response = None
@ -103,3 +104,5 @@ class RequestReplayThread(basethread.BaseThread):
self.channel.tell("log", Log(traceback.format_exc(), "error")) self.channel.tell("log", Log(traceback.format_exc(), "error"))
finally: finally:
r.first_line_format = first_line_format_backup r.first_line_format = first_line_format_backup
if server:
server.finish()

View File

@ -18,14 +18,15 @@ class TestInvalidRequests(tservers.HTTPProxyTest):
def test_double_connect(self): def test_double_connect(self):
p = self.pathoc() p = self.pathoc()
r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port)) with p.connect():
r = p.request("connect:'%s:%s'" % ("127.0.0.1", self.server2.port))
assert r.status_code == 400 assert r.status_code == 400
assert b"Invalid HTTP request form" in r.content assert b"Invalid HTTP request form" in r.content
def test_relative_request(self): def test_relative_request(self):
p = self.pathoc_raw() p = self.pathoc_raw()
p.connect() with p.connect():
r = p.request("get:/p/200") r = p.request("get:/p/200")
assert r.status_code == 400 assert r.status_code == 400
assert b"Invalid HTTP request form" in r.content assert b"Invalid HTTP request form" in r.content
@ -61,5 +62,8 @@ class TestHeadContentLength(tservers.HTTPProxyTest):
def test_head_content_length(self): def test_head_content_length(self):
p = self.pathoc() p = self.pathoc()
resp = p.request("""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase) with p.connect():
resp = p.request(
"""head:'%s/p/200:h"Content-Length"="42"'""" % self.server.urlbase
)
assert resp.headers["Content-Length"] == "42" assert resp.headers["Content-Length"] == "42"

View File

@ -17,5 +17,5 @@ def test_simple():
m = controller.Master(o) m = controller.Master(o)
a = addons.Addons(m) a = addons.Addons(m)
a.add(o, TAddon("one")) a.add(o, TAddon("one"))
assert a.has_addon("one") assert a.get("one")
assert not a.has_addon("two") assert not a.get("two")

View File

@ -11,17 +11,20 @@ class TestFuzzy(tservers.HTTPProxyTest):
def test_idna_err(self): def test_idna_err(self):
req = r'get:"http://localhost:%s":i10,"\xc6"' req = r'get:"http://localhost:%s":i10,"\xc6"'
p = self.pathoc() p = self.pathoc()
assert p.request(req % self.server.port).status_code == 400 with p.connect():
assert p.request(req % self.server.port).status_code == 400
def test_nullbytes(self): def test_nullbytes(self):
req = r'get:"http://localhost:%s":i19,"\x00"' req = r'get:"http://localhost:%s":i19,"\x00"'
p = self.pathoc() p = self.pathoc()
assert p.request(req % self.server.port).status_code == 400 with p.connect():
assert p.request(req % self.server.port).status_code == 400
def test_invalid_ipv6_url(self): def test_invalid_ipv6_url(self):
req = 'get:"http://localhost:%s":i13,"["' req = 'get:"http://localhost:%s":i13,"["'
p = self.pathoc() p = self.pathoc()
resp = p.request(req % self.server.port) with p.connect():
resp = p.request(req % self.server.port)
assert resp.status_code == 400 assert resp.status_code == 400
# def test_invalid_upstream(self): # def test_invalid_upstream(self):

View File

@ -91,11 +91,11 @@ class CommonMixin:
def test_invalid_http(self): def test_invalid_http(self):
t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t = tcp.TCPClient(("127.0.0.1", self.proxy.port))
t.connect() with t.connect():
t.wfile.write(b"invalid\r\n\r\n") t.wfile.write(b"invalid\r\n\r\n")
t.wfile.flush() t.wfile.flush()
line = t.rfile.readline() line = t.rfile.readline()
assert (b"Bad Request" in line) or (b"Bad Gateway" in line) assert (b"Bad Request" in line) or (b"Bad Gateway" in line)
def test_sni(self): def test_sni(self):
if not self.ssl: if not self.ssl:
@ -208,20 +208,22 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
def test_app_err(self): def test_app_err(self):
p = self.pathoc() p = self.pathoc()
ret = p.request("get:'http://errapp/'") with p.connect():
ret = p.request("get:'http://errapp/'")
assert ret.status_code == 500 assert ret.status_code == 500
assert b"ValueError" in ret.content assert b"ValueError" in ret.content
def test_invalid_connect(self): def test_invalid_connect(self):
t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t = tcp.TCPClient(("127.0.0.1", self.proxy.port))
t.connect() with t.connect():
t.wfile.write(b"CONNECT invalid\n\n") t.wfile.write(b"CONNECT invalid\n\n")
t.wfile.flush() t.wfile.flush()
assert b"Bad Request" in t.rfile.readline() assert b"Bad Request" in t.rfile.readline()
def test_upstream_ssl_error(self): def test_upstream_ssl_error(self):
p = self.pathoc() p = self.pathoc()
ret = p.request("get:'https://localhost:%s/'" % self.server.port) with p.connect():
ret = p.request("get:'https://localhost:%s/'" % self.server.port)
assert ret.status_code == 400 assert ret.status_code == 400
def test_connection_close(self): def test_connection_close(self):
@ -232,25 +234,28 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
# Lets sanity check that the connection does indeed stay open by # Lets sanity check that the connection does indeed stay open by
# issuing two requests over the same connection # issuing two requests over the same connection
p = self.pathoc() p = self.pathoc()
assert p.request("get:'%s'" % response) with p.connect():
assert p.request("get:'%s'" % response) assert p.request("get:'%s'" % response)
assert p.request("get:'%s'" % response)
# Now check that the connection is closed as the client specifies # Now check that the connection is closed as the client specifies
p = self.pathoc() p = self.pathoc()
assert p.request("get:'%s':h'Connection'='close'" % response) with p.connect():
# There's a race here, which means we can get any of a number of errors. assert p.request("get:'%s':h'Connection'='close'" % response)
# Rather than introduce yet another sleep into the test suite, we just # There's a race here, which means we can get any of a number of errors.
# relax the Exception specification. # Rather than introduce yet another sleep into the test suite, we just
with raises(Exception): # relax the Exception specification.
p.request("get:'%s'" % response) with raises(Exception):
p.request("get:'%s'" % response)
def test_reconnect(self): def test_reconnect(self):
req = "get:'%s/p/200:b@1:da'" % self.server.urlbase req = "get:'%s/p/200:b@1:da'" % self.server.urlbase
p = self.pathoc() p = self.pathoc()
assert p.request(req) with p.connect():
# Server has disconnected. Mitmproxy should detect this, and reconnect. assert p.request(req)
assert p.request(req) # Server has disconnected. Mitmproxy should detect this, and reconnect.
assert p.request(req) assert p.request(req)
assert p.request(req)
def test_get_connection_switching(self): def test_get_connection_switching(self):
def switched(l): def switched(l):
@ -260,18 +265,21 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
req = "get:'%s/p/200:b@1'" req = "get:'%s/p/200:b@1'"
p = self.pathoc() p = self.pathoc()
assert p.request(req % self.server.urlbase) with p.connect():
assert p.request(req % self.server2.urlbase) assert p.request(req % self.server.urlbase)
assert p.request(req % self.server2.urlbase)
assert switched(self.proxy.tlog) assert switched(self.proxy.tlog)
def test_blank_leading_line(self): def test_blank_leading_line(self):
p = self.pathoc() p = self.pathoc()
req = "get:'%s/p/201':i0,'\r\n'" with p.connect():
assert p.request(req % self.server.urlbase).status_code == 201 req = "get:'%s/p/201':i0,'\r\n'"
assert p.request(req % self.server.urlbase).status_code == 201
def test_invalid_headers(self): def test_invalid_headers(self):
p = self.pathoc() p = self.pathoc()
resp = p.request("get:'http://foo':h':foo'='bar'") with p.connect():
resp = p.request("get:'http://foo':h':foo'='bar'")
assert resp.status_code == 400 assert resp.status_code == 400
def test_stream(self): def test_stream(self):
@ -301,15 +309,16 @@ class TestHTTPAuth(tservers.HTTPProxyTest):
self.master.options.auth_singleuser = "test:test" self.master.options.auth_singleuser = "test:test"
assert self.pathod("202").status_code == 407 assert self.pathod("202").status_code == 407
p = self.pathoc() p = self.pathoc()
ret = p.request(""" with p.connect():
get ret = p.request("""
'http://localhost:%s/p/202' get
h'%s'='%s' 'http://localhost:%s/p/202'
""" % ( h'%s'='%s'
self.server.port, """ % (
http.authentication.BasicProxyAuth.AUTH_HEADER, self.server.port,
authentication.assemble_http_basic_auth("basic", "test", "test") http.authentication.BasicProxyAuth.AUTH_HEADER,
)) authentication.assemble_http_basic_auth("basic", "test", "test")
))
assert ret.status_code == 202 assert ret.status_code == 202
@ -318,14 +327,15 @@ class TestHTTPReverseAuth(tservers.ReverseProxyTest):
self.master.options.auth_singleuser = "test:test" self.master.options.auth_singleuser = "test:test"
assert self.pathod("202").status_code == 401 assert self.pathod("202").status_code == 401
p = self.pathoc() p = self.pathoc()
ret = p.request(""" with p.connect():
get ret = p.request("""
'/p/202' get
h'%s'='%s' '/p/202'
""" % ( h'%s'='%s'
http.authentication.BasicWebsiteAuth.AUTH_HEADER, """ % (
authentication.assemble_http_basic_auth("basic", "test", "test") http.authentication.BasicWebsiteAuth.AUTH_HEADER,
)) authentication.assemble_http_basic_auth("basic", "test", "test")
))
assert ret.status_code == 202 assert ret.status_code == 202
@ -354,7 +364,8 @@ class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin):
def test_error_post_connect(self): def test_error_post_connect(self):
p = self.pathoc() p = self.pathoc()
assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 with p.connect():
assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400
class TestHTTPSCertfile(tservers.HTTPProxyTest, CommonMixin): class TestHTTPSCertfile(tservers.HTTPProxyTest, CommonMixin):
@ -389,7 +400,8 @@ class TestHTTPSUpstreamServerVerificationWTrustedCert(tservers.HTTPProxyTest):
def _request(self): def _request(self):
p = self.pathoc(sni="example.mitmproxy.org") p = self.pathoc(sni="example.mitmproxy.org")
return p.request("get:/p/242") with p.connect():
return p.request("get:/p/242")
def test_verification_w_cadir(self): def test_verification_w_cadir(self):
self.config.options.update( self.config.options.update(
@ -426,7 +438,8 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxyTest):
def _request(self): def _request(self):
p = self.pathoc(sni="example.mitmproxy.org") p = self.pathoc(sni="example.mitmproxy.org")
return p.request("get:/p/242") with p.connect():
return p.request("get:/p/242")
@classmethod @classmethod
def get_options(cls): def get_options(cls):
@ -481,13 +494,15 @@ class TestSocks5(tservers.SocksModeTest):
def test_simple(self): def test_simple(self):
p = self.pathoc() p = self.pathoc()
p.socks_connect(("localhost", self.server.port)) with p.connect():
f = p.request("get:/p/200") p.socks_connect(("localhost", self.server.port))
f = p.request("get:/p/200")
assert f.status_code == 200 assert f.status_code == 200
def test_with_authentication_only(self): def test_with_authentication_only(self):
p = self.pathoc() p = self.pathoc()
f = p.request("get:/p/200") with p.connect():
f = p.request("get:/p/200")
assert f.status_code == 502 assert f.status_code == 502
assert b"SOCKS5 mode failure" in f.content assert b"SOCKS5 mode failure" in f.content
@ -496,21 +511,21 @@ class TestSocks5(tservers.SocksModeTest):
mitmproxy doesn't support UDP or BIND SOCKS CMDs mitmproxy doesn't support UDP or BIND SOCKS CMDs
""" """
p = self.pathoc() p = self.pathoc()
with p.connect():
socks.ClientGreeting(
socks.VERSION.SOCKS5,
[socks.METHOD.NO_AUTHENTICATION_REQUIRED]
).to_file(p.wfile)
socks.Message(
socks.VERSION.SOCKS5,
socks.CMD.BIND,
socks.ATYP.DOMAINNAME,
("example.com", 8080)
).to_file(p.wfile)
socks.ClientGreeting( p.wfile.flush()
socks.VERSION.SOCKS5, p.rfile.read(2) # read server greeting
[socks.METHOD.NO_AUTHENTICATION_REQUIRED] f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway.
).to_file(p.wfile)
socks.Message(
socks.VERSION.SOCKS5,
socks.CMD.BIND,
socks.ATYP.DOMAINNAME,
("example.com", 8080)
).to_file(p.wfile)
p.wfile.flush()
p.rfile.read(2) # read server greeting
f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway.
assert f.status_code == 502 assert f.status_code == 502
assert b"SOCKS5 mode failure" in f.content assert b"SOCKS5 mode failure" in f.content
@ -531,21 +546,23 @@ class TestHttps2Http(tservers.ReverseProxyTest):
p = pathoc.Pathoc( p = pathoc.Pathoc(
("localhost", self.proxy.port), ssl=True, sni=sni, fp=None ("localhost", self.proxy.port), ssl=True, sni=sni, fp=None
) )
p.connect()
return p return p
def test_all(self): def test_all(self):
p = self.pathoc(ssl=True) p = self.pathoc(ssl=True)
assert p.request("get:'/p/200'").status_code == 200 with p.connect():
assert p.request("get:'/p/200'").status_code == 200
def test_sni(self): def test_sni(self):
p = self.pathoc(ssl=True, sni="example.com") p = self.pathoc(ssl=True, sni="example.com")
assert p.request("get:'/p/200'").status_code == 200 with p.connect():
assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) assert p.request("get:'/p/200'").status_code == 200
assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)
def test_http(self): def test_http(self):
p = self.pathoc(ssl=False) p = self.pathoc(ssl=False)
assert p.request("get:'/p/200'").status_code == 200 with p.connect():
assert p.request("get:'/p/200'").status_code == 200
class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin): class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
@ -703,29 +720,29 @@ class TestRedirectRequest(tservers.HTTPProxyTest):
self.master.redirect_port = self.server2.port self.master.redirect_port = self.server2.port
p = self.pathoc() p = self.pathoc()
with p.connect():
self.server.clear_log()
self.server2.clear_log()
r1 = p.request("get:'/p/200'")
assert r1.status_code == 200
assert self.server.last_log()
assert not self.server2.last_log()
self.server.clear_log() self.server.clear_log()
self.server2.clear_log() self.server2.clear_log()
r1 = p.request("get:'/p/200'") r2 = p.request("get:'/p/201'")
assert r1.status_code == 200 assert r2.status_code == 201
assert self.server.last_log() assert not self.server.last_log()
assert not self.server2.last_log() assert self.server2.last_log()
self.server.clear_log() self.server.clear_log()
self.server2.clear_log() self.server2.clear_log()
r2 = p.request("get:'/p/201'") r3 = p.request("get:'/p/202'")
assert r2.status_code == 201 assert r3.status_code == 202
assert not self.server.last_log() assert self.server.last_log()
assert self.server2.last_log() assert not self.server2.last_log()
self.server.clear_log() assert r1.content == r2.content == r3.content
self.server2.clear_log()
r3 = p.request("get:'/p/202'")
assert r3.status_code == 202
assert self.server.last_log()
assert not self.server2.last_log()
assert r1.content == r2.content == r3.content
class MasterStreamRequest(tservers.TestMaster): class MasterStreamRequest(tservers.TestMaster):
@ -743,22 +760,22 @@ class TestStreamRequest(tservers.HTTPProxyTest):
def test_stream_simple(self): def test_stream_simple(self):
p = self.pathoc() p = self.pathoc()
with p.connect():
# a request with 100k of data but without content-length # a request with 100k of data but without content-length
r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase) r1 = p.request("get:'%s/p/200:r:b@100k:d102400'" % self.server.urlbase)
assert r1.status_code == 200 assert r1.status_code == 200
assert len(r1.content) > 100000 assert len(r1.content) > 100000
def test_stream_multiple(self): def test_stream_multiple(self):
p = self.pathoc() p = self.pathoc()
with p.connect():
# simple request with streaming turned on
r1 = p.request("get:'%s/p/200'" % self.server.urlbase)
assert r1.status_code == 200
# simple request with streaming turned on # now send back 100k of data, streamed but not chunked
r1 = p.request("get:'%s/p/200'" % self.server.urlbase) r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase)
assert r1.status_code == 200 assert r1.status_code == 201
# now send back 100k of data, streamed but not chunked
r1 = p.request("get:'%s/p/201:b@100k'" % self.server.urlbase)
assert r1.status_code == 201
def test_stream_chunked(self): def test_stream_chunked(self):
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -887,7 +904,8 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest, CommonMixin, AppMixin):
("~s", "baz", "ORLY") ("~s", "baz", "ORLY")
] ]
p = self.pathoc() p = self.pathoc()
req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) with p.connect():
req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase)
assert req.content == b"ORLY" assert req.content == b"ORLY"
assert req.status_code == 418 assert req.status_code == 418
@ -948,7 +966,8 @@ class TestUpstreamProxySSL(
def test_simple(self): def test_simple(self):
p = self.pathoc() p = self.pathoc()
req = p.request("get:'/p/418:b\"content\"'") with p.connect():
req = p.request("get:'/p/418:b\"content\"'")
assert req.content == b"content" assert req.content == b"content"
assert req.status_code == 418 assert req.status_code == 418
@ -1006,48 +1025,49 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):
]) ])
p = self.pathoc() p = self.pathoc()
req = p.request("get:'/p/418:b\"content\"'") with p.connect():
assert req.content == b"content" req = p.request("get:'/p/418:b\"content\"'")
assert req.status_code == 418 assert req.content == b"content"
assert req.status_code == 418
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
# CONNECT, failing request, # CONNECT, failing request,
assert self.chain[0].tmaster.state.flow_count() == 4 assert self.chain[0].tmaster.state.flow_count() == 4
# reCONNECT, request # reCONNECT, request
# failing request, request # failing request, request
assert self.chain[1].tmaster.state.flow_count() == 2 assert self.chain[1].tmaster.state.flow_count() == 2
# (doesn't store (repeated) CONNECTs from chain[0] # (doesn't store (repeated) CONNECTs from chain[0]
# as it is a regular proxy) # as it is a regular proxy)
assert not self.chain[1].tmaster.state.flows[0].response # killed assert not self.chain[1].tmaster.state.flows[0].response # killed
assert self.chain[1].tmaster.state.flows[1].response assert self.chain[1].tmaster.state.flows[1].response
assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority" assert self.proxy.tmaster.state.flows[0].request.first_line_format == "authority"
assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative" assert self.proxy.tmaster.state.flows[1].request.first_line_format == "relative"
assert self.chain[0].tmaster.state.flows[ assert self.chain[0].tmaster.state.flows[
0].request.first_line_format == "authority" 0].request.first_line_format == "authority"
assert self.chain[0].tmaster.state.flows[ assert self.chain[0].tmaster.state.flows[
1].request.first_line_format == "relative" 1].request.first_line_format == "relative"
assert self.chain[0].tmaster.state.flows[ assert self.chain[0].tmaster.state.flows[
2].request.first_line_format == "authority" 2].request.first_line_format == "authority"
assert self.chain[0].tmaster.state.flows[ assert self.chain[0].tmaster.state.flows[
3].request.first_line_format == "relative" 3].request.first_line_format == "relative"
assert self.chain[1].tmaster.state.flows[ assert self.chain[1].tmaster.state.flows[
0].request.first_line_format == "relative" 0].request.first_line_format == "relative"
assert self.chain[1].tmaster.state.flows[ assert self.chain[1].tmaster.state.flows[
1].request.first_line_format == "relative" 1].request.first_line_format == "relative"
req = p.request("get:'/p/418:b\"content2\"'") req = p.request("get:'/p/418:b\"content2\"'")
assert req.status_code == 502 assert req.status_code == 502
assert self.proxy.tmaster.state.flow_count() == 3 # + new request assert self.proxy.tmaster.state.flow_count() == 3 # + new request
# + new request, repeated CONNECT from chain[1] # + new request, repeated CONNECT from chain[1]
assert self.chain[0].tmaster.state.flow_count() == 6 assert self.chain[0].tmaster.state.flow_count() == 6
# (both terminated) # (both terminated)
# nothing happened here # nothing happened here
assert self.chain[1].tmaster.state.flow_count() == 2 assert self.chain[1].tmaster.state.flow_count() == 2
class AddUpstreamCertsToClientChainMixin: class AddUpstreamCertsToClientChainMixin:
@ -1066,12 +1086,13 @@ class AddUpstreamCertsToClientChainMixin:
d = f.read() d = f.read()
upstreamCert = SSLCert.from_pem(d) upstreamCert = SSLCert.from_pem(d)
p = self.pathoc() p = self.pathoc()
upstream_cert_found_in_client_chain = False with p.connect():
for receivedCert in p.server_certs: upstream_cert_found_in_client_chain = False
if receivedCert.digest('sha256') == upstreamCert.digest('sha256'): for receivedCert in p.server_certs:
upstream_cert_found_in_client_chain = True if receivedCert.digest('sha256') == upstreamCert.digest('sha256'):
break upstream_cert_found_in_client_chain = True
assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain) break
assert(upstream_cert_found_in_client_chain == self.master.options.add_upstream_certs_to_client_chain)
class TestHTTPSAddUpstreamCertsToClientChainTrue( class TestHTTPSAddUpstreamCertsToClientChainTrue(

View File

@ -3,6 +3,7 @@ import threading
import tempfile import tempfile
import flask import flask
import mock import mock
import sys
from mitmproxy.proxy.config import ProxyConfig from mitmproxy.proxy.config import ProxyConfig
from mitmproxy.proxy.server import ProxyServer from mitmproxy.proxy.server import ProxyServer
@ -10,6 +11,7 @@ import pathod.test
import pathod.pathoc import pathod.pathoc
from mitmproxy import flow, controller, options from mitmproxy import flow, controller, options
from mitmproxy import builtins from mitmproxy import builtins
import netlib.exceptions
testapp = flask.Flask(__name__) testapp = flask.Flask(__name__)
@ -104,6 +106,14 @@ class ProxyTestBase(object):
cls.server.shutdown() cls.server.shutdown()
cls.server2.shutdown() cls.server2.shutdown()
def teardown(self):
try:
self.server.wait_for_silence()
except netlib.exceptions.Timeout:
# FIXME: Track down the Windows sync issues
if sys.platform != "win32":
raise
def setup(self): def setup(self):
self.master.clear_log() self.master.clear_log()
self.master.state.clear() self.master.state.clear()
@ -125,6 +135,15 @@ class ProxyTestBase(object):
) )
class LazyPathoc(pathod.pathoc.Pathoc):
def __init__(self, lazy_connect, *args, **kwargs):
self.lazy_connect = lazy_connect
pathod.pathoc.Pathoc.__init__(self, *args, **kwargs)
def connect(self):
return pathod.pathoc.Pathoc.connect(self, self.lazy_connect)
class HTTPProxyTest(ProxyTestBase): class HTTPProxyTest(ProxyTestBase):
def pathoc_raw(self): def pathoc_raw(self):
@ -134,14 +153,14 @@ class HTTPProxyTest(ProxyTestBase):
""" """
Returns a connected Pathoc instance. Returns a connected Pathoc instance.
""" """
p = pathod.pathoc.Pathoc( if self.ssl:
conn = ("127.0.0.1", self.server.port)
else:
conn = None
return LazyPathoc(
conn,
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
) )
if self.ssl:
p.connect(("127.0.0.1", self.server.port))
else:
p.connect()
return p
def pathod(self, spec, sni=None): def pathod(self, spec, sni=None):
""" """
@ -152,18 +171,20 @@ class HTTPProxyTest(ProxyTestBase):
q = "get:'/p/%s'" % spec q = "get:'/p/%s'" % spec
else: else:
q = "get:'%s/p/%s'" % (self.server.urlbase, spec) q = "get:'%s/p/%s'" % (self.server.urlbase, spec)
return p.request(q) with p.connect():
return p.request(q)
def app(self, page): def app(self, page):
if self.ssl: if self.ssl:
p = pathod.pathoc.Pathoc( p = pathod.pathoc.Pathoc(
("127.0.0.1", self.proxy.port), True, fp=None ("127.0.0.1", self.proxy.port), True, fp=None
) )
p.connect((options.APP_HOST, options.APP_PORT)) with p.connect((options.APP_HOST, options.APP_PORT)):
return p.request("get:'%s'" % page) return p.request("get:'%s'" % page)
else: else:
p = self.pathoc() p = self.pathoc()
return p.request("get:'http://%s%s'" % (options.APP_HOST, page)) with p.connect():
return p.request("get:'http://%s%s'" % (options.APP_HOST, page))
class TResolver: class TResolver:
@ -210,7 +231,8 @@ class TransparentProxyTest(ProxyTestBase):
else: else:
p = self.pathoc() p = self.pathoc()
q = "get:'/p/%s'" % spec q = "get:'/p/%s'" % spec
return p.request(q) with p.connect():
return p.request(q)
def pathoc(self, sni=None): def pathoc(self, sni=None):
""" """
@ -219,7 +241,6 @@ class TransparentProxyTest(ProxyTestBase):
p = pathod.pathoc.Pathoc( p = pathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
) )
p.connect()
return p return p
@ -247,7 +268,6 @@ class ReverseProxyTest(ProxyTestBase):
p = pathod.pathoc.Pathoc( p = pathod.pathoc.Pathoc(
("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
) )
p.connect()
return p return p
def pathod(self, spec, sni=None): def pathod(self, spec, sni=None):
@ -260,7 +280,8 @@ class ReverseProxyTest(ProxyTestBase):
else: else:
p = self.pathoc() p = self.pathoc()
q = "get:'/p/%s'" % spec q = "get:'/p/%s'" % spec
return p.request(q) with p.connect():
return p.request(q)
class SocksModeTest(HTTPProxyTest): class SocksModeTest(HTTPProxyTest):