Merge pull request #921 from mitmproxy/model-cleanup

Model Cleanup
This commit is contained in:
Thomas Kriechbaumer 2016-02-08 11:41:30 +01:00
commit ec087a1960
10 changed files with 143 additions and 179 deletions

View File

@ -21,9 +21,23 @@ def convert_014_015(data):
return data
def convert_015_016(data):
for m in ("request", "response"):
if "body" in data[m]:
data[m]["content"] = data[m].pop("body")
if "httpversion" in data[m]:
data[m]["http_version"] = data[m].pop("httpversion")
if "msg" in data["response"]:
data["response"]["reason"] = data["response"].pop("msg")
data["request"].pop("form_out", None)
data["version"] = (0, 16)
return data
converters = {
(0, 13): convert_013_014,
(0, 14): convert_014_015,
(0, 15): convert_015_016,
}

View File

@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return self.ssl_established
_stateobject_attributes = dict(
address=tcp.Address,
clientcert=certutils.SSLCert,
ssl_established=bool,
timestamp_start=float,
timestamp_end=float,
timestamp_ssl_setup=float
)
def get_state(self, short=False):
d = super(ClientConnection, self).get_state(short)
d.update(
address=({
"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
clientcert=self.cert.to_pem() if self.clientcert else None)
return d
def load_state(self, state):
super(ClientConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.clientcert = certutils.SSLCert.from_pem(
state["clientcert"]) if state["clientcert"] else None
def copy(self):
return copy.copy(self)
@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod
def from_state(cls, state):
f = cls(None, tuple(), None)
f.load_state(state)
f.set_state(state)
return f
def convert_to_ssl(self, *args, **kwargs):
@ -130,33 +116,11 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
ssl_established=bool,
sni=str
)
_stateobject_long_attributes = {"cert"}
def get_state(self, short=False):
d = super(ServerConnection, self).get_state(short)
d.update(
address=({"address": self.address(),
"use_ipv6": self.address.use_ipv6} if self.address else {}),
source_address=({"address": self.source_address(),
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
cert=self.cert.to_pem() if self.cert else None
)
return d
def load_state(self, state):
super(ServerConnection, self).load_state(state)
self.address = tcp.Address(
**state["address"]) if state["address"] else None
self.source_address = tcp.Address(
**state["source_address"]) if state["source_address"] else None
self.cert = certutils.SSLCert.from_pem(
state["cert"]) if state["cert"] else None
@classmethod
def from_state(cls, state):
f = cls(tuple())
f.load_state(state)
f.set_state(state)
return f
def copy(self):

View File

@ -45,7 +45,7 @@ class Error(stateobject.StateObject):
# the default implementation assumes an empty constructor. Override
# accordingly.
f = cls(None)
f.load_state(state)
f.set_state(state)
return f
def copy(self):
@ -86,16 +86,19 @@ class Flow(stateobject.StateObject):
intercepted=bool
)
def get_state(self, short=False):
d = super(Flow, self).get_state(short)
def get_state(self):
d = super(Flow, self).get_state()
d.update(version=version.IVERSION)
if self._backup and self._backup != d:
if short:
d.update(modified=True)
else:
d.update(backup=self._backup)
d.update(backup=self._backup)
return d
def set_state(self, state):
state.pop("version")
if "backup" in state:
self._backup = state.pop("backup")
super(Flow, self).set_state(state)
def __eq__(self, other):
return self is other
@ -133,7 +136,7 @@ class Flow(stateobject.StateObject):
Revert to the last backed up state.
"""
if self._backup:
self.load_state(self._backup)
self.set_state(self._backup)
self._backup = None
def kill(self, master):

View File

@ -1,41 +1,20 @@
from __future__ import (absolute_import, print_function, division)
import Cookie
import copy
import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
from libmproxy import utils
from netlib import encoding
from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING, decoded
from netlib.http import status_codes, Headers, Request, Response, decoded
from netlib.tcp import Address
from .. import version, stateobject
from .. import version
from .flow import Flow
from collections import OrderedDict
class MessageMixin(stateobject.StateObject):
# The restoration order is important currently, e.g. because
# of .content setting .headers["content-length"] automatically.
# Using OrderedDict is the short term fix, restoring state should
# be implemented without side-effects again.
_stateobject_attributes = OrderedDict(
http_version=bytes,
headers=Headers,
timestamp_start=float,
timestamp_end=float
)
_stateobject_long_attributes = {"body"}
def get_state(self, short=False):
ret = super(MessageMixin, self).get_state(short)
if short:
if self.content:
ret["contentLength"] = len(self.content)
elif self.content == CONTENT_MISSING:
ret["contentLength"] = None
else:
ret["contentLength"] = 0
return ret
class MessageMixin(object):
def get_decoded_content(self):
"""
@ -141,6 +120,9 @@ class HTTPRequest(MessageMixin, Request):
timestamp_start=None,
timestamp_end=None,
form_out=None,
is_replay=False,
stickycookie=False,
stickyauth=False,
):
Request.__init__(
self,
@ -159,51 +141,26 @@ class HTTPRequest(MessageMixin, Request):
self.form_out = form_out or first_line_format # FIXME remove
# Have this request's cookies been modified by sticky cookies or auth?
self.stickycookie = False
self.stickyauth = False
self.stickycookie = stickycookie
self.stickyauth = stickyauth
# Is this request replayed?
self.is_replay = False
self.is_replay = is_replay
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
_stateobject_attributes.update(
content=bytes,
first_line_format=str,
method=bytes,
scheme=bytes,
host=bytes,
port=int,
path=bytes,
form_out=str,
is_replay=bool
)
def get_state(self):
state = super(HTTPRequest, self).get_state()
state.update(
stickycookie = self.stickycookie,
stickyauth = self.stickyauth,
is_replay = self.is_replay,
)
return state
@classmethod
def from_state(cls, state):
f = cls(
None,
b"",
None,
None,
None,
None,
None,
None,
None,
None,
None)
f.load_state(state)
return f
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
req = protocol.read_request(*args, **kwargs)
return self.wrap(req)
def set_state(self, state):
self.stickycookie = state.pop("stickycookie")
self.stickyauth = state.pop("stickyauth")
self.is_replay = state.pop("is_replay")
super(HTTPRequest, self).set_state(state)
@classmethod
def wrap(self, request):
@ -223,6 +180,15 @@ class HTTPRequest(MessageMixin, Request):
)
return req
@property
def form_out(self):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
return self.first_line_format
@form_out.setter
def form_out(self, value):
warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
def __hash__(self):
return id(self)
@ -275,6 +241,7 @@ class HTTPResponse(MessageMixin, Response):
content,
timestamp_start=None,
timestamp_end=None,
is_replay = False
):
Response.__init__(
self,
@ -288,32 +255,9 @@ class HTTPResponse(MessageMixin, Response):
)
# Is this request replayed?
self.is_replay = False
self.is_replay = is_replay
self.stream = False
_stateobject_attributes = MessageMixin._stateobject_attributes.copy()
_stateobject_attributes.update(
body=bytes,
status_code=int,
msg=bytes
)
@classmethod
def from_state(cls, state):
f = cls(None, None, None, None, None)
f.load_state(state)
return f
@classmethod
def from_protocol(
self,
protocol,
*args,
**kwargs
):
resp = protocol.read_response(*args, **kwargs)
return self.wrap(resp)
@classmethod
def wrap(self, response):
resp = HTTPResponse(
@ -424,7 +368,7 @@ class HTTPFlow(Flow):
@classmethod
def from_state(cls, state):
f = cls(None, None)
f.load_state(state)
f.set_state(state)
return f
def __repr__(self):

View File

@ -1,52 +1,51 @@
from __future__ import absolute_import
from netlib.utils import Serializable
class StateObject(object):
class StateObject(Serializable):
"""
An object with serializable state.
An object with serializable state.
State attributes can either be serializable types(str, tuple, bool, ...)
or StateObject instances themselves.
State attributes can either be serializable types(str, tuple, bool, ...)
or StateObject instances themselves.
"""
# An attribute-name -> class-or-type dict containing all attributes that
# should be serialized. If the attribute is a class, it must implement the
# StateObject protocol.
_stateobject_attributes = None
# A set() of attributes that should be ignored for short state
_stateobject_long_attributes = frozenset([])
"""
An attribute-name -> class-or-type dict containing all attributes that
should be serialized. If the attribute is a class, it must implement the
Serializable protocol.
"""
def from_state(self, state):
raise NotImplementedError()
def get_state(self, short=False):
def get_state(self):
"""
Retrieve object state. If short is true, return an abbreviated
format with long data elided.
Retrieve object state.
"""
state = {}
for attr, cls in self._stateobject_attributes.iteritems():
if short and attr in self._stateobject_long_attributes:
continue
val = getattr(self, attr)
if hasattr(val, "get_state"):
state[attr] = val.get_state(short)
state[attr] = val.get_state()
else:
state[attr] = val
return state
def load_state(self, state):
def set_state(self, state):
"""
Load object state from data returned by a get_state call.
Load object state from data returned by a get_state call.
"""
state = state.copy()
for attr, cls in self._stateobject_attributes.iteritems():
if state.get(attr, None) is None:
setattr(self, attr, None)
if state.get(attr) is None:
setattr(self, attr, state.pop(attr))
else:
curr = getattr(self, attr)
if hasattr(curr, "load_state"):
curr.load_state(state[attr])
if hasattr(curr, "set_state"):
curr.set_state(state.pop(attr))
elif hasattr(cls, "from_state"):
setattr(self, attr, cls.from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))
obj = cls.from_state(state.pop(attr))
setattr(self, attr, obj)
else: # primitive types such as int, str, ...
setattr(self, attr, cls(state.pop(attr)))
if state:
raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))

