mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 10:16:27 +00:00
Significantly refactor the master/slave message passing interface.
This commit is contained in:
parent
1ccb2c5dea
commit
aaf892e3af
@ -580,7 +580,7 @@ class ConsoleMaster(flow.FlowMaster):
|
||||
|
||||
self.view_flowlist()
|
||||
|
||||
self.server.start_slave(controller.Slave, self.masterq)
|
||||
self.server.start_slave(controller.Slave, controller.Channel(self.masterq))
|
||||
|
||||
if self.options.rfile:
|
||||
ret = self.load_flows(self.options.rfile)
|
||||
@ -1002,7 +1002,7 @@ class ConsoleMaster(flow.FlowMaster):
|
||||
if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay():
|
||||
f.intercept()
|
||||
else:
|
||||
r._ack()
|
||||
r.reply()
|
||||
self.sync_list_view()
|
||||
self.refresh_flow(f)
|
||||
|
||||
@ -1023,7 +1023,7 @@ class ConsoleMaster(flow.FlowMaster):
|
||||
# Handlers
|
||||
def handle_log(self, l):
|
||||
self.add_event(l.msg)
|
||||
l._ack()
|
||||
l.reply()
|
||||
|
||||
def handle_error(self, r):
|
||||
f = flow.FlowMaster.handle_error(self, r)
|
||||
|
@ -184,7 +184,7 @@ def format_flow(f, focus, extended=False, padding=2):
|
||||
req_timestamp = f.request.timestamp_start,
|
||||
req_is_replay = f.request.is_replay(),
|
||||
req_method = f.request.method,
|
||||
req_acked = f.request.acked,
|
||||
req_acked = f.request.reply.acked,
|
||||
req_url = f.request.get_url(),
|
||||
|
||||
err_msg = f.error.msg if f.error else None,
|
||||
@ -200,7 +200,7 @@ def format_flow(f, focus, extended=False, padding=2):
|
||||
d.update(dict(
|
||||
resp_code = f.response.code,
|
||||
resp_is_replay = f.response.is_replay(),
|
||||
resp_acked = f.response.acked,
|
||||
resp_acked = f.response.reply.acked,
|
||||
resp_clen = contentdesc
|
||||
))
|
||||
t = f.response.headers["content-type"]
|
||||
|
@ -17,37 +17,73 @@ import Queue, threading
|
||||
|
||||
should_exit = False
|
||||
|
||||
class Msg:
|
||||
|
||||
class DummyReply:
|
||||
"""
|
||||
A reply object that does nothing. Useful when we need an object to seem
|
||||
like it has a channel, and during testing.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.acked = False
|
||||
|
||||
def __call__(self, msg=False):
|
||||
self.acked = True
|
||||
|
||||
|
||||
class Reply:
|
||||
"""
|
||||
Messages sent through a channel are decorated with a "reply" attribute.
|
||||
This object is used to respond to the message through the return
|
||||
channel.
|
||||
"""
|
||||
def __init__(self, obj):
|
||||
self.obj = obj
|
||||
self.q = Queue.Queue()
|
||||
self.acked = False
|
||||
|
||||
def _ack(self, data=False):
|
||||
def __call__(self, msg=False):
|
||||
if not self.acked:
|
||||
self.acked = True
|
||||
if data is None:
|
||||
self.q.put(data)
|
||||
if msg is None:
|
||||
self.q.put(msg)
|
||||
else:
|
||||
self.q.put(data or self)
|
||||
self.q.put(msg or self.obj)
|
||||
|
||||
def _send(self, masterq):
|
||||
self.acked = False
|
||||
try:
|
||||
masterq.put(self, timeout=3)
|
||||
while not should_exit: # pragma: no cover
|
||||
try:
|
||||
g = self.q.get(timeout=0.5)
|
||||
except Queue.Empty:
|
||||
continue
|
||||
return g
|
||||
except (Queue.Empty, Queue.Full): # pragma: no cover
|
||||
return None
|
||||
|
||||
class Channel:
|
||||
def __init__(self, q):
|
||||
self.q = q
|
||||
|
||||
def ask(self, m):
|
||||
"""
|
||||
Send a message to the master, and wait for a response.
|
||||
"""
|
||||
m.reply = Reply(m)
|
||||
self.q.put(m)
|
||||
while not should_exit:
|
||||
try:
|
||||
# The timeout is here so we can handle a should_exit event.
|
||||
g = m.reply.q.get(timeout=0.5)
|
||||
except Queue.Empty:
|
||||
continue
|
||||
return g
|
||||
|
||||
def tell(self, m):
|
||||
"""
|
||||
Send a message to the master, and keep going.
|
||||
"""
|
||||
m.reply = None
|
||||
self.q.put(m)
|
||||
|
||||
|
||||
class Slave(threading.Thread):
|
||||
def __init__(self, masterq, server):
|
||||
self.masterq, self.server = masterq, server
|
||||
self.server.set_mqueue(masterq)
|
||||
"""
|
||||
Slaves get a channel end-point through which they can send messages to
|
||||
the master.
|
||||
"""
|
||||
def __init__(self, channel, server):
|
||||
self.channel, self.server = channel, server
|
||||
self.server.set_channel(channel)
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
@ -55,6 +91,9 @@ class Slave(threading.Thread):
|
||||
|
||||
|
||||
class Master:
|
||||
"""
|
||||
Masters get and respond to messages from slaves.
|
||||
"""
|
||||
def __init__(self, server):
|
||||
"""
|
||||
server may be None if no server is needed.
|
||||
@ -81,18 +120,18 @@ class Master:
|
||||
def run(self):
|
||||
global should_exit
|
||||
should_exit = False
|
||||
self.server.start_slave(Slave, self.masterq)
|
||||
self.server.start_slave(Slave, Channel(self.masterq))
|
||||
while not should_exit:
|
||||
self.tick(self.masterq)
|
||||
self.shutdown()
|
||||
|
||||
def handle(self, msg): # pragma: no cover
|
||||
def handle(self, msg):
|
||||
c = "handle_" + msg.__class__.__name__.lower()
|
||||
m = getattr(self, c, None)
|
||||
if m:
|
||||
m(msg)
|
||||
else:
|
||||
msg._ack()
|
||||
msg.reply()
|
||||
|
||||
def shutdown(self):
|
||||
global should_exit
|
||||
|
@ -150,16 +150,6 @@ class DumpMaster(flow.FlowMaster):
|
||||
print >> self.outfile, e
|
||||
self.outfile.flush()
|
||||
|
||||
def handle_log(self, l):
|
||||
self.add_event(l.msg)
|
||||
l._ack()
|
||||
|
||||
def handle_request(self, r):
|
||||
f = flow.FlowMaster.handle_request(self, r)
|
||||
if f:
|
||||
r._ack()
|
||||
return f
|
||||
|
||||
def indent(self, n, t):
|
||||
l = str(t).strip().split("\n")
|
||||
return "\n".join(" "*n + i for i in l)
|
||||
@ -210,10 +200,20 @@ class DumpMaster(flow.FlowMaster):
|
||||
self.outfile.flush()
|
||||
self.state.delete_flow(f)
|
||||
|
||||
def handle_log(self, l):
|
||||
self.add_event(l.msg)
|
||||
l.reply()
|
||||
|
||||
def handle_request(self, r):
|
||||
f = flow.FlowMaster.handle_request(self, r)
|
||||
if f:
|
||||
r.reply()
|
||||
return f
|
||||
|
||||
def handle_response(self, msg):
|
||||
f = flow.FlowMaster.handle_response(self, msg)
|
||||
if f:
|
||||
msg._ack()
|
||||
msg.reply()
|
||||
self._process_flow(f)
|
||||
return f
|
||||
|
||||
|
@ -196,7 +196,7 @@ class decoded(object):
|
||||
self.o.encode(self.ce)
|
||||
|
||||
|
||||
class HTTPMsg(controller.Msg):
|
||||
class HTTPMsg:
|
||||
def get_decoded_content(self):
|
||||
"""
|
||||
Returns the decoded content based on the current Content-Encoding header.
|
||||
@ -252,6 +252,7 @@ class HTTPMsg(controller.Msg):
|
||||
return 0
|
||||
return len(self.content)
|
||||
|
||||
|
||||
class Request(HTTPMsg):
|
||||
"""
|
||||
An HTTP request.
|
||||
@ -289,7 +290,6 @@ class Request(HTTPMsg):
|
||||
self.timestamp_start = timestamp_start or utils.timestamp()
|
||||
self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
|
||||
self.close = False
|
||||
controller.Msg.__init__(self)
|
||||
|
||||
# Have this request's cookies been modified by sticky cookies or auth?
|
||||
self.stickycookie = False
|
||||
@ -396,7 +396,6 @@ class Request(HTTPMsg):
|
||||
Returns a copy of this object.
|
||||
"""
|
||||
c = copy.copy(self)
|
||||
c.acked = True
|
||||
c.headers = self.headers.copy()
|
||||
return c
|
||||
|
||||
@ -603,7 +602,6 @@ class Response(HTTPMsg):
|
||||
self.cert = cert
|
||||
self.timestamp_start = timestamp_start or utils.timestamp()
|
||||
self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
|
||||
controller.Msg.__init__(self)
|
||||
self.replay = False
|
||||
|
||||
def _refresh_cookie(self, c, delta):
|
||||
@ -708,7 +706,6 @@ class Response(HTTPMsg):
|
||||
Returns a copy of this object.
|
||||
"""
|
||||
c = copy.copy(self)
|
||||
c.acked = True
|
||||
c.headers = self.headers.copy()
|
||||
return c
|
||||
|
||||
@ -773,7 +770,7 @@ class Response(HTTPMsg):
|
||||
cookies.append((cookie_name, (cookie_value, cookie_parameters)))
|
||||
return dict(cookies)
|
||||
|
||||
class ClientDisconnect(controller.Msg):
|
||||
class ClientDisconnect:
|
||||
"""
|
||||
A client disconnection event.
|
||||
|
||||
@ -782,11 +779,10 @@ class ClientDisconnect(controller.Msg):
|
||||
client_conn: ClientConnect object.
|
||||
"""
|
||||
def __init__(self, client_conn):
|
||||
controller.Msg.__init__(self)
|
||||
self.client_conn = client_conn
|
||||
|
||||
|
||||
class ClientConnect(controller.Msg):
|
||||
class ClientConnect:
|
||||
"""
|
||||
A single client connection. Each connection can result in multiple HTTP
|
||||
Requests.
|
||||
@ -807,7 +803,6 @@ class ClientConnect(controller.Msg):
|
||||
self.close = False
|
||||
self.requestcount = 0
|
||||
self.error = None
|
||||
controller.Msg.__init__(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self._get_state() == other._get_state()
|
||||
@ -838,11 +833,10 @@ class ClientConnect(controller.Msg):
|
||||
Returns a copy of this object.
|
||||
"""
|
||||
c = copy.copy(self)
|
||||
c.acked = True
|
||||
return c
|
||||
|
||||
|
||||
class Error(controller.Msg):
|
||||
class Error:
|
||||
"""
|
||||
An Error.
|
||||
|
||||
@ -860,7 +854,6 @@ class Error(controller.Msg):
|
||||
def __init__(self, request, msg, timestamp=None):
|
||||
self.request, self.msg = request, msg
|
||||
self.timestamp = timestamp or utils.timestamp()
|
||||
controller.Msg.__init__(self)
|
||||
|
||||
def _load_state(self, state):
|
||||
self.msg = state["msg"]
|
||||
@ -871,7 +864,6 @@ class Error(controller.Msg):
|
||||
Returns a copy of this object.
|
||||
"""
|
||||
c = copy.copy(self)
|
||||
c.acked = True
|
||||
return c
|
||||
|
||||
def _get_state(self):
|
||||
@ -1180,10 +1172,11 @@ class Flow:
|
||||
Kill this request.
|
||||
"""
|
||||
self.error = Error(self.request, "Connection killed")
|
||||
if self.request and not self.request.acked:
|
||||
self.request._ack(None)
|
||||
elif self.response and not self.response.acked:
|
||||
self.response._ack(None)
|
||||
self.error.reply = controller.DummyReply()
|
||||
if self.request and not self.request.reply.acked:
|
||||
self.request.reply(None)
|
||||
elif self.response and not self.response.reply.acked:
|
||||
self.response.reply(None)
|
||||
master.handle_error(self.error)
|
||||
self.intercepting = False
|
||||
|
||||
@ -1199,10 +1192,10 @@ class Flow:
|
||||
Continue with the flow - called after an intercept().
|
||||
"""
|
||||
if self.request:
|
||||
if not self.request.acked:
|
||||
self.request._ack()
|
||||
elif self.response and not self.response.acked:
|
||||
self.response._ack()
|
||||
if not self.request.reply.acked:
|
||||
self.request.reply()
|
||||
elif self.response and not self.response.reply.acked:
|
||||
self.response.reply()
|
||||
self.intercepting = False
|
||||
|
||||
def replace(self, pattern, repl, *args, **kwargs):
|
||||
@ -1464,7 +1457,7 @@ class FlowMaster(controller.Master):
|
||||
flow.response = response
|
||||
if self.refresh_server_playback:
|
||||
response.refresh()
|
||||
flow.request._ack(response)
|
||||
flow.request.reply(response)
|
||||
if self.server_playback.count() == 0:
|
||||
self.stop_server_playback()
|
||||
return True
|
||||
@ -1491,10 +1484,13 @@ class FlowMaster(controller.Master):
|
||||
Loads a flow, and returns a new flow object.
|
||||
"""
|
||||
if f.request:
|
||||
f.request.reply = controller.DummyReply()
|
||||
fr = self.handle_request(f.request)
|
||||
if f.response:
|
||||
f.response.reply = controller.DummyReply()
|
||||
self.handle_response(f.response)
|
||||
if f.error:
|
||||
f.error.reply = controller.DummyReply()
|
||||
self.handle_error(f.error)
|
||||
return fr
|
||||
|
||||
@ -1522,7 +1518,7 @@ class FlowMaster(controller.Master):
|
||||
if self.kill_nonreplay:
|
||||
f.kill(self)
|
||||
else:
|
||||
f.request._ack()
|
||||
f.request.reply()
|
||||
|
||||
def process_new_response(self, f):
|
||||
if self.stickycookie_state:
|
||||
@ -1561,11 +1557,11 @@ class FlowMaster(controller.Master):
|
||||
|
||||
def handle_clientconnect(self, cc):
|
||||
self.run_script_hook("clientconnect", cc)
|
||||
cc._ack()
|
||||
cc.reply()
|
||||
|
||||
def handle_clientdisconnect(self, r):
|
||||
self.run_script_hook("clientdisconnect", r)
|
||||
r._ack()
|
||||
r.reply()
|
||||
|
||||
def handle_error(self, r):
|
||||
f = self.state.add_error(r)
|
||||
@ -1573,7 +1569,7 @@ class FlowMaster(controller.Master):
|
||||
self.run_script_hook("error", f)
|
||||
if self.client_playback:
|
||||
self.client_playback.clear(f)
|
||||
r._ack()
|
||||
r.reply()
|
||||
return f
|
||||
|
||||
def handle_request(self, r):
|
||||
@ -1596,7 +1592,7 @@ class FlowMaster(controller.Master):
|
||||
if self.stream:
|
||||
self.stream.add(f)
|
||||
else:
|
||||
r._ack()
|
||||
r.reply()
|
||||
return f
|
||||
|
||||
def shutdown(self):
|
||||
|
@ -29,9 +29,8 @@ class ProxyError(Exception):
|
||||
return "ProxyError(%s, %s)"%(self.code, self.msg)
|
||||
|
||||
|
||||
class Log(controller.Msg):
|
||||
class Log:
|
||||
def __init__(self, msg):
|
||||
controller.Msg.__init__(self)
|
||||
self.msg = msg
|
||||
|
||||
|
||||
@ -51,7 +50,7 @@ class ProxyConfig:
|
||||
|
||||
class RequestReplayThread(threading.Thread):
|
||||
def __init__(self, config, flow, masterq):
|
||||
self.config, self.flow, self.masterq = config, flow, masterq
|
||||
self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
@ -66,10 +65,10 @@ class RequestReplayThread(threading.Thread):
|
||||
response = flow.Response(
|
||||
self.flow.request, httpversion, code, msg, headers, content, server.cert
|
||||
)
|
||||
response._send(self.masterq)
|
||||
self.channel.ask(response)
|
||||
except (ProxyError, http.HttpError, tcp.NetLibError), v:
|
||||
err = flow.Error(self.flow.request, str(v))
|
||||
err._send(self.masterq)
|
||||
self.channel.ask(err)
|
||||
|
||||
|
||||
class ServerConnection(tcp.TCPClient):
|
||||
@ -128,8 +127,8 @@ class ServerConnectionPool:
|
||||
|
||||
|
||||
class ProxyHandler(tcp.BaseHandler):
|
||||
def __init__(self, config, connection, client_address, server, mqueue, server_version):
|
||||
self.mqueue, self.server_version = mqueue, server_version
|
||||
def __init__(self, config, connection, client_address, server, channel, server_version):
|
||||
self.channel, self.server_version = channel, server_version
|
||||
self.config = config
|
||||
self.server_conn_pool = ServerConnectionPool(config)
|
||||
self.proxy_connect_state = None
|
||||
@ -139,18 +138,18 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
cc = flow.ClientConnect(self.client_address)
|
||||
self.log(cc, "connect")
|
||||
cc._send(self.mqueue)
|
||||
self.channel.ask(cc)
|
||||
while self.handle_request(cc) and not cc.close:
|
||||
pass
|
||||
cc.close = True
|
||||
cd = flow.ClientDisconnect(cc)
|
||||
|
||||
cd = flow.ClientDisconnect(cc)
|
||||
self.log(
|
||||
cc, "disconnect",
|
||||
[
|
||||
"handled %s requests"%cc.requestcount]
|
||||
)
|
||||
cd._send(self.mqueue)
|
||||
self.channel.ask(cd)
|
||||
|
||||
def handle_request(self, cc):
|
||||
try:
|
||||
@ -167,14 +166,14 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
self.log(cc, "Error in wsgi app.", err.split("\n"))
|
||||
return
|
||||
else:
|
||||
request = request._send(self.mqueue)
|
||||
request = self.channel.ask(request)
|
||||
if request is None:
|
||||
return
|
||||
|
||||
if isinstance(request, flow.Response):
|
||||
response = request
|
||||
request = False
|
||||
response = response._send(self.mqueue)
|
||||
response = self.channel.ask(response)
|
||||
else:
|
||||
if self.config.reverse_proxy:
|
||||
scheme, host, port = self.config.reverse_proxy
|
||||
@ -192,7 +191,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
request, httpversion, code, msg, headers, content, sc.cert,
|
||||
sc.rfile.first_byte_timestamp, utils.timestamp()
|
||||
)
|
||||
response = response._send(self.mqueue)
|
||||
response = self.channel.ask(response)
|
||||
if response is None:
|
||||
sc.terminate()
|
||||
if response is None:
|
||||
@ -214,7 +213,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
|
||||
if request:
|
||||
err = flow.Error(request, cc.error)
|
||||
err._send(self.mqueue)
|
||||
self.channel.ask(err)
|
||||
self.log(
|
||||
cc, cc.error,
|
||||
["url: %s"%request.get_url()]
|
||||
@ -235,7 +234,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
msg.append(" -> "+i)
|
||||
msg = "\n".join(msg)
|
||||
l = Log(msg)
|
||||
l._send(self.mqueue)
|
||||
self.channel.ask(l)
|
||||
|
||||
def find_cert(self, host, port, sni):
|
||||
if self.config.certfile:
|
||||
@ -438,18 +437,18 @@ class ProxyServer(tcp.TCPServer):
|
||||
tcp.TCPServer.__init__(self, (address, port))
|
||||
except socket.error, v:
|
||||
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
||||
self.masterq = None
|
||||
self.channel = None
|
||||
self.apps = AppRegistry()
|
||||
|
||||
def start_slave(self, klass, masterq):
|
||||
slave = klass(masterq, self)
|
||||
def start_slave(self, klass, channel):
|
||||
slave = klass(channel, self)
|
||||
slave.start()
|
||||
|
||||
def set_mqueue(self, q):
|
||||
self.masterq = q
|
||||
def set_channel(self, channel):
|
||||
self.channel = channel
|
||||
|
||||
def handle_connection(self, request, client_address):
|
||||
h = ProxyHandler(self.config, request, client_address, self, self.masterq, self.server_version)
|
||||
h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version)
|
||||
h.handle()
|
||||
try:
|
||||
h.finish()
|
||||
@ -487,7 +486,7 @@ class DummyServer:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
|
||||
def start_slave(self, klass, masterq):
|
||||
def start_slave(self, klass, channel):
|
||||
pass
|
||||
|
||||
def shutdown(self):
|
||||
|
@ -3,6 +3,7 @@ from cStringIO import StringIO
|
||||
import libpry
|
||||
from libmproxy import dump, flow, proxy
|
||||
import tutils
|
||||
import mock
|
||||
|
||||
def test_strfuncs():
|
||||
t = tutils.tresp()
|
||||
@ -21,6 +22,7 @@ class TestDumpMaster:
|
||||
req = tutils.treq()
|
||||
req.content = content
|
||||
l = proxy.Log("connect")
|
||||
l.reply = mock.MagicMock()
|
||||
m.handle_log(l)
|
||||
cc = req.client_conn
|
||||
cc.connection_error = "error"
|
||||
@ -29,7 +31,9 @@ class TestDumpMaster:
|
||||
m.handle_clientconnect(cc)
|
||||
m.handle_request(req)
|
||||
f = m.handle_response(resp)
|
||||
m.handle_clientdisconnect(flow.ClientDisconnect(cc))
|
||||
cd = flow.ClientDisconnect(cc)
|
||||
cd.reply = mock.MagicMock()
|
||||
m.handle_clientdisconnect(cd)
|
||||
return f
|
||||
|
||||
def _dummy_cycle(self, n, filt, content, **options):
|
||||
|
@ -223,16 +223,16 @@ class TestFlow:
|
||||
f = tutils.tflow()
|
||||
f.request = tutils.treq()
|
||||
f.intercept()
|
||||
assert not f.request.acked
|
||||
assert not f.request.reply.acked
|
||||
f.kill(fm)
|
||||
assert f.request.acked
|
||||
assert f.request.reply.acked
|
||||
f.intercept()
|
||||
f.response = tutils.tresp()
|
||||
f.request = f.response.request
|
||||
f.request._ack()
|
||||
assert not f.response.acked
|
||||
f.request.reply()
|
||||
assert not f.response.reply.acked
|
||||
f.kill(fm)
|
||||
assert f.response.acked
|
||||
assert f.response.reply.acked
|
||||
|
||||
def test_killall(self):
|
||||
s = flow.State()
|
||||
@ -245,25 +245,25 @@ class TestFlow:
|
||||
fm.handle_request(r)
|
||||
|
||||
for i in s.view:
|
||||
assert not i.request.acked
|
||||
assert not i.request.reply.acked
|
||||
s.killall(fm)
|
||||
for i in s.view:
|
||||
assert i.request.acked
|
||||
assert i.request.reply.acked
|
||||
|
||||
def test_accept_intercept(self):
|
||||
f = tutils.tflow()
|
||||
f.request = tutils.treq()
|
||||
f.intercept()
|
||||
assert not f.request.acked
|
||||
assert not f.request.reply.acked
|
||||
f.accept_intercept()
|
||||
assert f.request.acked
|
||||
assert f.request.reply.acked
|
||||
f.response = tutils.tresp()
|
||||
f.request = f.response.request
|
||||
f.intercept()
|
||||
f.request._ack()
|
||||
assert not f.response.acked
|
||||
f.request.reply()
|
||||
assert not f.response.reply.acked
|
||||
f.accept_intercept()
|
||||
assert f.response.acked
|
||||
assert f.response.reply.acked
|
||||
|
||||
def test_serialization(self):
|
||||
f = flow.Flow(None)
|
||||
@ -562,9 +562,11 @@ class TestFlowMaster:
|
||||
fm.handle_response(resp)
|
||||
assert fm.script.ns["log"][-1] == "response"
|
||||
dc = flow.ClientDisconnect(req.client_conn)
|
||||
dc.reply = controller.DummyReply()
|
||||
fm.handle_clientdisconnect(dc)
|
||||
assert fm.script.ns["log"][-1] == "clientdisconnect"
|
||||
err = flow.Error(f.request, "msg")
|
||||
err.reply = controller.DummyReply()
|
||||
fm.handle_error(err)
|
||||
assert fm.script.ns["log"][-1] == "error"
|
||||
|
||||
@ -598,10 +600,12 @@ class TestFlowMaster:
|
||||
assert not fm.handle_response(rx)
|
||||
|
||||
dc = flow.ClientDisconnect(req.client_conn)
|
||||
dc.reply = controller.DummyReply()
|
||||
req.client_conn.requestcount = 1
|
||||
fm.handle_clientdisconnect(dc)
|
||||
|
||||
err = flow.Error(f.request, "msg")
|
||||
err.reply = controller.DummyReply()
|
||||
fm.handle_error(err)
|
||||
|
||||
fm.load_script(tutils.test_data.path("scripts/a.py"))
|
||||
@ -621,7 +625,9 @@ class TestFlowMaster:
|
||||
fm.tick(q)
|
||||
assert fm.state.flow_count()
|
||||
|
||||
fm.handle_error(flow.Error(f.request, "error"))
|
||||
err = flow.Error(f.request, "error")
|
||||
err.reply = controller.DummyReply()
|
||||
fm.handle_error(err)
|
||||
|
||||
def test_server_playback(self):
|
||||
controller.should_exit = False
|
||||
|
@ -31,7 +31,7 @@ class TestMaster(flow.FlowMaster):
|
||||
|
||||
def handle(self, m):
|
||||
flow.FlowMaster.handle(self, m)
|
||||
m._ack()
|
||||
m.reply()
|
||||
|
||||
|
||||
class ProxyThread(threading.Thread):
|
||||
|
@ -1,15 +1,18 @@
|
||||
import os, shutil, tempfile
|
||||
from contextlib import contextmanager
|
||||
from libmproxy import flow, utils
|
||||
from libmproxy import flow, utils, controller
|
||||
from netlib import certutils
|
||||
|
||||
import mock
|
||||
|
||||
def treq(conn=None):
|
||||
if not conn:
|
||||
conn = flow.ClientConnect(("address", 22))
|
||||
conn.reply = controller.DummyReply()
|
||||
headers = flow.ODictCaseless()
|
||||
headers["header"] = ["qvalue"]
|
||||
return flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content")
|
||||
r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content")
|
||||
r.reply = controller.DummyReply()
|
||||
return r
|
||||
|
||||
|
||||
def tresp(req=None):
|
||||
@ -18,7 +21,9 @@ def tresp(req=None):
|
||||
headers = flow.ODictCaseless()
|
||||
headers["header_response"] = ["svalue"]
|
||||
cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert")).read())
|
||||
return flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert)
|
||||
resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert)
|
||||
resp.reply = controller.DummyReply()
|
||||
return resp
|
||||
|
||||
|
||||
def tflow():
|
||||
@ -37,9 +42,11 @@ def tflow_err():
|
||||
r = treq()
|
||||
f = flow.Flow(r)
|
||||
f.error = flow.Error(r, "error")
|
||||
f.error.reply = controller.DummyReply()
|
||||
return f
|
||||
|
||||
|
||||
|
||||
@contextmanager
|
||||
def tmpdir(*args, **kwargs):
|
||||
orig_workdir = os.getcwd()
|
||||
|
Loading…
Reference in New Issue
Block a user