Sketch out a more solid core

- Decorator for handler methods
- Stricter checking for double-acks and non-acks
This commit is contained in:
Aldo Cortesi 2016-05-26 12:31:29 +12:00
parent f7e77d543b
commit 23efee9813
7 changed files with 79 additions and 35 deletions

View File

@ -16,7 +16,7 @@ import weakref
from netlib import tcp
from .. import flow, script, contentviews
from .. import flow, script, contentviews, controller
from . import flowlist, flowview, help, window, signals, options
from . import grideditor, palettes, statusbar, palettepicker
from ..exceptions import FlowReadException, ScriptException

View File

@ -1,8 +1,14 @@
from __future__ import absolute_import
from six.moves import queue
import threading
import functools
import sys
from .exceptions import Kill
from . import exceptions
class ControlError(Exception):
pass
class Master(object):
@ -36,6 +42,12 @@ class Master(object):
while True:
mtype, obj = self.event_queue.get(timeout=timeout)
handle_func = getattr(self, "handle_" + mtype)
# if not handle_func.func_dict.get("handler"):
# raise ControlError(
# "Handler function %s is not decorated with controller.handler"%(
# handle_func
# )
# )
handle_func(obj)
self.event_queue.task_done()
changed = True
@ -100,7 +112,7 @@ class Channel(object):
master. Then wait for a response.
Raises:
Kill: All connections should be closed immediately.
exceptions.Kill: All connections should be closed immediately.
"""
m.reply = Reply(m)
self.q.put((mtype, m))
@ -110,11 +122,11 @@ class Channel(object):
g = m.reply.q.get(timeout=0.5)
except queue.Empty: # pragma: no cover
continue
if g == Kill:
raise Kill()
if g == exceptions.Kill:
raise exceptions.Kill()
return g
raise Kill()
raise exceptions.Kill()
def tell(self, mtype, m):
"""
@ -133,6 +145,10 @@ class DummyReply(object):
def __init__(self):
self.acked = False
self.taken = False
def take(self):
self.taken = True
def __call__(self, msg=False):
self.acked = True
@ -142,22 +158,44 @@ class DummyReply(object):
NO_REPLY = object()
def handler(f):
@functools.wraps(f)
def wrapper(obj, message, *args, **kwargs):
if not hasattr(message, "reply"):
raise ControlError("Message %s has no reply attribute"%message)
ret = f(obj, message, *args, **kwargs)
if not message.reply.acked and not message.reply.taken:
message.reply()
return ret
wrapper.func_dict["handler"] = True
return wrapper
class Reply(object):
"""
Messages sent through a channel are decorated with a "reply" attribute.
This object is used to respond to the message through the return
channel.
"""
def __init__(self, obj):
self.obj = obj
self.q = queue.Queue()
self.acked = False
self.taken = False
def take(self):
self.taken = True
def __call__(self, msg=NO_REPLY):
if self.acked:
raise ControlError("Message already acked.")
self.acked = True
if msg is NO_REPLY:
self.q.put(self.obj)
else:
self.q.put(msg)
def __del__(self):
if not self.acked:
self.acked = True
if msg is NO_REPLY:
self.q.put(self.obj)
else:
self.q.put(msg)
# This will be ignored by the interpreter, but emit a warning
raise ControlError("Un-acked message")

View File

@ -985,36 +985,36 @@ class FlowMaster(controller.ServerMaster):
if block:
rt.join()
@controller.handler
def handle_log(self, l):
self.add_event(l.msg, l.level)
l.reply()
@controller.handler
def handle_clientconnect(self, root_layer):
self.run_script_hook("clientconnect", root_layer)
root_layer.reply()
@controller.handler
def handle_clientdisconnect(self, root_layer):
self.run_script_hook("clientdisconnect", root_layer)
root_layer.reply()
@controller.handler
def handle_serverconnect(self, server_conn):
self.run_script_hook("serverconnect", server_conn)
server_conn.reply()
@controller.handler
def handle_serverdisconnect(self, server_conn):
self.run_script_hook("serverdisconnect", server_conn)
server_conn.reply()
@controller.handler
def handle_next_layer(self, top_layer):
self.run_script_hook("next_layer", top_layer)
top_layer.reply()
@controller.handler
def handle_error(self, f):
self.state.update_flow(f)
self.run_script_hook("error", f)
if self.client_playback:
self.client_playback.clear(f)
f.reply()
return f
def handle_request(self, f):

