From 45eab17e0c35b9527dd8f68364fa577c61f33551 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 4 Jan 2014 14:42:32 +1300 Subject: [PATCH] Decouple message type from message class name. --- libmproxy/controller.py | 18 +++++++++--------- libmproxy/proxy.py | 20 ++++++++++---------- test/test_controller.py | 2 +- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/libmproxy/controller.py b/libmproxy/controller.py index 62d1dbbcb..b662b6d5e 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -39,13 +39,13 @@ class Channel: def __init__(self, q): self.q = q - def ask(self, m): + def ask(self, mtype, m): """ Decorate a message with a reply attribute, and send it to the master. then wait for a response. """ m.reply = Reply(m) - self.q.put(m) + self.q.put((mtype, m)) while not should_exit: try: # The timeout is here so we can handle a should_exit event. @@ -54,13 +54,13 @@ class Channel: continue return g - def tell(self, m): + def tell(self, mtype, m): """ Decorate a message with a dummy reply attribute, send it to the master, then return immediately. """ m.reply = DummyReply() - self.q.put(m) + self.q.put((mtype, m)) class Slave(threading.Thread): @@ -98,7 +98,7 @@ class Master: while True: # Small timeout to prevent pegging the CPU msg = q.get(timeout=0.01) - self.handle(msg) + self.handle(*msg) changed = True except Queue.Empty: pass @@ -112,13 +112,13 @@ class Master: self.tick(self.masterq) self.shutdown() - def handle(self, msg): - c = "handle_" + msg.__class__.__name__.lower() + def handle(self, mtype, obj): + c = "handle_" + mtype m = getattr(self, c, None) if m: - m(msg) + m(obj) else: - msg.reply() + obj.reply() def shutdown(self): global should_exit diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 9b300aa18..7f39a5c53 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -97,10 +97,10 @@ class RequestReplayThread(threading.Thread): self.flow.request, httpversion, code, msg, headers, content, server.cert, server.rfile.first_byte_timestamp ) - self.channel.ask(response) + self.channel.ask("response", response) except (ProxyError, http.HttpError, tcp.NetLibError), v: err = flow.Error(self.flow.request, str(v)) - self.channel.ask(err) + self.channel.ask("error", err) class HandleSNI: @@ -173,7 +173,7 @@ class ProxyHandler(tcp.BaseHandler): self.server_conn.require_request = False self.server_conn.conn_info = conn_info - self.channel.ask(self.server_conn) + self.channel.ask("serverconnect", self.server_conn) self.server_conn.connect() except tcp.NetLibError, v: raise ProxyError(502, v) @@ -187,7 +187,7 @@ class ProxyHandler(tcp.BaseHandler): def handle(self): cc = flow.ClientConnect(self.client_address) self.log(cc, "connect") - self.channel.ask(cc) + self.channel.ask("clientconnect", cc) while self.handle_request(cc) and not cc.close: pass cc.close = True @@ -199,7 +199,7 @@ class ProxyHandler(tcp.BaseHandler): [ "handled %s requests"%cc.requestcount] ) - self.channel.tell(cd) + self.channel.tell("clientdisconnect", cd) def handle_request(self, cc): try: @@ -209,13 +209,13 @@ class ProxyHandler(tcp.BaseHandler): return cc.requestcount += 1 - request_reply = self.channel.ask(request) + request_reply = self.channel.ask("request", request) if request_reply is None or request_reply == KILL: return elif isinstance(request_reply, flow.Response): request = False response = request_reply - response_reply = self.channel.ask(response) + response_reply = self.channel.ask("response", response) else: request = request_reply if self.config.reverse_proxy: @@ -261,7 +261,7 @@ class ProxyHandler(tcp.BaseHandler): request, httpversion, code, msg, headers, content, sc.cert, sc.rfile.first_byte_timestamp ) - response_reply = self.channel.ask(response) + response_reply = self.channel.ask("response", response) # Not replying to the server invalidates the server # connection, so we terminate. if response_reply == KILL: @@ -288,7 +288,7 @@ class ProxyHandler(tcp.BaseHandler): if request: err = flow.Error(request, cc.error) - self.channel.ask(err) + self.channel.ask("error", err) self.log( cc, cc.error, ["url: %s"%request.get_url()] @@ -308,7 +308,7 @@ class ProxyHandler(tcp.BaseHandler): msg.append(" -> "+i) msg = "\n".join(msg) l = Log(msg) - self.channel.tell(l) + self.channel.tell("log", l) def find_cert(self, cc, host, port, sni): if self.config.certfile: diff --git a/test/test_controller.py b/test/test_controller.py index f6d6b5eb3..e71a148eb 100644 --- a/test/test_controller.py +++ b/test/test_controller.py @@ -6,7 +6,7 @@ class TestMaster: def test_default_handler(self): m = controller.Master(None) msg = mock.MagicMock() - m.handle(msg) + m.handle("type", msg) assert msg.reply.call_count == 1