mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
simplify state management
This commit is contained in:
parent
cd744592f6
commit
e9934cc008
@ -21,9 +21,22 @@ def convert_014_015(data):
|
||||
return data
|
||||
|
||||
|
||||
def convert_015_016(data):
|
||||
for m in ("request", "response"):
|
||||
if "body" in data[m]:
|
||||
data[m]["content"] = data[m].pop("body")
|
||||
if "httpversion" in data[m]:
|
||||
data[m]["http_version"] = data[m].pop("httpversion")
|
||||
if "msg" in data["response"]:
|
||||
data["response"]["reason"] = data["response"].pop("msg")
|
||||
data["version"] = (0, 16)
|
||||
return data
|
||||
|
||||
|
||||
converters = {
|
||||
(0, 13): convert_013_014,
|
||||
(0, 14): convert_014_015,
|
||||
(0, 15): convert_015_016,
|
||||
}
|
||||
|
||||
|
||||
|
@ -48,8 +48,8 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
|
||||
timestamp_ssl_setup=float
|
||||
)
|
||||
|
||||
def get_state(self, short=False):
|
||||
d = super(ClientConnection, self).get_state(short)
|
||||
def get_state(self):
|
||||
d = super(ClientConnection, self).get_state()
|
||||
d.update(
|
||||
address=({
|
||||
"address": self.address(),
|
||||
@ -130,10 +130,9 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
|
||||
ssl_established=bool,
|
||||
sni=str
|
||||
)
|
||||
_stateobject_long_attributes = {"cert"}
|
||||
|
||||
def get_state(self, short=False):
|
||||
d = super(ServerConnection, self).get_state(short)
|
||||
def get_state(self):
|
||||
d = super(ServerConnection, self).get_state()
|
||||
d.update(
|
||||
address=({"address": self.address(),
|
||||
"use_ipv6": self.address.use_ipv6} if self.address else {}),
|
||||
|
@ -86,13 +86,10 @@ class Flow(stateobject.StateObject):
|
||||
intercepted=bool
|
||||
)
|
||||
|
||||
def get_state(self, short=False):
|
||||
d = super(Flow, self).get_state(short)
|
||||
def get_state(self):
|
||||
d = super(Flow, self).get_state()
|
||||
d.update(version=version.IVERSION)
|
||||
if self._backup and self._backup != d:
|
||||
if short:
|
||||
d.update(modified=True)
|
||||
else:
|
||||
d.update(backup=self._backup)
|
||||
return d
|
||||
|
||||
|
@ -6,36 +6,30 @@ import time
|
||||
|
||||
from libmproxy import utils
|
||||
from netlib import encoding
|
||||
from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING, decoded
|
||||
from netlib.http import status_codes, Headers, Request, Response, decoded
|
||||
from netlib.tcp import Address
|
||||
from .. import version, stateobject
|
||||
from .flow import Flow
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class MessageMixin(stateobject.StateObject):
|
||||
# The restoration order is important currently, e.g. because
|
||||
# of .content setting .headers["content-length"] automatically.
|
||||
# Using OrderedDict is the short term fix, restoring state should
|
||||
# be implemented without side-effects again.
|
||||
_stateobject_attributes = OrderedDict(
|
||||
http_version=bytes,
|
||||
headers=Headers,
|
||||
timestamp_start=float,
|
||||
timestamp_end=float
|
||||
)
|
||||
_stateobject_long_attributes = {"body"}
|
||||
|
||||
def get_state(self, short=False):
|
||||
ret = super(MessageMixin, self).get_state(short)
|
||||
if short:
|
||||
if self.content:
|
||||
ret["contentLength"] = len(self.content)
|
||||
elif self.content == CONTENT_MISSING:
|
||||
ret["contentLength"] = None
|
||||
else:
|
||||
ret["contentLength"] = 0
|
||||
return ret
|
||||
def get_state(self):
|
||||
state = vars(self.data).copy()
|
||||
state["headers"] = state["headers"].get_state()
|
||||
return state
|
||||
|
||||
def load_state(self, state):
|
||||
for k, v in state.items():
|
||||
if k == "headers":
|
||||
v = Headers.from_state(v)
|
||||
setattr(self.data, k, v)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
state["headers"] = Headers.from_state(state["headers"])
|
||||
return cls(**state)
|
||||
|
||||
def get_decoded_content(self):
|
||||
"""
|
||||
@ -141,6 +135,7 @@ class HTTPRequest(MessageMixin, Request):
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
form_out=None,
|
||||
is_replay=False,
|
||||
):
|
||||
Request.__init__(
|
||||
self,
|
||||
@ -163,37 +158,7 @@ class HTTPRequest(MessageMixin, Request):
|
||||
self.stickyauth = False
|
||||
|
||||
# Is this request replayed?
|
||||
self.is_replay = False
|
||||
|
||||
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
|
||||
_stateobject_attributes.update(
|
||||
content=bytes,
|
||||
first_line_format=str,
|
||||
method=bytes,
|
||||
scheme=bytes,
|
||||
host=bytes,
|
||||
port=int,
|
||||
path=bytes,
|
||||
form_out=str,
|
||||
is_replay=bool
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
f = cls(
|
||||
None,
|
||||
b"",
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None)
|
||||
f.load_state(state)
|
||||
return f
|
||||
self.is_replay = is_replay
|
||||
|
||||
@classmethod
|
||||
def from_protocol(
|
||||
@ -275,6 +240,7 @@ class HTTPResponse(MessageMixin, Response):
|
||||
content,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
is_replay = False
|
||||
):
|
||||
Response.__init__(
|
||||
self,
|
||||
@ -288,22 +254,9 @@ class HTTPResponse(MessageMixin, Response):
|
||||
)
|
||||
|
||||
# Is this request replayed?
|
||||
self.is_replay = False
|
||||
self.is_replay = is_replay
|
||||
self.stream = False
|
||||
|
||||
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
|
||||
_stateobject_attributes.update(
|
||||
body=bytes,
|
||||
status_code=int,
|
||||
msg=bytes
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
f = cls(None, None, None, None, None)
|
||||
f.load_state(state)
|
||||
return f
|
||||
|
||||
@classmethod
|
||||
def from_protocol(
|
||||
self,
|
||||
|
@ -13,24 +13,20 @@ class StateObject(object):
|
||||
# should be serialized. If the attribute is a class, it must implement the
|
||||
# StateObject protocol.
|
||||
_stateobject_attributes = None
|
||||
# A set() of attributes that should be ignored for short state
|
||||
_stateobject_long_attributes = frozenset([])
|
||||
|
||||
def from_state(self, state):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_state(self, short=False):
|
||||
def get_state(self):
|
||||
"""
|
||||
Retrieve object state. If short is true, return an abbreviated
|
||||
format with long data elided.
|
||||
"""
|
||||
state = {}
|
||||
for attr, cls in self._stateobject_attributes.iteritems():
|
||||
if short and attr in self._stateobject_long_attributes:
|
||||
continue
|
||||
val = getattr(self, attr)
|
||||
if hasattr(val, "get_state"):
|
||||
state[attr] = val.get_state(short)
|
||||
state[attr] = val.get_state()
|
||||
else:
|
||||
state[attr] = val
|
||||
return state
|
||||
|
@ -1,6 +1,6 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
|
||||
IVERSION = (0, 15)
|
||||
IVERSION = (0, 16)
|
||||
VERSION = ".".join(str(i) for i in IVERSION)
|
||||
MINORVERSION = ".".join(str(i) for i in IVERSION[:2])
|
||||
NAME = "mitmproxy"
|
||||
|
@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
|
||||
import collections
|
||||
import tornado.ioloop
|
||||
import tornado.httpserver
|
||||
|
||||
from .. import controller, flow
|
||||
from . import app
|
||||
|
||||
@ -20,7 +21,7 @@ class WebFlowView(flow.FlowView):
|
||||
app.ClientConnection.broadcast(
|
||||
type="flows",
|
||||
cmd="add",
|
||||
data=f.get_state(short=True)
|
||||
data=app._strip_content(f.get_state())
|
||||
)
|
||||
|
||||
def _update(self, f):
|
||||
@ -28,7 +29,7 @@ class WebFlowView(flow.FlowView):
|
||||
app.ClientConnection.broadcast(
|
||||
type="flows",
|
||||
cmd="update",
|
||||
data=f.get_state(short=True)
|
||||
data=app._strip_content(f.get_state())
|
||||
)
|
||||
|
||||
def _remove(self, f):
|
||||
|
@ -4,9 +4,38 @@ import tornado.web
|
||||
import tornado.websocket
|
||||
import logging
|
||||
import json
|
||||
|
||||
from netlib.http import CONTENT_MISSING
|
||||
from .. import version, filt
|
||||
|
||||
|
||||
def _strip_content(flow_state):
|
||||
"""
|
||||
Remove flow message content and cert to save transmission space.
|
||||
|
||||
Args:
|
||||
flow_state: The original flow state. Will be left unmodified
|
||||
"""
|
||||
for attr in ("request", "response"):
|
||||
if attr in flow_state:
|
||||
message = flow_state[attr]
|
||||
if message["content"]:
|
||||
message["contentLength"] = len(message["content"])
|
||||
elif message["content"] == CONTENT_MISSING:
|
||||
message["contentLength"] = None
|
||||
else:
|
||||
message["contentLength"] = 0
|
||||
del message["content"]
|
||||
|
||||
if "backup" in flow_state:
|
||||
del flow_state["backup"]
|
||||
flow_state["modified"] = True
|
||||
|
||||
flow_state.get("server_conn", {}).pop("cert", None)
|
||||
|
||||
return flow_state
|
||||
|
||||
|
||||
class APIError(tornado.web.HTTPError):
|
||||
pass
|
||||
|
||||
@ -100,7 +129,7 @@ class Flows(RequestHandler):
|
||||
|
||||
def get(self):
|
||||
self.write(dict(
|
||||
data=[f.get_state(short=True) for f in self.state.flows]
|
||||
data=[_strip_content(f.get_state()) for f in self.state.flows]
|
||||
))
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user