A new interface for reply

Reply is now explicit - it's no longer a callable itself. Instead, we have:

    reply.kill()            - kill the flow
    reply.ack()             - ack, but don't send anything
    reply.send(message)     - send a response

This is part of an incremental move to detach reply from our flow objects,
and unify the script and handler interfaces.
This commit is contained in:
Aldo Cortesi 2016-06-08 10:44:20 +12:00
parent 982077ec31
commit a388ddfd78
6 changed files with 28 additions and 40 deletions

View File

@ -16,7 +16,7 @@ def request(context, flow):
"HTTP/1.1", 200, "OK", "HTTP/1.1", 200, "OK",
Headers(Content_Type="text/html"), Headers(Content_Type="text/html"),
"helloworld") "helloworld")
flow.reply(resp) flow.reply.send(resp)
# Method 2: Redirect the request to a different server # Method 2: Redirect the request to a different server
if flow.request.pretty_host.endswith("example.org"): if flow.request.pretty_host.endswith("example.org"):

View File

@ -134,5 +134,5 @@ def next_layer(context, next_layer):
# We don't intercept - reply with a pass-through layer and add a "skipped" entry. # We don't intercept - reply with a pass-through layer and add a "skipped" entry.
context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info")
next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False) next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False)
next_layer.reply(next_layer_replacement) next_layer.reply.send(next_layer_replacement)
context.tls_strategy.record_skipped(server_address) context.tls_strategy.record_skipped(server_address)

View File

@ -145,10 +145,6 @@ class Channel(object):
self.q.put((mtype, m)) self.q.put((mtype, m))
# Special value to distinguish the case where no reply was sent
NO_REPLY = object()
def handler(f): def handler(f):
@functools.wraps(f) @functools.wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
@ -199,21 +195,18 @@ class Reply(object):
self.handled = False self.handled = False
def ack(self): def ack(self):
self(NO_REPLY) self.send(self.obj)
def kill(self): def kill(self):
self(exceptions.Kill) self.send(exceptions.Kill)
def take(self): def take(self):
self.taken = True self.taken = True
def __call__(self, msg=NO_REPLY): def send(self, msg):
if self.acked: if self.acked:
raise exceptions.ControlException("Message already acked.") raise exceptions.ControlException("Message already acked.")
self.acked = True self.acked = True
if msg is NO_REPLY:
self.q.put(self.obj)
else:
self.q.put(msg) self.q.put(msg)
def __del__(self): def __del__(self):
@ -233,13 +226,13 @@ class DummyReply(object):
self.handled = False self.handled = False
def kill(self): def kill(self):
self() self.send(None)
def ack(self): def ack(self):
self() self.send(None)
def take(self): def take(self):
self.taken = True self.taken = True
def __call__(self, msg=False): def send(self, msg):
self.acked = True self.acked = True

View File

@ -4,6 +4,7 @@ offload computations from mitmproxy's main master thread.
""" """
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from mitmproxy import controller
import threading import threading
@ -14,15 +15,15 @@ class ReplyProxy(object):
self.script_thread = script_thread self.script_thread = script_thread
self.master_reply = None self.master_reply = None
def __call__(self, *args): def send(self, message):
if self.master_reply is None: if self.master_reply is None:
self.master_reply = args self.master_reply = message
self.script_thread.start() self.script_thread.start()
return return
self.reply_func(*args) self.reply_func(message)
def done(self): def done(self):
self.reply_func(*self.master_reply) self.reply_func.send(self.master_reply)
def __getattr__(self, k): def __getattr__(self, k):
return getattr(self.reply_func, k) return getattr(self.reply_func, k)
@ -49,17 +50,11 @@ class ScriptThread(threading.Thread):
def concurrent(fn): def concurrent(fn):
if fn.__name__ in ( if fn.__name__ not in controller.Events:
"request", raise NotImplementedError(
"response", "Concurrent decorator not supported for '%s' method." % fn.__name__
"error", )
"clientconnect",
"serverconnect",
"clientdisconnect",
"next_layer"):
def _concurrent(ctx, obj): def _concurrent(ctx, obj):
_handle_concurrent_reply(fn, obj, ctx, obj) _handle_concurrent_reply(fn, obj, ctx, obj)
return _concurrent return _concurrent
raise NotImplementedError(
"Concurrent decorator not supported for '%s' method." % fn.__name__)

View File

@ -66,7 +66,7 @@ class TestChannel(object):
def reply(): def reply():
m, obj = q.get() m, obj = q.get()
assert m == "test" assert m == "test"
obj.reply(42) obj.reply.send(42)
Thread(target=reply).start() Thread(target=reply).start()
@ -86,7 +86,7 @@ class TestDummyReply(object):
def test_simple(self): def test_simple(self):
reply = controller.DummyReply() reply = controller.DummyReply()
assert not reply.acked assert not reply.acked
reply() reply.ack()
assert reply.acked assert reply.acked
@ -94,16 +94,16 @@ class TestReply(object):
def test_simple(self): def test_simple(self):
reply = controller.Reply(42) reply = controller.Reply(42)
assert not reply.acked assert not reply.acked
reply("foo") reply.send("foo")
assert reply.acked assert reply.acked
assert reply.q.get() == "foo" assert reply.q.get() == "foo"
def test_default(self): def test_default(self):
reply = controller.Reply(42) reply = controller.Reply(42)
reply() reply.ack()
assert reply.q.get() == 42 assert reply.q.get() == 42
def test_reply_none(self): def test_reply_none(self):
reply = controller.Reply(42) reply = controller.Reply(42)
reply(None) reply.send(None)
assert reply.q.get() is None assert reply.q.get() is None

View File

@ -743,7 +743,7 @@ class MasterFakeResponse(tservers.TestMaster):
@controller.handler @controller.handler
def request(self, f): def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp = HTTPResponse.wrap(netlib.tutils.tresp())
f.reply(resp) f.reply.send(resp)
class TestFakeResponse(tservers.HTTPProxyTest): class TestFakeResponse(tservers.HTTPProxyTest):
@ -819,7 +819,7 @@ class MasterIncomplete(tservers.TestMaster):
def request(self, f): def request(self, f):
resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp = HTTPResponse.wrap(netlib.tutils.tresp())
resp.content = None resp.content = None
f.reply(resp) f.reply.send(resp)
class TestIncompleteResponse(tservers.HTTPProxyTest): class TestIncompleteResponse(tservers.HTTPProxyTest):