improve protocol handling

This commit is contained in:
Maximilian Hils 2015-08-16 23:25:02 +02:00
parent c04fa1b233
commit a2b8504889
11 changed files with 114 additions and 60 deletions

View File

@ -2,20 +2,20 @@ from __future__ import (absolute_import, print_function, division)
from .. import version from .. import version
from ..exceptions import InvalidCredentials, HttpException, ProtocolException from ..exceptions import InvalidCredentials, HttpException, ProtocolException
from .layer import Layer, ServerConnectionMixin from .layer import Layer
from libmproxy import utils from libmproxy import utils
from libmproxy.proxy.connection import ServerConnection
from .messages import SetServer, Connect, Reconnect, Kill from .messages import SetServer, Connect, Reconnect, Kill
from libmproxy.protocol import KILL from libmproxy.protocol import KILL
from libmproxy.protocol.http import HTTPFlow from libmproxy.protocol.http import HTTPFlow
from libmproxy.protocol.http_wrappers import HTTPResponse, HTTPRequest from libmproxy.protocol.http_wrappers import HTTPResponse, HTTPRequest
from libmproxy.protocol2.http_protocol_mock import HTTP1 from libmproxy.protocol2.http_protocol_mock import HTTP1
from libmproxy.protocol2.tls import TlsLayer
from netlib import tcp from netlib import tcp
from netlib.http import status_codes, http1, HttpErrorConnClosed from netlib.http import status_codes, http1, HttpErrorConnClosed
from netlib.http.semantics import CONTENT_MISSING from netlib.http.semantics import CONTENT_MISSING
from netlib import odict from netlib import odict
from netlib.tcp import NetLibError from netlib.tcp import NetLibError, Address
def make_error_response(status_code, message, headers=None): def make_error_response(status_code, message, headers=None):
@ -46,6 +46,7 @@ def make_error_response(status_code, message, headers=None):
def make_connect_request(address): def make_connect_request(address):
address = Address.wrap(address)
return HTTPRequest( return HTTPRequest(
"authority", "CONNECT", None, address.host, address.port, None, (1, 1), "authority", "CONNECT", None, address.host, address.port, None, (1, 1),
odict.ODictCaseless(), "" odict.ODictCaseless(), ""
@ -66,6 +67,22 @@ def make_connect_response(httpversion):
) )
class ConnectServerConnection(object):
"""
"Fake" ServerConnection to represent state after a CONNECT request to an upstream proxy.
"""
def __init__(self, address, ctx):
self.address = tcp.Address.wrap(address)
self._ctx = ctx
@property
def via(self):
return self._ctx.server_conn
def __getattr__(self, item):
return getattr(self.via, item)
class HttpLayer(Layer): class HttpLayer(Layer):
""" """
HTTP 1 Layer HTTP 1 Layer
@ -95,12 +112,8 @@ class HttpLayer(Layer):
# Regular Proxy Mode: Handle CONNECT # Regular Proxy Mode: Handle CONNECT
if self.mode == "regular" and request.form_in == "authority": if self.mode == "regular" and request.form_in == "authority":
yield SetServer((request.host, request.port), False, None) for message in self.handle_regular_mode_connect(request):
self.send_to_client(make_connect_response(request.httpversion)) yield message
layer = self.ctx.next_layer(self)
for message in layer():
if not self._handle_server_message(message):
yield message
return return
# Make sure that the incoming request matches our expectations # Make sure that the incoming request matches our expectations
@ -122,12 +135,50 @@ class HttpLayer(Layer):
if self.check_close_connection(flow): if self.check_close_connection(flow):
return return
# Upstream Proxy Mode: Handle CONNECT
if flow.request.form_in == "authority" and flow.response.code == 200: if flow.request.form_in == "authority" and flow.response.code == 200:
raise NotImplementedError("Upstream mode CONNECT not implemented") for message in self.handle_upstream_mode_connect(flow.request.copy()):
yield message
return
except (HttpErrorConnClosed, NetLibError) as e: except (HttpErrorConnClosed, NetLibError) as e:
make_error_response(502, repr(e)) make_error_response(502, repr(e))
raise ProtocolException(repr(e), e) raise ProtocolException(repr(e), e)
def handle_regular_mode_connect(self, request):
yield SetServer((request.host, request.port), False, None)
self.send_to_client(make_connect_response(request.httpversion))
layer = self.ctx.next_layer(self)
for message in layer():
yield message
def handle_upstream_mode_connect(self, connect_request):
layer = self.ctx.next_layer(self)
self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx)
for message in layer():
if message == Connect:
if not self.server_conn:
yield message
self.send_to_server(connect_request)
else:
pass # swallow the message
elif message == Reconnect:
yield message
self.send_to_server(connect_request)
elif message == SetServer:
if message.depth == 1:
if self.ctx.server_conn:
yield Reconnect()
connect_request.host = message.address.host
connect_request.port = message.address.port
self.server_conn.address = message.address
else:
message.depth -= 1
yield message
else:
yield message
def check_close_connection(self, flow): def check_close_connection(self, flow):
""" """
Checks if the connection should be closed depending on the HTTP Checks if the connection should be closed depending on the HTTP
@ -247,11 +298,11 @@ class HttpLayer(Layer):
if flow.request.form_in == "authority": if flow.request.form_in == "authority":
flow.request.scheme = "http" # pseudo value flow.request.scheme = "http" # pseudo value
else: else:
flow.request.host = self.ctx.server_address.host flow.request.host = self.ctx.server_conn.address.host
flow.request.port = self.ctx.server_address.port flow.request.port = self.ctx.server_conn.address.port
flow.request.scheme = "https" if self.server_conn.tls_established else "http" flow.request.scheme = "https" if self.server_conn.tls_established else "http"
# TODO: Expose ChangeServer functionality to inline scripts somehow? (yield_from_callback?) # TODO: Expose SetServer functionality to inline scripts somehow? (yield_from_callback?)
request_reply = self.channel.ask("request", flow) request_reply = self.channel.ask("request", flow)
if request_reply is None or request_reply == KILL: if request_reply is None or request_reply == KILL:
yield Kill() yield Kill()
@ -265,25 +316,26 @@ class HttpLayer(Layer):
tls = (flow.request.scheme == "https") tls = (flow.request.scheme == "https")
if self.mode == "regular" or self.mode == "transparent": if self.mode == "regular" or self.mode == "transparent":
# If there's an existing connection that doesn't match our expectations, kill it. # If there's an existing connection that doesn't match our expectations, kill it.
if self.server_address != address or tls != self.server_conn.ssl_established: if address != self.server_conn.address or tls != self.server_conn.ssl_established:
yield SetServer(address, tls, address.host) yield SetServer(address, tls, address.host)
# Establish connection is neccessary. # Establish connection is neccessary.
if not self.server_conn: if not self.server_conn:
yield Connect() yield Connect()
# ChangeServer is not guaranteed to work with TLS: # SetServer is not guaranteed to work with TLS:
# If there's not TlsLayer below which could catch the exception, # If there's not TlsLayer below which could catch the exception,
# TLS will not be established. # TLS will not be established.
if tls and not self.server_conn.tls_established: if tls and not self.server_conn.tls_established:
raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.") raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.")
else: else:
if not self.server_conn:
yield Connect()
if tls: if tls:
raise HttpException("Cannot change scheme in upstream proxy mode.") raise HttpException("Cannot change scheme in upstream proxy mode.")
""" """
# This is a very ugly (untested) workaround to solve a very ugly problem. # This is a very ugly (untested) workaround to solve a very ugly problem.
# FIXME: Check if connected first. if self.server_conn and self.server_conn.tls_established and not ssl:
if self.server_conn.tls_established and not ssl:
yield Reconnect() yield Reconnect()
elif ssl and not hasattr(self, "connected_to") or self.connected_to != address: elif ssl and not hasattr(self, "connected_to") or self.connected_to != address:
if self.server_conn.tls_established: if self.server_conn.tls_established:

