remove proxy mode abstraction: always be clear which mode we are in

This commit is contained in:
Maximilian Hils 2014-09-08 14:32:42 +02:00
parent 6dbe431c5e
commit d06b4bfa4e
8 changed files with 94 additions and 71 deletions

View File

@ -314,11 +314,6 @@ def common_options(parser):
action="store", choices=("relative", "absolute"), action="store", choices=("relative", "absolute"),
help="Override the HTTP request form sent upstream by the proxy" help="Override the HTTP request form sent upstream by the proxy"
) )
group.add_argument(
"--destination-server", dest="manual_destination_server", default=None,
action="store", type=parse_server_spec,
help="Override the destination server all requests are sent to: http[s][2http[s]]://host[:port]"
)
group = parser.add_argument_group("Web App") group = parser.add_argument_group("Web App")
group.add_argument( group.add_argument(

View File

@ -175,13 +175,12 @@ class StatusBar(common.WWrap):
if opts: if opts:
r.append("[%s]"%(":".join(opts))) r.append("[%s]"%(":".join(opts)))
if self.master.server.config.get_upstream_server and \ if self.master.server.config.mode in ["reverse", "upstream"]:
isinstance(self.master.server.config.get_upstream_server, proxy.ConstUpstreamServerResolver): dst = self.master.server.config.mode.dst
dst = self.master.server.config.get_upstream_server.dst
scheme = "https" if dst[0] else "http" scheme = "https" if dst[0] else "http"
if dst[1] != dst[0]: if dst[1] != dst[0]:
scheme += "2https" if dst[1] else "http" scheme += "2https" if dst[1] else "http"
r.append("[dest:%s]"%utils.unparse_url(scheme, *self.master.server.config.get_upstream_server.dst[2:])) r.append("[dest:%s]"%utils.unparse_url(scheme, *dst[2:]))
if self.master.scripts: if self.master.scripts:
r.append("[") r.append("[")
r.append(("heading_key", "s")) r.append(("heading_key", "s"))

View File

@ -865,8 +865,8 @@ class HTTPHandler(ProtocolHandler):
""" """
def __init__(self, c): def __init__(self, c):
super(HTTPHandler, self).__init__(c) super(HTTPHandler, self).__init__(c)
self.expected_form_in = c.config.http_form_in self.expected_form_in = c.config.mode.http_form_in
self.expected_form_out = c.config.http_form_out self.expected_form_out = c.config.mode.http_form_out
self.skip_authentication = False self.skip_authentication = False
def handle_messages(self): def handle_messages(self):
@ -1072,20 +1072,19 @@ class HTTPHandler(ProtocolHandler):
if self.c.client_conn.ssl_established: if self.c.client_conn.ssl_established:
raise http.HttpError(400, "Must not CONNECT on already encrypted connection") raise http.HttpError(400, "Must not CONNECT on already encrypted connection")
if self.expected_form_in == "absolute": if self.c.config.mode == "regular":
if not self.c.config.get_upstream_server: # Regular mode self.c.set_server_address((request.host, request.port))
self.c.set_server_address((request.host, request.port)) flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow self.c.establish_server_connection()
self.c.establish_server_connection() self.c.client_conn.send(
self.c.client_conn.send( 'HTTP/1.1 200 Connection established\r\n' +
'HTTP/1.1 200 Connection established\r\n' + 'Content-Length: 0\r\n' +
'Content-Length: 0\r\n' + ('Proxy-agent: %s\r\n' % self.c.server_version) +
('Proxy-agent: %s\r\n' % self.c.server_version) + '\r\n'
'\r\n' )
) return self.process_connect_request(self.c.server_conn.address)
return self.process_connect_request(self.c.server_conn.address) elif self.c.config.mode == "upstream":
else: # upstream proxy mode return None
return None
else: else:
pass # CONNECT should never occur if we don't expect absolute-form requests pass # CONNECT should never occur if we don't expect absolute-form requests
@ -1113,7 +1112,7 @@ class HTTPHandler(ProtocolHandler):
ssl = (flow.request.scheme == "https") ssl = (flow.request.scheme == "https")
if self.c.config.http_form_in == self.c.config.http_form_out == "absolute": # Upstream Proxy mode if self.c.config.mode == "upstream":
# The connection to the upstream proxy may have a state we may need to take into account. # The connection to the upstream proxy may have a state we may need to take into account.
connected_to = None connected_to = None
@ -1223,8 +1222,8 @@ class RequestReplayThread(threading.Thread):
form_out_backup = r.form_out form_out_backup = r.form_out
try: try:
# In all modes, we directly connect to the server displayed # In all modes, we directly connect to the server displayed
if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode if self.config.mode == "upstream":
server_address = self.config.get_upstream_server(self.flow.client_conn)[2:] server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:]
server = ServerConnection(server_address) server = ServerConnection(server_address)
server.connect() server.connect()
if r.scheme == "https": if r.scheme == "https":

View File

@ -3,7 +3,7 @@ import os
import re import re
from netlib import http_auth, certutils from netlib import http_auth, certutils
from .. import utils, platform from .. import utils, platform
from .primitives import ConstUpstreamServerResolver, TransparentUpstreamServerResolver from .primitives import RegularProxyMode, TransparentProxyMode, UpstreamProxyMode, ReverseProxyMode
TRANSPARENT_SSL_PORTS = [443, 8443] TRANSPARENT_SSL_PORTS = [443, 8443]
CONF_BASENAME = "mitmproxy" CONF_BASENAME = "mitmproxy"
@ -26,25 +26,17 @@ class ProxyConfig:
self.body_size_limit = body_size_limit self.body_size_limit = body_size_limit
if mode == "transparent": if mode == "transparent":
get_upstream_server = TransparentUpstreamServerResolver(platform.resolver(), TRANSPARENT_SSL_PORTS) self.mode = TransparentProxyMode(platform.resolver(), TRANSPARENT_SSL_PORTS)
http_form_in_default, http_form_out_default = "relative", "relative"
elif mode == "reverse": elif mode == "reverse":
get_upstream_server = ConstUpstreamServerResolver(upstream_server) self.mode = ReverseProxyMode(upstream_server)
http_form_in_default, http_form_out_default = "relative", "relative"
elif mode == "upstream": elif mode == "upstream":
get_upstream_server = ConstUpstreamServerResolver(upstream_server) self.mode = UpstreamProxyMode(upstream_server)
http_form_in_default, http_form_out_default = "absolute", "absolute"
elif upstream_server:
get_upstream_server = ConstUpstreamServerResolver(upstream_server)
http_form_in_default, http_form_out_default = "absolute", "relative"
else: else:
get_upstream_server, http_form_in_default, http_form_out_default = None, "absolute", "relative" self.mode = RegularProxyMode()
http_form_in = http_form_in or http_form_in_default
http_form_out = http_form_out or http_form_out_default self.mode.http_form_in = http_form_in or self.mode.http_form_in
self.mode.http_form_out = http_form_out or self.mode.http_form_out
self.get_upstream_server = get_upstream_server
self.http_form_in = http_form_in
self.http_form_out = http_form_out
self.ignore = parse_host_pattern(ignore) self.ignore = parse_host_pattern(ignore)
self.authenticator = authenticator self.authenticator = authenticator
self.confdir = os.path.expanduser(confdir) self.confdir = os.path.expanduser(confdir)
@ -74,13 +66,9 @@ def process_proxy_options(parser, options):
c += 1 c += 1
mode = "upstream" mode = "upstream"
upstream_server = options.upstream_proxy upstream_server = options.upstream_proxy
if options.manual_destination_server:
c += 1
mode = "manual"
upstream_server = options.manual_destination_server
if c > 1: if c > 1:
return parser.error("Transparent mode, reverse mode, upstream proxy mode and " return parser.error("Transparent mode, reverse mode and upstream proxy mode "
"specification of an upstream server are mutually exclusive.") "are mutually exclusive.")
if options.clientcerts: if options.clientcerts:
options.clientcerts = os.path.expanduser(options.clientcerts) options.clientcerts = os.path.expanduser(options.clientcerts)

View File

@ -11,28 +11,54 @@ class ProxyServerError(Exception):
pass pass
class UpstreamServerResolver(object): class ProxyMode(object):
def __call__(self, conn): http_form_in = None
http_form_out = None
def get_upstream_server(self, conn):
""" """
Returns the address of the server to connect to. Returns the address of the server to connect to.
Returns None if the address needs to be determined on the protocol level (regular proxy mode)
""" """
raise NotImplementedError # pragma: nocover raise NotImplementedError() # pragma: nocover
@property
def name(self):
return self.__class__.__name__.replace("ProxyMode", "").lower()
def __str__(self):
return self.name
def __eq__(self, other):
"""
Allow comparisions with "regular" etc.
"""
if isinstance(other, ProxyMode):
return self is other
else:
return self.name == other
def __ne__(self, other):
return not self.__eq__(other)
class ConstUpstreamServerResolver(UpstreamServerResolver): class RegularProxyMode(ProxyMode):
def __init__(self, dst): http_form_in = "absolute"
self.dst = dst http_form_out = "relative"
def __call__(self, conn): def get_upstream_server(self, conn):
return self.dst return None
class TransparentUpstreamServerResolver(UpstreamServerResolver): class TransparentProxyMode(ProxyMode):
http_form_in = "relative"
http_form_out = "relative"
def __init__(self, resolver, sslports): def __init__(self, resolver, sslports):
self.resolver = resolver self.resolver = resolver
self.sslports = sslports self.sslports = sslports
def __call__(self, conn): def get_upstream_server(self, conn):
try: try:
dst = self.resolver.original_addr(conn) dst = self.resolver.original_addr(conn)
except Exception, e: except Exception, e:
@ -45,6 +71,24 @@ class TransparentUpstreamServerResolver(UpstreamServerResolver):
return [ssl, ssl] + list(dst) return [ssl, ssl] + list(dst)
class _ConstDestinationProxyMode(ProxyMode):
def __init__(self, dst):
self.dst = dst
def get_upstream_server(self, conn):
return self.dst
class ReverseProxyMode(_ConstDestinationProxyMode):
http_form_in = "relative"
http_form_out = "relative"
class UpstreamProxyMode(_ConstDestinationProxyMode):
http_form_in = "absolute"
http_form_out = "absolute"
class Log: class Log:
def __init__(self, msg, level="info"): def __init__(self, msg, level="info"):
self.msg = msg self.msg = msg

View File

@ -73,14 +73,16 @@ class ConnectionHandler:
# Can we already identify the target server and connect to it? # Can we already identify the target server and connect to it?
client_ssl, server_ssl = False, False client_ssl, server_ssl = False, False
if self.config.get_upstream_server: upstream_info = self.config.mode.get_upstream_server(self.client_conn.connection)
upstream_info = self.config.get_upstream_server(self.client_conn.connection) if upstream_info:
self.set_server_address(upstream_info[2:]) self.set_server_address(upstream_info[2:])
client_ssl, server_ssl = upstream_info[:2] client_ssl, server_ssl = upstream_info[:2]
if self.check_ignore_address(self.server_conn.address): if self.check_ignore_address(self.server_conn.address):
self.log("Ignore host: %s:%s" % self.server_conn.address(), "info") self.log("Ignore host: %s:%s" % self.server_conn.address(), "info")
self.conntype = "tcp" self.conntype = "tcp"
client_ssl, server_ssl = False, False client_ssl, server_ssl = False, False
else:
pass # No upstream info from the metadata: upstream info in the protocol (e.g. HTTP absolute-form)
self.channel.ask("clientconnect", self) self.channel.ask("clientconnect", self)

View File

@ -91,10 +91,6 @@ class TestProcessProxyOptions:
self.assert_err("expected one argument", "-U") self.assert_err("expected one argument", "-U")
self.assert_err("Invalid server specification", "-U", "upstream") self.assert_err("Invalid server specification", "-U", "upstream")
self.assert_noerr("--destination-server", "http://localhost")
self.assert_err("expected one argument", "--destination-server")
self.assert_err("Invalid server specification", "--destination-server", "manual")
self.assert_err("mutually exclusive", "-R", "http://localhost", "-T") self.assert_err("mutually exclusive", "-R", "http://localhost", "-T")
def test_client_certs(self): def test_client_certs(self):
@ -144,7 +140,8 @@ class TestDummyServer:
class TestConnectionHandler: class TestConnectionHandler:
def test_fatal_error(self): def test_fatal_error(self):
config = dict(get_upstream_server=mock.Mock(side_effect=RuntimeError)) config = mock.Mock()
config.mode.get_upstream_server.side_effect = RuntimeError
c = ConnectionHandler(config, mock.MagicMock(), ("127.0.0.1", 8080), None, mock.MagicMock(), None) c = ConnectionHandler(config, mock.MagicMock(), ("127.0.0.1", 8080), None, mock.MagicMock(), None)
with tutils.capture_stderr(c.handle) as output: with tutils.capture_stderr(c.handle) as output:
assert "mitmproxy has crashed" in output assert "mitmproxy has crashed" in output

View File

@ -6,7 +6,7 @@ import mock
from libmproxy.proxy.config import ProxyConfig from libmproxy.proxy.config import ProxyConfig
from libmproxy.proxy.server import ProxyServer from libmproxy.proxy.server import ProxyServer
from libmproxy.proxy.primitives import TransparentUpstreamServerResolver from libmproxy.proxy.primitives import TransparentProxyMode
import libpathod.test, libpathod.pathoc import libpathod.test, libpathod.pathoc
from libmproxy import flow, controller from libmproxy import flow, controller
from libmproxy.cmdline import APP_HOST, APP_PORT from libmproxy.cmdline import APP_HOST, APP_PORT
@ -184,7 +184,7 @@ class TransparentProxTest(ProxTestBase):
ports = [cls.server.port, cls.server2.port] ports = [cls.server.port, cls.server2.port]
else: else:
ports = [] ports = []
cls.config.get_upstream_server = TransparentUpstreamServerResolver(cls.resolver(cls.server.port), ports) cls.config.mode = TransparentProxyMode(cls.resolver(cls.server.port), ports)
@classmethod @classmethod
def get_proxy_config(cls): def get_proxy_config(cls):
@ -224,8 +224,7 @@ class ReverseProxTest(ProxTestBase):
"127.0.0.1", "127.0.0.1",
cls.server.port cls.server.port
) )
d["http_form_in"] = "relative" d["mode"] = "reverse"
d["http_form_out"] = "relative"
return d return d
def pathoc(self, sni=None): def pathoc(self, sni=None):