asyncio: fix channel interface and tests

We now need to synthesize a tick event when changing addons in tests, because
tick is asynchronously called on the event loop.
This commit is contained in:
Aldo Cortesi 2018-04-01 11:37:35 +12:00
parent b6d943cfa3
commit 3cc5d81a4a
4 changed files with 33 additions and 24 deletions

View File

@ -8,8 +8,9 @@ class Channel:
The only way for the proxy server to communicate with the master The only way for the proxy server to communicate with the master
is to use the channel it has been given. is to use the channel it has been given.
""" """
def __init__(self, loop, q): def __init__(self, loop, q, should_exit):
self.loop = loop self.loop = loop
self.should_exit = should_exit
self._q = q self._q = q
def ask(self, mtype, m): def ask(self, mtype, m):

View File

@ -1,7 +1,6 @@
import threading import threading
import contextlib import contextlib
import asyncio import asyncio
import signal
from mitmproxy import addonmanager from mitmproxy import addonmanager
from mitmproxy import options from mitmproxy import options
@ -37,11 +36,16 @@ class Master:
""" """
def __init__(self, opts): def __init__(self, opts):
self.event_queue = asyncio.Queue() self.event_queue = asyncio.Queue()
self.should_exit = threading.Event()
self.channel = controller.Channel(
asyncio.get_event_loop(),
self.event_queue,
self.should_exit,
)
self.options = opts or options.Options() # type: options.Options self.options = opts or options.Options() # type: options.Options
self.commands = command.CommandManager(self) self.commands = command.CommandManager(self)
self.addons = addonmanager.AddonManager(self) self.addons = addonmanager.AddonManager(self)
self.should_exit = threading.Event()
self._server = None self._server = None
self.first_tick = True self.first_tick = True
self.waiting_flows = [] self.waiting_flows = []
@ -52,7 +56,7 @@ class Master:
@server.setter @server.setter
def server(self, server): def server(self, server):
server.set_channel(controller.Channel(asyncio.get_event_loop(), self.event_queue)) server.set_channel(self.channel)
self._server = server self._server = server
@contextlib.contextmanager @contextlib.contextmanager
@ -202,7 +206,7 @@ class Master:
host = f.request.headers.pop(":authority") host = f.request.headers.pop(":authority")
f.request.headers.insert(0, "host", host) f.request.headers.insert(0, "host", host)
rt = http_replay.RequestReplayThread(self.options, f, self.server.channel) rt = http_replay.RequestReplayThread(self.options, f, self.channel)
rt.start() # pragma: no cover rt.start() # pragma: no cover
if block: if block:
rt.join() rt.join()

View File

