[sans-io] use asyncio_utils

This commit is contained in:
Maximilian Hils 2020-11-21 20:40:56 +01:00
parent 5bf07b2176
commit 9f89c23a52
3 changed files with 45 additions and 35 deletions

View File

@ -16,6 +16,7 @@ from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layers, server
from mitmproxy.proxy2.context import Context, Server
from mitmproxy.proxy2.layer import CommandGenerator
from mitmproxy.utils import asyncio_utils
class MockServer(layers.http.HttpConnection):
@ -99,7 +100,10 @@ class ClientPlayback:
self.task = None
def running(self):
self.playback_task = asyncio.create_task(self.playback())
self.playback_task = asyncio_utils.create_task(
self.playback(),
name="client playback"
)
self.options = ctx.options
def done(self):

View File

@ -6,7 +6,7 @@ from mitmproxy import controller, ctx, eventsequence, flow, log, master, options
from mitmproxy.flow import Error
from mitmproxy.proxy2 import commands
from mitmproxy.proxy2 import server
from mitmproxy.utils import human
from mitmproxy.utils import asyncio_utils, human
class AsyncReply(controller.Reply):
@ -52,11 +52,10 @@ class ProxyConnectionHandler(server.StreamConnectionHandler):
def log(self, message: str, level: str = "info") -> None:
x = log.LogEntry(self.log_prefix + message, level)
x.reply = controller.DummyReply()
coro = self.master.addons.handle_lifecycle("log", x)
try:
asyncio.ensure_future(coro)
except RuntimeError:
coro.close() # event loop may already be closed, but we don't want a "has never been awaited error"
asyncio_utils.create_task(
self.master.addons.handle_lifecycle("log", x),
name="ProxyConnectionHandler.log"
)
class Proxyserver:
@ -95,7 +94,7 @@ class Proxyserver:
if not self.is_running:
return
if any(x in updated for x in ["server", "listen_host", "listen_port"]):
asyncio.ensure_future(self.refresh_server())
asyncio.create_task(self.refresh_server())
async def refresh_server(self):
async with self._lock:
@ -118,7 +117,11 @@ class Proxyserver:
self.server = None
async def handle_connection(self, r, w):
asyncio.current_task().set_name(f"proxy connection handler {w.get_extra_info('peername')}")
asyncio_utils.set_task_debug_info(
asyncio.current_task(),
name=f"Proxyserver.handle_connection",
client=w.get_extra_info('peername'),
)
handler = ProxyConnectionHandler(
self.master,
r,

View File

@ -8,7 +8,6 @@ The very high level overview is as follows:
"""
import abc
import asyncio
import sys
import time
import traceback
import typing
@ -21,18 +20,11 @@ from mitmproxy.proxy.protocol.http import HTTPMode
from mitmproxy.proxy2 import commands, events, layer, layers, server_hooks
from mitmproxy.proxy2.context import Client, Connection, ConnectionState, Context
from mitmproxy.proxy2.layers import tls
from mitmproxy.utils import asyncio_utils
from mitmproxy.utils import human
from mitmproxy.utils.data import pkg_data
def cancel_task(task: asyncio.Task, message: str) -> None:
"""Cancel messages are only available in Python 3.9+"""
if sys.version_info >= (3, 9):
task.cancel(message)
else:
task.cancel()
class TimeoutWatchdog:
last_activity: float
CONNECTION_TIMEOUT = 10 * 60
@ -52,10 +44,7 @@ class TimeoutWatchdog:
async def watch(self):
while True:
await self.can_timeout.wait()
try:
await asyncio.sleep(self.CONNECTION_TIMEOUT - (time.time() - self.last_activity))
except asyncio.CancelledError:
return
await asyncio.sleep(self.CONNECTION_TIMEOUT - (time.time() - self.last_activity))
if self.last_activity + self.CONNECTION_TIMEOUT < time.time():
await self.callback()
return
@ -96,7 +85,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
async def handle_client(self) -> None:
watch = asyncio.create_task(self.timeout_watchdog.watch(), name="timeout watchdog")
watch = asyncio_utils.create_task(
self.timeout_watchdog.watch(),
name="timeout watchdog",
client=self.client.peername,
)
self.log("client connect")
await self.handle_hook(server_hooks.ClientConnectedHook(self.client))
@ -104,9 +97,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.log("client kill connection")
self.transports.pop(self.client).writer.close()
else:
handler = asyncio.create_task(
handler = asyncio_utils.create_task(
self.handle_connection(self.client),
name=f"handle_connection {self.client.peername}"
name=f"client connection handler",
client=self.client.peername,
)
self.transports[self.client].handler = handler
self.server_event(events.Start())
@ -121,7 +115,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if self.transports:
self.log("closing transports...", "debug")
for io in self.transports.values():
cancel_task(io.handler, "client disconnected")
asyncio_utils.cancel_task(io.handler, "client disconnected")
await asyncio.wait([x.handler for x in self.transports.values()])
self.log("transports closed!", "debug")
@ -167,9 +161,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
else:
addr = human.format_address(command.connection.address)
self.log(f"server connect {addr}")
connected_hook = asyncio.create_task(
connected_hook = asyncio_utils.create_task(
self.handle_hook(server_hooks.ServerConnectedHook(hook_data)),
name=f"serverconnected {addr}"
name=f"handle_hook(server_connected) {addr}",
client=self.client.peername,
)
self.server_event(events.OpenConnectionReply(command, None))
@ -177,8 +172,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
# during connection opening, this function is the designated handler that can be cancelled.
# once we have a connection, we do want the teardown here to happen in any case, so we
# reassign the handler to .handle_connection and then clean up here once that is done.
new_handler = asyncio.create_task(self.handle_connection(command.connection),
name=f"handle_connection {command.connection.peername}")
new_handler = asyncio_utils.create_task(
self.handle_connection(command.connection),
name=f"server connection handler for {addr}",
client=self.client.peername,
)
self.transports[command.connection].handler = new_handler
await asyncio.wait([new_handler])
@ -233,7 +231,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
cancel_task(self.transports[self.client].handler, "timeout")
asyncio_utils.cancel_task(self.transports[self.client].handler, "timeout")
async def hook_task(self, hook: commands.Hook) -> None:
await self.handle_hook(hook)
@ -255,9 +253,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
if isinstance(command, commands.OpenConnection):
assert command.connection not in self.transports
handler = asyncio.create_task(
handler = asyncio_utils.create_task(
self.open_connection(command),
name=f"open_connection {command.connection.address}"
name=f"server connection manager {command.connection.address}",
client=self.client.peername,
)
self.transports[command.connection] = ConnectionIO(handler=handler)
elif isinstance(command, commands.ConnectionCommand) and command.connection not in self.transports:
@ -270,7 +269,11 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
socket = self.transports[command.connection].writer.get_extra_info("socket")
self.server_event(events.GetSocketReply(command, socket))
elif isinstance(command, commands.Hook):
asyncio.create_task(self.hook_task(command), name=f"hook {command.name}")
asyncio_utils.create_task(
self.hook_task(command),
name=f"handle_hook({command.name})",
client=self.client.peername,
)
elif isinstance(command, commands.Log):
self.log(command.message, command.level)
else:
@ -294,7 +297,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
connection.state = ConnectionState.CLOSED
if connection.state is ConnectionState.CLOSED:
cancel_task(self.transports[connection].handler, "closed by command")
asyncio_utils.cancel_task(self.transports[connection].handler, "closed by command")
class StreamConnectionHandler(ConnectionHandler, metaclass=abc.ABCMeta):