View File

@ -14,8 +14,7 @@ class HttpProxy(Layer, ServerConnectionMixin):
class HttpUpstreamProxy(Layer, ServerConnectionMixin): class HttpUpstreamProxy(Layer, ServerConnectionMixin):
def __init__(self, ctx, server_address): def __init__(self, ctx, server_address):
super(HttpUpstreamProxy, self).__init__(ctx) super(HttpUpstreamProxy, self).__init__(ctx, server_address=server_address)
self.server_address = server_address
def __call__(self): def __call__(self):
layer = HttpLayer(self, "upstream") layer = HttpLayer(self, "upstream")

View File

@ -44,8 +44,8 @@ class _LayerCodeCompletion(object):
Dummy class that provides type hinting in PyCharm, which simplifies development a lot. Dummy class that provides type hinting in PyCharm, which simplifies development a lot.
""" """
def __init__(self): def __init__(self, *args, **kwargs):
super(_LayerCodeCompletion, self).__init__() super(_LayerCodeCompletion, self).__init__(*args, **kwargs)
if True: if True:
return return
self.config = None self.config = None
@ -57,12 +57,12 @@ class _LayerCodeCompletion(object):
class Layer(_LayerCodeCompletion): class Layer(_LayerCodeCompletion):
def __init__(self, ctx): def __init__(self, ctx, *args, **kwargs):
""" """
Args: Args:
ctx: The (read-only) higher layer. ctx: The (read-only) higher layer.
""" """
super(Layer, self).__init__() super(Layer, self).__init__(*args, **kwargs)
self.ctx = ctx self.ctx = ctx
def __call__(self): def __call__(self):
@ -103,10 +103,9 @@ class ServerConnectionMixin(object):
Mixin that provides a layer with the capabilities to manage a server connection. Mixin that provides a layer with the capabilities to manage a server connection.
""" """
def __init__(self): def __init__(self, server_address=None):
super(ServerConnectionMixin, self).__init__() super(ServerConnectionMixin, self).__init__()
self._server_address = None self.server_conn = ServerConnection(server_address)
self.server_conn = None
def _handle_server_message(self, message): def _handle_server_message(self, message):
if message == Reconnect: if message == Reconnect:
@ -116,44 +115,38 @@ class ServerConnectionMixin(object):
elif message == Connect: elif message == Connect:
self._connect() self._connect()
return True return True
elif message == SetServer and message.depth == 1: elif message == SetServer:
if self.server_conn: if message.depth == 1:
self._disconnect() if self.server_conn:
self.server_address = message.address self._disconnect()
return True self.log("Set new server address: " + repr(message.address), "debug")
self.server_conn.address = message.address
return True
else:
message.depth -= 1
elif message == Kill: elif message == Kill:
self._disconnect() self._disconnect()
return False return False
@property
def server_address(self):
return self._server_address
@server_address.setter
def server_address(self, address):
self._server_address = tcp.Address.wrap(address)
self.log("Set new server address: " + repr(self.server_address), "debug")
def _disconnect(self): def _disconnect(self):
""" """
Deletes (and closes) an existing server connection. Deletes (and closes) an existing server connection.
""" """
self.log("serverdisconnect", "debug", [repr(self.server_address)]) self.log("serverdisconnect", "debug", [repr(self.server_conn.address)])
self.server_conn.finish() self.server_conn.finish()
self.server_conn.close() self.server_conn.close()
# self.channel.tell("serverdisconnect", self) # self.channel.tell("serverdisconnect", self)
self.server_conn = None self.server_conn = ServerConnection(None)
def _connect(self): def _connect(self):
if not self.server_address: if not self.server_conn.address:
raise ProtocolException("Cannot connect to server, no server address given.") raise ProtocolException("Cannot connect to server, no server address given.")
self.log("serverconnect", "debug", [repr(self.server_address)]) self.log("serverconnect", "debug", [repr(self.server_conn.address)])
self.server_conn = ServerConnection(self.server_address)
try: try:
self.server_conn.connect() self.server_conn.connect()
except tcp.NetLibError as e: except tcp.NetLibError as e:
raise ProtocolException("Server connection to '%s' failed: %s" % (self.server_address, e), e) raise ProtocolException("Server connection to '%s' failed: %s" % (self.server_conn.address, e), e)
def yield_from_callback(fun): def yield_from_callback(fun):

