mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
Sketch out a more solid core
- Decorator for handler methods - Stricter checking for double-acks and non-acks
This commit is contained in:
parent
f7e77d543b
commit
23efee9813
@ -16,7 +16,7 @@ import weakref
|
|||||||
|
|
||||||
from netlib import tcp
|
from netlib import tcp
|
||||||
|
|
||||||
from .. import flow, script, contentviews
|
from .. import flow, script, contentviews, controller
|
||||||
from . import flowlist, flowview, help, window, signals, options
|
from . import flowlist, flowview, help, window, signals, options
|
||||||
from . import grideditor, palettes, statusbar, palettepicker
|
from . import grideditor, palettes, statusbar, palettepicker
|
||||||
from ..exceptions import FlowReadException, ScriptException
|
from ..exceptions import FlowReadException, ScriptException
|
||||||
|
@ -1,8 +1,14 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from six.moves import queue
|
from six.moves import queue
|
||||||
import threading
|
import threading
|
||||||
|
import functools
|
||||||
|
import sys
|
||||||
|
|
||||||
from .exceptions import Kill
|
from . import exceptions
|
||||||
|
|
||||||
|
|
||||||
|
class ControlError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Master(object):
|
class Master(object):
|
||||||
@ -36,6 +42,12 @@ class Master(object):
|
|||||||
while True:
|
while True:
|
||||||
mtype, obj = self.event_queue.get(timeout=timeout)
|
mtype, obj = self.event_queue.get(timeout=timeout)
|
||||||
handle_func = getattr(self, "handle_" + mtype)
|
handle_func = getattr(self, "handle_" + mtype)
|
||||||
|
# if not handle_func.func_dict.get("handler"):
|
||||||
|
# raise ControlError(
|
||||||
|
# "Handler function %s is not decorated with controller.handler"%(
|
||||||
|
# handle_func
|
||||||
|
# )
|
||||||
|
# )
|
||||||
handle_func(obj)
|
handle_func(obj)
|
||||||
self.event_queue.task_done()
|
self.event_queue.task_done()
|
||||||
changed = True
|
changed = True
|
||||||
@ -100,7 +112,7 @@ class Channel(object):
|
|||||||
master. Then wait for a response.
|
master. Then wait for a response.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
Kill: All connections should be closed immediately.
|
exceptions.Kill: All connections should be closed immediately.
|
||||||
"""
|
"""
|
||||||
m.reply = Reply(m)
|
m.reply = Reply(m)
|
||||||
self.q.put((mtype, m))
|
self.q.put((mtype, m))
|
||||||
@ -110,11 +122,11 @@ class Channel(object):
|
|||||||
g = m.reply.q.get(timeout=0.5)
|
g = m.reply.q.get(timeout=0.5)
|
||||||
except queue.Empty: # pragma: no cover
|
except queue.Empty: # pragma: no cover
|
||||||
continue
|
continue
|
||||||
if g == Kill:
|
if g == exceptions.Kill:
|
||||||
raise Kill()
|
raise exceptions.Kill()
|
||||||
return g
|
return g
|
||||||
|
|
||||||
raise Kill()
|
raise exceptions.Kill()
|
||||||
|
|
||||||
def tell(self, mtype, m):
|
def tell(self, mtype, m):
|
||||||
"""
|
"""
|
||||||
@ -133,6 +145,10 @@ class DummyReply(object):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.acked = False
|
self.acked = False
|
||||||
|
self.taken = False
|
||||||
|
|
||||||
|
def take(self):
|
||||||
|
self.taken = True
|
||||||
|
|
||||||
def __call__(self, msg=False):
|
def __call__(self, msg=False):
|
||||||
self.acked = True
|
self.acked = True
|
||||||
@ -142,22 +158,44 @@ class DummyReply(object):
|
|||||||
NO_REPLY = object()
|
NO_REPLY = object()
|
||||||
|
|
||||||
|
|
||||||
|
def handler(f):
|
||||||
|
@functools.wraps(f)
|
||||||
|
def wrapper(obj, message, *args, **kwargs):
|
||||||
|
if not hasattr(message, "reply"):
|
||||||
|
raise ControlError("Message %s has no reply attribute"%message)
|
||||||
|
ret = f(obj, message, *args, **kwargs)
|
||||||
|
if not message.reply.acked and not message.reply.taken:
|
||||||
|
message.reply()
|
||||||
|
return ret
|
||||||
|
wrapper.func_dict["handler"] = True
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class Reply(object):
|
class Reply(object):
|
||||||
"""
|
"""
|
||||||
Messages sent through a channel are decorated with a "reply" attribute.
|
Messages sent through a channel are decorated with a "reply" attribute.
|
||||||
This object is used to respond to the message through the return
|
This object is used to respond to the message through the return
|
||||||
channel.
|
channel.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, obj):
|
def __init__(self, obj):
|
||||||
self.obj = obj
|
self.obj = obj
|
||||||
self.q = queue.Queue()
|
self.q = queue.Queue()
|
||||||
self.acked = False
|
self.acked = False
|
||||||
|
self.taken = False
|
||||||
|
|
||||||
|
def take(self):
|
||||||
|
self.taken = True
|
||||||
|
|
||||||
def __call__(self, msg=NO_REPLY):
|
def __call__(self, msg=NO_REPLY):
|
||||||
if not self.acked:
|
if self.acked:
|
||||||
|
raise ControlError("Message already acked.")
|
||||||
self.acked = True
|
self.acked = True
|
||||||
if msg is NO_REPLY:
|
if msg is NO_REPLY:
|
||||||
self.q.put(self.obj)
|
self.q.put(self.obj)
|
||||||
else:
|
else:
|
||||||
self.q.put(msg)
|
self.q.put(msg)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if not self.acked:
|
||||||
|
# This will be ignored by the interpreter, but emit a warning
|
||||||
|
raise ControlError("Un-acked message")
|
||||||
|
@ -985,36 +985,36 @@ class FlowMaster(controller.ServerMaster):
|
|||||||
if block:
|
if block:
|
||||||
rt.join()
|
rt.join()
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_log(self, l):
|
def handle_log(self, l):
|
||||||
self.add_event(l.msg, l.level)
|
self.add_event(l.msg, l.level)
|
||||||
l.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_clientconnect(self, root_layer):
|
def handle_clientconnect(self, root_layer):
|
||||||
self.run_script_hook("clientconnect", root_layer)
|
self.run_script_hook("clientconnect", root_layer)
|
||||||
root_layer.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_clientdisconnect(self, root_layer):
|
def handle_clientdisconnect(self, root_layer):
|
||||||
self.run_script_hook("clientdisconnect", root_layer)
|
self.run_script_hook("clientdisconnect", root_layer)
|
||||||
root_layer.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_serverconnect(self, server_conn):
|
def handle_serverconnect(self, server_conn):
|
||||||
self.run_script_hook("serverconnect", server_conn)
|
self.run_script_hook("serverconnect", server_conn)
|
||||||
server_conn.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_serverdisconnect(self, server_conn):
|
def handle_serverdisconnect(self, server_conn):
|
||||||
self.run_script_hook("serverdisconnect", server_conn)
|
self.run_script_hook("serverdisconnect", server_conn)
|
||||||
server_conn.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_next_layer(self, top_layer):
|
def handle_next_layer(self, top_layer):
|
||||||
self.run_script_hook("next_layer", top_layer)
|
self.run_script_hook("next_layer", top_layer)
|
||||||
top_layer.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_error(self, f):
|
def handle_error(self, f):
|
||||||
self.state.update_flow(f)
|
self.state.update_flow(f)
|
||||||
self.run_script_hook("error", f)
|
self.run_script_hook("error", f)
|
||||||
if self.client_playback:
|
if self.client_playback:
|
||||||
self.client_playback.clear(f)
|
self.client_playback.clear(f)
|
||||||
f.reply()
|
|
||||||
return f
|
return f
|
||||||
|
|
||||||
def handle_request(self, f):
|
def handle_request(self, f):
|
||||||
|
@ -900,7 +900,7 @@ class TCPServer(object):
|
|||||||
"""
|
"""
|
||||||
# If a thread has persisted after interpreter exit, the module might be
|
# If a thread has persisted after interpreter exit, the module might be
|
||||||
# none.
|
# none.
|
||||||
if traceback:
|
if traceback and six:
|
||||||
exc = six.text_type(traceback.format_exc())
|
exc = six.text_type(traceback.format_exc())
|
||||||
print(u'-' * 40, file=fp)
|
print(u'-' * 40, file=fp)
|
||||||
print(
|
print(
|
||||||
|
@ -2,7 +2,7 @@ from threading import Thread, Event
|
|||||||
|
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
|
|
||||||
from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master
|
from mitmproxy import controller
|
||||||
from six.moves import queue
|
from six.moves import queue
|
||||||
|
|
||||||
from mitmproxy.exceptions import Kill
|
from mitmproxy.exceptions import Kill
|
||||||
@ -10,10 +10,15 @@ from mitmproxy.proxy import DummyServer
|
|||||||
from netlib.tutils import raises
|
from netlib.tutils import raises
|
||||||
|
|
||||||
|
|
||||||
|
class TMsg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestMaster(object):
|
class TestMaster(object):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
|
|
||||||
class DummyMaster(Master):
|
class DummyMaster(controller.Master):
|
||||||
|
@controller.handler
|
||||||
def handle_panic(self, _):
|
def handle_panic(self, _):
|
||||||
m.should_exit.set()
|
m.should_exit.set()
|
||||||
|
|
||||||
@ -23,14 +28,16 @@ class TestMaster(object):
|
|||||||
|
|
||||||
m = DummyMaster()
|
m = DummyMaster()
|
||||||
assert not m.should_exit.is_set()
|
assert not m.should_exit.is_set()
|
||||||
m.event_queue.put(("panic", 42))
|
msg = TMsg()
|
||||||
|
msg.reply = controller.DummyReply()
|
||||||
|
m.event_queue.put(("panic", msg))
|
||||||
m.run()
|
m.run()
|
||||||
assert m.should_exit.is_set()
|
assert m.should_exit.is_set()
|
||||||
|
|
||||||
|
|
||||||
class TestServerMaster(object):
|
class TestServerMaster(object):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
m = ServerMaster()
|
m = controller.ServerMaster()
|
||||||
s = DummyServer(None)
|
s = DummyServer(None)
|
||||||
m.add_server(s)
|
m.add_server(s)
|
||||||
m.start()
|
m.start()
|
||||||
@ -42,7 +49,7 @@ class TestServerMaster(object):
|
|||||||
class TestServerThread(object):
|
class TestServerThread(object):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
m = Mock()
|
m = Mock()
|
||||||
t = ServerThread(m)
|
t = controller.ServerThread(m)
|
||||||
t.run()
|
t.run()
|
||||||
assert m.serve_forever.called
|
assert m.serve_forever.called
|
||||||
|
|
||||||
@ -50,7 +57,7 @@ class TestServerThread(object):
|
|||||||
class TestChannel(object):
|
class TestChannel(object):
|
||||||
def test_tell(self):
|
def test_tell(self):
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
channel = Channel(q, Event())
|
channel = controller.Channel(q, Event())
|
||||||
m = Mock()
|
m = Mock()
|
||||||
channel.tell("test", m)
|
channel.tell("test", m)
|
||||||
assert q.get() == ("test", m)
|
assert q.get() == ("test", m)
|
||||||
@ -66,21 +73,21 @@ class TestChannel(object):
|
|||||||
|
|
||||||
Thread(target=reply).start()
|
Thread(target=reply).start()
|
||||||
|
|
||||||
channel = Channel(q, Event())
|
channel = controller.Channel(q, Event())
|
||||||
assert channel.ask("test", Mock()) == 42
|
assert channel.ask("test", Mock()) == 42
|
||||||
|
|
||||||
def test_ask_shutdown(self):
|
def test_ask_shutdown(self):
|
||||||
q = queue.Queue()
|
q = queue.Queue()
|
||||||
done = Event()
|
done = Event()
|
||||||
done.set()
|
done.set()
|
||||||
channel = Channel(q, done)
|
channel = controller.Channel(q, done)
|
||||||
with raises(Kill):
|
with raises(Kill):
|
||||||
channel.ask("test", Mock())
|
channel.ask("test", Mock())
|
||||||
|
|
||||||
|
|
||||||
class TestDummyReply(object):
|
class TestDummyReply(object):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
reply = DummyReply()
|
reply = controller.DummyReply()
|
||||||
assert not reply.acked
|
assert not reply.acked
|
||||||
reply()
|
reply()
|
||||||
assert reply.acked
|
assert reply.acked
|
||||||
@ -88,18 +95,18 @@ class TestDummyReply(object):
|
|||||||
|
|
||||||
class TestReply(object):
|
class TestReply(object):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
reply = Reply(42)
|
reply = controller.Reply(42)
|
||||||
assert not reply.acked
|
assert not reply.acked
|
||||||
reply("foo")
|
reply("foo")
|
||||||
assert reply.acked
|
assert reply.acked
|
||||||
assert reply.q.get() == "foo"
|
assert reply.q.get() == "foo"
|
||||||
|
|
||||||
def test_default(self):
|
def test_default(self):
|
||||||
reply = Reply(42)
|
reply = controller.Reply(42)
|
||||||
reply()
|
reply()
|
||||||
assert reply.q.get() == 42
|
assert reply.q.get() == 42
|
||||||
|
|
||||||
def test_reply_none(self):
|
def test_reply_none(self):
|
||||||
reply = Reply(42)
|
reply = controller.Reply(42)
|
||||||
reply(None)
|
reply(None)
|
||||||
assert reply.q.get() is None
|
assert reply.q.get() is None
|
||||||
|
@ -866,7 +866,6 @@ class TestFlowMaster:
|
|||||||
|
|
||||||
f.response = HTTPResponse.wrap(netlib.tutils.tresp())
|
f.response = HTTPResponse.wrap(netlib.tutils.tresp())
|
||||||
fm.handle_response(f)
|
fm.handle_response(f)
|
||||||
assert not fm.handle_response(None)
|
|
||||||
assert s.flow_count() == 1
|
assert s.flow_count() == 1
|
||||||
|
|
||||||
fm.handle_clientdisconnect(f.client_conn)
|
fm.handle_clientdisconnect(f.client_conn)
|
||||||
|
@ -39,13 +39,13 @@ class TestMaster(flow.FlowMaster):
|
|||||||
self.apps.add(errapp, "errapp", 80)
|
self.apps.add(errapp, "errapp", 80)
|
||||||
self.clear_log()
|
self.clear_log()
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_request(self, f):
|
def handle_request(self, f):
|
||||||
flow.FlowMaster.handle_request(self, f)
|
flow.FlowMaster.handle_request(self, f)
|
||||||
f.reply()
|
|
||||||
|
|
||||||
|
@controller.handler
|
||||||
def handle_response(self, f):
|
def handle_response(self, f):
|
||||||
flow.FlowMaster.handle_response(self, f)
|
flow.FlowMaster.handle_response(self, f)
|
||||||
f.reply()
|
|
||||||
|
|
||||||
def clear_log(self):
|
def clear_log(self):
|
||||||
self.log = []
|
self.log = []
|
||||||
|
Loading…
Reference in New Issue
Block a user