mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 10:16:27 +00:00
Implement state loading that doesn't change object identity.
We need this to let us load state from copied Flows returned from scripts.
This commit is contained in:
parent
58fc0041fa
commit
9c5c3c2b1a
@ -99,11 +99,26 @@ class Flow:
|
||||
|
||||
def load_state(self, state):
|
||||
self._backup = state["backup"]
|
||||
if self.request:
|
||||
self.request.load_state(state["request"])
|
||||
else:
|
||||
self.request = proxy.Request.from_state(state["request"])
|
||||
|
||||
if state["response"]:
|
||||
if self.response:
|
||||
self.response.load_state(state["response"])
|
||||
else:
|
||||
self.response = proxy.Response.from_state(self.request, state["response"])
|
||||
else:
|
||||
self.response = None
|
||||
|
||||
if state["error"]:
|
||||
if self.error:
|
||||
self.error.load_state(state["error"])
|
||||
else:
|
||||
self.error = proxy.Error.from_state(state["error"])
|
||||
else:
|
||||
self.error = None
|
||||
|
||||
@classmethod
|
||||
def from_state(klass, state):
|
||||
|
@ -148,6 +148,23 @@ class Request(controller.Msg):
|
||||
def is_cached(self):
|
||||
return False
|
||||
|
||||
def load_state(self, state):
|
||||
if state["client_conn"]:
|
||||
if self.client_conn:
|
||||
self.client_conn.load_state(state["client_conn"])
|
||||
else:
|
||||
self.client_conn = ClientConnect.from_state(state["client_conn"])
|
||||
else:
|
||||
self.client_conn = None
|
||||
self.host = state["host"]
|
||||
self.port = state["port"]
|
||||
self.scheme = state["scheme"]
|
||||
self.method = state["method"]
|
||||
self.path = state["path"]
|
||||
self.headers = utils.Headers.from_state(state["headers"])
|
||||
self.content = base64.decodestring(state["content"])
|
||||
self.timestamp = state["timestamp"]
|
||||
|
||||
def get_state(self):
|
||||
return dict(
|
||||
client_conn = self.client_conn.get_state() if self.client_conn else None,
|
||||
@ -164,7 +181,7 @@ class Request(controller.Msg):
|
||||
@classmethod
|
||||
def from_state(klass, state):
|
||||
return klass(
|
||||
ClientConnect.from_state(state["client_conn"]) if state["client_conn"] else None,
|
||||
ClientConnect.from_state(state["client_conn"]),
|
||||
state["host"],
|
||||
state["port"],
|
||||
state["scheme"],
|
||||
@ -249,6 +266,13 @@ class Response(controller.Msg):
|
||||
self.cached = False
|
||||
controller.Msg.__init__(self)
|
||||
|
||||
def load_state(self, state):
|
||||
self.code = state["code"]
|
||||
self.msg = state["msg"]
|
||||
self.headers = utils.Headers.from_state(state["headers"])
|
||||
self.content = base64.decodestring(state["content"])
|
||||
self.timestamp = state["timestamp"]
|
||||
|
||||
def get_state(self):
|
||||
return dict(
|
||||
code = self.code,
|
||||
@ -325,12 +349,21 @@ class ClientConnect(controller.Msg):
|
||||
self.close = False
|
||||
controller.Msg.__init__(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.get_state() == other.get_state()
|
||||
|
||||
def load_state(self, state):
|
||||
self.address = state
|
||||
|
||||
def get_state(self):
|
||||
return list(self.address) if self.address else None
|
||||
|
||||
@classmethod
|
||||
def from_state(klass, state):
|
||||
if state:
|
||||
return klass(state)
|
||||
else:
|
||||
return None
|
||||
|
||||
def copy(self):
|
||||
return copy.copy(self)
|
||||
@ -342,6 +375,10 @@ class Error(controller.Msg):
|
||||
self.timestamp = timestamp or time.time()
|
||||
controller.Msg.__init__(self)
|
||||
|
||||
def load_state(self, state):
|
||||
self.msg = state["msg"]
|
||||
self.timestamp = state["timestamp"]
|
||||
|
||||
def copy(self):
|
||||
return copy.copy(self)
|
||||
|
||||
|
@ -234,6 +234,11 @@ class uRequest(libpry.AutoTree):
|
||||
state = r.get_state()
|
||||
assert proxy.Request.from_state(state) == r
|
||||
|
||||
r2 = proxy.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test")
|
||||
assert not r == r2
|
||||
r.load_state(r2.get_state())
|
||||
assert r == r2
|
||||
|
||||
|
||||
class uResponse(libpry.AutoTree):
|
||||
def test_simple(self):
|
||||
@ -256,6 +261,11 @@ class uResponse(libpry.AutoTree):
|
||||
state = resp.get_state()
|
||||
assert proxy.Response.from_state(req, state) == resp
|
||||
|
||||
resp2 = proxy.Response(req, 220, "foo", h.copy(), "test")
|
||||
assert not resp == resp2
|
||||
resp.load_state(resp2.get_state())
|
||||
assert resp == resp2
|
||||
|
||||
|
||||
class uError(libpry.AutoTree):
|
||||
def test_getset_state(self):
|
||||
@ -265,6 +275,12 @@ class uError(libpry.AutoTree):
|
||||
|
||||
assert e.copy()
|
||||
|
||||
e2 = proxy.Error(None, "bar")
|
||||
assert not e == e2
|
||||
e.load_state(e2.get_state())
|
||||
assert e == e2
|
||||
|
||||
|
||||
|
||||
class uProxyError(libpry.AutoTree):
|
||||
def test_simple(self):
|
||||
@ -272,6 +288,18 @@ class uProxyError(libpry.AutoTree):
|
||||
assert repr(p)
|
||||
|
||||
|
||||
class uClientConnect(libpry.AutoTree):
|
||||
def test_state(self):
|
||||
c = proxy.ClientConnect(("a", 22))
|
||||
assert proxy.ClientConnect.from_state(c.get_state()) == c
|
||||
|
||||
c2 = proxy.ClientConnect(("a", 25))
|
||||
assert not c == c2
|
||||
|
||||
c.load_state(c2.get_state())
|
||||
assert c == c2
|
||||
|
||||
|
||||
|
||||
tests = [
|
||||
uProxyError(),
|
||||
@ -281,8 +309,9 @@ tests = [
|
||||
u_parse_request_line(),
|
||||
u_parse_url(),
|
||||
uError(),
|
||||
uClientConnect(),
|
||||
_TestServers(), [
|
||||
uSanity(),
|
||||
uProxy(),
|
||||
]
|
||||
],
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user