mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
[sans-io] adapt to asyncio changes in core
This commit is contained in:
parent
8f3db90def
commit
467dc81d19
@ -1,7 +1,7 @@
|
||||
from mitmproxy import ctx
|
||||
from mitmproxy.net import server_spec
|
||||
from mitmproxy.proxy.config import HostMatcher
|
||||
from mitmproxy.proxy.protocol import is_tls_record_magic
|
||||
from mitmproxy.net.tls import is_tls_record_magic
|
||||
from mitmproxy.proxy2 import layer, layers
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ class NextLayer:
|
||||
nextlayer.context.server.tls = (
|
||||
server_spec.parse_with_mode(ctx.options.mode)[1].scheme == "https"
|
||||
)
|
||||
nextlayer.layer = layers.TLSLayer(nextlayer.context)
|
||||
nextlayer.layer = layers.ClientTLSLayer(nextlayer.context)
|
||||
return
|
||||
# TODO: Other top layers
|
||||
|
||||
@ -45,7 +45,7 @@ class NextLayer:
|
||||
if client_tls:
|
||||
nextlayer.context.client.tls = True
|
||||
nextlayer.context.server.tls = True
|
||||
nextlayer.layer = layers.TLSLayer(nextlayer.context)
|
||||
nextlayer.layer = layers.ClientTLSLayer(nextlayer.context)
|
||||
return
|
||||
|
||||
# 5. Check for --tcp
|
||||
@ -54,7 +54,7 @@ class NextLayer:
|
||||
return
|
||||
|
||||
# 6. Check for TLS ALPN (HTTP1/HTTP2)
|
||||
if isinstance(top_layer, layers.TLSLayer):
|
||||
if isinstance(top_layer, layers.ServerTLSLayer):
|
||||
alpn = nextlayer.context.client.alpn
|
||||
if alpn == b'http/1.1':
|
||||
nextlayer.layer = layers.HTTPLayer(nextlayer.context)
|
||||
|
@ -1,63 +1,62 @@
|
||||
import asyncio
|
||||
import queue
|
||||
import threading
|
||||
|
||||
from mitmproxy import ctx, controller, log
|
||||
from mitmproxy import ctx, controller, log, options, master
|
||||
from mitmproxy.proxy2 import commands
|
||||
from mitmproxy.proxy2 import events
|
||||
from mitmproxy.proxy2 import server
|
||||
|
||||
|
||||
class AsyncReply(controller.Reply):
|
||||
# temporary glue code - let's see how many years it survives.
|
||||
def __init__(self, submit, *args):
|
||||
self.submit = submit
|
||||
"""
|
||||
controller.Reply.q.get() is blocking, which we definitely want to avoid in a coroutine.
|
||||
This stub adds a .done asyncio.Event() that can be used instead.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
self.done = asyncio.Event()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
super().__init__(*args)
|
||||
|
||||
def commit(self):
|
||||
super().commit()
|
||||
self.submit(self.q.get_nowait())
|
||||
self.loop.call_soon_threadsafe(lambda: self.done.set())
|
||||
|
||||
|
||||
class ProxyConnectionHandler(server.ConnectionHandler):
|
||||
event_queue: queue.Queue
|
||||
loop: asyncio.AbstractEventLoop
|
||||
master: master.Master
|
||||
|
||||
def __init__(self, event_queue, loop, r, w, options):
|
||||
self.event_queue = event_queue
|
||||
self.loop = loop
|
||||
def __init__(self, master, r, w, options):
|
||||
self.master = master
|
||||
super().__init__(r, w, options)
|
||||
|
||||
async def handle_hook(self, hook: commands.Hook) -> None:
|
||||
q = asyncio.Queue()
|
||||
|
||||
hook.data.reply = AsyncReply(
|
||||
lambda x: self.loop.call_soon_threadsafe(lambda: q.put_nowait(x)),
|
||||
hook.data
|
||||
)
|
||||
self.event_queue.put((hook.name, hook.data))
|
||||
await q.get()
|
||||
hook.data.reply = AsyncReply(hook.data)
|
||||
await self.master.addons.handle_lifecycle(hook.name, hook.data)
|
||||
await hook.data.reply.done.wait()
|
||||
if hook.blocking:
|
||||
self.server_event(events.HookReply(hook))
|
||||
|
||||
def log(self, message: str, level: str = "info") -> None:
|
||||
x = log.LogEntry(message, level)
|
||||
x.reply = controller.DummyReply()
|
||||
self.event_queue.put(("log", x))
|
||||
asyncio.ensure_future(
|
||||
self.master.addons.handle_lifecycle("log", x)
|
||||
)
|
||||
|
||||
|
||||
class Proxyserver:
|
||||
"""
|
||||
This addon runs the actual proxy server.
|
||||
"""
|
||||
server: asyncio.AbstractServer
|
||||
listen_port: int
|
||||
master: master.Master
|
||||
options: options.Options
|
||||
is_running: bool
|
||||
|
||||
def __init__(self):
|
||||
self.server = None
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.listen_port = None
|
||||
self.event_queue = None
|
||||
self.options = None
|
||||
self._lock = asyncio.Lock()
|
||||
self.server = None
|
||||
self.is_running = False
|
||||
|
||||
def load(self, loader):
|
||||
@ -68,41 +67,39 @@ class Proxyserver:
|
||||
)
|
||||
|
||||
def running(self):
|
||||
self.master = ctx.master
|
||||
self.options = ctx.options
|
||||
self.event_queue = ctx.master.event_queue
|
||||
threading.Thread(target=self.loop.run_forever, daemon=True).start()
|
||||
self.is_running = True
|
||||
self.configure(["listen_port"])
|
||||
|
||||
async def start(self):
|
||||
async with self._lock:
|
||||
if self.server:
|
||||
print("Stopping server...")
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
|
||||
print("Starting server...")
|
||||
self.server = await asyncio.start_server(
|
||||
self.handle_connection,
|
||||
'127.0.0.1',
|
||||
self.listen_port,
|
||||
loop=self.loop
|
||||
)
|
||||
|
||||
async def handle_connection(self, r, w):
|
||||
await ProxyConnectionHandler(
|
||||
self.event_queue,
|
||||
self.loop,
|
||||
r,
|
||||
w,
|
||||
self.options
|
||||
).handle_client()
|
||||
|
||||
def configure(self, updated):
|
||||
if not self.is_running:
|
||||
return
|
||||
if "listen_port" in updated:
|
||||
self.listen_port = ctx.options.listen_port + 1
|
||||
asyncio.ensure_future(self.start_server())
|
||||
|
||||
# not sure if this actually required...
|
||||
self.loop.call_soon_threadsafe(lambda: asyncio.ensure_future(self.start()))
|
||||
async def start_server(self):
|
||||
async with self._lock:
|
||||
if self.server:
|
||||
await self.shutdown_server()
|
||||
print("Starting server...")
|
||||
self.server = await asyncio.start_server(
|
||||
self.handle_connection,
|
||||
'127.0.0.1',
|
||||
self.listen_port,
|
||||
)
|
||||
|
||||
async def shutdown_server(self):
|
||||
print("Stopping server...")
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
self.server = None
|
||||
|
||||
async def handle_connection(self, r, w):
|
||||
await ProxyConnectionHandler(
|
||||
self.master,
|
||||
r,
|
||||
w,
|
||||
self.options
|
||||
).handle_client()
|
||||
|
@ -5,7 +5,7 @@ from typing import MutableMapping, Optional, Iterator, Union, Generator, Any
|
||||
from OpenSSL import SSL
|
||||
|
||||
from mitmproxy.certs import CertStore
|
||||
from mitmproxy.proxy.protocol import TlsClientHello
|
||||
from mitmproxy.net.tls import ClientHello
|
||||
from mitmproxy.proxy.protocol import tls
|
||||
from mitmproxy.proxy2 import context
|
||||
from mitmproxy.proxy2 import layer, commands, events
|
||||
@ -69,7 +69,7 @@ def get_client_hello(data: bytes) -> Optional[bytes]:
|
||||
return None
|
||||
|
||||
|
||||
def parse_client_hello(data: bytes) -> Optional[TlsClientHello]:
|
||||
def parse_client_hello(data: bytes) -> Optional[ClientHello]:
|
||||
"""
|
||||
Check if the supplied bytes contain a full ClientHello message,
|
||||
and if so, parse it.
|
||||
@ -84,7 +84,7 @@ def parse_client_hello(data: bytes) -> Optional[TlsClientHello]:
|
||||
# Check if ClientHello is complete
|
||||
client_hello = get_client_hello(data)
|
||||
if client_hello:
|
||||
return TlsClientHello(client_hello[4:])
|
||||
return ClientHello(client_hello[4:])
|
||||
return None
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user