Implement state loading that doesn't change object identity.

We need this to let us load state from copied Flows returned from scripts.
This commit is contained in:
Aldo Cortesi 2011-02-20 09:36:13 +13:00
parent 58fc0041fa
commit 9c5c3c2b1a
3 changed files with 87 additions and 6 deletions

View File

@ -99,11 +99,26 @@ class Flow:
def load_state(self, state): def load_state(self, state):
self._backup = state["backup"] self._backup = state["backup"]
if self.request:
self.request.load_state(state["request"])
else:
self.request = proxy.Request.from_state(state["request"]) self.request = proxy.Request.from_state(state["request"])
if state["response"]: if state["response"]:
if self.response:
self.response.load_state(state["response"])
else:
self.response = proxy.Response.from_state(self.request, state["response"]) self.response = proxy.Response.from_state(self.request, state["response"])
else:
self.response = None
if state["error"]: if state["error"]:
if self.error:
self.error.load_state(state["error"])
else:
self.error = proxy.Error.from_state(state["error"]) self.error = proxy.Error.from_state(state["error"])
else:
self.error = None
@classmethod @classmethod
def from_state(klass, state): def from_state(klass, state):

View File

@ -148,6 +148,23 @@ class Request(controller.Msg):
def is_cached(self): def is_cached(self):
return False return False
def load_state(self, state):
if state["client_conn"]:
if self.client_conn:
self.client_conn.load_state(state["client_conn"])
else:
self.client_conn = ClientConnect.from_state(state["client_conn"])
else:
self.client_conn = None
self.host = state["host"]
self.port = state["port"]
self.scheme = state["scheme"]
self.method = state["method"]
self.path = state["path"]
self.headers = utils.Headers.from_state(state["headers"])
self.content = base64.decodestring(state["content"])
self.timestamp = state["timestamp"]
def get_state(self): def get_state(self):
return dict( return dict(
client_conn = self.client_conn.get_state() if self.client_conn else None, client_conn = self.client_conn.get_state() if self.client_conn else None,
@ -164,7 +181,7 @@ class Request(controller.Msg):
@classmethod @classmethod
def from_state(klass, state): def from_state(klass, state):
return klass( return klass(
ClientConnect.from_state(state["client_conn"]) if state["client_conn"] else None, ClientConnect.from_state(state["client_conn"]),
state["host"], state["host"],
state["port"], state["port"],
state["scheme"], state["scheme"],
@ -249,6 +266,13 @@ class Response(controller.Msg):
self.cached = False self.cached = False
controller.Msg.__init__(self) controller.Msg.__init__(self)
def load_state(self, state):
self.code = state["code"]
self.msg = state["msg"]
self.headers = utils.Headers.from_state(state["headers"])
self.content = base64.decodestring(state["content"])
self.timestamp = state["timestamp"]
def get_state(self): def get_state(self):
return dict( return dict(
code = self.code, code = self.code,
@ -325,12 +349,21 @@ class ClientConnect(controller.Msg):
self.close = False self.close = False
controller.Msg.__init__(self) controller.Msg.__init__(self)
def __eq__(self, other):
return self.get_state() == other.get_state()
def load_state(self, state):
self.address = state
def get_state(self): def get_state(self):
return list(self.address) if self.address else None return list(self.address) if self.address else None
@classmethod @classmethod
def from_state(klass, state): def from_state(klass, state):
if state:
return klass(state) return klass(state)
else:
return None
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)
@ -342,6 +375,10 @@ class Error(controller.Msg):
self.timestamp = timestamp or time.time() self.timestamp = timestamp or time.time()
controller.Msg.__init__(self) controller.Msg.__init__(self)
def load_state(self, state):
self.msg = state["msg"]
self.timestamp = state["timestamp"]
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)

View File

@ -234,6 +234,11 @@ class uRequest(libpry.AutoTree):
state = r.get_state() state = r.get_state()
assert proxy.Request.from_state(state) == r assert proxy.Request.from_state(state) == r
r2 = proxy.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test")
assert not r == r2
r.load_state(r2.get_state())
assert r == r2
class uResponse(libpry.AutoTree): class uResponse(libpry.AutoTree):
def test_simple(self): def test_simple(self):
@ -256,6 +261,11 @@ class uResponse(libpry.AutoTree):
state = resp.get_state() state = resp.get_state()
assert proxy.Response.from_state(req, state) == resp assert proxy.Response.from_state(req, state) == resp
resp2 = proxy.Response(req, 220, "foo", h.copy(), "test")
assert not resp == resp2
resp.load_state(resp2.get_state())
assert resp == resp2
class uError(libpry.AutoTree): class uError(libpry.AutoTree):
def test_getset_state(self): def test_getset_state(self):
@ -265,6 +275,12 @@ class uError(libpry.AutoTree):
assert e.copy() assert e.copy()
e2 = proxy.Error(None, "bar")
assert not e == e2
e.load_state(e2.get_state())
assert e == e2
class uProxyError(libpry.AutoTree): class uProxyError(libpry.AutoTree):
def test_simple(self): def test_simple(self):
@ -272,6 +288,18 @@ class uProxyError(libpry.AutoTree):
assert repr(p) assert repr(p)
class uClientConnect(libpry.AutoTree):
def test_state(self):
c = proxy.ClientConnect(("a", 22))
assert proxy.ClientConnect.from_state(c.get_state()) == c
c2 = proxy.ClientConnect(("a", 25))
assert not c == c2
c.load_state(c2.get_state())
assert c == c2
tests = [ tests = [
uProxyError(), uProxyError(),
@ -281,8 +309,9 @@ tests = [
u_parse_request_line(), u_parse_request_line(),
u_parse_url(), u_parse_url(),
uError(), uError(),
uClientConnect(),
_TestServers(), [ _TestServers(), [
uSanity(), uSanity(),
uProxy(), uProxy(),
] ],
] ]