Merge pull request #921 from mitmproxy/model-cleanup

Model Cleanup
This commit is contained in:
Thomas Kriechbaumer 2016-02-08 11:41:30 +01:00
commit ec087a1960
10 changed files with 143 additions and 179 deletions

View File

@ -21,9 +21,23 @@ def convert_014_015(data):
return data return data
def convert_015_016(data):
for m in ("request", "response"):
if "body" in data[m]:
data[m]["content"] = data[m].pop("body")
if "httpversion" in data[m]:
data[m]["http_version"] = data[m].pop("httpversion")
if "msg" in data["response"]:
data["response"]["reason"] = data["response"].pop("msg")
data["request"].pop("form_out", None)
data["version"] = (0, 16)
return data
converters = { converters = {
(0, 13): convert_013_014, (0, 13): convert_013_014,
(0, 14): convert_014_015, (0, 14): convert_014_015,
(0, 15): convert_015_016,
} }

View File

@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return self.ssl_established return self.ssl_established
_stateobject_attributes = dict( _stateobject_attributes = dict(
address=tcp.Address,
clientcert=certutils.SSLCert,
ssl_established=bool, ssl_established=bool,
timestamp_start=float, timestamp_start=float,
timestamp_end=float, timestamp_end=float,
timestamp_ssl_setup=float timestamp_ssl_setup=float
) )
def get_state(self, short=False):
d = super(ClientConnection, self).get_state(short)
d.update(
address=({
"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
clientcert=self.cert.to_pem() if self.clientcert else None)
return d
def load_state(self, state):
super(ClientConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.clientcert = certutils.SSLCert.from_pem(
state["clientcert"]) if state["clientcert"] else None
def copy(self): def copy(self):
return copy.copy(self) return copy.copy(self)
@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, tuple(), None) f = cls(None, tuple(), None)
f.load_state(state) f.set_state(state)
return f return f
def convert_to_ssl(self, *args, **kwargs): def convert_to_ssl(self, *args, **kwargs):
@ -130,33 +116,11 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
ssl_established=bool, ssl_established=bool,
sni=str sni=str
) )
_stateobject_long_attributes = {"cert"}
def get_state(self, short=False):
d = super(ServerConnection, self).get_state(short)
d.update(
address=({"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
source_address=({"address": self.source_address(),
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
cert=self.cert.to_pem() if self.cert else None
)
return d
def load_state(self, state):
super(ServerConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.source_address = tcp.Address(
**state["source_address"]) if state["source_address"] else None
self.cert = certutils.SSLCert.from_pem(
state["cert"]) if state["cert"] else None
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(tuple()) f = cls(tuple())
f.load_state(state) f.set_state(state)
return f return f
def copy(self): def copy(self):

View File

@ -45,7 +45,7 @@ class Error(stateobject.StateObject):
# the default implementation assumes an empty constructor. Override # the default implementation assumes an empty constructor. Override
# accordingly. # accordingly.
f = cls(None) f = cls(None)
f.load_state(state) f.set_state(state)
return f return f
def copy(self): def copy(self):
@ -86,16 +86,19 @@ class Flow(stateobject.StateObject):
intercepted=bool intercepted=bool
) )
def get_state(self, short=False): def get_state(self):
d = super(Flow, self).get_state(short) d = super(Flow, self).get_state()
d.update(version=version.IVERSION) d.update(version=version.IVERSION)
if self._backup and self._backup != d: if self._backup and self._backup != d:
if short: d.update(backup=self._backup)
d.update(modified=True)
else:
d.update(backup=self._backup)
return d return d
def set_state(self, state):
state.pop("version")
if "backup" in state:
self._backup = state.pop("backup")
super(Flow, self).set_state(state)
def __eq__(self, other): def __eq__(self, other):
return self is other return self is other
@ -133,7 +136,7 @@ class Flow(stateobject.StateObject):
Revert to the last backed up state. Revert to the last backed up state.
""" """
if self._backup: if self._backup:
self.load_state(self._backup) self.set_state(self._backup)
self._backup = None self._backup = None
def kill(self, master): def kill(self, master):

View File

@ -1,41 +1,20 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
import Cookie import Cookie
import copy import copy
import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz from email.utils import parsedate_tz, formatdate, mktime_tz
import time import time
from libmproxy import utils from libmproxy import utils
from netlib import encoding from netlib import encoding
from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING, decoded from netlib.http import status_codes, Headers, Request, Response, decoded
from netlib.tcp import Address from netlib.tcp import Address
from .. import version, stateobject from .. import version
from .flow import Flow from .flow import Flow
from collections import OrderedDict
class MessageMixin(stateobject.StateObject):
# The restoration order is important currently, e.g. because
# of .content setting .headers["content-length"] automatically.
# Using OrderedDict is the short term fix, restoring state should
# be implemented without side-effects again.
_stateobject_attributes = OrderedDict(
http_version=bytes,
headers=Headers,
timestamp_start=float,
timestamp_end=float
)
_stateobject_long_attributes = {"body"}
def get_state(self, short=False): class MessageMixin(object):
ret = super(MessageMixin, self).get_state(short)
if short:
if self.content:
ret["contentLength"] = len(self.content)
elif self.content == CONTENT_MISSING:
ret["contentLength"] = None
else:
ret["contentLength"] = 0
return ret
def get_decoded_content(self): def get_decoded_content(self):
""" """
@ -141,6 +120,9 @@ class HTTPRequest(MessageMixin, Request):
timestamp_start=None, timestamp_start=None,
timestamp_end=None, timestamp_end=None,
form_out=None, form_out=None,
is_replay=False,
stickycookie=False,
stickyauth=False,
): ):
Request.__init__( Request.__init__(
self, self,
@ -159,51 +141,26 @@ class HTTPRequest(MessageMixin, Request):
self.form_out = form_out or first_line_format # FIXME remove self.form_out = form_out or first_line_format # FIXME remove
# Have this request's cookies been modified by sticky cookies or auth? # Have this request's cookies been modified by sticky cookies or auth?
self.stickycookie = False self.stickycookie = stickycookie
self.stickyauth = False self.stickyauth = stickyauth
# Is this request replayed? # Is this request replayed?
self.is_replay = False self.is_replay = is_replay
_stateobject_attributes = MessageMixin._stateobject_attributes.copy() def get_state(self):
_stateobject_attributes.update( state = super(HTTPRequest, self).get_state()
content=bytes, state.update(
first_line_format=str, stickycookie = self.stickycookie,
method=bytes, stickyauth = self.stickyauth,
scheme=bytes, is_replay = self.is_replay,
host=bytes, )
port=int, return state
path=bytes,
form_out=str,
is_replay=bool
)
@classmethod def set_state(self, state):
def from_state(cls, state): self.stickycookie = state.pop("stickycookie")
f = cls( self.stickyauth = state.pop("stickyauth")
None, self.is_replay = state.pop("is_replay")
b"", super(HTTPRequest, self).set_state(state)
None,
None,
None,
None,
None,
None,
None,
None,
None)
f.load_state(state)
return f
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
req = protocol.read_request(*args, **kwargs)
return self.wrap(req)
@classmethod @classmethod
def wrap(self, request): def wrap(self, request):
@ -223,6 +180,15 @@ class HTTPRequest(MessageMixin, Request):
) )
return req return req
@property
def form_out(self):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
return self.first_line_format
@form_out.setter
def form_out(self, value):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
def __hash__(self): def __hash__(self):
return id(self) return id(self)
@ -275,6 +241,7 @@ class HTTPResponse(MessageMixin, Response):
content, content,
timestamp_start=None, timestamp_start=None,
timestamp_end=None, timestamp_end=None,
is_replay = False
): ):
Response.__init__( Response.__init__(
self, self,
@ -288,32 +255,9 @@ class HTTPResponse(MessageMixin, Response):
) )
# Is this request replayed? # Is this request replayed?
self.is_replay = False self.is_replay = is_replay
self.stream = False self.stream = False
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
_stateobject_attributes.update(
body=bytes,
status_code=int,
msg=bytes
)
@classmethod
def from_state(cls, state):
f = cls(None, None, None, None, None)
f.load_state(state)
return f
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
resp = protocol.read_response(*args, **kwargs)
return self.wrap(resp)
@classmethod @classmethod
def wrap(self, response): def wrap(self, response):
resp = HTTPResponse( resp = HTTPResponse(
@ -424,7 +368,7 @@ class HTTPFlow(Flow):
@classmethod @classmethod
def from_state(cls, state): def from_state(cls, state):
f = cls(None, None) f = cls(None, None)
f.load_state(state) f.set_state(state)
return f return f
def __repr__(self): def __repr__(self):

View File

@ -1,52 +1,51 @@
from __future__ import absolute_import from __future__ import absolute_import
from netlib.utils import Serializable
class StateObject(object): class StateObject(Serializable):
""" """
An object with serializable state. An object with serializable state.
State attributes can either be serializable types(str, tuple, bool, ...) State attributes can either be serializable types(str, tuple, bool, ...)
or StateObject instances themselves. or StateObject instances themselves.
""" """
# An attribute-name -> class-or-type dict containing all attributes that
# should be serialized. If the attribute is a class, it must implement the
# StateObject protocol.
_stateobject_attributes = None _stateobject_attributes = None
# A set() of attributes that should be ignored for short state """
_stateobject_long_attributes = frozenset([]) An attribute-name -> class-or-type dict containing all attributes that
should be serialized. If the attribute is a class, it must implement the
Serializable protocol.
"""
def from_state(self, state): def get_state(self):
raise NotImplementedError()
def get_state(self, short=False):
""" """
Retrieve object state. If short is true, return an abbreviated Retrieve object state.
format with long data elided.
""" """
state = {} state = {}
for attr, cls in self._stateobject_attributes.iteritems(): for attr, cls in self._stateobject_attributes.iteritems():
if short and attr in self._stateobject_long_attributes:
continue
val = getattr(self, attr) val = getattr(self, attr)
if hasattr(val, "get_state"): if hasattr(val, "get_state"):
state[attr] = val.get_state(short) state[attr] = val.get_state()
else: else:
state[attr] = val state[attr] = val
return state return state
def load_state(self, state): def set_state(self, state):
""" """
Load object state from data returned by a get_state call. Load object state from data returned by a get_state call.
""" """
state = state.copy()
for attr, cls in self._stateobject_attributes.iteritems(): for attr, cls in self._stateobject_attributes.iteritems():
if state.get(attr, None) is None: if state.get(attr) is None:
setattr(self, attr, None) setattr(self, attr, state.pop(attr))
else: else:
curr = getattr(self, attr) curr = getattr(self, attr)
if hasattr(curr, "load_state"): if hasattr(curr, "set_state"):
curr.load_state(state[attr]) curr.set_state(state.pop(attr))
elif hasattr(cls, "from_state"): elif hasattr(cls, "from_state"):
setattr(self, attr, cls.from_state(state[attr])) obj = cls.from_state(state.pop(attr))
else: setattr(self, attr, obj)
setattr(self, attr, cls(state[attr])) else: # primitive types such as int, str, ...
setattr(self, attr, cls(state.pop(attr)))
if state:
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))

View File

@ -1,6 +1,6 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
IVERSION = (0, 15) IVERSION = (0, 16)
VERSION = ".".join(str(i) for i in IVERSION) VERSION = ".".join(str(i) for i in IVERSION)
MINORVERSION = ".".join(str(i) for i in IVERSION[:2]) MINORVERSION = ".".join(str(i) for i in IVERSION[:2])
NAME = "mitmproxy" NAME = "mitmproxy"

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import collections import collections
import tornado.ioloop import tornado.ioloop
import tornado.httpserver import tornado.httpserver
from .. import controller, flow from .. import controller, flow
from . import app from . import app
@ -20,7 +21,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast( app.ClientConnection.broadcast(
type="flows", type="flows",
cmd="add", cmd="add",
data=f.get_state(short=True) data=app._strip_content(f.get_state())
) )
def _update(self, f): def _update(self, f):
@ -28,7 +29,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast( app.ClientConnection.broadcast(
type="flows", type="flows",
cmd="update", cmd="update",
data=f.get_state(short=True) data=app._strip_content(f.get_state())
) )
def _remove(self, f): def _remove(self, f):

View File

@ -4,9 +4,38 @@ import tornado.web
import tornado.websocket import tornado.websocket
import logging import logging
import json import json
from netlib.http import CONTENT_MISSING
from .. import version, filt from .. import version, filt
def _strip_content(flow_state):
"""
Remove flow message content and cert to save transmission space.
Args:
flow_state: The original flow state. Will be left unmodified
"""
for attr in ("request", "response"):
if attr in flow_state:
message = flow_state[attr]
if message["content"]:
message["contentLength"] = len(message["content"])
elif message["content"] == CONTENT_MISSING:
message["contentLength"] = None
else:
message["contentLength"] = 0
del message["content"]
if "backup" in flow_state:
del flow_state["backup"]
flow_state["modified"] = True
flow_state.get("server_conn", {}).pop("cert", None)
return flow_state
class APIError(tornado.web.HTTPError): class APIError(tornado.web.HTTPError):
pass pass
@ -100,7 +129,7 @@ class Flows(RequestHandler):
def get(self): def get(self):
self.write(dict( self.write(dict(
data=[f.get_state(short=True) for f in self.state.flows] data=[_strip_content(f.get_state()) for f in self.state.flows]
)) ))
@ -141,7 +170,7 @@ class FlowHandler(RequestHandler):
elif k == "port": elif k == "port":
request.port = int(v) request.port = int(v)
elif k == "headers": elif k == "headers":
request.headers.load_state(v) request.headers.set_state(v)
else: else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v) print "Warning: Unknown update {}.{}: {}".format(a, k, v)
@ -155,7 +184,7 @@ class FlowHandler(RequestHandler):
elif k == "http_version": elif k == "http_version":
response.http_version = str(v) response.http_version = str(v)
elif k == "headers": elif k == "headers":
response.headers.load_state(v) response.headers.set_state(v)
else: else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v) print "Warning: Unknown update {}.{}: {}".format(a, k, v)
else: else:

View File

@ -422,7 +422,7 @@ class TestFlow(object):
assert not f == f2 assert not f == f2
f2.error = Error("e2") f2.error = Error("e2")
assert not f == f2 assert not f == f2
f.load_state(f2.get_state()) f.set_state(f2.get_state())
assert f.get_state() == f2.get_state() assert f.get_state() == f2.get_state()
def test_kill(self): def test_kill(self):
@ -1204,7 +1204,7 @@ class TestError:
e2 = Error("bar") e2 = Error("bar")
assert not e == e2 assert not e == e2
e.load_state(e2.get_state()) e.set_state(e2.get_state())
assert e.get_state() == e2.get_state() assert e.get_state() == e2.get_state()
e3 = e.copy() e3 = e.copy()
@ -1224,7 +1224,7 @@ class TestClientConnection:
assert not c == c2 assert not c == c2
c2.timestamp_start = 42 c2.timestamp_start = 42
c.load_state(c2.get_state()) c.set_state(c2.get_state())
assert c.timestamp_start == 42 assert c.timestamp_start == 42
c3 = c.copy() c3 = c.copy()

View File

@ -76,7 +76,11 @@ def tclient_conn():
""" """
c = ClientConnection.from_state(dict( c = ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True), address=dict(address=("address", 22), use_ipv6=True),
clientcert=None clientcert=None,
ssl_established=False,
timestamp_start=1,
timestamp_ssl_setup=2,
timestamp_end=3,
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
return c return c
@ -88,9 +92,15 @@ def tserver_conn():
""" """
c = ServerConnection.from_state(dict( c = ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True), address=dict(address=("address", 22), use_ipv6=True),
state=[],
source_address=dict(address=("address", 22), use_ipv6=True), source_address=dict(address=("address", 22), use_ipv6=True),
cert=None cert=None,
timestamp_start=1,
timestamp_tcp_setup=2,
timestamp_ssl_setup=3,
timestamp_end=4,
ssl_established=False,
sni="address",
via=None
)) ))
c.reply = controller.DummyReply() c.reply = controller.DummyReply()
return c return c