View File

@ -2,6 +2,7 @@
This module contains all valid messages layers can send to the underlying layers. This module contains all valid messages layers can send to the underlying layers.
""" """
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
from netlib.tcp import Address
class _Message(object): class _Message(object):
@ -33,7 +34,7 @@ class SetServer(_Message):
""" """
def __init__(self, address, server_tls, sni, depth=1): def __init__(self, address, server_tls, sni, depth=1):
self.address = address self.address = Address.wrap(address)
self.server_tls = server_tls self.server_tls = server_tls
self.sni = sni self.sni = sni

View File

@ -7,8 +7,7 @@ from .tls import TlsLayer
class ReverseProxy(Layer, ServerConnectionMixin): class ReverseProxy(Layer, ServerConnectionMixin):
def __init__(self, ctx, server_address, client_tls, server_tls): def __init__(self, ctx, server_address, client_tls, server_tls):
super(ReverseProxy, self).__init__(ctx) super(ReverseProxy, self).__init__(ctx, server_address=server_address)
self.server_address = server_address
self._client_tls = client_tls self._client_tls = client_tls
self._server_tls = server_tls self._server_tls = server_tls

View File

@ -1,9 +1,11 @@
from __future__ import (absolute_import, print_function, division) from __future__ import (absolute_import, print_function, division)
from .messages import Kill
from .rawtcp import RawTcpLayer from .rawtcp import RawTcpLayer
from .tls import TlsLayer from .tls import TlsLayer
from .http import HttpLayer from .http import HttpLayer
class RootContext(object): class RootContext(object):
""" """
The outmost context provided to the root layer. The outmost context provided to the root layer.
@ -35,7 +37,7 @@ class RootContext(object):
) )
if not d: if not d:
return return iter([])
if is_tls_client_hello: if is_tls_client_hello:
return TlsLayer(top_layer, True, True) return TlsLayer(top_layer, True, True)
@ -44,7 +46,6 @@ class RootContext(object):
else: else:
return RawTcpLayer(top_layer) return RawTcpLayer(top_layer)
@property @property
def layers(self): def layers(self):
return [] return []

