diff --git a/examples/flowbasic b/examples/flowbasic index b8184262f..8dbe2f28b 100755 --- a/examples/flowbasic +++ b/examples/flowbasic @@ -16,16 +16,16 @@ class MyMaster(flow.FlowMaster): except KeyboardInterrupt: self.shutdown() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - r.reply() + f.reply() print f return f diff --git a/examples/proxapp b/examples/proxapp index 3a94cd558..9f299d259 100755 --- a/examples/proxapp +++ b/examples/proxapp @@ -20,16 +20,16 @@ class MyMaster(flow.FlowMaster): except KeyboardInterrupt: self.shutdown() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - r.reply() + f.reply() print f return f diff --git a/libmproxy/app.py b/libmproxy/app.py index 9941d6ea0..ed7ec72a5 100644 --- a/libmproxy/app.py +++ b/libmproxy/app.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import flask -import os.path, os -from . import proxy +import os +from .proxy import config mapp = flask.Flask(__name__) mapp.debug = True @@ -18,12 +18,12 @@ def index(): @mapp.route("/cert/pem") def certs_pem(): - p = os.path.join(master().server.config.confdir, proxy.config.CONF_BASENAME + "-ca-cert.pem") + p = os.path.join(master().server.config.confdir, config.CONF_BASENAME + "-ca-cert.pem") return flask.Response(open(p, "rb").read(), mimetype='application/x-x509-ca-cert') @mapp.route("/cert/p12") def certs_p12(): - p = os.path.join(master().server.config.confdir, proxy.config.CONF_BASENAME + "-ca-cert.p12") + p = os.path.join(master().server.config.confdir, config.CONF_BASENAME + "-ca-cert.p12") return flask.Response(open(p, "rb").read(), mimetype='application/x-pkcs12') diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index 1325aae59..a59209158 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -268,8 +268,8 @@ class ConsoleState(flow.State): d = self.flowsettings.get(flow, {}) return d.get(key, default) - def add_request(self, req): - f = flow.State.add_request(self, req) + def add_request(self, f): + flow.State.add_request(self, f) if self.focus is None: self.set_focus(0) elif self.follow_focus: @@ -996,11 +996,11 @@ class ConsoleMaster(flow.FlowMaster): if hasattr(self.statusbar, "refresh_flow"): self.statusbar.refresh_flow(c) - def process_flow(self, f, r): + def process_flow(self, f): if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: f.intercept() else: - r.reply() + f.reply() self.sync_list_view() self.refresh_flow(f) @@ -1022,20 +1022,20 @@ class ConsoleMaster(flow.FlowMaster): self.eventlist.set_focus(len(self.eventlist)-1) # Handlers - def handle_error(self, r): - f = flow.FlowMaster.handle_error(self, r) + def handle_error(self, f): + f = flow.FlowMaster.handle_error(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f diff --git a/libmproxy/dump.py b/libmproxy/dump.py index aeb34cc32..8ecd56e78 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -50,13 +50,13 @@ def str_response(resp): return r -def str_request(req, showhost): - if req.flow.client_conn: - c = req.flow.client_conn.address.host +def str_request(f, showhost): + if f.client_conn: + c = f.client_conn.address.host else: c = "[replay]" - r = "%s %s %s"%(c, req.method, req.get_url(showhost)) - if req.stickycookie: + r = "%s %s %s"%(c, f.request.method, f.request.get_url(showhost, f)) + if f.request.stickycookie: r = "[stickycookie] " + r return r @@ -185,16 +185,16 @@ class DumpMaster(flow.FlowMaster): result = " << %s"%f.error.msg if self.o.flow_detail == 1: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, result elif self.o.flow_detail == 2: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, self.indent(4, f.request.headers) print >> self.outfile print >> self.outfile, result print >> self.outfile, "\n" elif self.o.flow_detail >= 3: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, self.indent(4, f.request.headers) if utils.isBin(f.request.content): print >> self.outfile, self.indent(4, netlib.utils.hexdump(f.request.content)) @@ -206,21 +206,21 @@ class DumpMaster(flow.FlowMaster): if self.o.flow_detail: self.outfile.flush() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, msg): - f = flow.FlowMaster.handle_response(self, msg) + def handle_response(self, f): + flow.FlowMaster.handle_response(self, f) if f: - msg.reply() + f.reply() self._process_flow(f) return f - def handle_error(self, msg): - f = flow.FlowMaster.handle_error(self, msg) + def handle_error(self, f): + flow.FlowMaster.handle_error(self, f) if f: self._process_flow(f) return f diff --git a/libmproxy/filt.py b/libmproxy/filt.py index e17ed7353..925dbfbba 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -208,7 +208,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.get_host(), re.IGNORECASE)) + return bool(re.search(self.expr, f.request.get_host(False, f), re.IGNORECASE)) class FUrl(_Rex): @@ -222,7 +222,7 @@ class FUrl(_Rex): return klass(*toks) def __call__(self, f): - return re.search(self.expr, f.request.get_url()) + return re.search(self.expr, f.request.get_url(False, f)) class _Int(_Action): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 2540435ee..eb183d9f5 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -34,11 +34,11 @@ class AppRegistry: """ Returns an WSGIAdaptor instance if request matches an app, or None. """ - if (request.get_host(), request.get_port()) in self.apps: - return self.apps[(request.get_host(), request.get_port())] + if (request.host, request.port) in self.apps: + return self.apps[(request.host, request.port)] if "host" in request.headers: host = request.headers["host"][0] - return self.apps.get((host, request.get_port()), None) + return self.apps.get((host, request.port), None) class ReplaceHooks: @@ -185,11 +185,11 @@ class ClientPlaybackState: n = self.flows.pop(0) n.request.reply = controller.DummyReply() n.client_conn = None - self.current = master.handle_request(n.request) + self.current = master.handle_request(n) if not testing and not self.current.response: - master.replay_request(self.current) # pragma: no cover + master.replay_request(self.current) # pragma: no cover elif self.current.response: - master.handle_response(self.current.response) + master.handle_response(self.current) class ServerPlaybackState: @@ -260,8 +260,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. """ return ( - m["domain"] or f.request.get_host(), - f.request.get_port(), + m["domain"] or f.request.get_host(False, f), + f.request.get_port(f), m["path"] or "/" ) @@ -279,7 +279,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.get_host(), k[0]): + if self.domain_match(f.request.get_host(False, f), k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -287,8 +287,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.get_host(), i[0]), - f.request.get_port() == i[1], + self.domain_match(f.request.get_host(False, f), i[0]), + f.request.get_port(f) == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -307,7 +307,7 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): - host = f.request.get_host() + host = f.request.get_host(False, f) if "authorization" in f.request.headers: self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): @@ -342,33 +342,30 @@ class State(object): c += 1 return c - def add_request(self, req): + def add_request(self, flow): """ Add a request to the state. Returns the matching flow. """ - f = req.flow - self._flow_list.append(f) - if f.match(self._limit): - self.view.append(f) - return f + self._flow_list.append(flow) + if flow.match(self._limit): + self.view.append(flow) + return flow - def add_response(self, resp): + def add_response(self, f): """ Add a response to the state. Returns the matching flow. """ - f = resp.flow if not f: return False if f.match(self._limit) and not f in self.view: self.view.append(f) return f - def add_error(self, err): + def add_error(self, f): """ Add an error response to the state. Returns the matching flow, or None if there isn't one. """ - f = err.flow if not f: return None if f.match(self._limit) and not f in self.view: @@ -586,7 +583,7 @@ class FlowMaster(controller.Master): response.is_replay = True if self.refresh_server_playback: response.refresh() - flow.request.reply(response) + flow.reply(response) if self.server_playback.count() == 0: self.stop_server_playback() return True @@ -612,16 +609,14 @@ class FlowMaster(controller.Master): """ Loads a flow, and returns a new flow object. """ + f.reply = controller.DummyReply() if f.request: - f.request.reply = controller.DummyReply() - fr = self.handle_request(f.request) + self.handle_request(f) if f.response: - f.response.reply = controller.DummyReply() - self.handle_response(f.response) + self.handle_response(f) if f.error: - f.error.reply = controller.DummyReply() - self.handle_error(f.error) - return fr + self.handle_error(f) + return f def load_flows(self, fr): """ @@ -647,7 +642,7 @@ class FlowMaster(controller.Master): if self.kill_nonreplay: f.kill(self) else: - f.request.reply() + f.reply() def process_new_response(self, f): if self.stickycookie_state: @@ -694,54 +689,49 @@ class FlowMaster(controller.Master): self.run_script_hook("serverconnect", sc) sc.reply() - def handle_error(self, r): - f = self.state.add_error(r) - if f: - self.run_script_hook("error", f) + def handle_error(self, f): + self.state.add_error(f) + self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - r.reply() + f.reply() return f - def handle_request(self, r): - if r.flow.live: - app = self.apps.get(r) + def handle_request(self, f): + if f.live: + app = self.apps.get(f.request) if app: - err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self}) + err = app.serve(f, f.client_conn.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. %s"%err, "error") - r.reply(protocol.KILL) + f.reply(protocol.KILL) return - f = self.state.add_request(r) + self.state.add_request(f) self.replacehooks.run(f) self.setheaders.run(f) self.run_script_hook("request", f) self.process_new_request(f) return f - def handle_responseheaders(self, resp): - f = resp.flow + def handle_responseheaders(self, f): self.run_script_hook("responseheaders", f) if self.stream_large_bodies: self.stream_large_bodies.run(f, False) - resp.reply() + f.reply() return f - def handle_response(self, r): - f = self.state.add_response(r) - if f: - self.replacehooks.run(f) - self.setheaders.run(f) - self.run_script_hook("response", f) - if self.client_playback: - self.client_playback.clear(f) - self.process_new_response(f) - if self.stream: - self.stream.add(f) - else: - r.reply() + def handle_response(self, f): + self.state.add_response(f) + self.replacehooks.run(f) + self.setheaders.run(f) + self.run_script_hook("response", f) + if self.client_playback: + self.client_playback.clear(f) + self.process_new_response(f) + if self.stream: + self.stream.add(f) return f def shutdown(self): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 658c08ed9..3f9eecb32 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -77,9 +77,6 @@ class HTTPMessage(stateobject.SimpleStateObject): self.timestamp_start = timestamp_start if timestamp_start is not None else utils.timestamp() self.timestamp_end = timestamp_end if timestamp_end is not None else utils.timestamp() - self.flow = None # will usually be set by the flow backref mixin - """@type: HTTPFlow""" - _stateobject_attributes = dict( httpversion=tuple, headers=ODictCaseless, @@ -346,10 +343,10 @@ class HTTPRequest(HTTPMessage): del headers[k] if headers["Upgrade"] == ["h2c"]: # Suppress HTTP2 https://http2.github.io/http2-spec/index.html#discover-http del headers["Upgrade"] - if not 'host' in headers: + if not 'host' in headers and self.scheme and self.host and self.port: headers["Host"] = [utils.hostport(self.scheme, - self.host or self.flow.server_conn.address.host, - self.port or self.flow.server_conn.address.port)] + self.host, + self.port)] if self.content: headers["Content-Length"] = [str(len(self.content))] @@ -429,16 +426,16 @@ class HTTPRequest(HTTPMessage): self.headers["Content-Type"] = [HDR_FORM_URLENCODED] self.content = utils.urlencode(odict.lst) - def get_path_components(self): + def get_path_components(self, f): """ Returns the path components of the URL as a list of strings. Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) + _, _, path, _, _, _ = urlparse.urlparse(self.get_url(False, f)) return [urllib.unquote(i) for i in path.split("/") if i] - def set_path_components(self, lst): + def set_path_components(self, lst, f): """ Takes a list of strings, and sets the path component of the URL. @@ -446,27 +443,27 @@ class HTTPRequest(HTTPMessage): """ lst = [urllib.quote(i, safe="") for i in lst] path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url(False, f)) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) - def get_query(self): + def get_query(self, f): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) + _, _, _, _, query, _ = urlparse.urlparse(self.get_url(False, f)) if query: return ODict(utils.urldecode(query)) return ODict([]) - def set_query(self, odict): + def set_query(self, odict, f): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url(False, f)) query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) - def get_host(self, hostheader=False): + def get_host(self, hostheader, flow): """ Heuristic to get the host of the request. @@ -484,16 +481,16 @@ class HTTPRequest(HTTPMessage): if self.host: host = self.host else: - for s in self.flow.server_conn.state: + for s in flow.server_conn.state: if s[0] == "http" and s[1]["state"] == "connect": host = s[1]["host"] break if not host: - host = self.flow.server_conn.address.host + host = flow.server_conn.address.host host = host.encode("idna") return host - def get_scheme(self): + def get_scheme(self, flow): """ Returns the request port, either from the request itself or from the flow's server connection """ @@ -501,20 +498,20 @@ class HTTPRequest(HTTPMessage): return self.scheme if self.form_out == "authority": # On SSLed connections, the original CONNECT request is still unencrypted. return "http" - return "https" if self.flow.server_conn.ssl_established else "http" + return "https" if flow.server_conn.ssl_established else "http" - def get_port(self): + def get_port(self, flow): """ Returns the request port, either from the request itself or from the flow's server connection """ if self.port: return self.port - for s in self.flow.server_conn.state: + for s in flow.server_conn.state: if s[0] == "http" and s[1].get("state") == "connect": return s[1]["port"] - return self.flow.server_conn.address.port + return flow.server_conn.address.port - def get_url(self, hostheader=False): + def get_url(self, hostheader, flow): """ Returns a URL string, constructed from the Request's URL components. @@ -522,13 +519,13 @@ class HTTPRequest(HTTPMessage): Host header to construct the URL. """ if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.get_host(hostheader), self.get_port()) - return utils.unparse_url(self.get_scheme(), - self.get_host(hostheader), - self.get_port(), + return "%s:%s" % (self.get_host(hostheader, flow), self.get_port(flow)) + return utils.unparse_url(self.get_scheme(flow), + self.get_host(hostheader, flow), + self.get_port(flow), self.path).encode('ascii') - def set_url(self, url): + def set_url(self, url, flow): """ Parses a URL specification, and updates the Request's information accordingly. @@ -543,14 +540,14 @@ class HTTPRequest(HTTPMessage): self.path = path - if host != self.get_host() or port != self.get_port(): - if self.flow.live: - self.flow.live.change_server((host, port), ssl=is_ssl) + if host != self.get_host(False, flow) or port != self.get_port(flow): + if flow.live: + flow.live.change_server((host, port), ssl=is_ssl) else: # There's not live server connection, we're just changing the attributes here. - self.flow.server_conn = ServerConnection((host, port), + flow.server_conn = ServerConnection((host, port), proxy.AddressPriority.MANUALLY_CHANGED) - self.flow.server_conn.ssl_established = is_ssl + flow.server_conn.ssl_established = is_ssl # If this is an absolute request, replace the attributes on the request object as well. if self.host: @@ -802,8 +799,6 @@ class HTTPFlow(Flow): self.intercepting = False # FIXME: Should that rather be an attribute of Flow? - _backrefattr = Flow._backrefattr + ("request", "response") - _stateobject_attributes = Flow._stateobject_attributes.copy() _stateobject_attributes.update( request=HTTPRequest, @@ -855,13 +850,10 @@ class HTTPFlow(Flow): Kill this request. """ self.error = Error("Connection killed") - self.error.reply = controller.DummyReply() - if self.request and not self.request.reply.acked: - self.request.reply(KILL) - elif self.response and not self.response.reply.acked: - self.response.reply(KILL) - master.handle_error(self.error) self.intercepting = False + self.reply(KILL) + self.reply = controller.DummyReply() + master.handle_error(self) def intercept(self): """ @@ -874,12 +866,8 @@ class HTTPFlow(Flow): """ Continue with the flow - called after an intercept(). """ - if self.request: - if not self.request.reply.acked: - self.request.reply() - elif self.response and not self.response.reply.acked: - self.response.reply() - self.intercepting = False + self.intercepting = False + self.reply() def replace(self, pattern, repl, *args, **kwargs): """ @@ -961,7 +949,7 @@ class HTTPHandler(ProtocolHandler): # in an Error object that has an attached request that has not been # sent through to the Master. flow.request = req - request_reply = self.c.channel.ask("request", flow.request) + request_reply = self.c.channel.ask("request", flow) self.determine_server_address(flow, flow.request) # The inline script may have changed request.host flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow @@ -976,7 +964,7 @@ class HTTPHandler(ProtocolHandler): flow.response = self.get_response_from_server(flow.request, include_body=False) # call the appropriate script hook - this is an opportunity for an inline script to set flow.stream = True - self.c.channel.ask("responseheaders", flow.response) + self.c.channel.ask("responseheaders", flow) # now get the rest of the request body, if body still needs to be read but not streaming this response if flow.response.stream: @@ -991,7 +979,7 @@ class HTTPHandler(ProtocolHandler): flow.server_conn = self.c.server_conn self.c.log("response", "debug", [flow.response._assemble_first_line()]) - response_reply = self.c.channel.ask("response", flow.response) + response_reply = self.c.channel.ask("response", flow) if response_reply is None or response_reply == KILL: return False @@ -1079,7 +1067,7 @@ class HTTPHandler(ProtocolHandler): # TODO: no flows without request or with both request and response at the moment. if flow.request and not flow.response: flow.error = Error(message) - self.c.channel.ask("error", flow.error) + self.c.channel.ask("error", flow) try: code = getattr(error, "code", 502) @@ -1204,12 +1192,12 @@ class RequestReplayThread(threading.Thread): except proxy.ProxyError: pass if not server_address: - server_address = (r.get_host(), r.get_port()) + server_address = (r.get_host(False, self.flow), r.get_port(self.flow)) server = ServerConnection(server_address, None) server.connect() - if server_ssl or r.get_scheme() == "https": + if server_ssl or r.get_scheme(self.flow) == "https": if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT send_connect_request(server, r.get_host(), r.get_port()) r.form_out = "relative" @@ -1218,9 +1206,9 @@ class RequestReplayThread(threading.Thread): server.send(r._assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) - self.channel.ask("response", self.flow.response) + self.channel.ask("response", self.flow) except (proxy.ProxyError, http.HttpError, tcp.NetLibError), v: self.flow.error = Error(repr(v)) - self.channel.ask("error", self.flow.error) + self.channel.ask("error", self.flow) finally: r.form_out = form_out_backup \ No newline at end of file diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index a227d904d..a84b40614 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -9,24 +9,6 @@ from ..proxy.connection import ClientConnection, ServerConnection KILL = 0 # const for killed requests -class BackreferenceMixin(object): - """ - If an attribute from the _backrefattr tuple is set, - this mixin sets a reference back on the attribute object. - Example: - e = Error() - f = Flow() - f.error = e - assert f is e.flow - """ - _backrefattr = tuple() - - def __setattr__(self, key, value): - super(BackreferenceMixin, self).__setattr__(key, value) - if key in self._backrefattr and value is not None: - setattr(value, self._backrefname, self) - - class Error(stateobject.SimpleStateObject): """ An Error. @@ -70,7 +52,7 @@ class Error(stateobject.SimpleStateObject): return c -class Flow(stateobject.SimpleStateObject, BackreferenceMixin): +class Flow(stateobject.SimpleStateObject): def __init__(self, conntype, client_conn, server_conn, live=None): self.conntype = conntype self.client_conn = client_conn @@ -84,9 +66,6 @@ class Flow(stateobject.SimpleStateObject, BackreferenceMixin): """@type: Error""" self._backup = None - _backrefattr = ("error",) - _backrefname = "flow" - _stateobject_attributes = dict( error=Error, client_conn=ClientConnection, diff --git a/libmproxy/script.py b/libmproxy/script.py index e582c4e85..706d84d5b 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -125,13 +125,8 @@ def _handle_concurrent_reply(fn, o, *args, **kwargs): def concurrent(fn): - if fn.func_name in ["request", "response", "error"]: - def _concurrent(ctx, flow): - r = getattr(flow, fn.func_name) - _handle_concurrent_reply(fn, r, ctx, flow) - return _concurrent - elif fn.func_name in ["clientconnect", "serverconnect", "clientdisconnect"]: - def _concurrent(ctx, conn): - _handle_concurrent_reply(fn, conn, ctx, conn) + if fn.func_name in ("request", "response", "error", "clientconnect", "serverconnect", "clientdisconnect"): + def _concurrent(ctx, obj): + _handle_concurrent_reply(fn, obj, ctx, obj) return _concurrent raise NotImplementedError("Concurrent decorator not supported for this method.") diff --git a/test/test_console.py b/test/test_console.py index 0c5b45914..3b6c941d3 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -51,20 +51,20 @@ class TestConsoleState: assert c.get_focus() == (None, None) def _add_request(self, state): - r = tutils.treq() - return state.add_request(r) + f = tutils.tflow() + return state.add_request(f) def _add_response(self, state): f = self._add_request(state) - r = tutils.tresp(f.request) - state.add_response(r) + f.response = tutils.tresp() + state.add_response(f) def test_add_response(self): c = console.ConsoleState() f = self._add_request(c) - r = tutils.tresp(f.request) + f.response = tutils.tresp() c.focus = None - c.add_response(r) + c.add_response(f) def test_focus_view(self): c = console.ConsoleState() diff --git a/test/test_console_common.py b/test/test_console_common.py index d798e4dc6..1949dad5c 100644 --- a/test/test_console_common.py +++ b/test/test_console_common.py @@ -9,7 +9,7 @@ import tutils def test_format_flow(): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert common.format_flow(f, True) assert common.format_flow(f, True, hostheader=True) assert common.format_flow(f, True, extended=True) diff --git a/test/test_dump.py b/test/test_dump.py index 6f70450fb..fd93cc03c 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -10,31 +10,27 @@ def test_strfuncs(): t.is_replay = True dump.str_response(t) - t = tutils.treq() - t.flow.client_conn = None - t.stickycookie = True - assert "stickycookie" in dump.str_request(t, False) - assert "stickycookie" in dump.str_request(t, True) - assert "replay" in dump.str_request(t, False) - assert "replay" in dump.str_request(t, True) + f = tutils.tflow() + f.client_conn = None + f.request.stickycookie = True + assert "stickycookie" in dump.str_request(f, False) + assert "stickycookie" in dump.str_request(f, True) + assert "replay" in dump.str_request(f, False) + assert "replay" in dump.str_request(f, True) class TestDumpMaster: def _cycle(self, m, content): - req = tutils.treq(content=content) + f = tutils.tflow(req=tutils.treq(content)) l = Log("connect") l.reply = mock.MagicMock() m.handle_log(l) - cc = req.flow.client_conn - cc.reply = mock.MagicMock() - m.handle_clientconnect(cc) - sc = proxy.connection.ServerConnection((req.get_host(), req.get_port()), None) - sc.reply = mock.MagicMock() - m.handle_serverconnect(sc) - m.handle_request(req) - resp = tutils.tresp(req, content=content) - f = m.handle_response(resp) - m.handle_clientdisconnect(cc) + m.handle_clientconnect(f.client_conn) + m.handle_serverconnect(f.server_conn) + m.handle_request(f) + f.response = tutils.tresp(content) + f = m.handle_response(f) + m.handle_clientdisconnect(f.client_conn) return f def _dummy_cycle(self, n, filt, content, **options): @@ -49,8 +45,7 @@ class TestDumpMaster: def _flowfile(self, path): f = open(path, "wb") fw = flow.FlowWriter(f) - t = tutils.tflow_full() - t.response = tutils.tresp(t.request) + t = tutils.tflow(resp=True) fw.add(t) f.close() @@ -58,9 +53,9 @@ class TestDumpMaster: cs = StringIO() o = dump.Options(flow_detail=1) m = dump.DumpMaster(None, o, None, outfile=cs) - f = tutils.tflow_err() - m.handle_request(f.request) - assert m.handle_error(f.error) + f = tutils.tflow(err=True) + m.handle_request(f) + assert m.handle_error(f) assert "error" in cs.getvalue() def test_replay(self): diff --git a/test/test_flow.py b/test/test_flow.py index 88e7b9d70..6e9464e73 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -14,7 +14,8 @@ def test_app_registry(): ar.add("foo", "domain", 80) r = tutils.treq() - r.set_url("http://domain:80/") + r.host = "domain" + r.port = 80 assert ar.get(r) r.port = 81 @@ -32,8 +33,7 @@ def test_app_registry(): class TestStickyCookieState: def _response(self, cookie, host): s = flow.StickyCookieState(filt.parse(".*")) - f = tutils.tflow_full() - f.server_conn.address = tcp.Address((host, 80)) + f = tutils.tflow(req=tutils.treq(host=host, port=80), resp=True) f.response.headers["Set-Cookie"] = [cookie] s.handle_response(f) return s, f @@ -66,12 +66,12 @@ class TestStickyCookieState: class TestStickyAuthState: def test_handle_response(self): s = flow.StickyAuthState(filt.parse(".*")) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["authorization"] = ["foo"] s.handle_request(f) assert "address" in s.hosts - f = tutils.tflow_full() + f = tutils.tflow(resp=True) s.handle_request(f) assert f.request.headers["authorization"] == ["foo"] @@ -123,24 +123,24 @@ class TestServerPlaybackState: def test_headers(self): s = flow.ServerPlaybackState(["foo"], [], False, False) - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["foo"] = ["bar"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) assert not s._hash(r) == s._hash(r2) r2.request.headers["foo"] = ["bar"] assert s._hash(r) == s._hash(r2) r2.request.headers["oink"] = ["bar"] assert s._hash(r) == s._hash(r2) - r = tutils.tflow_full() - r2 = tutils.tflow_full() + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) assert s._hash(r) == s._hash(r2) def test_load(self): - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["key"] = ["one"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) r2.request.headers["key"] = ["two"] s = flow.ServerPlaybackState(None, [r, r2], False, False) @@ -158,10 +158,10 @@ class TestServerPlaybackState: assert not s.next_flow(r) def test_load_with_nopop(self): - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["key"] = ["one"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) r2.request.headers["key"] = ["two"] s = flow.ServerPlaybackState(None, [r, r2], False, True) @@ -173,7 +173,7 @@ class TestServerPlaybackState: class TestFlow: def test_copy(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) a0 = f._get_state() f2 = f.copy() a = f._get_state() @@ -188,7 +188,7 @@ class TestFlow: assert f.response == f2.response assert not f.response is f2.response - f = tutils.tflow_err() + f = tutils.tflow(err=True) f2 = f.copy() assert not f is f2 assert not f.request is f2.request @@ -198,12 +198,12 @@ class TestFlow: assert not f.error is f2.error def test_match(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert not f.match("~b test") assert f.match(None) assert not f.match("~b test") - f = tutils.tflow_err() + f = tutils.tflow(err=True) assert f.match("~e") tutils.raises(ValueError, f.match, "~") @@ -220,14 +220,14 @@ class TestFlow: assert f.request.content == "foo" def test_backup_idempotence(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.backup() f.revert() f.backup() f.revert() def test_getset_state(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) state = f._get_state() assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state() @@ -248,55 +248,42 @@ class TestFlow: s = flow.State() fm = flow.FlowMaster(None, s) f = tutils.tflow() - f.request = tutils.treq() f.intercept() - assert not f.request.reply.acked + assert not f.reply.acked f.kill(fm) - assert f.request.reply.acked - f.intercept() - f.response = tutils.tresp() - f.request.reply() - assert not f.response.reply.acked - f.kill(fm) - assert f.response.reply.acked + assert f.reply.acked def test_killall(self): s = flow.State() fm = flow.FlowMaster(None, s) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) for i in s.view: - assert not i.request.reply.acked + assert not i.reply.acked s.killall(fm) for i in s.view: - assert i.request.reply.acked + assert i.reply.acked def test_accept_intercept(self): f = tutils.tflow() - f.request = tutils.treq() + f.intercept() - assert not f.request.reply.acked + assert not f.reply.acked f.accept_intercept() - assert f.request.reply.acked - f.response = tutils.tresp() - f.intercept() - f.request.reply() - assert not f.response.reply.acked - f.accept_intercept() - assert f.response.reply.acked + assert f.reply.acked def test_replace_unicode(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.content = "\xc2foo" f.replace("foo", u"bar") def test_replace(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["foo"] = ["foo"] f.request.content = "afoob" @@ -311,7 +298,7 @@ class TestFlow: assert f.response.content == "abarb" def test_replace_encoded(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = "afoob" f.request.encode("gzip") f.response.content = "afoob" @@ -332,9 +319,8 @@ class TestFlow: class TestState: def test_backup(self): c = flow.State() - req = tutils.treq() - f = c.add_request(req) - + f = tutils.tflow() + c.add_request(f) f.backup() c.revert(f) @@ -344,72 +330,66 @@ class TestState: connect -> request -> response """ - bc = tutils.tclient_conn() c = flow.State() - - req = tutils.treq(bc) - f = c.add_request(req) + f = tutils.tflow() + c.add_request(f) assert f assert c.flow_count() == 1 assert c.active_flow_count() == 1 - newreq = tutils.treq() - assert c.add_request(newreq) + newf = tutils.tflow() + assert c.add_request(newf) assert c.active_flow_count() == 2 - resp = tutils.tresp(req) - assert c.add_response(resp) + f.response = tutils.tresp() + assert c.add_response(f) assert c.flow_count() == 2 assert c.active_flow_count() == 1 - unseen_resp = tutils.tresp() - unseen_resp.flow = None - assert not c.add_response(unseen_resp) + _ = tutils.tresp() + assert not c.add_response(None) assert c.active_flow_count() == 1 - resp = tutils.tresp(newreq) - assert c.add_response(resp) + newf.response = tutils.tresp() + assert c.add_response(newf) assert c.active_flow_count() == 0 def test_err(self): c = flow.State() - req = tutils.treq() - f = c.add_request(req) + f = tutils.tflow() + c.add_request(f) f.error = Error("message") - assert c.add_error(f.error) - - e = Error("message") - assert not c.add_error(e) + assert c.add_error(f) c = flow.State() - req = tutils.treq() - f = c.add_request(req) - e = tutils.terr() + f = tutils.tflow() + c.add_request(f) c.set_limit("~e") assert not c.view - assert c.add_error(e) + f.error = tutils.terr() + assert c.add_error(f) assert c.view def test_set_limit(self): c = flow.State() - req = tutils.treq() + f = tutils.tflow() assert len(c.view) == 0 - c.add_request(req) + c.add_request(f) assert len(c.view) == 1 c.set_limit("~s") assert c.limit_txt == "~s" assert len(c.view) == 0 - resp = tutils.tresp(req) - c.add_response(resp) + f.response = tutils.tresp() + c.add_response(f) assert len(c.view) == 1 c.set_limit(None) assert len(c.view) == 1 - req = tutils.treq() - c.add_request(req) + f = tutils.tflow() + c.add_request(f) assert len(c.view) == 2 c.set_limit("~q") assert len(c.view) == 1 @@ -427,20 +407,19 @@ class TestState: assert c.intercept_txt == None def _add_request(self, state): - req = tutils.treq() - f = state.add_request(req) + f = tutils.tflow() + state.add_request(f) return f def _add_response(self, state): - req = tutils.treq() - state.add_request(req) - resp = tutils.tresp(req) - state.add_response(resp) + f = tutils.tflow() + state.add_request(f) + f.response = tutils.tresp() + state.add_response(f) def _add_error(self, state): - req = tutils.treq() - f = state.add_request(req) - f.error = Error("msg") + f = tutils.tflow(err=True) + state.add_request(f) def test_clear(self): c = flow.State() @@ -479,10 +458,10 @@ class TestSerialize: sio = StringIO() w = flow.FlowWriter(sio) for i in range(3): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) w.add(f) for i in range(3): - f = tutils.tflow_err() + f = tutils.tflow(err=True) w.add(f) sio.seek(0) @@ -516,11 +495,11 @@ class TestSerialize: fl = filt.parse("~c 200") w = flow.FilteredFlowWriter(sio, fl) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.code = 200 w.add(f) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.code = 201 w.add(f) @@ -565,7 +544,7 @@ class TestFlowMaster: def test_replay(self): s = flow.State() fm = flow.FlowMaster(None, s) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = CONTENT_MISSING assert "missing" in fm.replay_request(f) @@ -576,48 +555,44 @@ class TestFlowMaster: s = flow.State() fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/reqerr.py")) - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) - assert fm.handle_request(req) + f = tutils.tflow() + fm.handle_clientconnect(f.client_conn) + assert fm.handle_request(f) def test_script(self): s = flow.State() fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) + f = tutils.tflow(resp=True) + + fm.handle_clientconnect(f.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" - sc = ServerConnection((req.get_host(), req.get_port()), None) - sc.reply = controller.DummyReply() - fm.handle_serverconnect(sc) + fm.handle_serverconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "serverconnect" - f = fm.handle_request(req) + fm.handle_request(f) assert fm.scripts[0].ns["log"][-1] == "request" - resp = tutils.tresp(req) - fm.handle_response(resp) + fm.handle_response(f) assert fm.scripts[0].ns["log"][-1] == "response" #load second script assert not fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 - fm.handle_clientdisconnect(sc) + fm.handle_clientdisconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" - #unload first script fm.unload_scripts() assert len(fm.scripts) == 0 - assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - err = tutils.terr() - err.reply = controller.DummyReply() - fm.handle_error(err) + + f.error = tutils.terr() + fm.handle_error(f) assert fm.scripts[0].ns["log"][-1] == "error" def test_duplicate_flow(self): s = flow.State() fm = flow.FlowMaster(None, s) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f = fm.load_flow(f) assert s.flow_count() == 1 f2 = fm.duplicate_flow(f) @@ -630,25 +605,22 @@ class TestFlowMaster: fm = flow.FlowMaster(None, s) fm.anticache = True fm.anticomp = True - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) - - f = fm.handle_request(req) + f = tutils.tflow(req=None) + fm.handle_clientconnect(f.client_conn) + f.request = tutils.treq() + fm.handle_request(f) assert s.flow_count() == 1 - resp = tutils.tresp(req) - fm.handle_response(resp) + f.response = tutils.tresp() + fm.handle_response(f) + assert not fm.handle_response(None) assert s.flow_count() == 1 - rx = tutils.tresp() - rx.flow = None - assert not fm.handle_response(rx) - - fm.handle_clientdisconnect(req.flow.client_conn) + fm.handle_clientdisconnect(f.client_conn) f.error = Error("msg") f.error.reply = controller.DummyReply() - fm.handle_error(f.error) + fm.handle_error(f) fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() @@ -656,8 +628,8 @@ class TestFlowMaster: def test_client_playback(self): s = flow.State() - f = tutils.tflow_full() - pb = [tutils.tflow_full(), f] + f = tutils.tflow(resp=True) + pb = [tutils.tflow(resp=True), f] fm = flow.FlowMaster(None, s) assert not fm.start_server_playback(pb, False, [], False, False) assert not fm.start_client_playback(pb, False) @@ -668,8 +640,7 @@ class TestFlowMaster: assert fm.state.flow_count() f.error = Error("error") - f.error.reply = controller.DummyReply() - fm.handle_error(f.error) + fm.handle_error(f) def test_server_playback(self): s = flow.State() @@ -723,15 +694,15 @@ class TestFlowMaster: assert not fm.stickycookie_state fm.set_stickycookie(".*") - tf = tutils.tflow_full() - tf.response.headers["set-cookie"] = ["foo=bar"] - fm.handle_request(tf.request) - fm.handle_response(tf.response) + f = tutils.tflow(resp=True) + f.response.headers["set-cookie"] = ["foo=bar"] + fm.handle_request(f) + fm.handle_response(f) assert fm.stickycookie_state.jar - assert not "cookie" in tf.request.headers - tf = tf.copy() - fm.handle_request(tf.request) - assert tf.request.headers["cookie"] == ["foo=bar"] + assert not "cookie" in f.request.headers + f = f.copy() + fm.handle_request(f) + assert f.request.headers["cookie"] == ["foo=bar"] def test_stickyauth(self): s = flow.State() @@ -743,14 +714,14 @@ class TestFlowMaster: assert not fm.stickyauth_state fm.set_stickyauth(".*") - tf = tutils.tflow_full() - tf.request.headers["authorization"] = ["foo"] - fm.handle_request(tf.request) + f = tutils.tflow(resp=True) + f.request.headers["authorization"] = ["foo"] + fm.handle_request(f) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert fm.stickyauth_state.hosts assert not "authorization" in f.request.headers - fm.handle_request(f.request) + fm.handle_request(f) assert f.request.headers["authorization"] == ["foo"] def test_stream(self): @@ -762,29 +733,30 @@ class TestFlowMaster: s = flow.State() fm = flow.FlowMaster(None, s) - tf = tutils.tflow_full() + f = tutils.tflow(resp=True) fm.start_stream(file(p, "ab"), None) - fm.handle_request(tf.request) - fm.handle_response(tf.response) + fm.handle_request(f) + fm.handle_response(f) fm.stop_stream() assert r()[0].response - tf = tutils.tflow() + f = tutils.tflow() fm.start_stream(file(p, "ab"), None) - fm.handle_request(tf.request) + fm.handle_request(f) fm.shutdown() assert not r()[1].response class TestRequest: def test_simple(self): - r = tutils.treq() - u = r.get_url() - assert r.set_url(u) - assert not r.set_url("") - assert r.get_url() == u + f = tutils.tflow() + r = f.request + u = r.get_url(False, f) + assert r.set_url(u, f) + assert not r.set_url("", f) + assert r.get_url(False, f) == u assert r._assemble() assert r.size() == len(r._assemble()) @@ -799,42 +771,45 @@ class TestRequest: tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) def test_get_url(self): - r = tutils.tflow().request + f = tutils.tflow() + r = f.request - assert r.get_url() == "http://address:22/path" + assert r.get_url(False, f) == "http://address:22/path" - r.flow.server_conn.ssl_established = True - assert r.get_url() == "https://address:22/path" + r.scheme = "https" + assert r.get_url(False, f) == "https://address:22/path" - r.flow.server_conn.address = tcp.Address(("host", 42)) - assert r.get_url() == "https://host:42/path" + r.host = "host" + r.port = 42 + assert r.get_url(False, f) == "https://host:42/path" r.host = "address" r.port = 22 - assert r.get_url() == "https://address:22/path" + assert r.get_url(False, f) == "https://address:22/path" - assert r.get_url(hostheader=True) == "https://address:22/path" + assert r.get_url(True, f) == "https://address:22/path" r.headers["Host"] = ["foo.com"] - assert r.get_url() == "https://address:22/path" - assert r.get_url(hostheader=True) == "https://foo.com:22/path" + assert r.get_url(False, f) == "https://address:22/path" + assert r.get_url(True, f) == "https://foo.com:22/path" def test_path_components(self): - r = tutils.treq() + f = tutils.tflow() + r = f.request r.path = "/" - assert r.get_path_components() == [] + assert r.get_path_components(f) == [] r.path = "/foo/bar" - assert r.get_path_components() == ["foo", "bar"] + assert r.get_path_components(f) == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] - r.set_query(q) - assert r.get_path_components() == ["foo", "bar"] + r.set_query(q, f) + assert r.get_path_components(f) == ["foo", "bar"] - r.set_path_components([]) - assert r.get_path_components() == [] - r.set_path_components(["foo"]) - assert r.get_path_components() == ["foo"] - r.set_path_components(["/oo"]) - assert r.get_path_components() == ["/oo"] + r.set_path_components([], f) + assert r.get_path_components(f) == [] + r.set_path_components(["foo"], f) + assert r.get_path_components(f) == ["foo"] + r.set_path_components(["/oo"], f) + assert r.get_path_components(f) == ["/oo"] assert "%2F" in r.path def test_getset_form_urlencoded(self): @@ -853,26 +828,26 @@ class TestRequest: def test_getset_query(self): h = flow.ODictCaseless() - r = tutils.treq() - r.path = "/foo?x=y&a=b" - q = r.get_query() + f = tutils.tflow() + f.request.path = "/foo?x=y&a=b" + q = f.request.get_query(f) assert q.lst == [("x", "y"), ("a", "b")] - r.path = "/" - q = r.get_query() + f.request.path = "/" + q = f.request.get_query(f) assert not q - r.path = "/?adsfa" - q = r.get_query() + f.request.path = "/?adsfa" + q = f.request.get_query(f) assert q.lst == [("adsfa", "")] - r.path = "/foo?x=y&a=b" - assert r.get_query() - r.set_query(flow.ODict([])) - assert not r.get_query() + f.request.path = "/foo?x=y&a=b" + assert f.request.get_query(f) + f.request.set_query(flow.ODict([]), f) + assert not f.request.get_query(f) qv = flow.ODict([("a", "b"), ("c", "d")]) - r.set_query(qv) - assert r.get_query() == qv + f.request.set_query(qv, f) + assert f.request.get_query(f) == qv def test_anticache(self): h = flow.ODictCaseless() @@ -979,8 +954,8 @@ class TestRequest: h["headername"] = ["headervalue"] r = tutils.treq() r.headers = h - result = len(r._assemble_headers()) - assert result == 62 + raw = r._assemble_headers() + assert len(raw) == 62 def test_get_content_type(self): h = flow.ODictCaseless() @@ -991,7 +966,7 @@ class TestRequest: class TestResponse: def test_simple(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) resp = f.response assert resp._assemble() assert resp.size() == len(resp._assemble()) @@ -1227,7 +1202,7 @@ def test_replacehooks(): h.run(f) assert f.request.content == "foo" - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = "foo" f.response.content = "foo" h.run(f) @@ -1280,7 +1255,7 @@ def test_setheaders(): h.clear() h.add("~s", "one", "two") h.add("~s", "one", "three") - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["one"] = ["xxx"] f.response.headers["one"] = ["xxx"] h.run(f) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index 3b922c063..c2ff7b440 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -26,10 +26,12 @@ def test_stripped_chunked_encoding_no_content(): class TestHTTPRequest: def test_asterisk_form(self): s = StringIO("OPTIONS * HTTP/1.1") - f = tutils.tflow_noreq() + f = tutils.tflow(req=None) f.request = HTTPRequest.from_stream(s) assert f.request.form_in == "relative" - x = f.request._assemble() + f.request.host = f.server_conn.address.host + f.request.port = f.server_conn.address.port + f.request.scheme = "http" assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_origin_form(self): @@ -41,6 +43,7 @@ class TestHTTPRequest: tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) s = StringIO("CONNECT address:22 HTTP/1.1") r = HTTPRequest.from_stream(s) + r.scheme, r.host, r.port = "http", "address", 22 assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_absolute_form(self): @@ -55,12 +58,12 @@ class TestHTTPRequest: tutils.raises("Invalid request form", r._assemble, "antiauthority") def test_set_url(self): - r = tutils.treq_absolute() - r.set_url("https://otheraddress:42/ORLY") - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" + f = tutils.tflow(req=tutils.treq_absolute()) + f.request.set_url("https://otheraddress:42/ORLY", f) + assert f.request.scheme == "https" + assert f.request.host == "otheraddress" + assert f.request.port == 42 + assert f.request.path == "/ORLY" class TestHTTPResponse: @@ -130,10 +133,10 @@ class TestProxyChainingSSL(tservers.HTTPChainProxyTest): """ https://github.com/mitmproxy/mitmproxy/issues/313 """ - def handle_request(r): - r.httpversion = (1,0) - del r.headers["Content-Length"] - r.reply() + def handle_request(f): + f.request.httpversion = (1, 0) + del f.request.headers["Content-Length"] + f.reply() _handle_request = self.chain[0].tmaster.handle_request self.chain[0].tmaster.handle_request = handle_request try: @@ -159,13 +162,13 @@ class TestProxyChainingSSLReconnect(tservers.HTTPChainProxyTest): def kill_requests(master, attr, exclude): k = [0] # variable scope workaround: put into array _func = getattr(master, attr) - def handler(r): + def handler(f): k[0] += 1 if not (k[0] in exclude): - r.flow.client_conn.finish() - r.flow.error = Error("terminated") - r.reply(KILL) - return _func(r) + f.client_conn.finish() + f.error = Error("terminated") + f.reply(KILL) + return _func(f) setattr(master, attr, handler) kill_requests(self.proxy.tmaster, "handle_request", diff --git a/test/test_proxy.py b/test/test_proxy.py index 2ff01acc6..91e4954f9 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -25,11 +25,11 @@ class TestServerConnection: def test_simple(self): sc = ServerConnection((self.d.IFACE, self.d.port), None) sc.connect() - r = tutils.treq() - r.flow.server_conn = sc - r.path = "/p/200:da" - sc.send(r._assemble()) - assert http.read_response(sc.rfile, r.method, 1000) + f = tutils.tflow() + f.server_conn = sc + f.request.path = "/p/200:da" + sc.send(f.request._assemble()) + assert http.read_response(sc.rfile, f.request.method, 1000) assert self.d.last_log() sc.finish() diff --git a/test/test_script.py b/test/test_script.py index 587c52d63..7c421fdec 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -29,8 +29,8 @@ class TestScript: s = flow.State() fm = flow.FlowMaster(None, s) fm.load_script(tutils.test_data.path("scripts/duplicate_flow.py")) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) assert fm.state.flow_count() == 2 assert not fm.state.view[0].request.is_replay assert fm.state.view[1].request.is_replay @@ -65,12 +65,12 @@ class TestScript: fm.load_script(tutils.test_data.path("scripts/concurrent_decorator.py")) with mock.patch("libmproxy.controller.DummyReply.__call__") as m: - r1, r2 = tutils.treq(), tutils.treq() + f1, f2 = tutils.tflow(), tutils.tflow() t_start = time.time() - fm.handle_request(r1) - r1.reply() - fm.handle_request(r2) - r2.reply() + fm.handle_request(f1) + f1.reply() + fm.handle_request(f2) + f2.reply() # Two instantiations assert m.call_count == 0 # No calls yet. diff --git a/test/test_server.py b/test/test_server.py index a570f10f3..48527547e 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -327,29 +327,32 @@ class TestProxySSL(tservers.HTTPProxTest): # tests that the ssl timestamp is present when ssl is used f = self.pathod("304:b@10k") assert f.status_code == 304 - first_request = self.master.state.view[0].request - assert first_request.flow.server_conn.timestamp_ssl_setup + first_flow = self.master.state.view[0] + assert first_flow.server_conn.timestamp_ssl_setup class MasterRedirectRequest(tservers.TestMaster): - def handle_request(self, request): + redirect_port = None # Set by TestRedirectRequest + + def handle_request(self, f): + request = f.request if request.path == "/p/201": - url = request.get_url() + url = request.get_url(False, f) new = "http://127.0.0.1:%s/p/201" % self.redirect_port - request.set_url(new) - request.set_url(new) - request.flow.live.change_server(("127.0.0.1", self.redirect_port), False) - request.set_url(url) - tutils.raises("SSL handshake error", request.flow.live.change_server, ("127.0.0.1", self.redirect_port), True) - request.set_url(new) - request.set_url(url) - request.set_url(new) - tservers.TestMaster.handle_request(self, request) + request.set_url(new, f) + request.set_url(new, f) + f.live.change_server(("127.0.0.1", self.redirect_port), False) + request.set_url(url, f) + tutils.raises("SSL handshake error", f.live.change_server, ("127.0.0.1", self.redirect_port), True) + request.set_url(new, f) + request.set_url(url, f) + request.set_url(new, f) + tservers.TestMaster.handle_request(self, f) - def handle_response(self, response): - response.content = str(response.flow.client_conn.address.port) - tservers.TestMaster.handle_response(self, response) + def handle_response(self, f): + f.response.content = str(f.client_conn.address.port) + tservers.TestMaster.handle_response(self, f) class TestRedirectRequest(tservers.HTTPProxTest): @@ -388,9 +391,9 @@ class MasterStreamRequest(tservers.TestMaster): """ Enables the stream flag on the flow for all requests """ - def handle_responseheaders(self, r): - r.stream = True - r.reply() + def handle_responseheaders(self, f): + f.response.stream = True + f.reply() class TestStreamRequest(tservers.HTTPProxTest): masterclass = MasterStreamRequest @@ -441,9 +444,9 @@ class TestStreamRequest(tservers.HTTPProxTest): class MasterFakeResponse(tservers.TestMaster): - def handle_request(self, m): + def handle_request(self, f): resp = tutils.tresp() - m.reply(resp) + f.reply(resp) class TestFakeResponse(tservers.HTTPProxTest): @@ -454,8 +457,8 @@ class TestFakeResponse(tservers.HTTPProxTest): class MasterKillRequest(tservers.TestMaster): - def handle_request(self, m): - m.reply(KILL) + def handle_request(self, f): + f.reply(KILL) class TestKillRequest(tservers.HTTPProxTest): @@ -467,8 +470,8 @@ class TestKillRequest(tservers.HTTPProxTest): class MasterKillResponse(tservers.TestMaster): - def handle_response(self, m): - m.reply(KILL) + def handle_response(self, f): + f.reply(KILL) class TestKillResponse(tservers.HTTPProxTest): @@ -491,10 +494,10 @@ class TestTransparentResolveError(tservers.TransparentProxTest): class MasterIncomplete(tservers.TestMaster): - def handle_request(self, m): + def handle_request(self, f): resp = tutils.tresp() resp.content = CONTENT_MISSING - m.reply(resp) + f.reply(resp) class TestIncompleteResponse(tservers.HTTPProxTest): diff --git a/test/tservers.py b/test/tservers.py index a12a440e3..9f2abbe1a 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -36,13 +36,13 @@ class TestMaster(flow.FlowMaster): self.apps.add(errapp, "errapp", 80) self.clear_log() - def handle_request(self, m): - flow.FlowMaster.handle_request(self, m) - m.reply() + def handle_request(self, f): + flow.FlowMaster.handle_request(self, f) + f.reply() - def handle_response(self, m): - flow.FlowMaster.handle_response(self, m) - m.reply() + def handle_response(self, f): + flow.FlowMaster.handle_response(self, f) + f.reply() def clear_log(self): self.log = [] diff --git a/test/tutils.py b/test/tutils.py index dc049adb7..84a9bba04 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -21,7 +21,38 @@ def SkipWindows(fn): return fn +def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None): + """ + @type client_conn: bool | None | libmproxy.proxy.connection.ClientConnection + @type server_conn: bool | None | libmproxy.proxy.connection.ServerConnection + @type req: bool | None | libmproxy.protocol.http.HTTPRequest + @type resp: bool | None | libmproxy.protocol.http.HTTPResponse + @type err: bool | None | libmproxy.protocol.primitives.Error + @return: bool | None | libmproxy.protocol.http.HTTPFlow + """ + if client_conn is True: + client_conn = tclient_conn() + if server_conn is True: + server_conn = tserver_conn() + if req is True: + req = treq() + if resp is True: + resp = tresp() + if err is True: + err = terr() + + f = http.HTTPFlow(client_conn, server_conn) + f.request = req + f.response = resp + f.error = err + f.reply = controller.DummyReply() + return f + + def tclient_conn(): + """ + @return: libmproxy.proxy.connection.ClientConnection + """ c = ClientConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), clientcert=None @@ -31,6 +62,9 @@ def tclient_conn(): def tserver_conn(): + """ + @return: libmproxy.proxy.connection.ServerConnection + """ c = ServerConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), state=[], @@ -41,75 +75,46 @@ def tserver_conn(): return c -def treq_absolute(conn=None, content="content"): - r = treq(conn, content) +def treq(content="content", scheme="http", host="address", port=22): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + headers = flow.ODictCaseless() + headers["header"] = ["qvalue"] + req = http.HTTPRequest("relative", "GET", scheme, host, port, "/path", (1, 1), headers, content, + None, None, None) + return req + +def treq_absolute(content="content"): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + r = treq(content) r.form_in = r.form_out = "absolute" r.host = "address" r.port = 22 r.scheme = "http" return r -def treq(conn=None, content="content"): - if not conn: - conn = tclient_conn() - server_conn = tserver_conn() - headers = flow.ODictCaseless() - headers["header"] = ["qvalue"] - f = http.HTTPFlow(conn, server_conn) - f.request = http.HTTPRequest("relative", "GET", None, None, None, "/path", (1, 1), headers, content, - None, None, None) - f.request.reply = controller.DummyReply() - return f.request - - -def tresp(req=None, content="message"): - if not req: - req = treq() - f = req.flow +def tresp(content="message"): + """ + @return: libmproxy.protocol.http.HTTPResponse + """ headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"), "rb").read()) - f.server_conn = ServerConnection._from_state(dict( - address=dict(address=("address", 22), use_ipv6=True), - state=[], - source_address=None, - cert=cert.to_pem())) - f.response = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) - f.response.reply = controller.DummyReply() - return f.response + + resp = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) + return resp -def terr(req=None): - if not req: - req = treq() - f = req.flow - f.error = Error("error") - f.error.reply = controller.DummyReply() - return f.error - -def tflow_noreq(): - f = tflow() - f.request = None - return f - -def tflow(req=None): - if not req: - req = treq() - return req.flow - - -def tflow_full(): - f = tflow() - f.response = tresp(f.request) - return f - - -def tflow_err(): - f = tflow() - f.error = terr(f.request) - return f +def terr(content="error"): + """ + @return: libmproxy.protocol.primitives.Error + """ + err = Error(content) + return err def tflowview(request_contents=None): m = Mock() @@ -117,8 +122,7 @@ def tflowview(request_contents=None): if request_contents == None: flow = tflow() else: - req = treq(None, request_contents) - flow = tflow(req) + flow = tflow(req=treq(request_contents)) fv = FlowView(m, cs, flow) return fv