diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index 3ae8abad9..ca3639dae 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -98,6 +98,8 @@ class ReplayHandler(server.ConnectionHandler): data.reply = AsyncReply(data) await ctx.master.addons.handle_lifecycle(hook) await data.reply.done.wait() + if isinstance(data, flow.Flow): + await data.wait_for_resume() if isinstance(hook, (layers.http.HttpResponseHook, layers.http.HttpErrorHook)): if self.transports: # close server connections diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index dcf8ec32a..df81b201f 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -51,6 +51,8 @@ class ProxyConnectionHandler(server.StreamConnectionHandler): await self.master.addons.handle_lifecycle(hook) await data.reply.done.wait() data.reply = None + if isinstance(data, flow.Flow): + await data.wait_for_resume() def log(self, message: str, level: str = "info") -> None: x = log.LogEntry(self.log_prefix + message, level) diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index 3fa0e914b..fab447cdc 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -1,3 +1,4 @@ +import asyncio import time import typing # noqa import uuid @@ -118,6 +119,7 @@ class Flow(stateobject.StateObject): self.live = live self.intercepted: bool = False + self._resume_event: typing.Optional[asyncio.Event] = None self._backup: typing.Optional[Flow] = None self.reply: typing.Optional[controller.Reply] = None self.marked: str = "" @@ -213,7 +215,18 @@ class Flow(stateobject.StateObject): if self.intercepted: return self.intercepted = True - self.reply.take() + if self._resume_event is not None: + self._resume_event.clear() + + async def wait_for_resume(self): + """ + Wait until this Flow is resumed. + """ + if not self.intercepted: + return + if self._resume_event is None: + self._resume_event = asyncio.Event() + await self._resume_event.wait() def resume(self): """ @@ -222,9 +235,8 @@ class Flow(stateobject.StateObject): if not self.intercepted: return self.intercepted = False - # If a flow is intercepted and then duplicated, the duplicated one is not taken. - if self.reply.state == "taken": - self.reply.commit() + if self._resume_event is not None: + self._resume_event.set() @property def timestamp_start(self) -> float: diff --git a/test/mitmproxy/test_http.py b/test/mitmproxy/test_http.py index f4bcab989..3cfd5e83c 100644 --- a/test/mitmproxy/test_http.py +++ b/test/mitmproxy/test_http.py @@ -1,3 +1,4 @@ +import asyncio import email import time import json @@ -719,16 +720,46 @@ class TestHTTPFlow: def test_intercept(self): f = tflow() f.intercept() - assert f.reply.state == "taken" + assert f.intercepted f.intercept() - assert f.reply.state == "taken" + assert f.intercepted def test_resume(self): f = tflow() - f.intercept() - assert f.reply.state == "taken" f.resume() - assert f.reply.state == "committed" + assert not f.intercepted + f.intercept() + assert f.intercepted + f.resume() + assert not f.intercepted + + @pytest.mark.asyncio + async def test_wait_for_resume(self): + f = tflow() + await f.wait_for_resume() + + f = tflow() + f.intercept() + f.resume() + await f.wait_for_resume() + + f = tflow() + f.intercept() + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(f.wait_for_resume(), 0.2) + f.resume() + await f.wait_for_resume() + + f = tflow() + f.intercept() + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(f.wait_for_resume(), 0.2) + f.resume() + f.intercept() + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(f.wait_for_resume(), 0.2) + f.resume() + await f.wait_for_resume() def test_resume_duplicated(self): f = tflow()