View File

@ -5,7 +5,7 @@ from ..proxy import ProxyError, Socks5ProxyMode
from .layer import Layer, ServerConnectionMixin from .layer import Layer, ServerConnectionMixin
class Socks5Proxy(Layer, ServerConnectionMixin): class Socks5Proxy(ServerConnectionMixin, Layer):
def __call__(self): def __call__(self):
try: try:
s5mode = Socks5ProxyMode(self.config.ssl_ports) s5mode = Socks5ProxyMode(self.config.ssl_ports)
@ -14,7 +14,7 @@ class Socks5Proxy(Layer, ServerConnectionMixin):
# TODO: Unmonkeypatch # TODO: Unmonkeypatch
raise ProtocolException(str(e), e) raise ProtocolException(str(e), e)
self.server_address = address self.server_conn.address = address
layer = self.ctx.next_layer(self) layer = self.ctx.next_layer(self)
for message in layer(): for message in layer():

View File

@ -220,7 +220,7 @@ class TlsLayer(Layer):
host = self.server_conn.address.host host = self.server_conn.address.host
sans = set() sans = set()
# Incorporate upstream certificate # Incorporate upstream certificate
if self.server_conn.tls_established and (not self.config.no_upstream_cert): if self.server_conn and self.server_conn.tls_established and (not self.config.no_upstream_cert):
upstream_cert = self.server_conn.cert upstream_cert = self.server_conn.cert
sans.update(upstream_cert.altnames) sans.update(upstream_cert.altnames)
if upstream_cert.cn: if upstream_cert.cn:
@ -232,4 +232,5 @@ class TlsLayer(Layer):
if self._sni_from_server_change: if self._sni_from_server_change:
sans.add(self._sni_from_server_change) sans.add(self._sni_from_server_change)
sans.discard(host)
return self.config.certstore.get_cert(host, list(sans)) return self.config.certstore.get_cert(host, list(sans))

