simplify server changes for inline scripts

This commit is contained in:
Maximilian Hils 2014-09-03 20:12:30 +02:00
parent b0cfeff06d
commit cd43c5ba9c
5 changed files with 115 additions and 73 deletions

View File

@ -26,7 +26,7 @@ def get_line(fp):
return line
def send_connect_request(conn, host, port):
def send_connect_request(conn, host, port, update_state=True):
upstream_request = HTTPRequest("authority", "CONNECT", None, host, port, None,
(1, 1), ODictCaseless(), "")
conn.send(upstream_request._assemble())
@ -36,6 +36,12 @@ def send_connect_request(conn, host, port):
"Cannot establish SSL " +
"connection with upstream proxy: \r\n" +
str(resp._assemble()))
if update_state:
conn.state.append(("http", {
"state": "connect",
"host": host,
"port": port}
))
return resp
@ -545,8 +551,7 @@ class HTTPRequest(HTTPMessage):
flow.live.change_server((host, port), ssl=is_ssl)
else:
# There's not live server connection, we're just changing the attributes here.
flow.server_conn = ServerConnection((host, port),
proxy.AddressPriority.MANUALLY_CHANGED)
flow.server_conn = ServerConnection((host, port))
flow.server_conn.ssl_established = is_ssl
# If this is an absolute request, replace the attributes on the request object as well.
@ -815,7 +820,7 @@ class HTTPFlow(Flow):
s = "<HTTPFlow"
for a in ("request", "response", "error", "client_conn", "server_conn"):
if getattr(self, a, False):
s += "\r\n %s = {flow.%s}" % (a,a)
s += "\r\n %s = {flow.%s}" % (a, a)
s += ">"
return s.format(flow=self)
@ -950,8 +955,7 @@ class HTTPHandler(ProtocolHandler):
# sent through to the Master.
flow.request = req
request_reply = self.c.channel.ask("request", flow)
self.determine_server_address(flow, flow.request) # The inline script may have changed request.host
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
self.process_server_address(flow) # The inline script may have changed request.host
if request_reply is None or request_reply == KILL:
return False
@ -1048,7 +1052,7 @@ class HTTPHandler(ProtocolHandler):
def handle_server_reconnect(self, state):
if state["state"] == "connect":
send_connect_request(self.c.server_conn, state["host"], state["port"])
send_connect_request(self.c.server_conn, state["host"], state["port"], update_state=False)
else: # pragma: nocover
raise RuntimeError("Unknown State: %s" % state["state"])
@ -1114,14 +1118,30 @@ class HTTPHandler(ProtocolHandler):
if not self.skip_authentication:
self.authenticate(request)
# Determine .scheme, .host and .port attributes
# For absolute-form requests, they are directly given in the request.
# For authority-form requests, we only need to determine the request scheme.
# For relative-form requests, we need to determine host and port as well.
if not request.scheme:
request.scheme = "https" if flow.server_conn and flow.server_conn.ssl_established else "http"
if not request.host:
# Host/Port Complication: In upstream mode, use the server we CONNECTed to,
# not the upstream proxy.
for s in flow.server_conn.state:
if s[0] == "http" and s[1]["state"] == "connect":
request.host, request.port = s[1]["host"], s[1]["port"]
if not request.host:
request.host = flow.server_conn.address.host
request.port = flow.server_conn.address.port
# Now we can process the request.
if request.form_in == "authority":
if self.c.client_conn.ssl_established:
raise http.HttpError(400, "Must not CONNECT on already encrypted connection")
if self.expected_form_in == "absolute":
if not self.c.config.get_upstream_server:
self.c.set_server_address((request.host, request.port),
proxy.AddressPriority.FROM_PROTOCOL)
if not self.c.config.get_upstream_server: # Regular mode
self.c.set_server_address((request.host, request.port))
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
self.c.establish_server_connection()
self.c.client_conn.send(
@ -1140,24 +1160,63 @@ class HTTPHandler(ProtocolHandler):
self.ssl_upgrade()
self.skip_authentication = True
return True
else:
else: # upstream proxy mode
return None
else:
pass # CONNECT should never occur if we don't expect absolute-form requests
elif request.form_in == self.expected_form_in:
request.form_out = self.expected_form_out
if request.form_in == "absolute":
if request.scheme != "http":
raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme)
self.determine_server_address(flow, request)
request.form_out = self.expected_form_out
if request.form_out == "relative":
self.c.set_server_address((request.host, request.port))
flow.server_conn = self.c.server_conn
return None
raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" %
(self.expected_form_in, request.form_in))
def determine_server_address(self, flow, request):
if request.form_in == "absolute":
self.c.set_server_address((request.host, request.port),
proxy.AddressPriority.FROM_PROTOCOL)
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
def process_server_address(self, flow):
# Depending on the proxy mode, server handling is entirely different
# We provide a mostly unified API to the user, which needs to be unfiddled here
# ( See also: https://github.com/mitmproxy/mitmproxy/issues/337 )
address = netlib.tcp.Address((flow.request.host, flow.request.port))
ssl = (flow.request.scheme == "https")
if self.c.config.http_form_in == self.c.config.http_form_out == "absolute": # Upstream Proxy mode
# The connection to the upstream proxy may have a state we may need to take into account.
connected_to = None
for s in flow.server_conn.state:
if s[0] == "http" and s[1]["state"] == "connect":
connected_to = tcp.Address((s[1]["host"], s[1]["port"]))
# We need to reconnect if the current flow either requires a (possibly impossible)
# change to the connection state, e.g. the host has changed but we already CONNECTed somewhere else.
needs_server_change = (
ssl != self.c.server_conn.ssl_established
or
(connected_to and address != connected_to) # HTTP proxying is "stateless", CONNECT isn't.
)
if needs_server_change:
# force create new connection to the proxy server to reset state
self.live.change_server(self.c.server_conn.address, force=True)
if ssl:
send_connect_request(self.c.server_conn, address.host, address.port)
self.c.establish_ssl(server=True)
else:
# If we're not in upstream mode, we just want to update the host and possibly establish TLS.
self.live.change_server(address, ssl=ssl) # this is a no op if the addresses match.
flow.server_conn = self.c.server_conn
def authenticate(self, request):
if self.c.config.authenticator:

