diff --git a/libmproxy/protocol2/__init__.py b/libmproxy/protocol2/__init__.py index 95f67c6c6..3f714f62d 100644 --- a/libmproxy/protocol2/__init__.py +++ b/libmproxy/protocol2/__init__.py @@ -1,6 +1,7 @@ from __future__ import (absolute_import, print_function, division) from .layer import RootContext from .socks import Socks5IncomingLayer +from .reverse_proxy import ReverseProxy from .rawtcp import TcpLayer from .auto import AutoLayer -__all__ = ["Socks5IncomingLayer", "TcpLayer", "AutoLayer", "RootContext"] +__all__ = ["Socks5IncomingLayer", "TcpLayer", "AutoLayer", "RootContext", "ReverseProxy"] diff --git a/libmproxy/protocol2/layer.py b/libmproxy/protocol2/layer.py index aaa51baf4..c18be83cf 100644 --- a/libmproxy/protocol2/layer.py +++ b/libmproxy/protocol2/layer.py @@ -176,7 +176,7 @@ def yield_from_callback(fun): """ yield_queue = Queue.Queue() - def do_yield(self, msg): + def do_yield(msg): yield_queue.put(msg) yield_queue.get() @@ -192,14 +192,14 @@ def yield_from_callback(fun): threading.Thread(target=run, name="YieldFromCallbackThread").start() while True: - e = yield_queue.get() - if e is True: + msg = yield_queue.get() + if msg is True: break - elif isinstance(e, Exception): + elif isinstance(msg, Exception): # TODO: Include func name? - raise ProxyError2("Error from callback: " + repr(e), e) + raise ProxyError2("Error in %s: %s" % (fun.__name__, repr(msg)), msg) else: - yield e + yield msg yield_queue.put(None) self.yield_from_callback = None diff --git a/libmproxy/protocol2/reverse_proxy.py b/libmproxy/protocol2/reverse_proxy.py new file mode 100644 index 000000000..dfffd2f25 --- /dev/null +++ b/libmproxy/protocol2/reverse_proxy.py @@ -0,0 +1,19 @@ +from __future__ import (absolute_import, print_function, division) + +from .layer import Layer, ServerConnectionMixin +from .ssl import SslLayer + + +class ReverseProxy(Layer, ServerConnectionMixin): + + def __init__(self, ctx, server_address, client_ssl, server_ssl): + super(ReverseProxy, self).__init__(ctx) + self.server_address = server_address + self.client_ssl = client_ssl + self.server_ssl = server_ssl + + def __call__(self): + layer = SslLayer(self, self.client_ssl, self.server_ssl) + for message in layer(): + if not self._handle_server_message(message): + yield message diff --git a/libmproxy/protocol2/ssl.py b/libmproxy/protocol2/ssl.py index 32798e720..a744a979c 100644 --- a/libmproxy/protocol2/ssl.py +++ b/libmproxy/protocol2/ssl.py @@ -14,7 +14,7 @@ class SslLayer(Layer): self._client_ssl = client_ssl self._server_ssl = server_ssl self._connected = False - self._sni_from_handshake = None + self.client_sni = None self._sni_from_server_change = None def __call__(self): @@ -74,7 +74,7 @@ class SslLayer(Layer): if self._sni_from_server_change is False: return None else: - return self._sni_from_server_change or self._sni_from_handshake + return self._sni_from_server_change or self.client_sni def _establish_ssl_with_client_and_server(self): """ @@ -97,7 +97,7 @@ class SslLayer(Layer): else: raise RuntimeError("Unexpected Message: %s" % message) - if server_err and not self._sni_from_handshake: + if server_err and not self.client_sni: raise server_err def handle_sni(self, connection): @@ -111,14 +111,14 @@ class SslLayer(Layer): sn = connection.get_servername() if not sn: return - self._sni_from_handshake = sn.decode("utf8").encode("idna") + self.client_sni = sn.decode("utf8").encode("idna") if old_upstream_sni != self.sni_for_upstream_connection: # Perform reconnect if self.server_ssl: self.yield_from_callback(Reconnect()) - if self._sni_from_handshake: + if self.client_sni: # Now, change client context to reflect possibly changed certificate: cert, key, chain_file = self.find_cert() new_context = self.client_conn.create_ssl_context( @@ -195,8 +195,8 @@ class SslLayer(Layer): sans.add(host) host = upstream_cert.cn.decode("utf8").encode("idna") # Also add SNI values. - if self._sni_from_handshake: - sans.add(self._sni_from_handshake) + if self.client_sni: + sans.add(self.client_sni) if self._sni_from_server_change: sans.add(self._sni_from_server_change) diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index c8990a9a1..32d596ad8 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -79,7 +79,7 @@ class ConnectionHandler2: self.config, self.channel ) - root_layer = protocol2.Socks5IncomingLayer(root_context) + root_layer = protocol2.ReverseProxy(root_context, ("localhost", 5000), True, True) try: for message in root_layer():