Significantly refactor the master/slave message passing interface.

This commit is contained in:
Aldo Cortesi 2013-02-17 12:42:48 +13:00
parent 1ccb2c5dea
commit aaf892e3af
10 changed files with 158 additions and 107 deletions

View File

@ -580,7 +580,7 @@ class ConsoleMaster(flow.FlowMaster):
self.view_flowlist()
self.server.start_slave(controller.Slave, self.masterq)
self.server.start_slave(controller.Slave, controller.Channel(self.masterq))
if self.options.rfile:
ret = self.load_flows(self.options.rfile)
@ -1002,7 +1002,7 @@ class ConsoleMaster(flow.FlowMaster):
if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay():
f.intercept()
else:
r._ack()
r.reply()
self.sync_list_view()
self.refresh_flow(f)
@ -1023,7 +1023,7 @@ class ConsoleMaster(flow.FlowMaster):
# Handlers
def handle_log(self, l):
self.add_event(l.msg)
l._ack()
l.reply()
def handle_error(self, r):
f = flow.FlowMaster.handle_error(self, r)

View File

@ -184,7 +184,7 @@ def format_flow(f, focus, extended=False, padding=2):
req_timestamp = f.request.timestamp_start,
req_is_replay = f.request.is_replay(),
req_method = f.request.method,
req_acked = f.request.acked,
req_acked = f.request.reply.acked,
req_url = f.request.get_url(),
err_msg = f.error.msg if f.error else None,
@ -200,7 +200,7 @@ def format_flow(f, focus, extended=False, padding=2):
d.update(dict(
resp_code = f.response.code,
resp_is_replay = f.response.is_replay(),
resp_acked = f.response.acked,
resp_acked = f.response.reply.acked,
resp_clen = contentdesc
))
t = f.response.headers["content-type"]

View File

