[sans-io] adapt to asyncio changes in core

This commit is contained in:
Maximilian Hils 2018-05-06 17:26:35 +02:00
parent 8f3db90def
commit 467dc81d19
3 changed files with 58 additions and 61 deletions

View File

@ -1,7 +1,7 @@
from mitmproxy import ctx from mitmproxy import ctx
from mitmproxy.net import server_spec from mitmproxy.net import server_spec
from mitmproxy.proxy.config import HostMatcher from mitmproxy.proxy.config import HostMatcher
from mitmproxy.proxy.protocol import is_tls_record_magic from mitmproxy.net.tls import is_tls_record_magic
from mitmproxy.proxy2 import layer, layers from mitmproxy.proxy2 import layer, layers
@ -34,7 +34,7 @@ class NextLayer:
nextlayer.context.server.tls = ( nextlayer.context.server.tls = (
server_spec.parse_with_mode(ctx.options.mode)[1].scheme == "https" server_spec.parse_with_mode(ctx.options.mode)[1].scheme == "https"
) )
nextlayer.layer = layers.TLSLayer(nextlayer.context) nextlayer.layer = layers.ClientTLSLayer(nextlayer.context)
return return
# TODO: Other top layers # TODO: Other top layers
@ -45,7 +45,7 @@ class NextLayer:
if client_tls: if client_tls:
nextlayer.context.client.tls = True nextlayer.context.client.tls = True
nextlayer.context.server.tls = True nextlayer.context.server.tls = True
nextlayer.layer = layers.TLSLayer(nextlayer.context) nextlayer.layer = layers.ClientTLSLayer(nextlayer.context)
return return
# 5. Check for --tcp # 5. Check for --tcp
@ -54,7 +54,7 @@ class NextLayer:
return return
# 6. Check for TLS ALPN (HTTP1/HTTP2) # 6. Check for TLS ALPN (HTTP1/HTTP2)
if isinstance(top_layer, layers.TLSLayer): if isinstance(top_layer, layers.ServerTLSLayer):
alpn = nextlayer.context.client.alpn alpn = nextlayer.context.client.alpn
if alpn == b'http/1.1': if alpn == b'http/1.1':
nextlayer.layer = layers.HTTPLayer(nextlayer.context) nextlayer.layer = layers.HTTPLayer(nextlayer.context)

View File

