simplify state management

This commit is contained in:
Maximilian Hils 2016-02-08 02:10:10 +01:00
parent cd744592f6
commit e9934cc008
8 changed files with 77 additions and 89 deletions

View File

@ -21,9 +21,22 @@ def convert_014_015(data):
return data return data
def convert_015_016(data):
    """
    Convert a serialized flow from the 0.15 format to the 0.16 format.

    Renames the message keys "body" -> "content" and
    "httpversion" -> "http_version" on the request and the response,
    renames the response key "msg" -> "reason", and stamps the flow with
    version (0, 16).

    Args:
        data: The flow state dict. It is upgraded in place.

    Returns:
        The same dict object that was passed in.
    """
    for m in ("request", "response"):
        message = data.get(m)
        # The message entry may be None or absent (presumably a flow that
        # never got a response); there is nothing to rename in that case.
        if not message:
            continue
        if "body" in message:
            message["content"] = message.pop("body")
        if "httpversion" in message:
            message["http_version"] = message.pop("httpversion")
    response = data.get("response")
    if response and "msg" in response:
        response["reason"] = response.pop("msg")
    data["version"] = (0, 16)
    return data
converters = { converters = {
(0, 13): convert_013_014, (0, 13): convert_013_014,
(0, 14): convert_014_015, (0, 14): convert_014_015,
(0, 15): convert_015_016,
} }

View File

