mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-30 03:14:22 +00:00
Implement state loading that doesn't change object identity.
We need this to let us load state from copied Flows returned from scripts.
This commit is contained in:
parent
58fc0041fa
commit
9c5c3c2b1a
@ -99,11 +99,26 @@ class Flow:
|
|||||||
|
|
||||||
def load_state(self, state):
|
def load_state(self, state):
|
||||||
self._backup = state["backup"]
|
self._backup = state["backup"]
|
||||||
|
if self.request:
|
||||||
|
self.request.load_state(state["request"])
|
||||||
|
else:
|
||||||
self.request = proxy.Request.from_state(state["request"])
|
self.request = proxy.Request.from_state(state["request"])
|
||||||
|
|
||||||
if state["response"]:
|
if state["response"]:
|
||||||
|
if self.response:
|
||||||
|
self.response.load_state(state["response"])
|
||||||
|
else:
|
||||||
self.response = proxy.Response.from_state(self.request, state["response"])
|
self.response = proxy.Response.from_state(self.request, state["response"])
|
||||||
|
else:
|
||||||
|
self.response = None
|
||||||
|
|
||||||
if state["error"]:
|
if state["error"]:
|
||||||
|
if self.error:
|
||||||
|
self.error.load_state(state["error"])
|
||||||
|
else:
|
||||||
self.error = proxy.Error.from_state(state["error"])
|
self.error = proxy.Error.from_state(state["error"])
|
||||||
|
else:
|
||||||
|
self.error = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(klass, state):
|
def from_state(klass, state):
|
||||||
|
@ -148,6 +148,23 @@ class Request(controller.Msg):
|
|||||||
def is_cached(self):
|
def is_cached(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def load_state(self, state):
|
||||||
|
if state["client_conn"]:
|
||||||
|
if self.client_conn:
|
||||||
|
self.client_conn.load_state(state["client_conn"])
|
||||||
|
else:
|
||||||
|
self.client_conn = ClientConnect.from_state(state["client_conn"])
|
||||||
|
else:
|
||||||
|
self.client_conn = None
|
||||||
|
self.host = state["host"]
|
||||||
|
self.port = state["port"]
|
||||||
|
self.scheme = state["scheme"]
|
||||||
|
self.method = state["method"]
|
||||||
|
self.path = state["path"]
|
||||||
|
self.headers = utils.Headers.from_state(state["headers"])
|
||||||
|
self.content = base64.decodestring(state["content"])
|
||||||
|
self.timestamp = state["timestamp"]
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return dict(
|
return dict(
|
||||||
client_conn = self.client_conn.get_state() if self.client_conn else None,
|
client_conn = self.client_conn.get_state() if self.client_conn else None,
|
||||||
@ -164,7 +181,7 @@ class Request(controller.Msg):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_state(klass, state):
|
def from_state(klass, state):
|
||||||
return klass(
|
return klass(
|
||||||
ClientConnect.from_state(state["client_conn"]) if state["client_conn"] else None,
|
ClientConnect.from_state(state["client_conn"]),
|
||||||
state["host"],
|
state["host"],
|
||||||
state["port"],
|
state["port"],
|
||||||
state["scheme"],
|
state["scheme"],
|
||||||
@ -249,6 +266,13 @@ class Response(controller.Msg):
|
|||||||
self.cached = False
|
self.cached = False
|
||||||
controller.Msg.__init__(self)
|
controller.Msg.__init__(self)
|
||||||
|
|
||||||
|
def load_state(self, state):
|
||||||
|
self.code = state["code"]
|
||||||
|
self.msg = state["msg"]
|
||||||
|
self.headers = utils.Headers.from_state(state["headers"])
|
||||||
|
self.content = base64.decodestring(state["content"])
|
||||||
|
self.timestamp = state["timestamp"]
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return dict(
|
return dict(
|
||||||
code = self.code,
|
code = self.code,
|
||||||
@ -325,12 +349,21 @@ class ClientConnect(controller.Msg):
|
|||||||
self.close = False
|
self.close = False
|
||||||
controller.Msg.__init__(self)
|
controller.Msg.__init__(self)
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.get_state() == other.get_state()
|
||||||
|
|
||||||
|
def load_state(self, state):
|
||||||
|
self.address = state
|
||||||
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return list(self.address) if self.address else None
|
return list(self.address) if self.address else None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(klass, state):
|
def from_state(klass, state):
|
||||||
|
if state:
|
||||||
return klass(state)
|
return klass(state)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return copy.copy(self)
|
return copy.copy(self)
|
||||||
@ -342,6 +375,10 @@ class Error(controller.Msg):
|
|||||||
self.timestamp = timestamp or time.time()
|
self.timestamp = timestamp or time.time()
|
||||||
controller.Msg.__init__(self)
|
controller.Msg.__init__(self)
|
||||||
|
|
||||||
|
def load_state(self, state):
|
||||||
|
self.msg = state["msg"]
|
||||||
|
self.timestamp = state["timestamp"]
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
return copy.copy(self)
|
return copy.copy(self)
|
||||||
|
|
||||||
|
@ -234,6 +234,11 @@ class uRequest(libpry.AutoTree):
|
|||||||
state = r.get_state()
|
state = r.get_state()
|
||||||
assert proxy.Request.from_state(state) == r
|
assert proxy.Request.from_state(state) == r
|
||||||
|
|
||||||
|
r2 = proxy.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test")
|
||||||
|
assert not r == r2
|
||||||
|
r.load_state(r2.get_state())
|
||||||
|
assert r == r2
|
||||||
|
|
||||||
|
|
||||||
class uResponse(libpry.AutoTree):
|
class uResponse(libpry.AutoTree):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
@ -256,6 +261,11 @@ class uResponse(libpry.AutoTree):
|
|||||||
state = resp.get_state()
|
state = resp.get_state()
|
||||||
assert proxy.Response.from_state(req, state) == resp
|
assert proxy.Response.from_state(req, state) == resp
|
||||||
|
|
||||||
|
resp2 = proxy.Response(req, 220, "foo", h.copy(), "test")
|
||||||
|
assert not resp == resp2
|
||||||
|
resp.load_state(resp2.get_state())
|
||||||
|
assert resp == resp2
|
||||||
|
|
||||||
|
|
||||||
class uError(libpry.AutoTree):
|
class uError(libpry.AutoTree):
|
||||||
def test_getset_state(self):
|
def test_getset_state(self):
|
||||||
@ -265,6 +275,12 @@ class uError(libpry.AutoTree):
|
|||||||
|
|
||||||
assert e.copy()
|
assert e.copy()
|
||||||
|
|
||||||
|
e2 = proxy.Error(None, "bar")
|
||||||
|
assert not e == e2
|
||||||
|
e.load_state(e2.get_state())
|
||||||
|
assert e == e2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class uProxyError(libpry.AutoTree):
|
class uProxyError(libpry.AutoTree):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
@ -272,6 +288,18 @@ class uProxyError(libpry.AutoTree):
|
|||||||
assert repr(p)
|
assert repr(p)
|
||||||
|
|
||||||
|
|
||||||
|
class uClientConnect(libpry.AutoTree):
|
||||||
|
def test_state(self):
|
||||||
|
c = proxy.ClientConnect(("a", 22))
|
||||||
|
assert proxy.ClientConnect.from_state(c.get_state()) == c
|
||||||
|
|
||||||
|
c2 = proxy.ClientConnect(("a", 25))
|
||||||
|
assert not c == c2
|
||||||
|
|
||||||
|
c.load_state(c2.get_state())
|
||||||
|
assert c == c2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
tests = [
|
tests = [
|
||||||
uProxyError(),
|
uProxyError(),
|
||||||
@ -281,8 +309,9 @@ tests = [
|
|||||||
u_parse_request_line(),
|
u_parse_request_line(),
|
||||||
u_parse_url(),
|
u_parse_url(),
|
||||||
uError(),
|
uError(),
|
||||||
|
uClientConnect(),
|
||||||
_TestServers(), [
|
_TestServers(), [
|
||||||
uSanity(),
|
uSanity(),
|
||||||
uProxy(),
|
uProxy(),
|
||||||
]
|
],
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user