View File

@ -900,7 +900,7 @@ class TCPServer(object):
"""
# If a thread has persisted after interpreter exit, the module might be
# none.
if traceback:
if traceback and six:
exc = six.text_type(traceback.format_exc())
print(u'-' * 40, file=fp)
print(

View File

@ -2,7 +2,7 @@ from threading import Thread, Event
from mock import Mock
from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master
from mitmproxy import controller
from six.moves import queue
from mitmproxy.exceptions import Kill
@ -10,10 +10,15 @@ from mitmproxy.proxy import DummyServer
from netlib.tutils import raises
class TMsg:
pass
class TestMaster(object):
def test_simple(self):
class DummyMaster(Master):
class DummyMaster(controller.Master):
@controller.handler
def handle_panic(self, _):
m.should_exit.set()
@ -23,14 +28,16 @@ class TestMaster(object):
m = DummyMaster()
assert not m.should_exit.is_set()
m.event_queue.put(("panic", 42))
msg = TMsg()
msg.reply = controller.DummyReply()
m.event_queue.put(("panic", msg))
m.run()
assert m.should_exit.is_set()
class TestServerMaster(object):
def test_simple(self):
m = ServerMaster()
m = controller.ServerMaster()
s = DummyServer(None)
m.add_server(s)
m.start()
@ -42,7 +49,7 @@ class TestServerMaster(object):
class TestServerThread(object):
def test_simple(self):
m = Mock()
t = ServerThread(m)
t = controller.ServerThread(m)
t.run()
assert m.serve_forever.called
@ -50,7 +57,7 @@ class TestServerThread(object):
class TestChannel(object):
def test_tell(self):
q = queue.Queue()
channel = Channel(q, Event())
channel = controller.Channel(q, Event())
m = Mock()
channel.tell("test", m)
assert q.get() == ("test", m)
@ -66,21 +73,21 @@ class TestChannel(object):
Thread(target=reply).start()
channel = Channel(q, Event())
channel = controller.Channel(q, Event())
assert channel.ask("test", Mock()) == 42
def test_ask_shutdown(self):
q = queue.Queue()
done = Event()
done.set()
channel = Channel(q, done)
channel = controller.Channel(q, done)
with raises(Kill):
channel.ask("test", Mock())
class TestDummyReply(object):
def test_simple(self):
reply = DummyReply()
reply = controller.DummyReply()
assert not reply.acked
reply()
assert reply.acked
@ -88,18 +95,18 @@ class TestDummyReply(object):
class TestReply(object):
def test_simple(self):
reply = Reply(42)
reply = controller.Reply(42)
assert not reply.acked
reply("foo")
assert reply.acked
assert reply.q.get() == "foo"
def test_default(self):
reply = Reply(42)
reply = controller.Reply(42)
reply()
assert reply.q.get() == 42
def test_reply_none(self):
reply = Reply(42)
reply = controller.Reply(42)
reply(None)
assert reply.q.get() is None

View File

@ -866,7 +866,6 @@ class TestFlowMaster:
f.response = HTTPResponse.wrap(netlib.tutils.tresp())
fm.handle_response(f)
assert not fm.handle_response(None)
assert s.flow_count() == 1
fm.handle_clientdisconnect(f.client_conn)

View File

@ -39,13 +39,13 @@ class TestMaster(flow.FlowMaster):
self.apps.add(errapp, "errapp", 80)
self.clear_log()
@controller.handler
def handle_request(self, f):
flow.FlowMaster.handle_request(self, f)
f.reply()
@controller.handler
def handle_response(self, f):
flow.FlowMaster.handle_response(self, f)
f.reply()
def clear_log(self):
self.log = []