remove proxy mode abstraction: always be clear which mode we are in

This commit is contained in:
Maximilian Hils 2014-09-08 14:32:42 +02:00
parent 6dbe431c5e
commit d06b4bfa4e
8 changed files with 94 additions and 71 deletions

View File

@ -314,11 +314,6 @@ def common_options(parser):
action="store", choices=("relative", "absolute"),
help="Override the HTTP request form sent upstream by the proxy"
)
group.add_argument(
"--destination-server", dest="manual_destination_server", default=None,
action="store", type=parse_server_spec,
help="Override the destination server all requests are sent to: http[s][2http[s]]://host[:port]"
)
group = parser.add_argument_group("Web App")
group.add_argument(

View File

@ -175,13 +175,12 @@ class StatusBar(common.WWrap):
if opts:
r.append("[%s]"%(":".join(opts)))
if self.master.server.config.get_upstream_server and \
isinstance(self.master.server.config.get_upstream_server, proxy.ConstUpstreamServerResolver):
dst = self.master.server.config.get_upstream_server.dst
if self.master.server.config.mode in ["reverse", "upstream"]:
dst = self.master.server.config.mode.dst
scheme = "https" if dst[0] else "http"
if dst[1] != dst[0]:
scheme += "2https" if dst[1] else "http"
r.append("[dest:%s]"%utils.unparse_url(scheme, *self.master.server.config.get_upstream_server.dst[2:]))
r.append("[dest:%s]"%utils.unparse_url(scheme, *dst[2:]))
if self.master.scripts:
r.append("[")
r.append(("heading_key", "s"))

View File

@ -865,8 +865,8 @@ class HTTPHandler(ProtocolHandler):
"""
def __init__(self, c):
super(HTTPHandler, self).__init__(c)
self.expected_form_in = c.config.http_form_in
self.expected_form_out = c.config.http_form_out
self.expected_form_in = c.config.mode.http_form_in
self.expected_form_out = c.config.mode.http_form_out
self.skip_authentication = False
def handle_messages(self):
@ -1072,8 +1072,7 @@ class HTTPHandler(ProtocolHandler):
if self.c.client_conn.ssl_established:
raise http.HttpError(400, "Must not CONNECT on already encrypted connection")
if self.expected_form_in == "absolute":
if not self.c.config.get_upstream_server: # Regular mode
if self.c.config.mode == "regular":
self.c.set_server_address((request.host, request.port))
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
self.c.establish_server_connection()
@ -1084,7 +1083,7 @@ class HTTPHandler(ProtocolHandler):
'\r\n'
)
return self.process_connect_request(self.c.server_conn.address)
else: # upstream proxy mode
elif self.c.config.mode == "upstream":
return None
else:
pass # CONNECT should never occur if we don't expect absolute-form requests
@ -1113,7 +1112,7 @@ class HTTPHandler(ProtocolHandler):
ssl = (flow.request.scheme == "https")
if self.c.config.http_form_in == self.c.config.http_form_out == "absolute": # Upstream Proxy mode
if self.c.config.mode == "upstream":
# The connection to the upstream proxy may have a state we may need to take into account.
connected_to = None
@ -1223,8 +1222,8 @@ class RequestReplayThread(threading.Thread):
form_out_backup = r.form_out
try:
# In all modes, we directly connect to the server displayed
if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode
server_address = self.config.get_upstream_server(self.flow.client_conn)[2:]
if self.config.mode == "upstream":
server_address = self.config.mode.get_upstream_server(self.flow.client_conn)[2:]
server = ServerConnection(server_address)
server.connect()
if r.scheme == "https":

View File

@ -3,7 +3,7 @@ import os
import re
from netlib import http_auth, certutils
from .. import utils, platform
from .primitives import ConstUpstreamServerResolver, TransparentUpstreamServerResolver
from .primitives import RegularProxyMode, TransparentProxyMode, UpstreamProxyMode, ReverseProxyMode
TRANSPARENT_SSL_PORTS = [443, 8443]
CONF_BASENAME = "mitmproxy"
@ -26,25 +26,17 @@ class ProxyConfig:
self.body_size_limit = body_size_limit
if mode == "transparent":
get_upstream_server = TransparentUpstreamServerResolver(platform.resolver(), TRANSPARENT_SSL_PORTS)
http_form_in_default, http_form_out_default = "relative", "relative"
self.mode = TransparentProxyMode(platform.resolver(), TRANSPARENT_SSL_PORTS)
elif mode == "reverse":
get_upstream_server = ConstUpstreamServerResolver(upstream_server)
http_form_in_default, http_form_out_default = "relative", "relative"
self.mode = ReverseProxyMode(upstream_server)
elif mode == "upstream":
get_upstream_server = ConstUpstreamServerResolver(upstream_server)
http_form_in_default, http_form_out_default = "absolute", "absolute"
elif upstream_server:
get_upstream_server = ConstUpstreamServerResolver(upstream_server)
http_form_in_default, http_form_out_default = "absolute", "relative"
self.mode = UpstreamProxyMode(upstream_server)
else:
get_upstream_server, http_form_in_default, http_form_out_default = None, "absolute", "relative"
http_form_in = http_form_in or http_form_in_default
http_form_out = http_form_out or http_form_out_default
self.mode = RegularProxyMode()
self.mode.http_form_in = http_form_in or self.mode.http_form_in
self.mode.http_form_out = http_form_out or self.mode.http_form_out
self.get_upstream_server = get_upstream_server
self.http_form_in = http_form_in
self.http_form_out = http_form_out
self.ignore = parse_host_pattern(ignore)
self.authenticator = authenticator
self.confdir = os.path.expanduser(confdir)
@ -74,13 +66,9 @@ def process_proxy_options(parser, options):
c += 1
mode = "upstream"
upstream_server = options.upstream_proxy
if options.manual_destination_server:
c += 1
mode = "manual"
upstream_server = options.manual_destination_server
if c > 1:
return parser.error("Transparent mode, reverse mode, upstream proxy mode and "
"specification of an upstream server are mutually exclusive.")
return parser.error("Transparent mode, reverse mode and upstream proxy mode "
"are mutually exclusive.")
if options.clientcerts:
options.clientcerts = os.path.expanduser(options.clientcerts)

View File

@ -11,28 +11,54 @@ class ProxyServerError(Exception):
pass
class UpstreamServerResolver(object):
def __call__(self, conn):
class ProxyMode(object):
http_form_in = None
http_form_out = None
def get_upstream_server(self, conn):
"""
Returns the address of the server to connect to.
Returns None if the address needs to be determined on the protocol level (regular proxy mode)
"""
raise NotImplementedError # pragma: nocover
raise NotImplementedError() # pragma: nocover
@property
def name(self):
return self.__class__.__name__.replace("ProxyMode", "").lower()
def __str__(self):
return self.name
def __eq__(self, other):
"""
Allow comparisions with "regular" etc.
"""
if isinstance(other, ProxyMode):
return self is other
else:
return self.name == other
def __ne__(self, other):
return not self.__eq__(other)
class ConstUpstreamServerResolver(UpstreamServerResolver):
def __init__(self, dst):
self.dst = dst
class RegularProxyMode(ProxyMode):
http_form_in = "absolute"
http_form_out = "relative"
def __call__(self, conn):
return self.dst
def get_upstream_server(self, conn):
return None
class TransparentUpstreamServerResolver(UpstreamServerResolver):
class TransparentProxyMode(ProxyMode):
http_form_in = "relative"
http_form_out = "relative"
def __init__(self, resolver, sslports):
self.resolver = resolver
self.sslports = sslports
def __call__(self, conn):
def get_upstream_server(self, conn):
try:
dst = self.resolver.original_addr(conn)
except Exception, e:
@ -45,6 +71,24 @@ class TransparentUpstreamServerResolver(UpstreamServerResolver):
return [ssl, ssl] + list(dst)
class _ConstDestinationProxyMode(ProxyMode):
def __init__(self, dst):
self.dst = dst
def get_upstream_server(self, conn):
return self.dst
class ReverseProxyMode(_ConstDestinationProxyMode):
http_form_in = "relative"
http_form_out = "relative"
class UpstreamProxyMode(_ConstDestinationProxyMode):
http_form_in = "absolute"
http_form_out = "absolute"
class Log:
def __init__(self, msg, level="info"):
self.msg = msg

View File

@ -73,14 +73,16 @@ class ConnectionHandler:
# Can we already identify the target server and connect to it?
client_ssl, server_ssl = False, False
if self.config.get_upstream_server:
upstream_info = self.config.get_upstream_server(self.client_conn.connection)
upstream_info = self.config.mode.get_upstream_server(self.client_conn.connection)
if upstream_info:
self.set_server_address(upstream_info[2:])
client_ssl, server_ssl = upstream_info[:2]
if self.check_ignore_address(self.server_conn.address):
self.log("Ignore host: %s:%s" % self.server_conn.address(), "info")
self.conntype = "tcp"
client_ssl, server_ssl = False, False
else:
pass # No upstream info from the metadata: upstream info in the protocol (e.g. HTTP absolute-form)
self.channel.ask("clientconnect", self)

View File

@ -91,10 +91,6 @@ class TestProcessProxyOptions:
self.assert_err("expected one argument", "-U")
self.assert_err("Invalid server specification", "-U", "upstream")
self.assert_noerr("--destination-server", "http://localhost")
self.assert_err("expected one argument", "--destination-server")
self.assert_err("Invalid server specification", "--destination-server", "manual")
self.assert_err("mutually exclusive", "-R", "http://localhost", "-T")
def test_client_certs(self):
@ -144,7 +140,8 @@ class TestDummyServer:
class TestConnectionHandler:
def test_fatal_error(self):
config = dict(get_upstream_server=mock.Mock(side_effect=RuntimeError))
config = mock.Mock()
config.mode.get_upstream_server.side_effect = RuntimeError
c = ConnectionHandler(config, mock.MagicMock(), ("127.0.0.1", 8080), None, mock.MagicMock(), None)
with tutils.capture_stderr(c.handle) as output:
assert "mitmproxy has crashed" in output

View File

@ -6,7 +6,7 @@ import mock
from libmproxy.proxy.config import ProxyConfig
from libmproxy.proxy.server import ProxyServer
from libmproxy.proxy.primitives import TransparentUpstreamServerResolver
from libmproxy.proxy.primitives import TransparentProxyMode
import libpathod.test, libpathod.pathoc
from libmproxy import flow, controller
from libmproxy.cmdline import APP_HOST, APP_PORT
@ -184,7 +184,7 @@ class TransparentProxTest(ProxTestBase):
ports = [cls.server.port, cls.server2.port]
else:
ports = []
cls.config.get_upstream_server = TransparentUpstreamServerResolver(cls.resolver(cls.server.port), ports)
cls.config.mode = TransparentProxyMode(cls.resolver(cls.server.port), ports)
@classmethod
def get_proxy_config(cls):
@ -224,8 +224,7 @@ class ReverseProxTest(ProxTestBase):
"127.0.0.1",
cls.server.port
)
d["http_form_in"] = "relative"
d["http_form_out"] = "relative"
d["mode"] = "reverse"
return d
def pathoc(self, sni=None):