@ -17,37 +17,73 @@ import Queue, threading
should_exit = False
class Msg:
class DummyReply:
"""
A reply object that does nothing. Useful when we need an object to seem
like it has a channel, and during testing.
"""
def __init__(self):
self.acked = False
def __call__(self, msg=False):
self.acked = True
class Reply:
"""
Messages sent through a channel are decorated with a "reply" attribute.
This object is used to respond to the message through the return
channel.
"""
def __init__(self, obj):
self.obj = obj
self.q = Queue.Queue()
self.acked = False
def _ack(self, data=False):
def __call__(self, msg=False):
if not self.acked:
self.acked = True
if data is None:
self.q.put(data)
if msg is None:
self.q.put(msg)
else:
self.q.put(data or self)
self.q.put(msg or self.obj)
def _send(self, masterq):
self.acked = False
try:
masterq.put(self, timeout=3)
while not should_exit: # pragma: no cover
try:
g = self.q.get(timeout=0.5)
except Queue.Empty:
continue
return g
except (Queue.Empty, Queue.Full): # pragma: no cover
return None
class Channel:
def __init__(self, q):
self.q = q
def ask(self, m):
"""
Send a message to the master, and wait for a response.
"""
m.reply = Reply(m)
self.q.put(m)
while not should_exit:
try:
# The timeout is here so we can handle a should_exit event.
g = m.reply.q.get(timeout=0.5)
except Queue.Empty:
continue
return g
def tell(self, m):
"""
Send a message to the master, and keep going.
"""
m.reply = None
self.q.put(m)
class Slave(threading.Thread):
def __init__(self, masterq, server):
self.masterq, self.server = masterq, server
self.server.set_mqueue(masterq)
"""
Slaves get a channel end-point through which they can send messages to
the master.
"""
def __init__(self, channel, server):
self.channel, self.server = channel, server
self.server.set_channel(channel)
threading.Thread.__init__(self)
def run(self):
@ -55,6 +91,9 @@ class Slave(threading.Thread):
class Master:
"""
Masters get and respond to messages from slaves.
"""
def __init__(self, server):
"""
server may be None if no server is needed.
@ -81,18 +120,18 @@ class Master:
def run(self):
global should_exit
should_exit = False
self.server.start_slave(Slave, self.masterq)
self.server.start_slave(Slave, Channel(self.masterq))
while not should_exit:
self.tick(self.masterq)
self.shutdown()
def handle(self, msg): # pragma: no cover
def handle(self, msg):
c = "handle_" + msg.__class__.__name__.lower()
m = getattr(self, c, None)
if m:
m(msg)
else:
msg._ack()
msg.reply()
def shutdown(self):
global should_exit

View File

@ -150,16 +150,6 @@ class DumpMaster(flow.FlowMaster):
print >> self.outfile, e
self.outfile.flush()
def handle_log(self, l):
self.add_event(l.msg)
l._ack()
def handle_request(self, r):
f = flow.FlowMaster.handle_request(self, r)
if f:
r._ack()
return f
def indent(self, n, t):
l = str(t).strip().split("\n")
return "\n".join(" "*n + i for i in l)
@ -210,10 +200,20 @@ class DumpMaster(flow.FlowMaster):
self.outfile.flush()
self.state.delete_flow(f)
def handle_log(self, l):
self.add_event(l.msg)
l.reply()
def handle_request(self, r):
f = flow.FlowMaster.handle_request(self, r)
if f:
r.reply()
return f
def handle_response(self, msg):
f = flow.FlowMaster.handle_response(self, msg)
if f:
msg._ack()
msg.reply()
self._process_flow(f)
return f

View File

@ -196,7 +196,7 @@ class decoded(object):
self.o.encode(self.ce)
class HTTPMsg(controller.Msg):
class HTTPMsg:
def get_decoded_content(self):
"""
Returns the decoded content based on the current Content-Encoding header.
@ -252,6 +252,7 @@ class HTTPMsg(controller.Msg):
return 0
return len(self.content)
class Request(HTTPMsg):
"""
An HTTP request.
@ -289,7 +290,6 @@ class Request(HTTPMsg):
self.timestamp_start = timestamp_start or utils.timestamp()
self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
self.close = False
controller.Msg.__init__(self)
# Have this request's cookies been modified by sticky cookies or auth?
self.stickycookie = False
@ -396,7 +396,6 @@ class Request(HTTPMsg):
Returns a copy of this object.
"""
c = copy.copy(self)
c.acked = True
c.headers = self.headers.copy()
return c
@ -603,7 +602,6 @@ class Response(HTTPMsg):
self.cert = cert
self.timestamp_start = timestamp_start or utils.timestamp()
self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
controller.Msg.__init__(self)
self.replay = False
def _refresh_cookie(self, c, delta):
@ -708,7 +706,6 @@ class Response(HTTPMsg):
Returns a copy of this object.
"""
c = copy.copy(self)
c.acked = True
c.headers = self.headers.copy()
return c
@ -773,7 +770,7 @@ class Response(HTTPMsg):
cookies.append((cookie_name, (cookie_value, cookie_parameters)))
return dict(cookies)
class ClientDisconnect(controller.Msg):
class ClientDisconnect:
"""
A client disconnection event.
@ -782,11 +779,10 @@ class ClientDisconnect(controller.Msg):
client_conn: ClientConnect object.
"""
def __init__(self, client_conn):
controller.Msg.__init__(self)
self.client_conn = client_conn
class ClientConnect(controller.Msg):
class ClientConnect:
"""
A single client connection. Each connection can result in multiple HTTP
Requests.
@ -807,7 +803,6 @@ class ClientConnect(controller.Msg):
self.close = False
self.requestcount = 0
self.error = None
controller.Msg.__init__(self)
def __eq__(self, other):
return self._get_state() == other._get_state()
@ -838,11 +833,10 @@ class ClientConnect(controller.Msg):
Returns a copy of this object.
"""
c = copy.copy(self)
c.acked = True
return c
class Error(controller.Msg):
class Error:
"""
An Error.
@ -860,7 +854,6 @@ class Error(controller.Msg):
def __init__(self, request, msg, timestamp=None):
self.request, self.msg = request, msg
self.timestamp = timestamp or utils.timestamp()
controller.Msg.__init__(self)
def _load_state(self, state):
self.msg = state["msg"]
@ -871,7 +864,6 @@ class Error(controller.Msg):
Returns a copy of this object.
"""
c = copy.copy(self)
c.acked = True
return c
def _get_state(self):
@ -1180,10 +1172,11 @@ class Flow:
Kill this request.
"""
self.error = Error(self.request, "Connection killed")
if self.request and not self.request.acked:
self.request._ack(None)
elif self.response and not self.response.acked:
self.response._ack(None)
self.error.reply = controller.DummyReply()
if self.request and not self.request.reply.acked:
self.request.reply(None)
elif self.response and not self.response.reply.acked:
self.response.reply(None)
master.handle_error(self.error)
self.intercepting = False
@ -1199,10 +1192,10 @@ class Flow:
Continue with the flow - called after an intercept().
"""
if self.request:
if not self.request.acked:
self.request._ack()
elif self.response and not self.response.acked:
self.response._ack()
if not self.request.reply.acked:
self.request.reply()
elif self.response and not self.response.reply.acked:
self.response.reply()
self.intercepting = False
def replace(self, pattern, repl, *args, **kwargs):
@ -1464,7 +1457,7 @@ class FlowMaster(controller.Master):
flow.response = response
if self.refresh_server_playback:
response.refresh()
flow.request._ack(response)
flow.request.reply(response)
if self.server_playback.count() == 0:
self.stop_server_playback()
return True
@ -1491,10 +1484,13 @@ class FlowMaster(controller.Master):
Loads a flow, and returns a new flow object.
"""
if f.request:
f.request.reply = controller.DummyReply()
fr = self.handle_request(f.request)
if f.response:
f.response.reply = controller.DummyReply()
self.handle_response(f.response)
if f.error:
f.error.reply = controller.DummyReply()
self.handle_error(f.error)
return fr
@ -1522,7 +1518,7 @@ class FlowMaster(controller.Master):
if self.kill_nonreplay:
f.kill(self)
else:
f.request._ack()
f.request.reply()
def process_new_response(self, f):
if self.stickycookie_state:
@ -1561,11 +1557,11 @@ class FlowMaster(controller.Master):
def handle_clientconnect(self, cc):
self.run_script_hook("clientconnect", cc)
cc._ack()
cc.reply()
def handle_clientdisconnect(self, r):
self.run_script_hook("clientdisconnect", r)
r._ack()
r.reply()
def handle_error(self, r):
f = self.state.add_error(r)
@ -1573,7 +1569,7 @@ class FlowMaster(controller.Master):
self.run_script_hook("error", f)
if self.client_playback:
self.client_playback.clear(f)
r._ack()
r.reply()
return f
def handle_request(self, r):
@ -1596,7 +1592,7 @@ class FlowMaster(controller.Master):
if self.stream:
self.stream.add(f)
else:
r._ack()
r.reply()
return f
def shutdown(self):

