Sketch out a more solid core

- Decorator for handler methods
- Stricter checking for double-acks and non-acks
This commit is contained in:
Aldo Cortesi 2016-05-26 12:31:29 +12:00
parent f7e77d543b
commit 23efee9813
7 changed files with 79 additions and 35 deletions

View File

@ -16,7 +16,7 @@ import weakref
from netlib import tcp from netlib import tcp
from .. import flow, script, contentviews from .. import flow, script, contentviews, controller
from . import flowlist, flowview, help, window, signals, options from . import flowlist, flowview, help, window, signals, options
from . import grideditor, palettes, statusbar, palettepicker from . import grideditor, palettes, statusbar, palettepicker
from ..exceptions import FlowReadException, ScriptException from ..exceptions import FlowReadException, ScriptException

View File

@ -1,8 +1,14 @@
from __future__ import absolute_import from __future__ import absolute_import
from six.moves import queue from six.moves import queue
import threading import threading
import functools
import sys
from .exceptions import Kill from . import exceptions
class ControlError(Exception):
pass
class Master(object): class Master(object):
@ -36,6 +42,12 @@ class Master(object):
while True: while True:
mtype, obj = self.event_queue.get(timeout=timeout) mtype, obj = self.event_queue.get(timeout=timeout)
handle_func = getattr(self, "handle_" + mtype) handle_func = getattr(self, "handle_" + mtype)
# if not handle_func.func_dict.get("handler"):
# raise ControlError(
# "Handler function %s is not decorated with controller.handler"%(
# handle_func
# )
# )
handle_func(obj) handle_func(obj)
self.event_queue.task_done() self.event_queue.task_done()
changed = True changed = True
@ -100,7 +112,7 @@ class Channel(object):
master. Then wait for a response. master. Then wait for a response.
Raises: Raises:
Kill: All connections should be closed immediately. exceptions.Kill: All connections should be closed immediately.
""" """
m.reply = Reply(m) m.reply = Reply(m)
self.q.put((mtype, m)) self.q.put((mtype, m))
@ -110,11 +122,11 @@ class Channel(object):
g = m.reply.q.get(timeout=0.5) g = m.reply.q.get(timeout=0.5)
except queue.Empty: # pragma: no cover except queue.Empty: # pragma: no cover
continue continue
if g == Kill: if g == exceptions.Kill:
raise Kill() raise exceptions.Kill()
return g return g
raise Kill() raise exceptions.Kill()
def tell(self, mtype, m): def tell(self, mtype, m):
""" """
@ -133,6 +145,10 @@ class DummyReply(object):
def __init__(self): def __init__(self):
self.acked = False self.acked = False
self.taken = False
def take(self):
self.taken = True
def __call__(self, msg=False): def __call__(self, msg=False):
self.acked = True self.acked = True
@ -142,22 +158,44 @@ class DummyReply(object):
NO_REPLY = object() NO_REPLY = object()
def handler(f):
@functools.wraps(f)
def wrapper(obj, message, *args, **kwargs):
if not hasattr(message, "reply"):
raise ControlError("Message %s has no reply attribute"%message)
ret = f(obj, message, *args, **kwargs)
if not message.reply.acked and not message.reply.taken:
message.reply()
return ret
wrapper.func_dict["handler"] = True
return wrapper
class Reply(object): class Reply(object):
""" """
Messages sent through a channel are decorated with a "reply" attribute. Messages sent through a channel are decorated with a "reply" attribute.
This object is used to respond to the message through the return This object is used to respond to the message through the return
channel. channel.
""" """
def __init__(self, obj): def __init__(self, obj):
self.obj = obj self.obj = obj
self.q = queue.Queue() self.q = queue.Queue()
self.acked = False self.acked = False
self.taken = False
def take(self):
self.taken = True
def __call__(self, msg=NO_REPLY): def __call__(self, msg=NO_REPLY):
if not self.acked: if self.acked:
raise ControlError("Message already acked.")
self.acked = True self.acked = True
if msg is NO_REPLY: if msg is NO_REPLY:
self.q.put(self.obj) self.q.put(self.obj)
else: else:
self.q.put(msg) self.q.put(msg)
def __del__(self):
if not self.acked:
# This will be ignored by the interpreter, but emit a warning
raise ControlError("Un-acked message")

View File

@ -985,36 +985,36 @@ class FlowMaster(controller.ServerMaster):
if block: if block:
rt.join() rt.join()
@controller.handler
def handle_log(self, l): def handle_log(self, l):
self.add_event(l.msg, l.level) self.add_event(l.msg, l.level)
l.reply()
@controller.handler
def handle_clientconnect(self, root_layer): def handle_clientconnect(self, root_layer):
self.run_script_hook("clientconnect", root_layer) self.run_script_hook("clientconnect", root_layer)
root_layer.reply()
@controller.handler
def handle_clientdisconnect(self, root_layer): def handle_clientdisconnect(self, root_layer):
self.run_script_hook("clientdisconnect", root_layer) self.run_script_hook("clientdisconnect", root_layer)
root_layer.reply()
@controller.handler
def handle_serverconnect(self, server_conn): def handle_serverconnect(self, server_conn):
self.run_script_hook("serverconnect", server_conn) self.run_script_hook("serverconnect", server_conn)
server_conn.reply()
@controller.handler
def handle_serverdisconnect(self, server_conn): def handle_serverdisconnect(self, server_conn):
self.run_script_hook("serverdisconnect", server_conn) self.run_script_hook("serverdisconnect", server_conn)
server_conn.reply()
@controller.handler
def handle_next_layer(self, top_layer): def handle_next_layer(self, top_layer):
self.run_script_hook("next_layer", top_layer) self.run_script_hook("next_layer", top_layer)
top_layer.reply()
@controller.handler
def handle_error(self, f): def handle_error(self, f):
self.state.update_flow(f) self.state.update_flow(f)
self.run_script_hook("error", f) self.run_script_hook("error", f)
if self.client_playback: if self.client_playback:
self.client_playback.clear(f) self.client_playback.clear(f)
f.reply()
return f return f
def handle_request(self, f): def handle_request(self, f):