@ -1,63 +1,62 @@
import asyncio import asyncio
import queue
import threading
from mitmproxy import ctx, controller, log from mitmproxy import ctx, controller, log, options, master
from mitmproxy.proxy2 import commands from mitmproxy.proxy2 import commands
from mitmproxy.proxy2 import events from mitmproxy.proxy2 import events
from mitmproxy.proxy2 import server from mitmproxy.proxy2 import server
class AsyncReply(controller.Reply): class AsyncReply(controller.Reply):
# temporary glue code - let's see how many years it survives. """
def __init__(self, submit, *args): controller.Reply.q.get() is blocking, which we definitely want to avoid in a coroutine.
self.submit = submit This stub adds a .done asyncio.Event() that can be used instead.
"""
def __init__(self, *args):
self.done = asyncio.Event()
self.loop = asyncio.get_event_loop()
super().__init__(*args) super().__init__(*args)
def commit(self): def commit(self):
super().commit() super().commit()
self.submit(self.q.get_nowait()) self.loop.call_soon_threadsafe(lambda: self.done.set())
class ProxyConnectionHandler(server.ConnectionHandler): class ProxyConnectionHandler(server.ConnectionHandler):
event_queue: queue.Queue master: master.Master
loop: asyncio.AbstractEventLoop
def __init__(self, event_queue, loop, r, w, options): def __init__(self, master, r, w, options):
self.event_queue = event_queue self.master = master
self.loop = loop
super().__init__(r, w, options) super().__init__(r, w, options)
async def handle_hook(self, hook: commands.Hook) -> None: async def handle_hook(self, hook: commands.Hook) -> None:
q = asyncio.Queue() hook.data.reply = AsyncReply(hook.data)
await self.master.addons.handle_lifecycle(hook.name, hook.data)
hook.data.reply = AsyncReply( await hook.data.reply.done.wait()
lambda x: self.loop.call_soon_threadsafe(lambda: q.put_nowait(x)),
hook.data
)
self.event_queue.put((hook.name, hook.data))
await q.get()
if hook.blocking: if hook.blocking:
self.server_event(events.HookReply(hook)) self.server_event(events.HookReply(hook))
def log(self, message: str, level: str = "info") -> None: def log(self, message: str, level: str = "info") -> None:
x = log.LogEntry(message, level) x = log.LogEntry(message, level)
x.reply = controller.DummyReply() x.reply = controller.DummyReply()
self.event_queue.put(("log", x)) asyncio.ensure_future(
self.master.addons.handle_lifecycle("log", x)
)
class Proxyserver: class Proxyserver:
""" """
This addon runs the actual proxy server. This addon runs the actual proxy server.
""" """
server: asyncio.AbstractServer
listen_port: int
master: master.Master
options: options.Options
is_running: bool
def __init__(self): def __init__(self):
self.server = None
self.loop = asyncio.get_event_loop()
self.listen_port = None
self.event_queue = None
self.options = None
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self.server = None
self.is_running = False self.is_running = False
def load(self, loader): def load(self, loader):
@ -68,41 +67,39 @@ class Proxyserver:
) )
def running(self): def running(self):
self.master = ctx.master
self.options = ctx.options self.options = ctx.options
self.event_queue = ctx.master.event_queue
threading.Thread(target=self.loop.run_forever, daemon=True).start()
self.is_running = True self.is_running = True
self.configure(["listen_port"]) self.configure(["listen_port"])
async def start(self):
async with self._lock:
if self.server:
print("Stopping server...")
self.server.close()
await self.server.wait_closed()
print("Starting server...")
self.server = await asyncio.start_server(
self.handle_connection,
'127.0.0.1',
self.listen_port,
loop=self.loop
)
async def handle_connection(self, r, w):
await ProxyConnectionHandler(
self.event_queue,
self.loop,
r,
w,
self.options
).handle_client()
def configure(self, updated): def configure(self, updated):
if not self.is_running: if not self.is_running:
return return
if "listen_port" in updated: if "listen_port" in updated:
self.listen_port = ctx.options.listen_port + 1 self.listen_port = ctx.options.listen_port + 1
asyncio.ensure_future(self.start_server())
# not sure if this actually required... async def start_server(self):
self.loop.call_soon_threadsafe(lambda: asyncio.ensure_future(self.start())) async with self._lock:
if self.server:
await self.shutdown_server()
print("Starting server...")
self.server = await asyncio.start_server(
self.handle_connection,
'127.0.0.1',
self.listen_port,
)
async def shutdown_server(self):
print("Stopping server...")
self.server.close()
await self.server.wait_closed()
self.server = None
async def handle_connection(self, r, w):
await ProxyConnectionHandler(
self.master,
r,
w,
self.options
).handle_client()

View File

@ -5,7 +5,7 @@ from typing import MutableMapping, Optional, Iterator, Union, Generator, Any
from OpenSSL import SSL from OpenSSL import SSL
from mitmproxy.certs import CertStore from mitmproxy.certs import CertStore
from mitmproxy.proxy.protocol import TlsClientHello from mitmproxy.net.tls import ClientHello
from mitmproxy.proxy.protocol import tls from mitmproxy.proxy.protocol import tls
from mitmproxy.proxy2 import context from mitmproxy.proxy2 import context
from mitmproxy.proxy2 import layer, commands, events from mitmproxy.proxy2 import layer, commands, events
@ -69,7 +69,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
return None return None
def parse_client_hello(data: bytes) -> Optional[TlsClientHello]: def parse_client_hello(data: bytes) -> Optional[ClientHello]:
""" """
Check if the supplied bytes contain a full ClientHello message, Check if the supplied bytes contain a full ClientHello message,
and if so, parse it. and if so, parse it.
@ -84,7 +84,7 @@ def parse_client_hello(data: bytes) -> Optional[TlsClientHello]:
# Check if ClientHello is complete # Check if ClientHello is complete
client_hello = get_client_hello(data) client_hello = get_client_hello(data)
if client_hello: if client_hello:
return TlsClientHello(client_hello[4:]) return ClientHello(client_hello[4:])
return None return None