View File

@ -29,9 +29,8 @@ class ProxyError(Exception):
return "ProxyError(%s, %s)"%(self.code, self.msg)
class Log(controller.Msg):
class Log:
def __init__(self, msg):
controller.Msg.__init__(self)
self.msg = msg
@ -51,7 +50,7 @@ class ProxyConfig:
class RequestReplayThread(threading.Thread):
def __init__(self, config, flow, masterq):
self.config, self.flow, self.masterq = config, flow, masterq
self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
threading.Thread.__init__(self)
def run(self):
@ -66,10 +65,10 @@ class RequestReplayThread(threading.Thread):
response = flow.Response(
self.flow.request, httpversion, code, msg, headers, content, server.cert
)
response._send(self.masterq)
self.channel.ask(response)
except (ProxyError, http.HttpError, tcp.NetLibError), v:
err = flow.Error(self.flow.request, str(v))
err._send(self.masterq)
self.channel.ask(err)
class ServerConnection(tcp.TCPClient):
@ -128,8 +127,8 @@ class ServerConnectionPool:
class ProxyHandler(tcp.BaseHandler):
def __init__(self, config, connection, client_address, server, mqueue, server_version):
self.mqueue, self.server_version = mqueue, server_version
def __init__(self, config, connection, client_address, server, channel, server_version):
self.channel, self.server_version = channel, server_version
self.config = config
self.server_conn_pool = ServerConnectionPool(config)
self.proxy_connect_state = None
@ -139,18 +138,18 @@ class ProxyHandler(tcp.BaseHandler):
def handle(self):
cc = flow.ClientConnect(self.client_address)
self.log(cc, "connect")
cc._send(self.mqueue)
self.channel.ask(cc)
while self.handle_request(cc) and not cc.close:
pass
cc.close = True
cd = flow.ClientDisconnect(cc)
cd = flow.ClientDisconnect(cc)
self.log(
cc, "disconnect",
[
"handled %s requests"%cc.requestcount]
)
cd._send(self.mqueue)
self.channel.ask(cd)
def handle_request(self, cc):
try:
@ -167,14 +166,14 @@ class ProxyHandler(tcp.BaseHandler):
self.log(cc, "Error in wsgi app.", err.split("\n"))
return
else:
request = request._send(self.mqueue)
request = self.channel.ask(request)
if request is None:
return
if isinstance(request, flow.Response):
response = request
request = False
response = response._send(self.mqueue)
response = self.channel.ask(response)
else:
if self.config.reverse_proxy:
scheme, host, port = self.config.reverse_proxy
@ -192,7 +191,7 @@ class ProxyHandler(tcp.BaseHandler):
request, httpversion, code, msg, headers, content, sc.cert,
sc.rfile.first_byte_timestamp, utils.timestamp()
)
response = response._send(self.mqueue)
response = self.channel.ask(response)
if response is None:
sc.terminate()
if response is None:
@ -214,7 +213,7 @@ class ProxyHandler(tcp.BaseHandler):
if request:
err = flow.Error(request, cc.error)
err._send(self.mqueue)
self.channel.ask(err)
self.log(
cc, cc.error,
["url: %s"%request.get_url()]
@ -235,7 +234,7 @@ class ProxyHandler(tcp.BaseHandler):
msg.append(" -> "+i)
msg = "\n".join(msg)
l = Log(msg)
l._send(self.mqueue)
self.channel.ask(l)
def find_cert(self, host, port, sni):
if self.config.certfile:
@ -438,18 +437,18 @@ class ProxyServer(tcp.TCPServer):
tcp.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.masterq = None
self.channel = None
self.apps = AppRegistry()
def start_slave(self, klass, masterq):
slave = klass(masterq, self)
def start_slave(self, klass, channel):
slave = klass(channel, self)
slave.start()
def set_mqueue(self, q):
self.masterq = q
def set_channel(self, channel):
self.channel = channel
def handle_connection(self, request, client_address):
h = ProxyHandler(self.config, request, client_address, self, self.masterq, self.server_version)
h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version)
h.handle()
try:
h.finish()
@ -487,7 +486,7 @@ class DummyServer:
def __init__(self, config):
self.config = config
def start_slave(self, klass, masterq):
def start_slave(self, klass, channel):
pass
def shutdown(self):