View File

@ -900,7 +900,7 @@ class TCPServer(object):
""" """
# If a thread has persisted after interpreter exit, the module might be # If a thread has persisted after interpreter exit, the module might be
# none. # none.
if traceback: if traceback and six:
exc = six.text_type(traceback.format_exc()) exc = six.text_type(traceback.format_exc())
print(u'-' * 40, file=fp) print(u'-' * 40, file=fp)
print( print(

View File

@ -2,7 +2,7 @@ from threading import Thread, Event
from mock import Mock from mock import Mock
from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master from mitmproxy import controller
from six.moves import queue from six.moves import queue
from mitmproxy.exceptions import Kill from mitmproxy.exceptions import Kill
@ -10,10 +10,15 @@ from mitmproxy.proxy import DummyServer
from netlib.tutils import raises from netlib.tutils import raises
class TMsg:
pass
class TestMaster(object): class TestMaster(object):
def test_simple(self): def test_simple(self):
class DummyMaster(Master): class DummyMaster(controller.Master):
@controller.handler
def handle_panic(self, _): def handle_panic(self, _):
m.should_exit.set() m.should_exit.set()
@ -23,14 +28,16 @@ class TestMaster(object):
m = DummyMaster() m = DummyMaster()
assert not m.should_exit.is_set() assert not m.should_exit.is_set()
m.event_queue.put(("panic", 42)) msg = TMsg()
msg.reply = controller.DummyReply()
m.event_queue.put(("panic", msg))
m.run() m.run()
assert m.should_exit.is_set() assert m.should_exit.is_set()
class TestServerMaster(object): class TestServerMaster(object):
def test_simple(self): def test_simple(self):
m = ServerMaster() m = controller.ServerMaster()
s = DummyServer(None) s = DummyServer(None)
m.add_server(s) m.add_server(s)
m.start() m.start()
@ -42,7 +49,7 @@ class TestServerMaster(object):
class TestServerThread(object): class TestServerThread(object):
def test_simple(self): def test_simple(self):
m = Mock() m = Mock()
t = ServerThread(m) t = controller.ServerThread(m)
t.run() t.run()
assert m.serve_forever.called assert m.serve_forever.called
@ -50,7 +57,7 @@ class TestServerThread(object):
class TestChannel(object): class TestChannel(object):
def test_tell(self): def test_tell(self):
q = queue.Queue() q = queue.Queue()
channel = Channel(q, Event()) channel = controller.Channel(q, Event())
m = Mock() m = Mock()
channel.tell("test", m) channel.tell("test", m)
assert q.get() == ("test", m) assert q.get() == ("test", m)
@ -66,21 +73,21 @@ class TestChannel(object):
Thread(target=reply).start() Thread(target=reply).start()
channel = Channel(q, Event()) channel = controller.Channel(q, Event())
assert channel.ask("test", Mock()) == 42 assert channel.ask("test", Mock()) == 42
def test_ask_shutdown(self): def test_ask_shutdown(self):
q = queue.Queue() q = queue.Queue()
done = Event() done = Event()
done.set() done.set()
channel = Channel(q, done) channel = controller.Channel(q, done)
with raises(Kill): with raises(Kill):
channel.ask("test", Mock()) channel.ask("test", Mock())
class TestDummyReply(object): class TestDummyReply(object):
def test_simple(self): def test_simple(self):
reply = DummyReply() reply = controller.DummyReply()
assert not reply.acked assert not reply.acked
reply() reply()
assert reply.acked assert reply.acked
@ -88,18 +95,18 @@ class TestDummyReply(object):
class TestReply(object): class TestReply(object):
def test_simple(self): def test_simple(self):
reply = Reply(42) reply = controller.Reply(42)
assert not reply.acked assert not reply.acked
reply("foo") reply("foo")
assert reply.acked assert reply.acked
assert reply.q.get() == "foo" assert reply.q.get() == "foo"
def test_default(self): def test_default(self):
reply = Reply(42) reply = controller.Reply(42)
reply() reply()
assert reply.q.get() == 42 assert reply.q.get() == 42
def test_reply_none(self): def test_reply_none(self):
reply = Reply(42) reply = controller.Reply(42)
reply(None) reply(None)
assert reply.q.get() is None assert reply.q.get() is None

View File

@ -866,7 +866,6 @@ class TestFlowMaster:
f.response = HTTPResponse.wrap(netlib.tutils.tresp()) f.response = HTTPResponse.wrap(netlib.tutils.tresp())
fm.handle_response(f) fm.handle_response(f)
assert not fm.handle_response(None)
assert s.flow_count() == 1 assert s.flow_count() == 1
fm.handle_clientdisconnect(f.client_conn) fm.handle_clientdisconnect(f.client_conn)

View File

@ -39,13 +39,13 @@ class TestMaster(flow.FlowMaster):
self.apps.add(errapp, "errapp", 80) self.apps.add(errapp, "errapp", 80)
self.clear_log() self.clear_log()
@controller.handler
def handle_request(self, f): def handle_request(self, f):
flow.FlowMaster.handle_request(self, f) flow.FlowMaster.handle_request(self, f)
f.reply()
@controller.handler
def handle_response(self, f): def handle_response(self, f):
flow.FlowMaster.handle_response(self, f) flow.FlowMaster.handle_response(self, f)
f.reply()
def clear_log(self): def clear_log(self):
self.log = [] self.log = []