mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
make stateobject simpler and stricter
This commit is contained in:
parent
e9934cc008
commit
bdb763d9cf
@ -29,6 +29,7 @@ def convert_015_016(data):
|
||||
data[m]["http_version"] = data[m].pop("httpversion")
|
||||
if "msg" in data["response"]:
|
||||
data["response"]["reason"] = data["response"].pop("msg")
|
||||
data["request"].pop("form_out", None)
|
||||
data["version"] = (0, 16)
|
||||
return data
|
||||
|
||||
|
@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
||||
return self.ssl_established
|
||||
|
||||
_stateobject_attributes = dict(
|
||||
address=tcp.Address,
|
||||
clientcert=certutils.SSLCert,
|
||||
ssl_established=bool,
|
||||
timestamp_start=float,
|
||||
timestamp_end=float,
|
||||
timestamp_ssl_setup=float
|
||||
)
|
||||
|
||||
def get_state(self):
|
||||
d = super(ClientConnection, self).get_state()
|
||||
d.update(
|
||||
address=({
|
||||
"address": self.address(),
|
||||
"use_ipv6": self.address.use_ipv6} if self.address else {}),
|
||||
clientcert=self.cert.to_pem() if self.clientcert else None)
|
||||
return d
|
||||
|
||||
def load_state(self, state):
|
||||
super(ClientConnection, self).load_state(state)
|
||||
self.address = tcp.Address(
|
||||
**state["address"]) if state["address"] else None
|
||||
self.clientcert = certutils.SSLCert.from_pem(
|
||||
state["clientcert"]) if state["clientcert"] else None
|
||||
|
||||
def copy(self):
|
||||
return copy.copy(self)
|
||||
|
||||
@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
f = cls(None, tuple(), None)
|
||||
f.load_state(state)
|
||||
f.set_state(state)
|
||||
return f
|
||||
|
||||
def convert_to_ssl(self, *args, **kwargs):
|
||||
@ -131,31 +117,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||
sni=str
|
||||
)
|
||||
|
||||
def get_state(self):
|
||||
d = super(ServerConnection, self).get_state()
|
||||
d.update(
|
||||
address=({"address": self.address(),
|
||||
"use_ipv6": self.address.use_ipv6} if self.address else {}),
|
||||
source_address=({"address": self.source_address(),
|
||||
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
|
||||
cert=self.cert.to_pem() if self.cert else None
|
||||
)
|
||||
return d
|
||||
|
||||
def load_state(self, state):
|
||||
super(ServerConnection, self).load_state(state)
|
||||
|
||||
self.address = tcp.Address(
|
||||
**state["address"]) if state["address"] else None
|
||||
self.source_address = tcp.Address(
|
||||
**state["source_address"]) if state["source_address"] else None
|
||||
self.cert = certutils.SSLCert.from_pem(
|
||||
state["cert"]) if state["cert"] else None
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
f = cls(tuple())
|
||||
f.load_state(state)
|
||||
f.set_state(state)
|
||||
return f
|
||||
|
||||
def copy(self):
|
||||
|
@ -45,7 +45,7 @@ class Error(stateobject.StateObject):
|
||||
# the default implementation assumes an empty constructor. Override
|
||||
# accordingly.
|
||||
f = cls(None)
|
||||
f.load_state(state)
|
||||
f.set_state(state)
|
||||
return f
|
||||
|
||||
def copy(self):
|
||||
@ -93,6 +93,12 @@ class Flow(stateobject.StateObject):
|
||||
d.update(backup=self._backup)
|
||||
return d
|
||||
|
||||
def set_state(self, state):
|
||||
state.pop("version")
|
||||
if "backup" in state:
|
||||
self._backup = state.pop("backup")
|
||||
super(Flow, self).set_state(state)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self is other
|
||||
|
||||
@ -130,7 +136,7 @@ class Flow(stateobject.StateObject):
|
||||
Revert to the last backed up state.
|
||||
"""
|
||||
if self._backup:
|
||||
self.load_state(self._backup)
|
||||
self.set_state(self._backup)
|
||||
self._backup = None
|
||||
|
||||
def kill(self, master):
|
||||
|
@ -1,6 +1,7 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import Cookie
|
||||
import copy
|
||||
import warnings
|
||||
from email.utils import parsedate_tz, formatdate, mktime_tz
|
||||
import time
|
||||
|
||||
@ -8,28 +9,12 @@ from libmproxy import utils
|
||||
from netlib import encoding
|
||||
from netlib.http import status_codes, Headers, Request, Response, decoded
|
||||
from netlib.tcp import Address
|
||||
from .. import version, stateobject
|
||||
from .. import version
|
||||
from .flow import Flow
|
||||
|
||||
|
||||
|
||||
class MessageMixin(stateobject.StateObject):
|
||||
|
||||
def get_state(self):
|
||||
state = vars(self.data).copy()
|
||||
state["headers"] = state["headers"].get_state()
|
||||
return state
|
||||
|
||||
def load_state(self, state):
|
||||
for k, v in state.items():
|
||||
if k == "headers":
|
||||
v = Headers.from_state(v)
|
||||
setattr(self.data, k, v)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
state["headers"] = Headers.from_state(state["headers"])
|
||||
return cls(**state)
|
||||
class MessageMixin(object):
|
||||
|
||||
def get_decoded_content(self):
|
||||
"""
|
||||
@ -136,6 +121,8 @@ class HTTPRequest(MessageMixin, Request):
|
||||
timestamp_end=None,
|
||||
form_out=None,
|
||||
is_replay=False,
|
||||
stickycookie=False,
|
||||
stickyauth=False,
|
||||
):
|
||||
Request.__init__(
|
||||
self,
|
||||
@ -154,21 +141,26 @@ class HTTPRequest(MessageMixin, Request):
|
||||
self.form_out = form_out or first_line_format # FIXME remove
|
||||
|
||||
# Have this request's cookies been modified by sticky cookies or auth?
|
||||
self.stickycookie = False
|
||||
self.stickyauth = False
|
||||
self.stickycookie = stickycookie
|
||||
self.stickyauth = stickyauth
|
||||
|
||||
# Is this request replayed?
|
||||
self.is_replay = is_replay
|
||||
|
||||
@classmethod
|
||||
def from_protocol(
|
||||
self,
|
||||
protocol,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
req = protocol.read_request(*args, **kwargs)
|
||||
return self.wrap(req)
|
||||
def get_state(self):
|
||||
state = super(HTTPRequest, self).get_state()
|
||||
state.update(
|
||||
stickycookie = self.stickycookie,
|
||||
stickyauth = self.stickyauth,
|
||||
is_replay = self.is_replay,
|
||||
)
|
||||
return state
|
||||
|
||||
def set_state(self, state):
|
||||
self.stickycookie = state.pop("stickycookie")
|
||||
self.stickyauth = state.pop("stickyauth")
|
||||
self.is_replay = state.pop("is_replay")
|
||||
super(HTTPRequest, self).set_state(state)
|
||||
|
||||
@classmethod
|
||||
def wrap(self, request):
|
||||
@ -188,6 +180,15 @@ class HTTPRequest(MessageMixin, Request):
|
||||
)
|
||||
return req
|
||||
|
||||
@property
|
||||
def form_out(self):
|
||||
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
|
||||
return self.first_line_format
|
||||
|
||||
@form_out.setter
|
||||
def form_out(self, value):
|
||||
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
|
||||
|
||||
def __hash__(self):
|
||||
return id(self)
|
||||
|
||||
@ -257,16 +258,6 @@ class HTTPResponse(MessageMixin, Response):
|
||||
self.is_replay = is_replay
|
||||
self.stream = False
|
||||
|
||||
@classmethod
|
||||
def from_protocol(
|
||||
self,
|
||||
protocol,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
resp = protocol.read_response(*args, **kwargs)
|
||||
return self.wrap(resp)
|
||||
|
||||
@classmethod
|
||||
def wrap(self, response):
|
||||
resp = HTTPResponse(
|
||||
@ -377,7 +368,7 @@ class HTTPFlow(Flow):
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
f = cls(None, None)
|
||||
f.load_state(state)
|
||||
f.set_state(state)
|
||||
return f
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -1,26 +1,25 @@
|
||||
from __future__ import absolute_import
|
||||
from netlib.utils import Serializable
|
||||
|
||||
|
||||
class StateObject(object):
|
||||
|
||||
class StateObject(Serializable):
|
||||
"""
|
||||
An object with serializable state.
|
||||
|
||||
State attributes can either be serializable types(str, tuple, bool, ...)
|
||||
or StateObject instances themselves.
|
||||
"""
|
||||
# An attribute-name -> class-or-type dict containing all attributes that
|
||||
# should be serialized. If the attribute is a class, it must implement the
|
||||
# StateObject protocol.
|
||||
_stateobject_attributes = None
|
||||
|
||||
def from_state(self, state):
|
||||
raise NotImplementedError()
|
||||
_stateobject_attributes = None
|
||||
"""
|
||||
An attribute-name -> class-or-type dict containing all attributes that
|
||||
should be serialized. If the attribute is a class, it must implement the
|
||||
Serializable protocol.
|
||||
"""
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
Retrieve object state. If short is true, return an abbreviated
|
||||
format with long data elided.
|
||||
Retrieve object state.
|
||||
"""
|
||||
state = {}
|
||||
for attr, cls in self._stateobject_attributes.iteritems():
|
||||
@ -31,18 +30,22 @@ class StateObject(object):
|
||||
state[attr] = val
|
||||
return state
|
||||
|
||||
def load_state(self, state):
|
||||
def set_state(self, state):
|
||||
"""
|
||||
Load object state from data returned by a get_state call.
|
||||
"""
|
||||
state = state.copy()
|
||||
for attr, cls in self._stateobject_attributes.iteritems():
|
||||
if state.get(attr, None) is None:
|
||||
setattr(self, attr, None)
|
||||
if state.get(attr) is None:
|
||||
setattr(self, attr, state.pop(attr))
|
||||
else:
|
||||
curr = getattr(self, attr)
|
||||
if hasattr(curr, "load_state"):
|
||||
curr.load_state(state[attr])
|
||||
if hasattr(curr, "set_state"):
|
||||
curr.set_state(state.pop(attr))
|
||||
elif hasattr(cls, "from_state"):
|
||||
setattr(self, attr, cls.from_state(state[attr]))
|
||||
else:
|
||||
setattr(self, attr, cls(state[attr]))
|
||||
obj = cls.from_state(state.pop(attr))
|
||||
setattr(self, attr, obj)
|
||||
else: # primitive types such as int, str, ...
|
||||
setattr(self, attr, cls(state.pop(attr)))
|
||||
if state:
|
||||
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))
|
||||
|
@ -170,7 +170,7 @@ class FlowHandler(RequestHandler):
|
||||
elif k == "port":
|
||||
request.port = int(v)
|
||||
elif k == "headers":
|
||||
request.headers.load_state(v)
|
||||
request.headers.set_state(v)
|
||||
else:
|
||||
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
|
||||
|
||||
@ -184,7 +184,7 @@ class FlowHandler(RequestHandler):
|
||||
elif k == "http_version":
|
||||
response.http_version = str(v)
|
||||
elif k == "headers":
|
||||
response.headers.load_state(v)
|
||||
response.headers.set_state(v)
|
||||
else:
|
||||
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
|
||||
else:
|
||||
|
@ -422,7 +422,7 @@ class TestFlow(object):
|
||||
assert not f == f2
|
||||
f2.error = Error("e2")
|
||||
assert not f == f2
|
||||
f.load_state(f2.get_state())
|
||||
f.set_state(f2.get_state())
|
||||
assert f.get_state() == f2.get_state()
|
||||
|
||||
def test_kill(self):
|
||||
@ -1204,7 +1204,7 @@ class TestError:
|
||||
|
||||
e2 = Error("bar")
|
||||
assert not e == e2
|
||||
e.load_state(e2.get_state())
|
||||
e.set_state(e2.get_state())
|
||||
assert e.get_state() == e2.get_state()
|
||||
|
||||
e3 = e.copy()
|
||||
@ -1224,7 +1224,7 @@ class TestClientConnection:
|
||||
assert not c == c2
|
||||
|
||||
c2.timestamp_start = 42
|
||||
c.load_state(c2.get_state())
|
||||
c.set_state(c2.get_state())
|
||||
assert c.timestamp_start == 42
|
||||
|
||||
c3 = c.copy()
|
||||
|
@ -76,7 +76,11 @@ def tclient_conn():
|
||||
"""
|
||||
c = ClientConnection.from_state(dict(
|
||||
address=dict(address=("address", 22), use_ipv6=True),
|
||||
clientcert=None
|
||||
clientcert=None,
|
||||
ssl_established=False,
|
||||
timestamp_start=1,
|
||||
timestamp_ssl_setup=2,
|
||||
timestamp_end=3,
|
||||
))
|
||||
c.reply = controller.DummyReply()
|
||||
return c
|
||||
@ -88,9 +92,15 @@ def tserver_conn():
|
||||
"""
|
||||
c = ServerConnection.from_state(dict(
|
||||
address=dict(address=("address", 22), use_ipv6=True),
|
||||
state=[],
|
||||
source_address=dict(address=("address", 22), use_ipv6=True),
|
||||
cert=None
|
||||
cert=None,
|
||||
timestamp_start=1,
|
||||
timestamp_tcp_setup=2,
|
||||
timestamp_ssl_setup=3,
|
||||
timestamp_end=4,
|
||||
ssl_established=False,
|
||||
sni="address",
|
||||
via=None
|
||||
))
|
||||
c.reply = controller.DummyReply()
|
||||
return c
|
||||
|
Loading…
Reference in New Issue
Block a user