View File

@ -3,6 +3,7 @@ from cStringIO import StringIO
import libpry
from libmproxy import dump, flow, proxy
import tutils
import mock
def test_strfuncs():
t = tutils.tresp()
@ -21,6 +22,7 @@ class TestDumpMaster:
req = tutils.treq()
req.content = content
l = proxy.Log("connect")
l.reply = mock.MagicMock()
m.handle_log(l)
cc = req.client_conn
cc.connection_error = "error"
@ -29,7 +31,9 @@ class TestDumpMaster:
m.handle_clientconnect(cc)
m.handle_request(req)
f = m.handle_response(resp)
m.handle_clientdisconnect(flow.ClientDisconnect(cc))
cd = flow.ClientDisconnect(cc)
cd.reply = mock.MagicMock()
m.handle_clientdisconnect(cd)
return f
def _dummy_cycle(self, n, filt, content, **options):

View File

@ -223,16 +223,16 @@ class TestFlow:
f = tutils.tflow()
f.request = tutils.treq()
f.intercept()
assert not f.request.acked
assert not f.request.reply.acked
f.kill(fm)
assert f.request.acked
assert f.request.reply.acked
f.intercept()
f.response = tutils.tresp()
f.request = f.response.request
f.request._ack()
assert not f.response.acked
f.request.reply()
assert not f.response.reply.acked
f.kill(fm)
assert f.response.acked
assert f.response.reply.acked
def test_killall(self):
s = flow.State()
@ -245,25 +245,25 @@ class TestFlow:
fm.handle_request(r)
for i in s.view:
assert not i.request.acked
assert not i.request.reply.acked
s.killall(fm)
for i in s.view:
assert i.request.acked
assert i.request.reply.acked
def test_accept_intercept(self):
f = tutils.tflow()
f.request = tutils.treq()
f.intercept()
assert not f.request.acked
assert not f.request.reply.acked
f.accept_intercept()
assert f.request.acked
assert f.request.reply.acked
f.response = tutils.tresp()
f.request = f.response.request
f.intercept()
f.request._ack()
assert not f.response.acked
f.request.reply()
assert not f.response.reply.acked
f.accept_intercept()
assert f.response.acked
assert f.response.reply.acked
def test_serialization(self):
f = flow.Flow(None)
@ -562,9 +562,11 @@ class TestFlowMaster:
fm.handle_response(resp)
assert fm.script.ns["log"][-1] == "response"
dc = flow.ClientDisconnect(req.client_conn)
dc.reply = controller.DummyReply()
fm.handle_clientdisconnect(dc)
assert fm.script.ns["log"][-1] == "clientdisconnect"
err = flow.Error(f.request, "msg")
err.reply = controller.DummyReply()
fm.handle_error(err)
assert fm.script.ns["log"][-1] == "error"
@ -598,10 +600,12 @@ class TestFlowMaster:
assert not fm.handle_response(rx)
dc = flow.ClientDisconnect(req.client_conn)
dc.reply = controller.DummyReply()
req.client_conn.requestcount = 1
fm.handle_clientdisconnect(dc)
err = flow.Error(f.request, "msg")
err.reply = controller.DummyReply()
fm.handle_error(err)
fm.load_script(tutils.test_data.path("scripts/a.py"))
@ -621,7 +625,9 @@ class TestFlowMaster:
fm.tick(q)
assert fm.state.flow_count()
fm.handle_error(flow.Error(f.request, "error"))
err = flow.Error(f.request, "error")
err.reply = controller.DummyReply()
fm.handle_error(err)
def test_server_playback(self):
controller.should_exit = False

