From a2b85048892626e6834df06e9022498814724636 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 16 Aug 2015 23:25:02 +0200 Subject: [PATCH] improve protocol handling --- libmproxy/protocol2/http.py | 86 +++++++++++++++++++----- libmproxy/protocol2/http_proxy.py | 3 +- libmproxy/protocol2/layer.py | 47 ++++++------- libmproxy/protocol2/messages.py | 3 +- libmproxy/protocol2/reverse_proxy.py | 3 +- libmproxy/protocol2/root_context.py | 5 +- libmproxy/protocol2/socks_proxy.py | 4 +- libmproxy/protocol2/tls.py | 3 +- libmproxy/protocol2/transparent_proxy.py | 2 +- libmproxy/proxy/connection.py | 7 +- libmproxy/proxy/server.py | 11 ++- 11 files changed, 114 insertions(+), 60 deletions(-) diff --git a/libmproxy/protocol2/http.py b/libmproxy/protocol2/http.py index 7cc276527..907846667 100644 --- a/libmproxy/protocol2/http.py +++ b/libmproxy/protocol2/http.py @@ -2,20 +2,20 @@ from __future__ import (absolute_import, print_function, division) from .. import version from ..exceptions import InvalidCredentials, HttpException, ProtocolException -from .layer import Layer, ServerConnectionMixin +from .layer import Layer from libmproxy import utils +from libmproxy.proxy.connection import ServerConnection from .messages import SetServer, Connect, Reconnect, Kill from libmproxy.protocol import KILL from libmproxy.protocol.http import HTTPFlow from libmproxy.protocol.http_wrappers import HTTPResponse, HTTPRequest from libmproxy.protocol2.http_protocol_mock import HTTP1 -from libmproxy.protocol2.tls import TlsLayer from netlib import tcp from netlib.http import status_codes, http1, HttpErrorConnClosed from netlib.http.semantics import CONTENT_MISSING from netlib import odict -from netlib.tcp import NetLibError +from netlib.tcp import NetLibError, Address def make_error_response(status_code, message, headers=None): @@ -46,6 +46,7 @@ def make_error_response(status_code, message, headers=None): def make_connect_request(address): + address = Address.wrap(address) return HTTPRequest( "authority", "CONNECT", None, address.host, address.port, None, (1, 1), odict.ODictCaseless(), "" @@ -66,6 +67,22 @@ def make_connect_response(httpversion): ) +class ConnectServerConnection(object): + """ + "Fake" ServerConnection to represent state after a CONNECT request to an upstream proxy. + """ + def __init__(self, address, ctx): + self.address = tcp.Address.wrap(address) + self._ctx = ctx + + @property + def via(self): + return self._ctx.server_conn + + def __getattr__(self, item): + return getattr(self.via, item) + + class HttpLayer(Layer): """ HTTP 1 Layer @@ -95,12 +112,8 @@ class HttpLayer(Layer): # Regular Proxy Mode: Handle CONNECT if self.mode == "regular" and request.form_in == "authority": - yield SetServer((request.host, request.port), False, None) - self.send_to_client(make_connect_response(request.httpversion)) - layer = self.ctx.next_layer(self) - for message in layer(): - if not self._handle_server_message(message): - yield message + for message in self.handle_regular_mode_connect(request): + yield message return # Make sure that the incoming request matches our expectations @@ -122,12 +135,50 @@ class HttpLayer(Layer): if self.check_close_connection(flow): return + # Upstream Proxy Mode: Handle CONNECT if flow.request.form_in == "authority" and flow.response.code == 200: - raise NotImplementedError("Upstream mode CONNECT not implemented") + for message in self.handle_upstream_mode_connect(flow.request.copy()): + yield message + return + except (HttpErrorConnClosed, NetLibError) as e: make_error_response(502, repr(e)) raise ProtocolException(repr(e), e) + def handle_regular_mode_connect(self, request): + yield SetServer((request.host, request.port), False, None) + self.send_to_client(make_connect_response(request.httpversion)) + layer = self.ctx.next_layer(self) + for message in layer(): + yield message + + def handle_upstream_mode_connect(self, connect_request): + layer = self.ctx.next_layer(self) + self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx) + + for message in layer(): + if message == Connect: + if not self.server_conn: + yield message + self.send_to_server(connect_request) + else: + pass # swallow the message + elif message == Reconnect: + yield message + self.send_to_server(connect_request) + elif message == SetServer: + if message.depth == 1: + if self.ctx.server_conn: + yield Reconnect() + connect_request.host = message.address.host + connect_request.port = message.address.port + self.server_conn.address = message.address + else: + message.depth -= 1 + yield message + else: + yield message + def check_close_connection(self, flow): """ Checks if the connection should be closed depending on the HTTP @@ -247,11 +298,11 @@ class HttpLayer(Layer): if flow.request.form_in == "authority": flow.request.scheme = "http" # pseudo value else: - flow.request.host = self.ctx.server_address.host - flow.request.port = self.ctx.server_address.port + flow.request.host = self.ctx.server_conn.address.host + flow.request.port = self.ctx.server_conn.address.port flow.request.scheme = "https" if self.server_conn.tls_established else "http" - # TODO: Expose ChangeServer functionality to inline scripts somehow? (yield_from_callback?) + # TODO: Expose SetServer functionality to inline scripts somehow? (yield_from_callback?) request_reply = self.channel.ask("request", flow) if request_reply is None or request_reply == KILL: yield Kill() @@ -265,25 +316,26 @@ class HttpLayer(Layer): tls = (flow.request.scheme == "https") if self.mode == "regular" or self.mode == "transparent": # If there's an existing connection that doesn't match our expectations, kill it. - if self.server_address != address or tls != self.server_conn.ssl_established: + if address != self.server_conn.address or tls != self.server_conn.ssl_established: yield SetServer(address, tls, address.host) # Establish connection is neccessary. if not self.server_conn: yield Connect() - # ChangeServer is not guaranteed to work with TLS: + # SetServer is not guaranteed to work with TLS: # If there's not TlsLayer below which could catch the exception, # TLS will not be established. if tls and not self.server_conn.tls_established: raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.") else: + if not self.server_conn: + yield Connect() if tls: raise HttpException("Cannot change scheme in upstream proxy mode.") """ # This is a very ugly (untested) workaround to solve a very ugly problem. - # FIXME: Check if connected first. - if self.server_conn.tls_established and not ssl: + if self.server_conn and self.server_conn.tls_established and not ssl: yield Reconnect() elif ssl and not hasattr(self, "connected_to") or self.connected_to != address: if self.server_conn.tls_established: diff --git a/libmproxy/protocol2/http_proxy.py b/libmproxy/protocol2/http_proxy.py index ca70b0126..3cc7fee2c 100644 --- a/libmproxy/protocol2/http_proxy.py +++ b/libmproxy/protocol2/http_proxy.py @@ -14,8 +14,7 @@ class HttpProxy(Layer, ServerConnectionMixin): class HttpUpstreamProxy(Layer, ServerConnectionMixin): def __init__(self, ctx, server_address): - super(HttpUpstreamProxy, self).__init__(ctx) - self.server_address = server_address + super(HttpUpstreamProxy, self).__init__(ctx, server_address=server_address) def __call__(self): layer = HttpLayer(self, "upstream") diff --git a/libmproxy/protocol2/layer.py b/libmproxy/protocol2/layer.py index c1648a62d..67f3d549c 100644 --- a/libmproxy/protocol2/layer.py +++ b/libmproxy/protocol2/layer.py @@ -44,8 +44,8 @@ class _LayerCodeCompletion(object): Dummy class that provides type hinting in PyCharm, which simplifies development a lot. """ - def __init__(self): - super(_LayerCodeCompletion, self).__init__() + def __init__(self, *args, **kwargs): + super(_LayerCodeCompletion, self).__init__(*args, **kwargs) if True: return self.config = None @@ -57,12 +57,12 @@ class _LayerCodeCompletion(object): class Layer(_LayerCodeCompletion): - def __init__(self, ctx): + def __init__(self, ctx, *args, **kwargs): """ Args: ctx: The (read-only) higher layer. """ - super(Layer, self).__init__() + super(Layer, self).__init__(*args, **kwargs) self.ctx = ctx def __call__(self): @@ -103,10 +103,9 @@ class ServerConnectionMixin(object): Mixin that provides a layer with the capabilities to manage a server connection. """ - def __init__(self): + def __init__(self, server_address=None): super(ServerConnectionMixin, self).__init__() - self._server_address = None - self.server_conn = None + self.server_conn = ServerConnection(server_address) def _handle_server_message(self, message): if message == Reconnect: @@ -116,44 +115,38 @@ class ServerConnectionMixin(object): elif message == Connect: self._connect() return True - elif message == SetServer and message.depth == 1: - if self.server_conn: - self._disconnect() - self.server_address = message.address - return True + elif message == SetServer: + if message.depth == 1: + if self.server_conn: + self._disconnect() + self.log("Set new server address: " + repr(message.address), "debug") + self.server_conn.address = message.address + return True + else: + message.depth -= 1 elif message == Kill: self._disconnect() return False - @property - def server_address(self): - return self._server_address - - @server_address.setter - def server_address(self, address): - self._server_address = tcp.Address.wrap(address) - self.log("Set new server address: " + repr(self.server_address), "debug") - def _disconnect(self): """ Deletes (and closes) an existing server connection. """ - self.log("serverdisconnect", "debug", [repr(self.server_address)]) + self.log("serverdisconnect", "debug", [repr(self.server_conn.address)]) self.server_conn.finish() self.server_conn.close() # self.channel.tell("serverdisconnect", self) - self.server_conn = None + self.server_conn = ServerConnection(None) def _connect(self): - if not self.server_address: + if not self.server_conn.address: raise ProtocolException("Cannot connect to server, no server address given.") - self.log("serverconnect", "debug", [repr(self.server_address)]) - self.server_conn = ServerConnection(self.server_address) + self.log("serverconnect", "debug", [repr(self.server_conn.address)]) try: self.server_conn.connect() except tcp.NetLibError as e: - raise ProtocolException("Server connection to '%s' failed: %s" % (self.server_address, e), e) + raise ProtocolException("Server connection to '%s' failed: %s" % (self.server_conn.address, e), e) def yield_from_callback(fun): diff --git a/libmproxy/protocol2/messages.py b/libmproxy/protocol2/messages.py index 17e12f112..f5907537a 100644 --- a/libmproxy/protocol2/messages.py +++ b/libmproxy/protocol2/messages.py @@ -2,6 +2,7 @@ This module contains all valid messages layers can send to the underlying layers. """ from __future__ import (absolute_import, print_function, division) +from netlib.tcp import Address class _Message(object): @@ -33,7 +34,7 @@ class SetServer(_Message): """ def __init__(self, address, server_tls, sni, depth=1): - self.address = address + self.address = Address.wrap(address) self.server_tls = server_tls self.sni = sni diff --git a/libmproxy/protocol2/reverse_proxy.py b/libmproxy/protocol2/reverse_proxy.py index bb414ec37..2ee3d9d8d 100644 --- a/libmproxy/protocol2/reverse_proxy.py +++ b/libmproxy/protocol2/reverse_proxy.py @@ -7,8 +7,7 @@ from .tls import TlsLayer class ReverseProxy(Layer, ServerConnectionMixin): def __init__(self, ctx, server_address, client_tls, server_tls): - super(ReverseProxy, self).__init__(ctx) - self.server_address = server_address + super(ReverseProxy, self).__init__(ctx, server_address=server_address) self._client_tls = client_tls self._server_tls = server_tls diff --git a/libmproxy/protocol2/root_context.py b/libmproxy/protocol2/root_context.py index bda8b12b2..f369fd488 100644 --- a/libmproxy/protocol2/root_context.py +++ b/libmproxy/protocol2/root_context.py @@ -1,9 +1,11 @@ from __future__ import (absolute_import, print_function, division) +from .messages import Kill from .rawtcp import RawTcpLayer from .tls import TlsLayer from .http import HttpLayer + class RootContext(object): """ The outmost context provided to the root layer. @@ -35,7 +37,7 @@ class RootContext(object): ) if not d: - return + return iter([]) if is_tls_client_hello: return TlsLayer(top_layer, True, True) @@ -44,7 +46,6 @@ class RootContext(object): else: return RawTcpLayer(top_layer) - @property def layers(self): return [] diff --git a/libmproxy/protocol2/socks_proxy.py b/libmproxy/protocol2/socks_proxy.py index c89477cae..c6126a424 100644 --- a/libmproxy/protocol2/socks_proxy.py +++ b/libmproxy/protocol2/socks_proxy.py @@ -5,7 +5,7 @@ from ..proxy import ProxyError, Socks5ProxyMode from .layer import Layer, ServerConnectionMixin -class Socks5Proxy(Layer, ServerConnectionMixin): +class Socks5Proxy(ServerConnectionMixin, Layer): def __call__(self): try: s5mode = Socks5ProxyMode(self.config.ssl_ports) @@ -14,7 +14,7 @@ class Socks5Proxy(Layer, ServerConnectionMixin): # TODO: Unmonkeypatch raise ProtocolException(str(e), e) - self.server_address = address + self.server_conn.address = address layer = self.ctx.next_layer(self) for message in layer(): diff --git a/libmproxy/protocol2/tls.py b/libmproxy/protocol2/tls.py index fcc12f18c..12c67f4e4 100644 --- a/libmproxy/protocol2/tls.py +++ b/libmproxy/protocol2/tls.py @@ -220,7 +220,7 @@ class TlsLayer(Layer): host = self.server_conn.address.host sans = set() # Incorporate upstream certificate - if self.server_conn.tls_established and (not self.config.no_upstream_cert): + if self.server_conn and self.server_conn.tls_established and (not self.config.no_upstream_cert): upstream_cert = self.server_conn.cert sans.update(upstream_cert.altnames) if upstream_cert.cn: @@ -232,4 +232,5 @@ class TlsLayer(Layer): if self._sni_from_server_change: sans.add(self._sni_from_server_change) + sans.discard(host) return self.config.certstore.get_cert(host, list(sans)) diff --git a/libmproxy/protocol2/transparent_proxy.py b/libmproxy/protocol2/transparent_proxy.py index f073e2f86..4ed4c14bd 100644 --- a/libmproxy/protocol2/transparent_proxy.py +++ b/libmproxy/protocol2/transparent_proxy.py @@ -13,7 +13,7 @@ class TransparentProxy(Layer, ServerConnectionMixin): def __call__(self): try: - self.server_address = self.resolver.original_addr(self.client_conn.connection) + self.server_conn.address = self.resolver.original_addr(self.client_conn.connection) except Exception as e: raise ProtocolException("Transparent mode failure: %s" % repr(e), e) diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py index f33e84cdf..f92b53aac 100644 --- a/libmproxy/proxy/connection.py +++ b/libmproxy/proxy/connection.py @@ -96,6 +96,9 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.timestamp_ssl_setup = None self.protocol = None + def __nonzero__(self): + return bool(self.connection) + def __repr__(self): if self.ssl_established and self.sni: ssl = "[ssl: {0}] ".format(self.sni) @@ -132,8 +135,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): d.update( address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, - source_address= ({"address": self.source_address(), - "use_ipv6": self.source_address.use_ipv6} if self.source_address else None), + source_address=({"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None), cert=self.cert.to_pem() if self.cert else None ) return d diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index ffca55eea..e23a7d72f 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -80,7 +80,12 @@ class ConnectionHandler2: self.config, self.channel ) - root_layer = protocol2.HttpProxy(root_context) + + # FIXME: properly parse config + if self.config.mode == "upstream": + root_layer = protocol2.HttpUpstreamProxy(root_context, ("localhost", 8081)) + else: + root_layer = protocol2.HttpProxy(root_context) try: for message in root_layer(): @@ -302,7 +307,7 @@ class ConnectionHandler: if ssl_cert_err is not None: self.log( "SSL verification failed for upstream server at depth %s with error: %s" % - (ssl_cert_err['depth'], ssl_cert_err['errno']), + (ssl_cert_err['depth'], ssl_cert_err['errno']), "error") self.log("Ignoring server verification error, continuing with connection", "error") except tcp.NetLibError as v: @@ -318,7 +323,7 @@ class ConnectionHandler: if ssl_cert_err is not None: self.log( "SSL verification failed for upstream server at depth %s with error: %s" % - (ssl_cert_err['depth'], ssl_cert_err['errno']), + (ssl_cert_err['depth'], ssl_cert_err['errno']), "error") self.log("Aborting connection attempt", "error") raise e