View File

@ -13,7 +13,7 @@ class TransparentProxy(Layer, ServerConnectionMixin):
def __call__(self): def __call__(self):
try: try:
self.server_address = self.resolver.original_addr(self.client_conn.connection) self.server_conn.address = self.resolver.original_addr(self.client_conn.connection)
except Exception as e: except Exception as e:
raise ProtocolException("Transparent mode failure: %s" % repr(e), e) raise ProtocolException("Transparent mode failure: %s" % repr(e), e)

View File

@ -96,6 +96,9 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.timestamp_ssl_setup = None self.timestamp_ssl_setup = None
self.protocol = None self.protocol = None
def __nonzero__(self):
return bool(self.connection)
def __repr__(self): def __repr__(self):
if self.ssl_established and self.sni: if self.ssl_established and self.sni:
ssl = "[ssl: {0}] ".format(self.sni) ssl = "[ssl: {0}] ".format(self.sni)
@ -132,8 +135,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
d.update( d.update(
address={"address": self.address(), address={"address": self.address(),
"use_ipv6": self.address.use_ipv6}, "use_ipv6": self.address.use_ipv6},
source_address= ({"address": self.source_address(), source_address=({"address": self.source_address(),
"use_ipv6": self.source_address.use_ipv6} if self.source_address else None), "use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
cert=self.cert.to_pem() if self.cert else None cert=self.cert.to_pem() if self.cert else None
) )
return d return d

View File

@ -80,7 +80,12 @@ class ConnectionHandler2:
self.config, self.config,
self.channel self.channel
) )
root_layer = protocol2.HttpProxy(root_context)
# FIXME: properly parse config
if self.config.mode == "upstream":
root_layer = protocol2.HttpUpstreamProxy(root_context, ("localhost", 8081))
else:
root_layer = protocol2.HttpProxy(root_context)
try: try:
for message in root_layer(): for message in root_layer():
@ -302,7 +307,7 @@ class ConnectionHandler:
if ssl_cert_err is not None: if ssl_cert_err is not None:
self.log( self.log(
"SSL verification failed for upstream server at depth %s with error: %s" % "SSL verification failed for upstream server at depth %s with error: %s" %
(ssl_cert_err['depth'], ssl_cert_err['errno']), (ssl_cert_err['depth'], ssl_cert_err['errno']),
"error") "error")
self.log("Ignoring server verification error, continuing with connection", "error") self.log("Ignoring server verification error, continuing with connection", "error")
except tcp.NetLibError as v: except tcp.NetLibError as v:
@ -318,7 +323,7 @@ class ConnectionHandler:
if ssl_cert_err is not None: if ssl_cert_err is not None:
self.log( self.log(
"SSL verification failed for upstream server at depth %s with error: %s" % "SSL verification failed for upstream server at depth %s with error: %s" %
(ssl_cert_err['depth'], ssl_cert_err['errno']), (ssl_cert_err['depth'], ssl_cert_err['errno']),
"error") "error")
self.log("Aborting connection attempt", "error") self.log("Aborting connection attempt", "error")
raise e raise e