diff --git a/mitmproxy/addonmanager.py b/mitmproxy/addonmanager.py index 70cfda309..37c501ee2 100644 --- a/mitmproxy/addonmanager.py +++ b/mitmproxy/addonmanager.py @@ -230,7 +230,7 @@ class AddonManager: self.trigger(name, message) - if message.reply.state != "taken": + if message.reply.state == "start": message.reply.take() if not message.reply.has_message: message.reply.ack() diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index 63117ef03..f39c1b24f 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -105,16 +105,16 @@ class Reply: self.q.put(self.value) def ack(self, force=False): - if self.state not in {"start", "taken"}: - raise exceptions.ControlException( - "Reply is {}, but expected it to be start or taken.".format(self.state) - ) self.send(self.obj, force) def kill(self, force=False): self.send(exceptions.Kill, force) def send(self, msg, force=False): + if self.state not in {"start", "taken"}: + raise exceptions.ControlException( + "Reply is {}, but expected it to be start or taken.".format(self.state) + ) if self.has_message and not force: raise exceptions.ControlException("There is already a reply message.") self.value = msg diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 111566b8d..944c032d8 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -1,13 +1,12 @@ import time +import typing # noqa import uuid -from mitmproxy import controller # noqa -from mitmproxy import stateobject from mitmproxy import connections +from mitmproxy import controller, exceptions # noqa +from mitmproxy import stateobject from mitmproxy import version -import typing # noqa - class Error(stateobject.StateObject): @@ -145,7 +144,11 @@ class Flow(stateobject.StateObject): @property def killable(self): - return self.reply and self.reply.state == "taken" + return ( + self.reply and + self.reply.state in {"start", "taken"} and + self.reply.value != exceptions.Kill + ) def kill(self): """ @@ -153,13 +156,7 @@ class Flow(stateobject.StateObject): """ self.error = Error("Connection killed") self.intercepted = False - - # reply.state should be "taken" here, or .take() will raise an - # exception. - if self.reply.state != "taken": - self.reply.take() self.reply.kill(force=True) - self.reply.commit() self.live = False def intercept(self): diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index a7acdc4db..d9389faf8 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -221,6 +221,25 @@ class TestSimple(_WebSocketTest): assert frame.payload == b'foo' +class TestKillFlow(_WebSocketTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + def test_kill(self): + class KillFlow: + def websocket_message(self, f): + f.kill() + + self.master.addons.add(KillFlow()) + self.setup_connection() + + with pytest.raises(exceptions.TcpDisconnect): + websockets.Frame.from_file(self.client.rfile) + + class TestSimpleTLS(_WebSocketTest): ssl = True