mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
refactor protocol handling, fix #332
This commit is contained in:
parent
82730c1c6f
commit
1f47f7b6b2
@ -6,21 +6,12 @@ protocols = {
|
||||
'tcp': dict(handler=tcp.TCPHandler)
|
||||
}
|
||||
|
||||
def protocol_handler(protocol):
|
||||
"""
|
||||
@type protocol: str
|
||||
@returns: libmproxy.protocol.primitives.ProtocolHandler
|
||||
"""
|
||||
if protocol in protocols:
|
||||
return protocols[protocol]["handler"]
|
||||
|
||||
def _handler(conntype, connection_handler):
|
||||
if conntype in protocols:
|
||||
return protocols[conntype]["handler"](connection_handler)
|
||||
|
||||
raise NotImplementedError # pragma: nocover
|
||||
|
||||
|
||||
def handle_messages(conntype, connection_handler):
|
||||
return _handler(conntype, connection_handler).handle_messages()
|
||||
|
||||
|
||||
def handle_error(conntype, connection_handler, error):
|
||||
return _handler(conntype, connection_handler).handle_error(error)
|
||||
|
||||
|
||||
def handle_server_reconnect(conntype, connection_handler, state):
|
||||
return _handler(conntype, connection_handler).handle_server_reconnect(state)
|
||||
raise NotImplementedError("Unknown Protocol: %s" % protocol) # pragma: nocover
|
@ -5,7 +5,8 @@ import threading
|
||||
from netlib import http, tcp, http_status
|
||||
import netlib.utils
|
||||
from netlib.odict import ODict, ODictCaseless
|
||||
from .primitives import KILL, ProtocolHandler, LiveConnection, Flow, Error
|
||||
from .tcp import TCPHandler
|
||||
from .primitives import KILL, ProtocolHandler, Flow, Error
|
||||
from ..proxy.connection import ServerConnection
|
||||
from .. import encoding, utils, controller, stateobject, proxy
|
||||
|
||||
@ -914,7 +915,6 @@ class HTTPHandler(ProtocolHandler):
|
||||
def handle_messages(self):
|
||||
while self.handle_flow():
|
||||
pass
|
||||
self.c.close = True
|
||||
|
||||
def get_response_from_server(self, request, include_body=True):
|
||||
self.c.establish_server_connection()
|
||||
@ -948,9 +948,9 @@ class HTTPHandler(ProtocolHandler):
|
||||
req = HTTPRequest.from_stream(self.c.client_conn.rfile,
|
||||
body_size_limit=self.c.config.body_size_limit)
|
||||
self.c.log("request", "debug", [req._assemble_first_line(req.form_in)])
|
||||
send_request_upstream = self.process_request(flow, req)
|
||||
if not send_request_upstream:
|
||||
return True
|
||||
ret = self.process_request(flow, req)
|
||||
if ret is not None:
|
||||
return ret
|
||||
|
||||
# Be careful NOT to assign the request to the flow before
|
||||
# process_request completes. This is because the call can raise an
|
||||
@ -959,7 +959,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
# sent through to the Master.
|
||||
flow.request = req
|
||||
request_reply = self.c.channel.ask("request", flow.request)
|
||||
self.determine_server_address(flow, flow.request)
|
||||
self.determine_server_address(flow, flow.request) # The inline script may have changed request.host
|
||||
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
|
||||
|
||||
if request_reply is None or request_reply == KILL:
|
||||
@ -1025,6 +1025,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
else:
|
||||
return False
|
||||
|
||||
# We sent a CONNECT request to an upstream proxy.
|
||||
if flow.request.form_in == "authority" and flow.response.code == 200:
|
||||
# TODO: Eventually add headers (space/usefulness tradeoff)
|
||||
# Make sure to add state info before the actual upgrade happens.
|
||||
@ -1034,7 +1035,15 @@ class HTTPHandler(ProtocolHandler):
|
||||
self.c.server_conn.state.append(("http", {"state": "connect",
|
||||
"host": flow.request.host,
|
||||
"port": flow.request.port}))
|
||||
self.ssl_upgrade()
|
||||
|
||||
if self.c.check_ignore_address((flow.request.host, flow.request.port)):
|
||||
self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info")
|
||||
TCPHandler(self.c).handle_messages()
|
||||
return False
|
||||
else:
|
||||
if flow.request.port in self.c.config.ssl_ports:
|
||||
self.ssl_upgrade()
|
||||
self.skip_authentication = True
|
||||
|
||||
# If the user has changed the target server on this connection,
|
||||
# restore the original target server
|
||||
@ -1065,7 +1074,7 @@ class HTTPHandler(ProtocolHandler):
|
||||
|
||||
if flow:
|
||||
flow.error = Error(message)
|
||||
# FIXME: no flows without request or with both request and response at the moement.
|
||||
# FIXME: no flows without request or with both request and response at the moment.
|
||||
if flow.request and not flow.response:
|
||||
self.c.channel.ask("error", flow.error)
|
||||
else:
|
||||
@ -1103,6 +1112,13 @@ class HTTPHandler(ProtocolHandler):
|
||||
self.c.log("Upgrade to SSL completed.", "debug")
|
||||
|
||||
def process_request(self, flow, request):
|
||||
"""
|
||||
@returns:
|
||||
True, if the request should not be sent upstream
|
||||
False, if the connection should be aborted
|
||||
None, if the request should be sent upstream
|
||||
(a status code != None should be returned directly by handle_flow)
|
||||
"""
|
||||
|
||||
if not self.skip_authentication:
|
||||
self.authenticate(request)
|
||||
@ -1115,8 +1131,8 @@ class HTTPHandler(ProtocolHandler):
|
||||
if not self.c.config.get_upstream_server:
|
||||
self.c.set_server_address((request.host, request.port),
|
||||
proxy.AddressPriority.FROM_PROTOCOL)
|
||||
self.c.establish_server_connection()
|
||||
flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow
|
||||
self.c.establish_server_connection()
|
||||
self.c.client_conn.send(
|
||||
'HTTP/1.1 200 Connection established\r\n' +
|
||||
'Content-Length: 0\r\n' +
|
||||
@ -1124,18 +1140,24 @@ class HTTPHandler(ProtocolHandler):
|
||||
'\r\n'
|
||||
)
|
||||
|
||||
self.ssl_upgrade()
|
||||
self.skip_authentication = True
|
||||
return False
|
||||
if self.c.check_ignore_address(self.c.server_conn.address):
|
||||
self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info")
|
||||
TCPHandler(self.c).handle_messages()
|
||||
return False
|
||||
else:
|
||||
if self.c.server_conn.address.port in self.c.config.ssl_ports:
|
||||
self.ssl_upgrade()
|
||||
self.skip_authentication = True
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
return None
|
||||
elif request.form_in == self.expected_form_in:
|
||||
if request.form_in == "absolute":
|
||||
if request.scheme != "http":
|
||||
raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme)
|
||||
self.determine_server_address(flow, request)
|
||||
request.form_out = self.expected_form_out
|
||||
return True
|
||||
return None
|
||||
|
||||
raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" %
|
||||
(self.expected_form_in, request.form_in))
|
||||
@ -1172,10 +1194,11 @@ class RequestReplayThread(threading.Thread):
|
||||
server_address, server_ssl = False, False
|
||||
if self.config.get_upstream_server:
|
||||
try:
|
||||
# this will fail in transparent mode
|
||||
upstream_info = self.config.get_upstream_server(self.flow.client_conn)
|
||||
server_ssl = upstream_info[1]
|
||||
server_address = upstream_info[2:]
|
||||
except proxy.ProxyError: # this will fail in transparent mode
|
||||
except proxy.ProxyError:
|
||||
pass
|
||||
if not server_address:
|
||||
server_address = (r.get_host(), r.get_port())
|
||||
@ -1184,7 +1207,7 @@ class RequestReplayThread(threading.Thread):
|
||||
server.connect()
|
||||
|
||||
if server_ssl or r.get_scheme() == "https":
|
||||
if self.config.http_form_out == "absolute":
|
||||
if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT
|
||||
send_connect_request(server, r.get_host(), r.get_port())
|
||||
r.form_out = "relative"
|
||||
server.establish_ssl(self.config.clientcerts,
|
||||
|
@ -18,7 +18,7 @@ class TCPHandler(ProtocolHandler):
|
||||
buf = memoryview(bytearray(self.chunk_size))
|
||||
|
||||
conns = [self.c.client_conn.rfile, self.c.server_conn.rfile]
|
||||
while not self.c.close:
|
||||
while True:
|
||||
r, _, _ = select.select(conns, [], [], 10)
|
||||
for rfile in r:
|
||||
if self.c.client_conn.rfile == rfile:
|
||||
@ -51,7 +51,7 @@ class TCPHandler(ProtocolHandler):
|
||||
dst.connection.shutdown(socket.SHUT_WR)
|
||||
|
||||
if len(conns) == 0:
|
||||
self.c.close = True
|
||||
return
|
||||
continue
|
||||
|
||||
if src.ssl_established or dst.ssl_established:
|
||||
|
@ -15,7 +15,7 @@ class ProxyConfig:
|
||||
no_upstream_cert=False, body_size_limit=None,
|
||||
mode=None, upstream_server=None, http_form_in=None, http_form_out=None,
|
||||
authenticator=None, ignore=[],
|
||||
ciphers=None, certs=[], certforward=False):
|
||||
ciphers=None, certs=[], certforward=False, ssl_ports=TRANSPARENT_SSL_PORTS):
|
||||
self.ciphers = ciphers
|
||||
self.clientcerts = clientcerts
|
||||
self.no_upstream_cert = no_upstream_cert
|
||||
@ -49,6 +49,7 @@ class ProxyConfig:
|
||||
for spec, cert in certs:
|
||||
self.certstore.add_cert_file(spec, cert)
|
||||
self.certforward = certforward
|
||||
self.ssl_ports = ssl_ports
|
||||
|
||||
|
||||
def process_proxy_options(parser, options):
|
||||
@ -157,4 +158,10 @@ def ssl_option_group(parser):
|
||||
"--no-upstream-cert", default=False,
|
||||
action="store_true", dest="no_upstream_cert",
|
||||
help="Don't connect to upstream server to look up certificate details."
|
||||
)
|
||||
group.add_argument(
|
||||
"--ssl-port", action="append", type=int, dest="ssl_ports", default=TRANSPARENT_SSL_PORTS,
|
||||
metavar="PORT",
|
||||
help="Can be passed multiple times. Specify destination ports which are assumed to be SSL. "
|
||||
"Defaults to %s." % str(TRANSPARENT_SSL_PORTS)
|
||||
)
|
@ -6,13 +6,6 @@ class ProxyError(Exception):
|
||||
super(ProxyError, self).__init__(self, message)
|
||||
self.code, self.headers = code, headers
|
||||
|
||||
class ConnectionTypeChange(Exception):
|
||||
"""
|
||||
Gets raised if the connection type has been changed (e.g. after HTTP/1.1 101 Switching Protocols).
|
||||
It's up to the raising ProtocolHandler to specify the new conntype before raising the exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ProxyServerError(Exception):
|
||||
pass
|
||||
|
@ -5,10 +5,9 @@ import socket
|
||||
from OpenSSL import SSL
|
||||
|
||||
from netlib import tcp
|
||||
from .primitives import ProxyServerError, Log, ProxyError, ConnectionTypeChange, \
|
||||
AddressPriority
|
||||
from .primitives import ProxyServerError, Log, ProxyError, AddressPriority
|
||||
from .connection import ClientConnection, ServerConnection
|
||||
from ..protocol.handle import handle_messages, handle_error, handle_server_reconnect
|
||||
from ..protocol.handle import protocol_handler
|
||||
from .. import version
|
||||
|
||||
|
||||
@ -66,7 +65,6 @@ class ConnectionHandler:
|
||||
"""@type: libmproxy.proxy.connection.ServerConnection"""
|
||||
self.channel, self.server_version = channel, server_version
|
||||
|
||||
self.close = False
|
||||
self.conntype = "http"
|
||||
self.sni = None
|
||||
|
||||
@ -77,28 +75,28 @@ class ConnectionHandler:
|
||||
# Can we already identify the target server and connect to it?
|
||||
client_ssl, server_ssl = False, False
|
||||
if self.config.get_upstream_server:
|
||||
upstream_info = self.config.get_upstream_server(
|
||||
self.client_conn.connection)
|
||||
upstream_info = self.config.get_upstream_server(self.client_conn.connection)
|
||||
self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS)
|
||||
client_ssl, server_ssl = upstream_info[:2]
|
||||
if self.check_ignore_address(self.server_conn.address):
|
||||
self.log("Ignore host: %s:%s" % self.server_conn.address(), "info")
|
||||
self.conntype = "tcp"
|
||||
client_ssl, server_ssl = False, False
|
||||
|
||||
self.determine_conntype()
|
||||
self.channel.ask("clientconnect", self)
|
||||
|
||||
# Check for existing connection: If an inline script already established a
|
||||
# connection, do not apply client_ssl or server_ssl.
|
||||
if self.server_conn and not self.server_conn.connection:
|
||||
self.establish_server_connection()
|
||||
if client_ssl or server_ssl:
|
||||
self.establish_ssl(client=client_ssl, server=server_ssl)
|
||||
|
||||
while not self.close:
|
||||
try:
|
||||
handle_messages(self.conntype, self)
|
||||
except ConnectionTypeChange:
|
||||
self.log("Connection type changed: %s" % self.conntype, "info")
|
||||
continue
|
||||
# Delegate handling to the protocol handler
|
||||
protocol_handler(self.conntype)(self).handle_messages()
|
||||
|
||||
except (ProxyError, tcp.NetLibError), e:
|
||||
handle_error(self.conntype, self, e)
|
||||
protocol_handler(self.conntype)(self).handle_error(e)
|
||||
except Exception:
|
||||
import traceback, sys
|
||||
|
||||
@ -106,7 +104,6 @@ class ConnectionHandler:
|
||||
print >> sys.stderr, traceback.format_exc()
|
||||
print >> sys.stderr, "mitmproxy has crashed!"
|
||||
print >> sys.stderr, "Please lodge a bug report at: https://github.com/mitmproxy/mitmproxy"
|
||||
raise
|
||||
|
||||
self.del_server_connection()
|
||||
self.log("clientdisconnect", "info")
|
||||
@ -124,12 +121,13 @@ class ConnectionHandler:
|
||||
self.server_conn = None
|
||||
self.sni = None
|
||||
|
||||
def determine_conntype(self):
|
||||
if self.server_conn and any(rex.search(self.server_conn.address.host) for rex in self.config.ignore):
|
||||
self.log("Ignore host: %s" % self.server_conn.address.host, "info")
|
||||
self.conntype = "tcp"
|
||||
def check_ignore_address(self, address):
|
||||
address = tcp.Address.wrap(address)
|
||||
host = "%s:%s" % (address.host, address.port)
|
||||
if host and any(rex.search(host) for rex in self.config.ignore):
|
||||
return True
|
||||
else:
|
||||
self.conntype = "http"
|
||||
return False
|
||||
|
||||
def set_server_address(self, address, priority):
|
||||
"""
|
||||
@ -175,15 +173,8 @@ class ConnectionHandler:
|
||||
def establish_ssl(self, client=False, server=False):
|
||||
"""
|
||||
Establishes SSL on the existing connection(s) to the server or the client,
|
||||
as specified by the parameters. If the target server is on the pass-through list,
|
||||
the conntype attribute will be changed and a ConnTypeChanged exception will be raised.
|
||||
as specified by the parameters.
|
||||
"""
|
||||
# If the host is on our ignore list, change to passthrough/ignore mode.
|
||||
for host in (self.server_conn.address.host, self.sni):
|
||||
if host and any(rex.search(host) for rex in self.config.ignore):
|
||||
self.log("Ignore host: %s" % host, "info")
|
||||
self.conntype = "tcp"
|
||||
raise ConnectionTypeChange()
|
||||
|
||||
# Logging
|
||||
if client or server:
|
||||
@ -224,7 +215,7 @@ class ConnectionHandler:
|
||||
self.establish_server_connection()
|
||||
|
||||
for s in state:
|
||||
handle_server_reconnect(s[0], self, s[1])
|
||||
protocol_handler(s[0])(self).handle_server_reconnect(s[1])
|
||||
self.server_conn.state = state
|
||||
|
||||
if had_ssl:
|
||||
@ -288,4 +279,5 @@ class ConnectionHandler:
|
||||
# make dang sure it doesn't happen.
|
||||
except Exception: # pragma: no cover
|
||||
import traceback
|
||||
|
||||
self.log("Error in handle_sni:\r\n" + traceback.format_exc(), "error")
|
@ -175,6 +175,7 @@ class TestHTTPAuth(tservers.HTTPProxTest):
|
||||
class TestHTTPConnectSSLError(tservers.HTTPProxTest):
|
||||
certfile = True
|
||||
def test_go(self):
|
||||
self.config.ssl_ports.append(self.proxy.port)
|
||||
p = self.pathoc_raw()
|
||||
dst = ("localhost", self.proxy.port)
|
||||
p.connect(connect_to=dst)
|
||||
|
@ -95,6 +95,7 @@ class ProxTestBase(object):
|
||||
confdir = cls.confdir,
|
||||
authenticator = cls.authenticator,
|
||||
certforward = cls.certforward,
|
||||
ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []),
|
||||
**pconf
|
||||
)
|
||||
tmaster = cls.masterclass(cls.config)
|
||||
@ -267,17 +268,20 @@ class ChainProxTest(ProxTestBase):
|
||||
Chain n instances of mitmproxy in a row - because we can.
|
||||
"""
|
||||
n = 2
|
||||
chain_config = [lambda port: ProxyConfig(
|
||||
chain_config = [lambda port, sslports: ProxyConfig(
|
||||
upstream_server= (False, False, "127.0.0.1", port),
|
||||
http_form_in = "absolute",
|
||||
http_form_out = "absolute"
|
||||
http_form_out = "absolute",
|
||||
ssl_ports=sslports
|
||||
)] * n
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
super(ChainProxTest, cls).setupAll()
|
||||
cls.chain = []
|
||||
for i in range(cls.n):
|
||||
config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port)
|
||||
sslports = [cls.server.port, cls.server2.port]
|
||||
config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port,
|
||||
sslports)
|
||||
tmaster = cls.masterclass(config)
|
||||
tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp)
|
||||
cls.chain.append(ProxyThread(tmaster))
|
||||
|
Loading…
Reference in New Issue
Block a user