add priorities for the destination server address

This commit is contained in:
Maximilian Hils 2014-02-04 02:56:59 +01:00
parent 2db5f9de26
commit f6253a80ff
2 changed files with 114 additions and 55 deletions

View File

@ -5,7 +5,7 @@ from netlib import http, tcp, http_status, odict
from netlib.odict import ODict, ODictCaseless from netlib.odict import ODict, ODictCaseless
from . import ProtocolHandler, ConnectionTypeChange, KILL from . import ProtocolHandler, ConnectionTypeChange, KILL
from .. import encoding, utils, version, filt, controller, stateobject from .. import encoding, utils, version, filt, controller, stateobject
from ..proxy import ProxyError from ..proxy import ProxyError, AddressPriority
from ..flow import Flow, Error from ..flow import Flow, Error
@ -816,7 +816,7 @@ class HTTPHandler(ProtocolHandler):
raise v raise v
def handle_flow(self): def handle_flow(self):
flow = HTTPFlow(self.c.client_conn, self.c.server_conn, None, None, None) flow = HTTPFlow(self.c.client_conn, self.c.server_conn)
try: try:
flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile, flow.request = HTTPRequest.from_stream(self.c.client_conn.rfile,
body_size_limit=self.c.config.body_size_limit) body_size_limit=self.c.config.body_size_limit)
@ -831,9 +831,10 @@ class HTTPHandler(ProtocolHandler):
flow.response = request_reply flow.response = request_reply
else: else:
self.process_request(flow.request) self.process_request(flow.request)
self.c.establish_server_connection()
flow.response = self.get_response_from_server(flow.request) flow.response = self.get_response_from_server(flow.request)
self.c.log("response", [flow.response._assemble_response_line()]) self.c.log("response", [flow.response._assemble_first_line()])
response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse", response_reply = self.c.channel.ask("response" if LEGACY else "httpresponse",
flow.response if LEGACY else flow) flow.response if LEGACY else flow)
if response_reply is None or response_reply == KILL: if response_reply is None or response_reply == KILL:
@ -853,16 +854,6 @@ class HTTPHandler(ProtocolHandler):
flow.server_conn = self.c.server_conn flow.server_conn = self.c.server_conn
"""
FIXME: Remove state test
d = flow._get_state()
print d
flow._load_state(d)
print flow._get_state()
copy = HTTPFlow._from_state(d)
print copy._get_state()
"""
return True return True
except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e: except (HttpAuthenticationError, http.HttpError, ProxyError, tcp.NetLibError), e:
self.handle_error(e, flow) self.handle_error(e, flow)
@ -887,8 +878,10 @@ class HTTPHandler(ProtocolHandler):
if flow: if flow:
flow.error = Error(err) flow.error = Error(err)
self.c.channel.ask("error" if LEGACY else "httperror", if not (LEGACY and not flow.request) and not (LEGACY and flow.request and flow.response):
flow.error if LEGACY else flow) # no flows without request or with both request and response in legacy mode
self.c.channel.ask("error" if LEGACY else "httperror",
flow.error if LEGACY else flow)
else: else:
pass # FIXME: Is there any use case for persisting errors that occur outside of flows? pass # FIXME: Is there any use case for persisting errors that occur outside of flows?
@ -923,6 +916,7 @@ class HTTPHandler(ProtocolHandler):
This isn't particular beautiful code, but it isolates this rare edge-case from the This isn't particular beautiful code, but it isolates this rare edge-case from the
protocol-agnostic ConnectionHandler protocol-agnostic ConnectionHandler
""" """
self.c.log("Received CONNECT request. Upgrading to SSL...")
self.c.mode = "transparent" self.c.mode = "transparent"
self.c.determine_conntype() self.c.determine_conntype()
self.c.establish_ssl(server=True, client=True) self.c.establish_ssl(server=True, client=True)
@ -933,7 +927,7 @@ class HTTPHandler(ProtocolHandler):
def reconnect_http_proxy(): def reconnect_http_proxy():
self.c.log("Hooked reconnect function") self.c.log("Hooked reconnect function")
self.c.log("Hook: Run original redirect") self.c.log("Hook: Run original reconnect")
original_reconnect_func(no_ssl=True) original_reconnect_func(no_ssl=True)
self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()]) self.c.log("Hook: Write CONNECT request to upstream proxy", [upstream_request._assemble_first_line()])
self.c.server_conn.wfile.write(upstream_request._assemble()) self.c.server_conn.wfile.write(upstream_request._assemble())
@ -948,6 +942,7 @@ class HTTPHandler(ProtocolHandler):
self.c.server_reconnect = reconnect_http_proxy self.c.server_reconnect = reconnect_http_proxy
self.c.log("Upgrade to SSL completed.")
raise ConnectionTypeChange raise ConnectionTypeChange
def process_request(self, request): def process_request(self, request):
@ -958,9 +953,10 @@ class HTTPHandler(ProtocolHandler):
# If we have a CONNECT request, we might need to intercept # If we have a CONNECT request, we might need to intercept
if request.form_in == "authority": if request.form_in == "authority":
directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy directly_addressed_at_mitmproxy = (self.c.mode == "regular" and not self.c.config.forward_proxy)
if directly_addressed_at_mitmproxy: if directly_addressed_at_mitmproxy:
self.c.establish_server_connection((request.host, request.port)) self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL)
self.c.establish_server_connection()
self.c.client_conn.wfile.write( self.c.client_conn.wfile.write(
'HTTP/1.1 200 Connection established\r\n' + 'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n' % self.c.server_version) + ('Proxy-agent: %s\r\n' % self.c.server_version) +
@ -977,9 +973,7 @@ class HTTPHandler(ProtocolHandler):
raise http.HttpError(400, "Invalid Request") raise http.HttpError(400, "Invalid Request")
if not self.c.config.forward_proxy: if not self.c.config.forward_proxy:
request.form_out = "origin" request.form_out = "origin"
if ((not self.c.server_conn) or self.c.set_server_address((request.host, request.port), AddressPriority.FROM_PROTOCOL)
(self.c.server_conn.address != (request.host, request.port))):
self.c.establish_server_connection((request.host, request.port))
else: else:
raise http.HttpError(400, "Invalid Request") raise http.HttpError(400, "Invalid Request")

View File

@ -7,6 +7,23 @@ import utils, version, platform, controller, stateobject
TRANSPARENT_SSL_PORTS = [443, 8443] TRANSPARENT_SSL_PORTS = [443, 8443]
class AddressPriority(object):
"""
Enum that signifies the priority of the given address when choosing the destination host.
Higher is better (None < i)
"""
FORCE = 5
"""forward mode"""
MANUALLY_CHANGED = 4
"""user changed the target address in the ui"""
FROM_SETTINGS = 3
"""reverse proxy mode"""
FROM_CONNECTION = 2
"""derived from transparent resolver"""
FROM_PROTOCOL = 1
"""derived from protocol (e.g. absolute-form http requests)"""
class ProxyError(Exception): class ProxyError(Exception):
def __init__(self, code, msg, headers=None): def __init__(self, code, msg, headers=None):
self.code, self.msg, self.headers = code, msg, headers self.code, self.msg, self.headers = code, msg, headers
@ -189,6 +206,7 @@ class ConnectionHandler:
self.close = False self.close = False
self.conntype = None self.conntype = None
self.sni = None self.sni = None
self.server_address_priority = None
self.mode = "regular" self.mode = "regular"
if self.config.reverse_proxy: if self.config.reverse_proxy:
@ -196,14 +214,6 @@ class ConnectionHandler:
if self.config.transparent_proxy: if self.config.transparent_proxy:
self.mode = "transparent" self.mode = "transparent"
def del_server_connection(self):
if self.server_conn and self.server_conn.connection:
self.server_conn.finish()
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)])
self.channel.tell("serverdisconnect", self)
self.server_conn = None
self.sni = None
def handle(self): def handle(self):
self.log("clientconnect") self.log("clientconnect")
self.channel.ask("clientconnect", self) self.channel.ask("clientconnect", self)
@ -214,20 +224,23 @@ class ConnectionHandler:
try: try:
# Can we already identify the target server and connect to it? # Can we already identify the target server and connect to it?
server_address = None server_address = None
address_priority = None
if self.config.forward_proxy: if self.config.forward_proxy:
server_address = self.config.forward_proxy[1:] server_address = self.config.forward_proxy[1:]
else: address_priority = AddressPriority.FORCE
if self.config.reverse_proxy: elif self.config.reverse_proxy:
server_address = self.config.reverse_proxy[1:] server_address = self.config.reverse_proxy[1:]
elif self.config.transparent_proxy: address_priority = AddressPriority.FROM_SETTINGS
server_address = self.config.transparent_proxy["resolver"].original_addr( elif self.config.transparent_proxy:
self.client_conn.connection) server_address = self.config.transparent_proxy["resolver"].original_addr(
if not server_address: self.client_conn.connection)
raise ProxyError(502, "Transparent mode failure: could not resolve original destination.") if not server_address:
self.log("transparent to %s:%s" % server_address) raise ProxyError(502, "Transparent mode failure: could not resolve original destination.")
address_priority = AddressPriority.FROM_CONNECTION
self.log("transparent to %s:%s" % server_address)
if server_address: if server_address:
self.establish_server_connection(server_address) self.set_server_address(server_address, address_priority)
self._handle_ssl() self._handle_ssl()
while not self.close: while not self.close:
@ -252,53 +265,95 @@ class ConnectionHandler:
def _handle_ssl(self): def _handle_ssl(self):
""" """
Helper function of .handle()
Check if we can already identify SSL connections. Check if we can already identify SSL connections.
If so, connect to the server and establish an SSL connection
""" """
client_ssl = False
server_ssl = False
if self.config.transparent_proxy: if self.config.transparent_proxy:
client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"]) client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"])
elif self.config.reverse_proxy: elif self.config.reverse_proxy:
client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
# TODO: Make protocol generic (as with transparent proxies) # TODO: Make protocol generic (as with transparent proxies)
# TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa) # TODO: Add SSL-terminating capatbility (SSL -> mitmproxy -> plain and vice versa)
self.establish_ssl(client=client_ssl, server=server_ssl) if client_ssl or server_ssl:
self.establish_server_connection()
self.establish_ssl(client=client_ssl, server=server_ssl)
def finish(self): def del_server_connection(self):
self.client_conn.finish() """
Deletes an existing server connection.
"""
if self.server_conn and self.server_conn.connection:
self.server_conn.finish()
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.address.host, self.server_conn.address.port)])
self.channel.tell("serverdisconnect", self)
self.server_conn = None
self.server_address_priority = None
self.sni = None
def determine_conntype(self): def determine_conntype(self):
#TODO: Add ruleset to select correct protocol depending on mode/target port etc. #TODO: Add ruleset to select correct protocol depending on mode/target port etc.
self.conntype = "http" self.conntype = "http"
def establish_server_connection(self, address): def set_server_address(self, address, priority):
""" """
Establishes a new server connection to the given server Sets a new server address with the given priority
If there is already an existing server connection, it will be killed. @type priority: AddressPriority
""" """
self.del_server_connection() address = tcp.Address.wrap(address)
self.server_conn = ServerConnection(address) self.log("Set server address: %s:%s" % (address.host, address.port))
if self.server_conn and (self.server_address_priority > priority):
self.log("Server address priority too low (is: %s, got: %s)" % (self.server_address_priority, priority))
return
self.address_priority = priority
if self.server_conn and (self.server_conn.address == address):
self.log("Addresses match, skip.")
return
server_conn = ServerConnection(address)
if self.server_conn and self.server_conn.connection:
self.del_server_connection()
self.server_conn = server_conn
self.establish_server_connection()
else:
self.server_conn = server_conn
def establish_server_connection(self):
"""
Establishes a new server connection.
If there is already an existing server connection, the function returns immediately.
"""
if self.server_conn.connection:
return
self.log("serverconnect", ["%s:%s" % self.server_conn.address()[:2]])
self.channel.tell("serverconnect", self)
try: try:
self.server_conn.connect() self.server_conn.connect()
except tcp.NetLibError, v: except tcp.NetLibError, v:
raise ProxyError(502, v) raise ProxyError(502, v)
self.log("serverconnect", ["%s:%s" % address[:2]])
self.channel.tell("serverconnect", self)
def establish_ssl(self, client=False, server=False): def establish_ssl(self, client=False, server=False):
""" """
Establishes SSL on the existing connection(s) to the server or the client, Establishes SSL on the existing connection(s) to the server or the client,
as specified by the parameters. If the target server is on the pass-through list, as specified by the parameters. If the target server is on the pass-through list,
the conntype attribute will be changed and no the SSL connection won't be wrapped. the conntype attribute will be changed and the SSL connection won't be wrapped.
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
""" """
# TODO: Implement SSL pass-through handling and change conntype # TODO: Implement SSL pass-through handling and change conntype
passthrough = [ passthrough = [
"echo.websocket.org", "echo.websocket.org",
"174.129.224.73" # echo.websocket.org, transparent mode "174.129.224.73" # echo.websocket.org, transparent mode
] ]
if self.server_conn.address.host in passthrough or self.sni in passthrough: if self.server_conn.address.host in passthrough or self.sni in passthrough:
self.conntype = "tcp" self.conntype = "tcp"
return return
# Logging
if client or server: if client or server:
subs = [] subs = []
if client: if client:
@ -319,13 +374,21 @@ class ConnectionHandler:
handle_sni=self.handle_sni) handle_sni=self.handle_sni)
def server_reconnect(self, no_ssl=False): def server_reconnect(self, no_ssl=False):
had_ssl, sni = self.server_conn.ssl_established, self.sni address = self.server_conn.address
had_ssl = self.server_conn.ssl_established
priority = self.server_address_priority
sni = self.sni
self.log("(server reconnect follows)") self.log("(server reconnect follows)")
self.establish_server_connection(self.server_conn.address()) self.del_server_connection()
self.set_server_address(address, priority)
self.establish_server_connection()
if had_ssl and not no_ssl: if had_ssl and not no_ssl:
self.sni = sni self.sni = sni
self.establish_ssl(server=True) self.establish_ssl(server=True)
def finish(self):
self.client_conn.finish()
def log(self, msg, subs=()): def log(self, msg, subs=()):
msg = [ msg = [
"%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg) "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg)
@ -363,6 +426,7 @@ class ConnectionHandler:
sn = connection.get_servername() sn = connection.get_servername()
if sn and sn != self.sni: if sn and sn != self.sni:
self.sni = sn.decode("utf8").encode("idna") self.sni = sn.decode("utf8").encode("idna")
self.log("SNI received: %s" % self.sni)
self.server_reconnect() # reconnect to upstream server with SNI self.server_reconnect() # reconnect to upstream server with SNI
# Now, change client context to reflect changed certificate: # Now, change client context to reflect changed certificate:
new_context = SSL.Context(SSL.TLSv1_METHOD) new_context = SSL.Context(SSL.TLSv1_METHOD)
@ -372,11 +436,12 @@ class ConnectionHandler:
connection.set_context(new_context) connection.set_context(new_context)
# An unhandled exception in this method will core dump PyOpenSSL, so # An unhandled exception in this method will core dump PyOpenSSL, so
# make dang sure it doesn't happen. # make dang sure it doesn't happen.
except Exception, e: # pragma: no cover except Exception, e: # pragma: no cover
pass pass
class ProxyServerError(Exception): pass class ProxyServerError(Exception):
pass
class ProxyServer(tcp.TCPServer): class ProxyServer(tcp.TCPServer):