diff --git a/mitmproxy/addons/clientplayback.py b/mitmproxy/addons/clientplayback.py index 7bdaeb33b..7adefd7aa 100644 --- a/mitmproxy/addons/clientplayback.py +++ b/mitmproxy/addons/clientplayback.py @@ -1,23 +1,23 @@ import queue import threading -import typing import time +import typing -from mitmproxy import log -from mitmproxy import controller -from mitmproxy import exceptions -from mitmproxy import http -from mitmproxy import flow -from mitmproxy import options +import mitmproxy.types +from mitmproxy import command from mitmproxy import connections +from mitmproxy import controller +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import http +from mitmproxy import io +from mitmproxy import log +from mitmproxy import options +from mitmproxy.coretypes import basethread from mitmproxy.net import server_spec, tls from mitmproxy.net.http import http1 -from mitmproxy.coretypes import basethread from mitmproxy.utils import human -from mitmproxy import ctx -from mitmproxy import io -from mitmproxy import command -import mitmproxy.types class RequestReplayThread(basethread.BaseThread): @@ -117,7 +117,7 @@ class RequestReplayThread(basethread.BaseThread): finally: r.first_line_format = first_line_format_backup f.live = False - if server.connected(): + if server and server.connected(): server.finish() server.close() diff --git a/mitmproxy/addons/serverplayback.py b/mitmproxy/addons/serverplayback.py index 18bc35453..7f642585b 100644 --- a/mitmproxy/addons/serverplayback.py +++ b/mitmproxy/addons/serverplayback.py @@ -1,16 +1,19 @@ import hashlib -import urllib import typing +import urllib -from mitmproxy import ctx -from mitmproxy import flow -from mitmproxy import exceptions -from mitmproxy import io -from mitmproxy import command import mitmproxy.types +from mitmproxy import command +from mitmproxy import ctx, http +from mitmproxy import exceptions +from mitmproxy import flow +from mitmproxy import io class ServerPlayback: + flowmap: typing.Dict[typing.Hashable, typing.List[http.HTTPFlow]] + configured: bool + def __init__(self): self.flowmap = {} self.configured = False @@ -82,10 +85,10 @@ class ServerPlayback: Replay server responses from flows. """ self.flowmap = {} - for i in flows: - if i.response: # type: ignore - l = self.flowmap.setdefault(self._hash(i), []) - l.append(i) + for f in flows: + if isinstance(f, http.HTTPFlow): + lst = self.flowmap.setdefault(self._hash(f), []) + lst.append(f) ctx.master.addons.trigger("update", []) @command.command("replay.server.file") @@ -108,12 +111,11 @@ class ServerPlayback: def count(self) -> int: return sum([len(i) for i in self.flowmap.values()]) - def _hash(self, flow): + def _hash(self, flow: http.HTTPFlow) -> typing.Hashable: """ Calculates a loose hash of the flow request. """ r = flow.request - _, _, path, _, query, _ = urllib.parse.urlparse(r.url) queriesArray = urllib.parse.parse_qsl(query, keep_blank_values=True) @@ -158,20 +160,32 @@ class ServerPlayback: repr(key).encode("utf8", "surrogateescape") ).digest() - def next_flow(self, request): + def next_flow(self, flow: http.HTTPFlow) -> typing.Optional[http.HTTPFlow]: """ Returns the next flow object, or None if no matching flow was found. """ - hsh = self._hash(request) - if hsh in self.flowmap: + hash = self._hash(flow) + if hash in self.flowmap: if ctx.options.server_replay_nopop: - return self.flowmap[hsh][0] + return next(( + flow + for flow in self.flowmap[hash] + if flow.response + ), None) else: - ret = self.flowmap[hsh].pop(0) - if not self.flowmap[hsh]: - del self.flowmap[hsh] + ret = self.flowmap[hash].pop(0) + while not ret.response: + if self.flowmap[hash]: + ret = self.flowmap[hash].pop(0) + else: + del self.flowmap[hash] + return None + if not self.flowmap[hash]: + del self.flowmap[hash] return ret + else: + return None def configure(self, updated): if not self.configured and ctx.options.server_replay: @@ -182,10 +196,11 @@ class ServerPlayback: raise exceptions.OptionsError(str(e)) self.load_flows(flows) - def request(self, f): + def request(self, f: http.HTTPFlow) -> None: if self.flowmap: rflow = self.next_flow(f) if rflow: + assert rflow.response response = rflow.response.copy() response.is_replay = True if ctx.options.server_replay_refresh: @@ -197,4 +212,5 @@ class ServerPlayback: f.request.url ) ) + assert f.reply f.reply.kill() diff --git a/test/mitmproxy/addons/test_serverplayback.py b/test/mitmproxy/addons/test_serverplayback.py index c6a0c1f48..2e42fa030 100644 --- a/test/mitmproxy/addons/test_serverplayback.py +++ b/test/mitmproxy/addons/test_serverplayback.py @@ -1,13 +1,13 @@ import urllib + import pytest -from mitmproxy.test import taddons -from mitmproxy.test import tflow - import mitmproxy.test.tutils -from mitmproxy.addons import serverplayback from mitmproxy import exceptions from mitmproxy import io +from mitmproxy.addons import serverplayback +from mitmproxy.test import taddons +from mitmproxy.test import tflow def tdump(path, flows): @@ -321,7 +321,7 @@ def test_server_playback_full(): with taddons.context(s) as tctx: tctx.configure( s, - server_replay_refresh = True, + server_replay_refresh=True, ) f = tflow.tflow() @@ -345,7 +345,7 @@ def test_server_playback_kill(): with taddons.context(s) as tctx: tctx.configure( s, - server_replay_refresh = True, + server_replay_refresh=True, server_replay_kill_extra=True ) @@ -357,3 +357,25 @@ def test_server_playback_kill(): f.request.host = "nonexistent" tctx.cycle(s, f) assert f.reply.value == exceptions.Kill + + +def test_server_playback_response_deleted(): + """ + The server playback addon holds references to flows that can be modified by the user in the meantime. + One thing that can happen is that users remove the response object. This happens for example when doing a client + replay at the same time. + """ + sp = serverplayback.ServerPlayback() + with taddons.context(sp) as tctx: + tctx.configure(sp) + f1 = tflow.tflow(resp=True) + f2 = tflow.tflow(resp=True) + + assert not sp.flowmap + + sp.load_flows([f1, f2]) + assert sp.flowmap + + f1.response = f2.response = None + assert not sp.next_flow(f1) + assert not sp.flowmap