Clean up and clarify StateObject

- Flatten the class hierarchy
- get_state, load_state, from_state are public
- Simplify code
- Remove __eq__ and __ne__. Overriding equality fundamentally changes the
semantics of objects that inherit from StateObject, in a way that's not
part of the core function of the class
This commit is contained in:
Aldo Cortesi 2014-09-17 11:35:14 +12:00
parent b9531ac89b
commit d998790c2f
10 changed files with 123 additions and 144 deletions

View File

@ -12,7 +12,7 @@ with open("logfile", "rb") as logfile:
for f in freader.stream():
print(f)
print(f.request.host)
json.dump(f._get_state(), sys.stdout, indent=4)
json.dump(f.get_state(), sys.stdout, indent=4)
print ""
except flow.FlowReadError, v:
print "Flow file corrupted. Stopped loading."

View File

@ -561,7 +561,7 @@ class FlowMaster(controller.Master):
rflow = self.server_playback.next_flow(flow)
if not rflow:
return None
response = http.HTTPResponse._from_state(rflow.response._get_state())
response = http.HTTPResponse.from_state(rflow.response.get_state())
response.is_replay = True
if self.refresh_server_playback:
response.refresh()
@ -740,7 +740,7 @@ class FlowWriter:
self.fo = fo
def add(self, flow):
d = flow._get_state()
d = flow.get_state()
tnetstring.dump(d, self.fo)
@ -766,7 +766,7 @@ class FlowReader:
v = ".".join(str(i) for i in data["version"])
raise FlowReadError("Incompatible serialized data version: %s"%v)
off = self.fo.tell()
yield handle.protocols[data["conntype"]]["flow"]._from_state(data)
yield handle.protocols[data["conntype"]]["flow"].from_state(data)
except ValueError, v:
# Error is due to EOF
if self.fo.tell() == off and self.fo.read() == '':
@ -782,5 +782,5 @@ class FilteredFlowWriter:
def add(self, f):
if self.filt and not f.match(self.filt):
return
d = f._get_state()
d = f.get_state()
tnetstring.dump(d, self.fo)

View File

@ -85,7 +85,7 @@ class decoded(object):
self.o.encode(self.ce)
class HTTPMessage(stateobject.SimpleStateObject):
class HTTPMessage(stateobject.StateObject):
"""
Base class for HTTPRequest and HTTPResponse
"""
@ -275,9 +275,9 @@ class HTTPRequest(HTTPMessage):
)
@classmethod
def _from_state(cls, state):
def from_state(cls, state):
f = cls(None, None, None, None, None, None, None, None, None, None, None)
f._load_state(state)
f.load_state(state)
return f
def __repr__(self):
@ -626,9 +626,9 @@ class HTTPResponse(HTTPMessage):
)
@classmethod
def _from_state(cls, state):
def from_state(cls, state):
f = cls(None, None, None, None, None)
f._load_state(state)
f.load_state(state)
return f
def __repr__(self):
@ -814,9 +814,9 @@ class HTTPFlow(Flow):
)
@classmethod
def _from_state(cls, state):
def from_state(cls, state):
f = cls(None, None)
f._load_state(state)
f.load_state(state)
return f
def __repr__(self):

View File

@ -8,7 +8,7 @@ from ..proxy.connection import ClientConnection, ServerConnection
KILL = 0 # const for killed requests
class Error(stateobject.SimpleStateObject):
class Error(stateobject.StateObject):
"""
An Error.
@ -41,11 +41,11 @@ class Error(stateobject.SimpleStateObject):
return self.msg
@classmethod
def _from_state(cls, state):
def from_state(cls, state):
# the default implementation assumes an empty constructor. Override
# accordingly.
f = cls(None)
f._load_state(state)
f.load_state(state)
return f
def copy(self):
@ -53,7 +53,7 @@ class Error(stateobject.SimpleStateObject):
return c
class Flow(stateobject.SimpleStateObject):
class Flow(stateobject.StateObject):
"""
A Flow is a collection of objects representing a single transaction.
This class is usually subclassed for each protocol, e.g. HTTPFlow.
@ -78,8 +78,8 @@ class Flow(stateobject.SimpleStateObject):
conntype=str
)
def _get_state(self):
d = super(Flow, self)._get_state()
def get_state(self):
d = super(Flow, self).get_state()
d.update(version=version.IVERSION)
return d
@ -101,7 +101,7 @@ class Flow(stateobject.SimpleStateObject):
Has this Flow been modified?
"""
if self._backup:
return self._backup != self._get_state()
return self._backup != self.get_state()
else:
return False
@ -111,14 +111,14 @@ class Flow(stateobject.SimpleStateObject):
call to .revert().
"""
if not self._backup:
self._backup = self._get_state()
self._backup = self.get_state()
def revert(self):
"""
Revert to the last backed up state.
"""
if self._backup:
self._load_state(self._backup)
self.load_state(self._backup)
self._backup = None