View File

@ -31,7 +31,7 @@ class TestMaster(flow.FlowMaster):
def handle(self, m):
flow.FlowMaster.handle(self, m)
m._ack()
m.reply()
class ProxyThread(threading.Thread):

View File

@ -1,15 +1,18 @@
import os, shutil, tempfile
from contextlib import contextmanager
from libmproxy import flow, utils
from libmproxy import flow, utils, controller
from netlib import certutils
import mock
def treq(conn=None):
if not conn:
conn = flow.ClientConnect(("address", 22))
conn.reply = controller.DummyReply()
headers = flow.ODictCaseless()
headers["header"] = ["qvalue"]
return flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content")
r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content")
r.reply = controller.DummyReply()
return r
def tresp(req=None):
@ -18,7 +21,9 @@ def tresp(req=None):
headers = flow.ODictCaseless()
headers["header_response"] = ["svalue"]
cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert")).read())
return flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert)
resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert)
resp.reply = controller.DummyReply()
return resp
def tflow():
@ -37,9 +42,11 @@ def tflow_err():
r = treq()
f = flow.Flow(r)
f.error = flow.Error(r, "error")
f.error.reply = controller.DummyReply()
return f
@contextmanager
def tmpdir(*args, **kwargs):
orig_workdir = os.getcwd()