View File

@ -1,6 +1,6 @@
from __future__ import (absolute_import, print_function, division)
IVERSION = (0, 15)
IVERSION = (0, 16)
VERSION = ".".join(str(i) for i in IVERSION)
MINORVERSION = ".".join(str(i) for i in IVERSION[:2])
NAME = "mitmproxy"

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import collections
import tornado.ioloop
import tornado.httpserver
from .. import controller, flow
from . import app
@ -20,7 +21,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast(
type="flows",
cmd="add",
data=f.get_state(short=True)
data=app._strip_content(f.get_state())
)
def _update(self, f):
@ -28,7 +29,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast(
type="flows",
cmd="update",
data=f.get_state(short=True)
data=app._strip_content(f.get_state())
)
def _remove(self, f):

View File

@ -4,9 +4,38 @@ import tornado.web
import tornado.websocket
import logging
import json
from netlib.http import CONTENT_MISSING
from .. import version, filt
def _strip_content(flow_state):
"""
Remove flow message content and cert to save transmission space.
Args:
flow_state: The original flow state. Will be left unmodified
"""
for attr in ("request", "response"):
if attr in flow_state:
message = flow_state[attr]
if message["content"]:
message["contentLength"] = len(message["content"])
elif message["content"] == CONTENT_MISSING:
message["contentLength"] = None
else:
message["contentLength"] = 0
del message["content"]
if "backup" in flow_state:
del flow_state["backup"]
flow_state["modified"] = True
flow_state.get("server_conn", {}).pop("cert", None)
return flow_state
class APIError(tornado.web.HTTPError):
pass
@ -100,7 +129,7 @@ class Flows(RequestHandler):
def get(self):
self.write(dict(
data=[f.get_state(short=True) for f in self.state.flows]
data=[_strip_content(f.get_state()) for f in self.state.flows]
))
@ -141,7 +170,7 @@ class FlowHandler(RequestHandler):
elif k == "port":
request.port = int(v)
elif k == "headers":
request.headers.load_state(v)
request.headers.set_state(v)
else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
@ -155,7 +184,7 @@ class FlowHandler(RequestHandler):
elif k == "http_version":
response.http_version = str(v)
elif k == "headers":
response.headers.load_state(v)
response.headers.set_state(v)
else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
else:

