mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
Sketch out a more solid core
- Decorator for handler methods - Stricter checking for double-acks and non-acks
This commit is contained in:
parent
f7e77d543b
commit
23efee9813
@ -16,7 +16,7 @@ import weakref
|
||||
|
||||
from netlib import tcp
|
||||
|
||||
from .. import flow, script, contentviews
|
||||
from .. import flow, script, contentviews, controller
|
||||
from . import flowlist, flowview, help, window, signals, options
|
||||
from . import grideditor, palettes, statusbar, palettepicker
|
||||
from ..exceptions import FlowReadException, ScriptException
|
||||
|
@ -1,8 +1,14 @@
|
||||
from __future__ import absolute_import
|
||||
from six.moves import queue
|
||||
import threading
|
||||
import functools
|
||||
import sys
|
||||
|
||||
from .exceptions import Kill
|
||||
from . import exceptions
|
||||
|
||||
|
||||
class ControlError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Master(object):
|
||||
@ -36,6 +42,12 @@ class Master(object):
|
||||
while True:
|
||||
mtype, obj = self.event_queue.get(timeout=timeout)
|
||||
handle_func = getattr(self, "handle_" + mtype)
|
||||
# if not handle_func.func_dict.get("handler"):
|
||||
# raise ControlError(
|
||||
# "Handler function %s is not decorated with controller.handler"%(
|
||||
# handle_func
|
||||
# )
|
||||
# )
|
||||
handle_func(obj)
|
||||
self.event_queue.task_done()
|
||||
changed = True
|
||||
@ -100,7 +112,7 @@ class Channel(object):
|
||||
master. Then wait for a response.
|
||||
|
||||
Raises:
|
||||
Kill: All connections should be closed immediately.
|
||||
exceptions.Kill: All connections should be closed immediately.
|
||||
"""
|
||||
m.reply = Reply(m)
|
||||
self.q.put((mtype, m))
|
||||
@ -110,11 +122,11 @@ class Channel(object):
|
||||
g = m.reply.q.get(timeout=0.5)
|
||||
except queue.Empty: # pragma: no cover
|
||||
continue
|
||||
if g == Kill:
|
||||
raise Kill()
|
||||
if g == exceptions.Kill:
|
||||
raise exceptions.Kill()
|
||||
return g
|
||||
|
||||
raise Kill()
|
||||
raise exceptions.Kill()
|
||||
|
||||
def tell(self, mtype, m):
|
||||
"""
|
||||
@ -133,6 +145,10 @@ class DummyReply(object):
|
||||
|
||||
def __init__(self):
|
||||
self.acked = False
|
||||
self.taken = False
|
||||
|
||||
def take(self):
|
||||
self.taken = True
|
||||
|
||||
def __call__(self, msg=False):
|
||||
self.acked = True
|
||||
@ -142,22 +158,44 @@ class DummyReply(object):
|
||||
NO_REPLY = object()
|
||||
|
||||
|
||||
def handler(f):
|
||||
@functools.wraps(f)
|
||||
def wrapper(obj, message, *args, **kwargs):
|
||||
if not hasattr(message, "reply"):
|
||||
raise ControlError("Message %s has no reply attribute"%message)
|
||||
ret = f(obj, message, *args, **kwargs)
|
||||
if not message.reply.acked and not message.reply.taken:
|
||||
message.reply()
|
||||
return ret
|
||||
wrapper.func_dict["handler"] = True
|
||||
return wrapper
|
||||
|
||||
|
||||
class Reply(object):
|
||||
"""
|
||||
Messages sent through a channel are decorated with a "reply" attribute.
|
||||
This object is used to respond to the message through the return
|
||||
channel.
|
||||
"""
|
||||
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
self.q = queue.Queue()
|
||||
self.acked = False
|
||||
self.taken = False
|
||||
|
||||
def take(self):
|
||||
self.taken = True
|
||||
|
||||
def __call__(self, msg=NO_REPLY):
|
||||
if self.acked:
|
||||
raise ControlError("Message already acked.")
|
||||
self.acked = True
|
||||
if msg is NO_REPLY:
|
||||
self.q.put(self.obj)
|
||||
else:
|
||||
self.q.put(msg)
|
||||
|
||||
def __del__(self):
|
||||
if not self.acked:
|
||||
self.acked = True
|
||||
if msg is NO_REPLY:
|
||||
self.q.put(self.obj)
|
||||
else:
|
||||
self.q.put(msg)
|
||||
# This will be ignored by the interpreter, but emit a warning
|
||||
raise ControlError("Un-acked message")
|
||||
|
@ -985,36 +985,36 @@ class FlowMaster(controller.ServerMaster):
|
||||
if block:
|
||||
rt.join()
|
||||
|
||||
@controller.handler
|
||||
def handle_log(self, l):
|
||||
self.add_event(l.msg, l.level)
|
||||
l.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_clientconnect(self, root_layer):
|
||||
self.run_script_hook("clientconnect", root_layer)
|
||||
root_layer.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_clientdisconnect(self, root_layer):
|
||||
self.run_script_hook("clientdisconnect", root_layer)
|
||||
root_layer.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_serverconnect(self, server_conn):
|
||||
self.run_script_hook("serverconnect", server_conn)
|
||||
server_conn.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_serverdisconnect(self, server_conn):
|
||||
self.run_script_hook("serverdisconnect", server_conn)
|
||||
server_conn.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_next_layer(self, top_layer):
|
||||
self.run_script_hook("next_layer", top_layer)
|
||||
top_layer.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_error(self, f):
|
||||
self.state.update_flow(f)
|
||||
self.run_script_hook("error", f)
|
||||
if self.client_playback:
|
||||
self.client_playback.clear(f)
|
||||
f.reply()
|
||||
return f
|
||||
|
||||
def handle_request(self, f):
|
||||
|
@ -900,7 +900,7 @@ class TCPServer(object):
|
||||
"""
|
||||
# If a thread has persisted after interpreter exit, the module might be
|
||||
# none.
|
||||
if traceback:
|
||||
if traceback and six:
|
||||
exc = six.text_type(traceback.format_exc())
|
||||
print(u'-' * 40, file=fp)
|
||||
print(
|
||||
|
@ -2,7 +2,7 @@ from threading import Thread, Event
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from mitmproxy.controller import Reply, DummyReply, Channel, ServerThread, ServerMaster, Master
|
||||
from mitmproxy import controller
|
||||
from six.moves import queue
|
||||
|
||||
from mitmproxy.exceptions import Kill
|
||||
@ -10,10 +10,15 @@ from mitmproxy.proxy import DummyServer
|
||||
from netlib.tutils import raises
|
||||
|
||||
|
||||
class TMsg:
|
||||
pass
|
||||
|
||||
|
||||
class TestMaster(object):
|
||||
def test_simple(self):
|
||||
|
||||
class DummyMaster(Master):
|
||||
class DummyMaster(controller.Master):
|
||||
@controller.handler
|
||||
def handle_panic(self, _):
|
||||
m.should_exit.set()
|
||||
|
||||
@ -23,14 +28,16 @@ class TestMaster(object):
|
||||
|
||||
m = DummyMaster()
|
||||
assert not m.should_exit.is_set()
|
||||
m.event_queue.put(("panic", 42))
|
||||
msg = TMsg()
|
||||
msg.reply = controller.DummyReply()
|
||||
m.event_queue.put(("panic", msg))
|
||||
m.run()
|
||||
assert m.should_exit.is_set()
|
||||
|
||||
|
||||
class TestServerMaster(object):
|
||||
def test_simple(self):
|
||||
m = ServerMaster()
|
||||
m = controller.ServerMaster()
|
||||
s = DummyServer(None)
|
||||
m.add_server(s)
|
||||
m.start()
|
||||
@ -42,7 +49,7 @@ class TestServerMaster(object):
|
||||
class TestServerThread(object):
|
||||
def test_simple(self):
|
||||
m = Mock()
|
||||
t = ServerThread(m)
|
||||
t = controller.ServerThread(m)
|
||||
t.run()
|
||||
assert m.serve_forever.called
|
||||
|
||||
@ -50,7 +57,7 @@ class TestServerThread(object):
|
||||
class TestChannel(object):
|
||||
def test_tell(self):
|
||||
q = queue.Queue()
|
||||
channel = Channel(q, Event())
|
||||
channel = controller.Channel(q, Event())
|
||||
m = Mock()
|
||||
channel.tell("test", m)
|
||||
assert q.get() == ("test", m)
|
||||
@ -66,21 +73,21 @@ class TestChannel(object):
|
||||
|
||||
Thread(target=reply).start()
|
||||
|
||||
channel = Channel(q, Event())
|
||||
channel = controller.Channel(q, Event())
|
||||
assert channel.ask("test", Mock()) == 42
|
||||
|
||||
def test_ask_shutdown(self):
|
||||
q = queue.Queue()
|
||||
done = Event()
|
||||
done.set()
|
||||
channel = Channel(q, done)
|
||||
channel = controller.Channel(q, done)
|
||||
with raises(Kill):
|
||||
channel.ask("test", Mock())
|
||||
|
||||
|
||||
class TestDummyReply(object):
|
||||
def test_simple(self):
|
||||
reply = DummyReply()
|
||||
reply = controller.DummyReply()
|
||||
assert not reply.acked
|
||||
reply()
|
||||
assert reply.acked
|
||||
@ -88,18 +95,18 @@ class TestDummyReply(object):
|
||||
|
||||
class TestReply(object):
|
||||
def test_simple(self):
|
||||
reply = Reply(42)
|
||||
reply = controller.Reply(42)
|
||||
assert not reply.acked
|
||||
reply("foo")
|
||||
assert reply.acked
|
||||
assert reply.q.get() == "foo"
|
||||
|
||||
def test_default(self):
|
||||
reply = Reply(42)
|
||||
reply = controller.Reply(42)
|
||||
reply()
|
||||
assert reply.q.get() == 42
|
||||
|
||||
def test_reply_none(self):
|
||||
reply = Reply(42)
|
||||
reply = controller.Reply(42)
|
||||
reply(None)
|
||||
assert reply.q.get() is None
|
||||
|
@ -866,7 +866,6 @@ class TestFlowMaster:
|
||||
|
||||
f.response = HTTPResponse.wrap(netlib.tutils.tresp())
|
||||
fm.handle_response(f)
|
||||
assert not fm.handle_response(None)
|
||||
assert s.flow_count() == 1
|
||||
|
||||
fm.handle_clientdisconnect(f.client_conn)
|
||||
|
@ -39,13 +39,13 @@ class TestMaster(flow.FlowMaster):
|
||||
self.apps.add(errapp, "errapp", 80)
|
||||
self.clear_log()
|
||||
|
||||
@controller.handler
|
||||
def handle_request(self, f):
|
||||
flow.FlowMaster.handle_request(self, f)
|
||||
f.reply()
|
||||
|
||||
@controller.handler
|
||||
def handle_response(self, f):
|
||||
flow.FlowMaster.handle_response(self, f)
|
||||
f.reply()
|
||||
|
||||
def clear_log(self):
|
||||
self.log = []
|
||||
|
Loading…
Reference in New Issue
Block a user