View File

@ -2,7 +2,6 @@ from __future__ import absolute_import
import copy
import netlib.tcp
from .. import stateobject, utils, version
from ..proxy.primitives import AddressPriority
from ..proxy.connection import ClientConnection, ServerConnection
@ -153,44 +152,48 @@ class LiveConnection(object):
without requiring the expose the ConnectionHandler.
"""
def __init__(self, c):
self._c = c
self.c = c
self._backup_server_conn = None
"""@type: libmproxy.proxy.server.ConnectionHandler"""
def change_server(self, address, ssl, persistent_change=False):
def change_server(self, address, ssl=False, force=False, persistent_change=False):
address = netlib.tcp.Address.wrap(address)
if address != self._c.server_conn.address:
if force or address != self.c.server_conn.address or ssl != self.c.server_conn.ssl_established:
self._c.log("Change server connection: %s:%s -> %s:%s" % (
self._c.server_conn.address.host,
self._c.server_conn.address.port,
self.c.log("Change server connection: %s:%s -> %s:%s [persistent: %s]" % (
self.c.server_conn.address.host,
self.c.server_conn.address.port,
address.host,
address.port
address.port,
persistent_change
), "debug")
if not hasattr(self, "_backup_server_conn"):
self._backup_server_conn = self._c.server_conn
self._c.server_conn = None
if self._backup_server_conn:
self._backup_server_conn = self.c.server_conn
self.c.server_conn = None
else: # This is at least the second temporary change. We can kill the current connection.
self._c.del_server_connection()
self.c.del_server_connection()
self._c.set_server_address(address, AddressPriority.MANUALLY_CHANGED)
self._c.establish_server_connection(ask=False)
self.c.set_server_address(address)
self.c.establish_server_connection(ask=False)
if ssl:
self._c.establish_ssl(server=True)
if hasattr(self, "_backup_server_conn") and persistent_change:
del self._backup_server_conn
self.c.establish_ssl(server=True)
if persistent_change:
self._backup_server_conn = None
def restore_server(self):
if not hasattr(self, "_backup_server_conn"):
# TODO: Similar to _backup_server_conn, introduce _cache_server_conn, which keeps the changed connection open
# This may be beneficial if a user is rewriting all requests from http to https or similar.
if not self._backup_server_conn:
return
self._c.log("Restore original server connection: %s:%s -> %s:%s" % (
self._c.server_conn.address.host,
self._c.server_conn.address.port,
self.c.log("Restore original server connection: %s:%s -> %s:%s" % (
self.c.server_conn.address.host,
self.c.server_conn.address.port,
self._backup_server_conn.address.host,
self._backup_server_conn.address.port
), "debug")
self._c.del_server_connection()
self._c.server_conn = self._backup_server_conn
del self._backup_server_conn
self.c.del_server_connection()
self.c.server_conn = self._backup_server_conn
self._backup_server_conn = None

View File

@ -72,9 +72,8 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject):
class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
def __init__(self, address, priority):
def __init__(self, address):
tcp.TCPClient.__init__(self, address)
self.priority = priority
self.state = [] # a list containing (conntype, state) tuples
self.peername = None
@ -131,7 +130,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject):
@classmethod
def _from_state(cls, state):
f = cls(tuple(), None)
f = cls(tuple())
f._load_state(state)
return f

View File

@ -45,19 +45,6 @@ class TransparentUpstreamServerResolver(UpstreamServerResolver):
return [ssl, ssl] + list(dst)
class AddressPriority(object):
"""
Enum that signifies the priority of the given address when choosing the destination host.
Higher is better (None < i)
"""
MANUALLY_CHANGED = 3
"""user changed the target address in the ui"""
FROM_SETTINGS = 2
"""upstream server from arguments (reverse proxy, upstream proxy or from transparent resolver)"""
FROM_PROTOCOL = 1
"""derived from protocol (e.g. absolute-form http requests)"""
class Log:
def __init__(self, msg, level="info"):
self.msg = msg

View File

@ -5,7 +5,7 @@ import socket
from OpenSSL import SSL
from netlib import tcp
from .primitives import ProxyServerError, Log, ProxyError, AddressPriority
from .primitives import ProxyServerError, Log, ProxyError
from .connection import ClientConnection, ServerConnection
from ..protocol.handle import protocol_handler
from .. import version
@ -76,7 +76,7 @@ class ConnectionHandler:
client_ssl, server_ssl = False, False
if self.config.get_upstream_server:
upstream_info = self.config.get_upstream_server(self.client_conn.connection)
self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS)
self.set_server_address(upstream_info[2:])
client_ssl, server_ssl = upstream_info[:2]
if self.check_ignore_address(self.server_conn.address):
self.log("Ignore host: %s:%s" % self.server_conn.address(), "info")
@ -129,27 +129,22 @@ class ConnectionHandler:
else:
return False
def set_server_address(self, address, priority):
def set_server_address(self, address):
"""
Sets a new server address with the given priority.
Does not re-establish either connection or SSL handshake.
"""
address = tcp.Address.wrap(address)
if self.server_conn:
if self.server_conn.priority > priority:
self.log("Attempt to change server address, "
"but priority is too low (is: %s, got: %s)" % (
self.server_conn.priority, priority), "debug")
return
if self.server_conn.address == address:
self.server_conn.priority = priority # Possibly increase priority
return
# Don't reconnect to the same destination.
if self.server_conn and self.server_conn.address == address:
return
if self.server_conn:
self.del_server_connection()
self.log("Set new server address: %s:%s" % (address.host, address.port), "debug")
self.server_conn = ServerConnection(address, priority)
self.server_conn = ServerConnection(address)
def establish_server_connection(self, ask=True):
"""
@ -212,12 +207,11 @@ class ConnectionHandler:
def server_reconnect(self):
address = self.server_conn.address
had_ssl = self.server_conn.ssl_established
priority = self.server_conn.priority
state = self.server_conn.state
sni = self.sni
self.log("(server reconnect follows)", "debug")
self.del_server_connection()
self.set_server_address(address, priority)
self.set_server_address(address)
self.establish_server_connection()
for s in state: