diff --git a/libmproxy/flow.py b/libmproxy/flow.py index e88b8f168..ffcbed63a 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -92,21 +92,28 @@ class Flow: ) return bson.dumps(data) - def get_state(self): - return dict( + def get_state(self, nobackup=False): + d = dict( request = self.request.get_state() if self.request else None, response = self.response.get_state() if self.response else None, error = self.error.get_state() if self.error else None, + client_conn = self.client_conn.get_state() ) + if nobackup: + d["backup"] = None + else: + d["backup"] = self._backup + return d def load_state(self, state): + self.client_conn = proxy.ClientConnection.from_state(state["client_conn"]) + self._backup = state["backup"] if state["request"]: - self.request = proxy.Request.from_state(state["request"]) + self.request = proxy.Request.from_state(self.client_conn, state["request"]) if state["response"]: self.response = proxy.Response.from_state(self.request, state["response"]) if state["error"]: self.error = proxy.Error.from_state(state["error"]) - self.client_conn = self.request.client_conn @classmethod def from_state(klass, state): @@ -126,18 +133,11 @@ class Flow: return False def backup(self): - if not self._backup: - self._backup = [ - self.client_conn.copy(), - self.request.copy() if self.request else None, - self.response.copy() if self.response else None, - self.error.copy() if self.error else None, - ] + self._backup = self.get_state(nobackup=True) def revert(self): if self._backup: - restore = [i.copy() if i else None for i in self._backup] - self.client_conn, self.request, self.response, self.error = restore + self.load_state(self._backup) self._backup = None def match(self, pattern): diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 938c4d217..4c29d7474 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -104,9 +104,9 @@ class Request(controller.Msg): ) @classmethod - def from_state(klass, state): + def from_state(klass, client_conn, state): return klass( - ClientConnection(None), + client_conn, state["host"], state["port"], state["scheme"], @@ -230,6 +230,13 @@ class ClientConnection(controller.Msg): self.address = address controller.Msg.__init__(self) + def get_state(self): + return self.address + + @classmethod + def from_state(klass, state): + return klass(state) + def set_replay(self): self.address = None diff --git a/test/test_flow.py b/test/test_flow.py index 9629934fb..b71ce6afb 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -32,10 +32,15 @@ class uFlow(libpry.AutoTree): def test_backup(self): f = utils.tflow() + f.response = utils.tresp() + f.request = f.response.request + f.request.content = "foo" assert not f.modified() f.backup() + f.request.content = "bar" assert f.modified() f.revert() + assert f.request.content == "foo" def test_getset_state(self): f = utils.tflow() diff --git a/test/test_proxy.py b/test/test_proxy.py index 340b6697d..e343e6934 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -236,7 +236,7 @@ class uRequest(libpry.AutoTree): c = proxy.ClientConnection(("addr", 2222)) r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content") state = r.get_state() - assert proxy.Request.from_state(state) == r + assert proxy.Request.from_state(c, state) == r class uResponse(libpry.AutoTree):