diff --git a/mitmproxy/addons/next_layer.py b/mitmproxy/addons/next_layer.py index c059b0dd1..4374efac0 100644 --- a/mitmproxy/addons/next_layer.py +++ b/mitmproxy/addons/next_layer.py @@ -1,7 +1,7 @@ from mitmproxy import ctx from mitmproxy.net import server_spec from mitmproxy.proxy.config import HostMatcher -from mitmproxy.proxy.protocol import is_tls_record_magic +from mitmproxy.net.tls import is_tls_record_magic from mitmproxy.proxy2 import layer, layers @@ -34,7 +34,7 @@ class NextLayer: nextlayer.context.server.tls = ( server_spec.parse_with_mode(ctx.options.mode)[1].scheme == "https" ) - nextlayer.layer = layers.TLSLayer(nextlayer.context) + nextlayer.layer = layers.ClientTLSLayer(nextlayer.context) return # TODO: Other top layers @@ -45,7 +45,7 @@ class NextLayer: if client_tls: nextlayer.context.client.tls = True nextlayer.context.server.tls = True - nextlayer.layer = layers.TLSLayer(nextlayer.context) + nextlayer.layer = layers.ClientTLSLayer(nextlayer.context) return # 5. Check for --tcp @@ -54,7 +54,7 @@ class NextLayer: return # 6. Check for TLS ALPN (HTTP1/HTTP2) - if isinstance(top_layer, layers.TLSLayer): + if isinstance(top_layer, layers.ServerTLSLayer): alpn = nextlayer.context.client.alpn if alpn == b'http/1.1': nextlayer.layer = layers.HTTPLayer(nextlayer.context) diff --git a/mitmproxy/addons/proxyserver.py b/mitmproxy/addons/proxyserver.py index 249d7201c..6011b0ca2 100644 --- a/mitmproxy/addons/proxyserver.py +++ b/mitmproxy/addons/proxyserver.py @@ -1,63 +1,62 @@ import asyncio -import queue -import threading -from mitmproxy import ctx, controller, log +from mitmproxy import ctx, controller, log, options, master from mitmproxy.proxy2 import commands from mitmproxy.proxy2 import events from mitmproxy.proxy2 import server class AsyncReply(controller.Reply): - # temporary glue code - let's see how many years it survives. - def __init__(self, submit, *args): - self.submit = submit + """ + controller.Reply.q.get() is blocking, which we definitely want to avoid in a coroutine. + This stub adds a .done asyncio.Event() that can be used instead. + """ + + def __init__(self, *args): + self.done = asyncio.Event() + self.loop = asyncio.get_event_loop() super().__init__(*args) def commit(self): super().commit() - self.submit(self.q.get_nowait()) + self.loop.call_soon_threadsafe(lambda: self.done.set()) class ProxyConnectionHandler(server.ConnectionHandler): - event_queue: queue.Queue - loop: asyncio.AbstractEventLoop + master: master.Master - def __init__(self, event_queue, loop, r, w, options): - self.event_queue = event_queue - self.loop = loop + def __init__(self, master, r, w, options): + self.master = master super().__init__(r, w, options) async def handle_hook(self, hook: commands.Hook) -> None: - q = asyncio.Queue() - - hook.data.reply = AsyncReply( - lambda x: self.loop.call_soon_threadsafe(lambda: q.put_nowait(x)), - hook.data - ) - self.event_queue.put((hook.name, hook.data)) - await q.get() + hook.data.reply = AsyncReply(hook.data) + await self.master.addons.handle_lifecycle(hook.name, hook.data) + await hook.data.reply.done.wait() if hook.blocking: self.server_event(events.HookReply(hook)) def log(self, message: str, level: str = "info") -> None: x = log.LogEntry(message, level) x.reply = controller.DummyReply() - self.event_queue.put(("log", x)) + asyncio.ensure_future( + self.master.addons.handle_lifecycle("log", x) + ) class Proxyserver: """ This addon runs the actual proxy server. """ + server: asyncio.AbstractServer + listen_port: int + master: master.Master + options: options.Options + is_running: bool def __init__(self): - self.server = None - self.loop = asyncio.get_event_loop() - self.listen_port = None - self.event_queue = None - self.options = None self._lock = asyncio.Lock() + self.server = None self.is_running = False def load(self, loader): @@ -68,41 +67,39 @@ class Proxyserver: ) def running(self): + self.master = ctx.master self.options = ctx.options - self.event_queue = ctx.master.event_queue - threading.Thread(target=self.loop.run_forever, daemon=True).start() self.is_running = True self.configure(["listen_port"]) - async def start(self): - async with self._lock: - if self.server: - print("Stopping server...") - self.server.close() - await self.server.wait_closed() - - print("Starting server...") - self.server = await asyncio.start_server( - self.handle_connection, - '127.0.0.1', - self.listen_port, - loop=self.loop - ) - - async def handle_connection(self, r, w): - await ProxyConnectionHandler( - self.event_queue, - self.loop, - r, - w, - self.options - ).handle_client() - def configure(self, updated): if not self.is_running: return if "listen_port" in updated: self.listen_port = ctx.options.listen_port + 1 + asyncio.ensure_future(self.start_server()) - # not sure if this actually required... - self.loop.call_soon_threadsafe(lambda: asyncio.ensure_future(self.start())) + async def start_server(self): + async with self._lock: + if self.server: + await self.shutdown_server() + print("Starting server...") + self.server = await asyncio.start_server( + self.handle_connection, + '127.0.0.1', + self.listen_port, + ) + + async def shutdown_server(self): + print("Stopping server...") + self.server.close() + await self.server.wait_closed() + self.server = None + + async def handle_connection(self, r, w): + await ProxyConnectionHandler( + self.master, + r, + w, + self.options + ).handle_client() diff --git a/mitmproxy/proxy2/layers/tls.py b/mitmproxy/proxy2/layers/tls.py index 118b1da2b..53aa22e56 100644 --- a/mitmproxy/proxy2/layers/tls.py +++ b/mitmproxy/proxy2/layers/tls.py @@ -5,7 +5,7 @@ from typing import MutableMapping, Optional, Iterator, Union, Generator, Any from OpenSSL import SSL from mitmproxy.certs import CertStore -from mitmproxy.proxy.protocol import TlsClientHello +from mitmproxy.net.tls import ClientHello from mitmproxy.proxy.protocol import tls from mitmproxy.proxy2 import context from mitmproxy.proxy2 import layer, commands, events @@ -69,7 +69,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]: return None -def parse_client_hello(data: bytes) -> Optional[TlsClientHello]: +def parse_client_hello(data: bytes) -> Optional[ClientHello]: """ Check if the supplied bytes contain a full ClientHello message, and if so, parse it. @@ -84,7 +84,7 @@ def parse_client_hello(data: bytes) -> Optional[TlsClientHello]: # Check if ClientHello is complete client_hello = get_client_hello(data) if client_hello: - return TlsClientHello(client_hello[4:]) + return ClientHello(client_hello[4:]) return None