View File

@ -422,7 +422,7 @@ class TestFlow(object):
assert not f == f2
f2.error = Error("e2")
assert not f == f2
f.load_state(f2.get_state())
f.set_state(f2.get_state())
assert f.get_state() == f2.get_state()
def test_kill(self):
@ -1204,7 +1204,7 @@ class TestError:
e2 = Error("bar")
assert not e == e2
e.load_state(e2.get_state())
e.set_state(e2.get_state())
assert e.get_state() == e2.get_state()
e3 = e.copy()
@ -1224,7 +1224,7 @@ class TestClientConnection:
assert not c == c2
c2.timestamp_start = 42
c.load_state(c2.get_state())
c.set_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()

View File

@ -76,7 +76,11 @@ def tclient_conn():
"""
c = ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
clientcert=None
clientcert=None,
ssl_established=False,
timestamp_start=1,
timestamp_ssl_setup=2,
timestamp_end=3,
))
c.reply = controller.DummyReply()
return c
@ -88,9 +92,15 @@ def tserver_conn():
"""
c = ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
state=[],
source_address=dict(address=("address", 22), use_ipv6=True),
cert=None
cert=None,
timestamp_start=1,
timestamp_tcp_setup=2,
timestamp_ssl_setup=3,
timestamp_end=4,
ssl_established=False,
sni="address",
via=None
))
c.reply = controller.DummyReply()
return c