@ -276,10 +276,9 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin):
s = script.Script( s = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/stream_modify.py") tutils.test_data.path("mitmproxy/data/addonscripts/stream_modify.py")
) )
self.master.addons.add(s) self.set_addons(s)
d = self.pathod('200:b"foo"') d = self.pathod('200:b"foo"')
assert d.content == b"bar" assert d.content == b"bar"
self.master.addons.remove(s)
def test_first_line_rewrite(self): def test_first_line_rewrite(self):
""" """
@ -583,12 +582,11 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
s = script.Script( s = script.Script(
tutils.test_data.path("mitmproxy/data/addonscripts/tcp_stream_modify.py") tutils.test_data.path("mitmproxy/data/addonscripts/tcp_stream_modify.py")
) )
self.master.addons.add(s) self.set_addons(s)
self._tcpproxy_on() self._tcpproxy_on()
d = self.pathod('200:b"foo"') d = self.pathod('200:b"foo"')
self._tcpproxy_off() self._tcpproxy_off()
assert d.content == b"bar" assert d.content == b"bar"
self.master.addons.remove(s)
class TestTransparentSSL(tservers.TransparentProxyTest, CommonMixin, TcpMixin): class TestTransparentSSL(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
@ -731,7 +729,7 @@ class TestRedirectRequest(tservers.HTTPProxyTest):
This test verifies that the original destination is restored for the third request. This test verifies that the original destination is restored for the third request.
""" """
self.proxy.tmaster.addons.add(ARedirectRequest(self.server2.port)) self.set_addons(ARedirectRequest(self.server2.port))
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
@ -770,7 +768,7 @@ class AStreamRequest:
class TestStreamRequest(tservers.HTTPProxyTest): class TestStreamRequest(tservers.HTTPProxyTest):
def test_stream_simple(self): def test_stream_simple(self):
self.proxy.tmaster.addons.add(AStreamRequest()) self.set_addons(AStreamRequest())
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
# a request with 100k of data but without content-length # a request with 100k of data but without content-length
@ -779,7 +777,7 @@ class TestStreamRequest(tservers.HTTPProxyTest):
assert len(r1.content) > 100000 assert len(r1.content) > 100000
def test_stream_multiple(self): def test_stream_multiple(self):
self.proxy.tmaster.addons.add(AStreamRequest()) self.set_addons(AStreamRequest())
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
# simple request with streaming turned on # simple request with streaming turned on
@ -791,7 +789,7 @@ class TestStreamRequest(tservers.HTTPProxyTest):
assert r1.status_code == 201 assert r1.status_code == 201
def test_stream_chunked(self): def test_stream_chunked(self):
self.proxy.tmaster.addons.add(AStreamRequest()) self.set_addons(AStreamRequest())
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
connection.connect(("127.0.0.1", self.proxy.port)) connection.connect(("127.0.0.1", self.proxy.port))
fconn = connection.makefile("rb") fconn = connection.makefile("rb")
@ -820,7 +818,7 @@ class AFakeResponse:
class TestFakeResponse(tservers.HTTPProxyTest): class TestFakeResponse(tservers.HTTPProxyTest):
def test_fake(self): def test_fake(self):
self.proxy.tmaster.addons.add(AFakeResponse()) self.set_addons(AFakeResponse())
f = self.pathod("200") f = self.pathod("200")
assert "header-response" in f.headers assert "header-response" in f.headers
@ -836,7 +834,7 @@ class TestServerConnect(tservers.HTTPProxyTest):
def test_unnecessary_serverconnect(self): def test_unnecessary_serverconnect(self):
"""A replayed/fake response with no upstream_cert should not connect to an upstream server""" """A replayed/fake response with no upstream_cert should not connect to an upstream server"""
self.proxy.tmaster.addons.add(AFakeResponse()) self.set_addons(AFakeResponse())
assert self.pathod("200").status_code == 200 assert self.pathod("200").status_code == 200
assert not self.proxy.tmaster.has_log("serverconnect") assert not self.proxy.tmaster.has_log("serverconnect")
@ -849,7 +847,7 @@ class AKillRequest:
class TestKillRequest(tservers.HTTPProxyTest): class TestKillRequest(tservers.HTTPProxyTest):
def test_kill(self): def test_kill(self):
self.proxy.tmaster.addons.add(AKillRequest()) self.set_addons(AKillRequest())
with pytest.raises(exceptions.HttpReadDisconnect): with pytest.raises(exceptions.HttpReadDisconnect):
self.pathod("200") self.pathod("200")
# Nothing should have hit the server # Nothing should have hit the server
@ -863,7 +861,7 @@ class AKillResponse:
class TestKillResponse(tservers.HTTPProxyTest): class TestKillResponse(tservers.HTTPProxyTest):
def test_kill(self): def test_kill(self):
self.proxy.tmaster.addons.add(AKillResponse()) self.set_addons(AKillResponse())
with pytest.raises(exceptions.HttpReadDisconnect): with pytest.raises(exceptions.HttpReadDisconnect):
self.pathod("200") self.pathod("200")
# The server should have seen a request # The server should have seen a request
@ -886,7 +884,7 @@ class AIncomplete:
class TestIncompleteResponse(tservers.HTTPProxyTest): class TestIncompleteResponse(tservers.HTTPProxyTest):
def test_incomplete(self): def test_incomplete(self):
self.proxy.tmaster.addons.add(AIncomplete()) self.set_addons(AIncomplete())
assert self.pathod("200").status_code == 502 assert self.pathod("200").status_code == 502
@ -969,7 +967,7 @@ class TestUpstreamProxySSL(
def test_change_upstream_proxy_connect(self): def test_change_upstream_proxy_connect(self):
# skip chain[0]. # skip chain[0].
self.proxy.tmaster.addons.add( self.set_addons(
UpstreamProxyChanger( UpstreamProxyChanger(
("127.0.0.1", self.chain[1].port) ("127.0.0.1", self.chain[1].port)
) )
@ -988,8 +986,8 @@ class TestUpstreamProxySSL(
Client <- HTTPS -> Proxy <- HTTP -> Proxy <- HTTPS -> Server Client <- HTTPS -> Proxy <- HTTP -> Proxy <- HTTPS -> Server
""" """
self.proxy.tmaster.addons.add(RewriteToHttp()) self.set_addons(RewriteToHttp())
self.chain[1].tmaster.addons.add(RewriteToHttps()) self.set_addons(RewriteToHttps())
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():
resp = p.request("get:'/p/418'") resp = p.request("get:'/p/418'")
@ -1063,8 +1061,8 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest):
http1obj.server_conn.wfile.write(headers) http1obj.server_conn.wfile.write(headers)
http1obj.server_conn.wfile.flush() http1obj.server_conn.wfile.flush()
self.chain[0].tmaster.addons.add(RequestKiller([1, 2])) self.chain[0].set_addons(RequestKiller([1, 2]))
self.chain[1].tmaster.addons.add(RequestKiller([1])) self.chain[1].set_addons(RequestKiller([1]))
p = self.pathoc() p = self.pathoc()
with p.connect(): with p.connect():

View File

@ -92,6 +92,7 @@ class ProxyThread(threading.Thread):
self.masterclass = masterclass self.masterclass = masterclass
self.options = options self.options = options
self.tmaster = None self.tmaster = None
self.event_loop = None
controller.should_exit = False controller.should_exit = False
@property @property
@ -106,7 +107,8 @@ class ProxyThread(threading.Thread):
self.tmaster.shutdown() self.tmaster.shutdown()
def run(self): def run(self):
asyncio.set_event_loop(asyncio.new_event_loop()) self.event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.event_loop)
self.tmaster = self.masterclass(self.options) self.tmaster = self.masterclass(self.options)
self.tmaster.addons.add(core.Core()) self.tmaster.addons.add(core.Core())
self.name = "ProxyThread (%s:%s)" % ( self.name = "ProxyThread (%s:%s)" % (
@ -177,6 +179,10 @@ class ProxyTestBase:
ssl_insecure=True, ssl_insecure=True,
) )
def set_addons(self, *addons):
self.proxy.tmaster.reset(addons)
self.proxy.tmaster.addons.trigger("tick")
def addons(self): def addons(self):
""" """
Can be over-ridden to add a standard set of addons to tests. Can be over-ridden to add a standard set of addons to tests.