View File

@ -1,8 +1,10 @@
from __future__ import absolute_import
import select, socket
import select
import socket
from .primitives import ProtocolHandler
from netlib.utils import cleanBin
class TCPHandler(ProtocolHandler):
"""
TCPHandler acts as a generic TCP forwarder.
@ -34,7 +36,9 @@ class TCPHandler(ProtocolHandler):
closed = False
if src.ssl_established:
# Unfortunately, pyOpenSSL lacks a recv_into function.
contents = src.rfile.read(1) # We need to read a single byte before .pending() becomes usable
# We need to read a single byte before .pending()
# becomes usable
contents = src.rfile.read(1)
contents += src.rfile.read(src.connection.pending())
if not contents:
closed = True
@ -56,14 +60,29 @@ class TCPHandler(ProtocolHandler):
continue
if src.ssl_established or dst.ssl_established:
# if one of the peers is over SSL, we need to send bytes/strings
if not src.ssl_established: # only ssl to dst, i.e. we revc'd into buf but need bytes/string now.
# if one of the peers is over SSL, we need to send
# bytes/strings
if not src.ssl_established:
# only ssl to dst, i.e. we recv'd into buf but need
# bytes/string now.
contents = buf[:size].tobytes()
self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(contents)), "debug")
self.c.log(
"%s %s\r\n%s" % (
direction, dst_str, cleanBin(contents)
),
"debug"
)
dst.connection.send(contents)
else:
# socket.socket.send supports raw bytearrays/memoryviews
self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(buf.tobytes())), "debug")
self.c.log(
"%s %s\r\n%s" % (
direction,
dst_str,
cleanBin(buf.tobytes())
),
"debug"
)
dst.connection.send(buf[:size])
except socket.error as e:
self.c.log("TCP connection closed unexpectedly.", "debug")

View File

@ -5,7 +5,7 @@ from netlib import tcp, certutils
from .. import stateobject, utils
class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
def __init__(self, client_connection, address, server):
if client_connection: # Eventually, this object is restored from state. We don't have a connection then.
tcp.BaseHandler.__init__(self, client_connection, address, server)
@ -36,16 +36,16 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
timestamp_ssl_setup=float
)
def _get_state(self):
d = super(ClientConnection, self)._get_state()
def get_state(self):
d = super(ClientConnection, self).get_state()
d.update(
address={"address": self.address(), "use_ipv6": self.address.use_ipv6},
clientcert=self.cert.to_pem() if self.clientcert else None
)
return d
def _load_state(self, state):
super(ClientConnection, self)._load_state(state)
def load_state(self, state):
super(ClientConnection, self).load_state(state)
self.address = tcp.Address(**state["address"]) if state["address"] else None
self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None
@ -57,9 +57,9 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
self.wfile.flush()
@classmethod
def _from_state(cls, state):
def from_state(cls, state):
f = cls(None, tuple(), None)
f._load_state(state)
f.load_state(state)
return f
def convert_to_ssl(self, *args, **kwargs):
@ -71,7 +71,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
self.timestamp_end = utils.timestamp()
class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
class ServerConnection(tcp.TCPClient, stateobject.StateObject):
def __init__(self, address):
tcp.TCPClient.__init__(self, address)
@ -107,8 +107,8 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
sni=str
)
def _get_state(self):
d = super(ServerConnection, self)._get_state()
def get_state(self):
d = super(ServerConnection, self).get_state()
d.update(
address={"address": self.address(),
"use_ipv6": self.address.use_ipv6},
@ -118,17 +118,17 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
)
return d
def _load_state(self, state):
super(ServerConnection, self)._load_state(state)
def load_state(self, state):
super(ServerConnection, self).load_state(state)
self.address = tcp.Address(**state["address"]) if state["address"] else None
self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None
self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None
@classmethod
def _from_state(cls, state):
def from_state(cls, state):
f = cls(tuple())
f._load_state(state)
f.load_state(state)
return f
def copy(self):

View File

@ -2,82 +2,42 @@ from __future__ import absolute_import
class StateObject(object):
def _get_state(self):
raise NotImplementedError # pragma: nocover
def _load_state(self, state):
raise NotImplementedError # pragma: nocover
@classmethod
def _from_state(cls, state):
raise NotImplementedError # pragma: nocover
# Usually, this function roughly equals to the following code:
# f = cls()
# f._load_state(state)
# return f
def __eq__(self, other):
try:
return self._get_state() == other._get_state()
except AttributeError:
# we may compare with something that's not a StateObject
return False
def __ne__(self, other):
return not self.__eq__(other)
class SimpleStateObject(StateObject):
"""
A StateObject with opinionated conventions that tries to keep everything DRY.
An object with serializable state.
Simply put, you agree on a list of attributes and their type. Attributes can
either be primitive types(str, tuple, bool, ...) or StateObject instances
themselves. SimpleStateObject uses this information for the default
_get_state(), _from_state(s) and _load_state(s) methods. Overriding
_get_state or _load_state to add custom adjustments is always possible.
State attributes can be either serializable types (str, tuple, bool, ...)
or StateObject instances themselves.
"""
_stateobject_attributes = None # none by default to raise an exception if definition was forgotten
"""
An attribute-name -> class-or-type dict containing all attributes that
should be serialized If the attribute is a class, this class must be a
subclass of StateObject.
"""
def _get_state(self):
return {attr: self._get_state_attr(attr, cls)
for attr, cls in self._stateobject_attributes.iteritems()}
# An attribute-name -> class-or-type dict containing all attributes that
# should be serialized. If the attribute is a class, it must be a subclass
# of StateObject.
_stateobject_attributes = None
def _get_state_attr(self, attr, cls):
"""
helper for _get_state.
returns the value of the given attribute
"""
val = getattr(self, attr)
if hasattr(val, "_get_state"):
return val._get_state()
if hasattr(val, "get_state"):
return val.get_state()
else:
return val
def _load_state(self, state):
def from_state(self):
raise NotImplementedError
def get_state(self):
state = {}
for attr, cls in self._stateobject_attributes.iteritems():
self._load_state_attr(attr, cls, state)
def _load_state_attr(self, attr, cls, state):
"""
helper for _load_state.
loads the given attribute from the state.
"""
if state.get(attr, None) is None:
setattr(self, attr, None)
return
curr = getattr(self, attr)
if hasattr(curr, "_load_state"):
curr._load_state(state[attr])
elif hasattr(cls, "_from_state"):
setattr(self, attr, cls._from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))
state[attr] = self._get_state_attr(attr, cls)
return state
def load_state(self, state):
for attr, cls in self._stateobject_attributes.iteritems():
if state.get(attr, None) is None:
setattr(self, attr, None)
else:
curr = getattr(self, attr)
if hasattr(curr, "load_state"):
curr.load_state(state[attr])
elif hasattr(cls, "from_state"):
setattr(self, attr, cls.from_state(state[attr]))
else:
setattr(self, attr, cls(state[attr]))

View File

@ -1,8 +1,8 @@
import tornado.ioloop
import tornado.httpserver
from .. import controller, utils, flow, script, proxy
import app
import pprint
class Stop(Exception):
@ -81,7 +81,7 @@ class WebMaster(flow.FlowMaster):
self.shutdown()
def handle_request(self, f):
print f
pprint.pprint(f.get_state())
flow.FlowMaster.handle_request(self, f)
if f:
f.reply()

View File

@ -175,18 +175,18 @@ class TestServerPlaybackState:
class TestFlow:
def test_copy(self):
f = tutils.tflow(resp=True)
a0 = f._get_state()
a0 = f.get_state()
f2 = f.copy()
a = f._get_state()
b = f2._get_state()
assert f._get_state() == f2._get_state()
a = f.get_state()
b = f2.get_state()
assert f.get_state() == f2.get_state()
assert not f == f2
assert not f is f2
assert f.request == f2.request
assert f.request.get_state() == f2.request.get_state()
assert not f.request is f2.request
assert f.request.headers == f2.request.headers
assert not f.request.headers is f2.request.headers
assert f.response == f2.response
assert f.response.get_state() == f2.response.get_state()
assert not f.response is f2.response
f = tutils.tflow(err=True)
@ -195,7 +195,7 @@ class TestFlow:
assert not f.request is f2.request
assert f.request.headers == f2.request.headers
assert not f.request.headers is f2.request.headers
assert f.error == f2.error
assert f.error.get_state() == f2.error.get_state()
assert not f.error is f2.error
def test_match(self):
@ -229,21 +229,21 @@ class TestFlow:
def test_getset_state(self):
f = tutils.tflow(resp=True)
state = f._get_state()
assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state()
state = f.get_state()
assert f.get_state() == protocol.http.HTTPFlow.from_state(state).get_state()
f.response = None
f.error = Error("error")
state = f._get_state()
assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state()
state = f.get_state()
assert f.get_state() == protocol.http.HTTPFlow.from_state(state).get_state()
f2 = f.copy()
assert f._get_state() == f2._get_state()
assert f.get_state() == f2.get_state()
assert not f == f2
f2.error = Error("e2")
assert not f == f2
f._load_state(f2._get_state())
assert f._get_state() == f2._get_state()
f.load_state(f2.get_state())
assert f.get_state() == f2.get_state()
def test_kill(self):
s = flow.State()
@ -481,7 +481,7 @@ class TestSerialize:
assert len(l) == 1
f2 = l[0]
assert f2._get_state() == f._get_state()
assert f2.get_state() == f.get_state()
assert f2.request.assemble() == f.request.assemble()
def test_load_flows(self):
@ -521,7 +521,7 @@ class TestSerialize:
def test_versioncheck(self):
f = tutils.tflow()
d = f._get_state()
d = f.get_state()
d["version"] = (0, 0)
sio = StringIO()
tnetstring.dump(d, sio)
@ -770,7 +770,7 @@ class TestRequest:
assert r.size() == len(r.assemble())
r2 = r.copy()
assert r == r2
assert r.get_state() == r2.get_state()
r.content = None
assert r.assemble()
@ -979,7 +979,7 @@ class TestResponse:
assert resp.size() == len(resp.assemble())
resp2 = resp.copy()
assert resp2 == resp
assert resp2.get_state() == resp.get_state()
resp.content = None
assert resp.assemble()
@ -1122,37 +1122,37 @@ class TestResponse:
class TestError:
def test_getset_state(self):
e = Error("Error")
state = e._get_state()
assert Error._from_state(state) == e
state = e.get_state()
assert Error.from_state(state).get_state() == e.get_state()
assert e.copy()
e2 = Error("bar")
assert not e == e2
e._load_state(e2._get_state())
assert e == e2
e.load_state(e2.get_state())
assert e.get_state() == e2.get_state()
e3 = e.copy()
assert e3 == e
assert e3.get_state() == e.get_state()
class TestClientConnection:
def test_state(self):
c = tutils.tclient_conn()
assert ClientConnection._from_state(c._get_state()) == c
assert ClientConnection.from_state(c.get_state()).get_state() ==\
c.get_state()
c2 = tutils.tclient_conn()
c2.address.address = (c2.address.host, 4242)
assert not c == c2
c2.timestamp_start = 42
c._load_state(c2._get_state())
c.load_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()
assert c3 == c
assert c3.get_state() == c.get_state()
assert str(c)

View File

@ -55,7 +55,7 @@ def tclient_conn():
"""
@return: libmproxy.proxy.connection.ClientConnection
"""
c = ClientConnection._from_state(dict(
c = ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
clientcert=None
))
@ -67,7 +67,7 @@ def tserver_conn():
"""
@return: libmproxy.proxy.connection.ServerConnection
"""
c = ServerConnection._from_state(dict(
c = ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
state=[],
source_address=dict(address=("address", 22), use_ipv6=True),