@ -48,8 +48,8 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
timestamp_ssl_setup=float timestamp_ssl_setup=float
) )
def get_state(self, short=False): def get_state(self):
d = super(ClientConnection, self).get_state(short) d = super(ClientConnection, self).get_state()
d.update( d.update(
address=({ address=({
"address": self.address(), "address": self.address(),
@ -130,10 +130,9 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
ssl_established=bool, ssl_established=bool,
sni=str sni=str
) )
_stateobject_long_attributes = {"cert"}
def get_state(self, short=False): def get_state(self):
d = super(ServerConnection, self).get_state(short) d = super(ServerConnection, self).get_state()
d.update( d.update(
address=({"address": self.address(), address=({"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}), "use_ipv6": self.address.use_ipv6} if self.address else {}),

View File

@ -86,13 +86,10 @@ class Flow(stateobject.StateObject):
intercepted=bool intercepted=bool
) )
def get_state(self, short=False): def get_state(self):
d = super(Flow, self).get_state(short) d = super(Flow, self).get_state()
d.update(version=version.IVERSION) d.update(version=version.IVERSION)
if self._backup and self._backup != d: if self._backup and self._backup != d:
if short:
d.update(modified=True)
else:
d.update(backup=self._backup) d.update(backup=self._backup)
return d return d

View File

@ -6,36 +6,30 @@ import time
from libmproxy import utils from libmproxy import utils
from netlib import encoding from netlib import encoding
from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING, decoded from netlib.http import status_codes, Headers, Request, Response, decoded
from netlib.tcp import Address from netlib.tcp import Address
from .. import version, stateobject from .. import version, stateobject
from .flow import Flow from .flow import Flow
from collections import OrderedDict
class MessageMixin(stateobject.StateObject): class MessageMixin(stateobject.StateObject):
# The restoration order is important currently, e.g. because
# of .content setting .headers["content-length"] automatically.
# Using OrderedDict is the short term fix, restoring state should
# be implemented without side-effects again.
_stateobject_attributes = OrderedDict(
http_version=bytes,
headers=Headers,
timestamp_start=float,
timestamp_end=float
)
_stateobject_long_attributes = {"body"}
def get_state(self, short=False): def get_state(self):
ret = super(MessageMixin, self).get_state(short) state = vars(self.data).copy()
if short: state["headers"] = state["headers"].get_state()
if self.content: return state
ret["contentLength"] = len(self.content)
elif self.content == CONTENT_MISSING: def load_state(self, state):
ret["contentLength"] = None for k, v in state.items():
else: if k == "headers":
ret["contentLength"] = 0 v = Headers.from_state(v)
return ret setattr(self.data, k, v)
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
return cls(**state)
def get_decoded_content(self): def get_decoded_content(self):
""" """
@ -141,6 +135,7 @@ class HTTPRequest(MessageMixin, Request):
timestamp_start=None, timestamp_start=None,
timestamp_end=None, timestamp_end=None,
form_out=None, form_out=None,
is_replay=False,
): ):
Request.__init__( Request.__init__(
self, self,
@ -163,37 +158,7 @@ class HTTPRequest(MessageMixin, Request):
self.stickyauth = False self.stickyauth = False
# Is this request replayed? # Is this request replayed?
self.is_replay = False self.is_replay = is_replay
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
_stateobject_attributes.update(
content=bytes,
first_line_format=str,
method=bytes,
scheme=bytes,
host=bytes,
port=int,
path=bytes,
form_out=str,
is_replay=bool
)
@classmethod
def from_state(cls, state):
f = cls(
None,
b"",
None,
None,
None,
None,
None,
None,
None,
None,
None)
f.load_state(state)
return f
@classmethod @classmethod
def from_protocol( def from_protocol(
@ -275,6 +240,7 @@ class HTTPResponse(MessageMixin, Response):
content, content,
timestamp_start=None, timestamp_start=None,
timestamp_end=None, timestamp_end=None,
is_replay = False
): ):
Response.__init__( Response.__init__(
self, self,
@ -288,22 +254,9 @@ class HTTPResponse(MessageMixin, Response):
) )
# Is this request replayed? # Is this request replayed?
self.is_replay = False self.is_replay = is_replay
self.stream = False self.stream = False
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
_stateobject_attributes.update(
body=bytes,
status_code=int,
msg=bytes
)
@classmethod
def from_state(cls, state):
f = cls(None, None, None, None, None)
f.load_state(state)
return f
@classmethod @classmethod
def from_protocol( def from_protocol(
self, self,

View File

@ -13,24 +13,20 @@ class StateObject(object):
# should be serialized. If the attribute is a class, it must implement the # should be serialized. If the attribute is a class, it must implement the
# StateObject protocol. # StateObject protocol.
_stateobject_attributes = None _stateobject_attributes = None
# A set() of attributes that should be ignored for short state
_stateobject_long_attributes = frozenset([])
def from_state(self, state): def from_state(self, state):
raise NotImplementedError() raise NotImplementedError()
def get_state(self, short=False): def get_state(self):
""" """
Retrieve object state. If short is true, return an abbreviated Retrieve object state. If short is true, return an abbreviated
format with long data elided. format with long data elided.
""" """
state = {} state = {}
for attr, cls in self._stateobject_attributes.iteritems(): for attr, cls in self._stateobject_attributes.iteritems():
if short and attr in self._stateobject_long_attributes:
continue
val = getattr(self, attr) val = getattr(self, attr)
if hasattr(val, "get_state"): if hasattr(val, "get_state"):
state[attr] = val.get_state(short) state[attr] = val.get_state()
else: else:
state[attr] = val state[attr] = val
return state return state

View File

@ -1,6 +1,6 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
IVERSION = (0, 15) IVERSION = (0, 16)
VERSION = ".".join(str(i) for i in IVERSION) VERSION = ".".join(str(i) for i in IVERSION)
MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) MINORVERSION = ".".join(str(i) for i in IVERSION[:2])
NAME = "mitmproxy" NAME = "mitmproxy"

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import collections import collections
import tornado.ioloop import tornado.ioloop
import tornado.httpserver import tornado.httpserver
from .. import controller, flow from .. import controller, flow
from . import app from . import app
@ -20,7 +21,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast( app.ClientConnection.broadcast(
type="flows", type="flows",
cmd="add", cmd="add",
data=f.get_state(short=True) data=app._strip_content(f.get_state())
) )
def _update(self, f): def _update(self, f):
@ -28,7 +29,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast( app.ClientConnection.broadcast(
type="flows", type="flows",
cmd="update", cmd="update",
data=f.get_state(short=True) data=app._strip_content(f.get_state())
) )
def _remove(self, f): def _remove(self, f):

View File

@ -4,9 +4,38 @@ import tornado.web
import tornado.websocket import tornado.websocket
import logging import logging
import json import json
from netlib.http import CONTENT_MISSING
from .. import version, filt from .. import version, filt
def _strip_content(flow_state):
"""
Remove flow message content and cert to save transmission space.
Args:
flow_state: The original flow state. Will be left unmodified
"""
for attr in ("request", "response"):
if attr in flow_state:
message = flow_state[attr]
if message["content"]:
message["contentLength"] = len(message["content"])
elif message["content"] == CONTENT_MISSING:
message["contentLength"] = None
else:
message["contentLength"] = 0
del message["content"]
if "backup" in flow_state:
del flow_state["backup"]
flow_state["modified"] = True
flow_state.get("server_conn", {}).pop("cert", None)
return flow_state
class APIError(tornado.web.HTTPError): class APIError(tornado.web.HTTPError):
pass pass
@ -100,7 +129,7 @@ class Flows(RequestHandler):
def get(self): def get(self):
self.write(dict( self.write(dict(
data=[f.get_state(short=True) for f in self.state.flows] data=[_strip_content(f.get_state()) for f in self.state.flows]
)) ))