mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
[sans-io] add timeout watchdog to close lingering tcp streams
This commit is contained in:
parent
549e41ee40
commit
fbe1d73eab
@ -11,8 +11,10 @@ import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing
|
||||
from contextlib import contextmanager
|
||||
|
||||
from mitmproxy import http, options as moptions
|
||||
from mitmproxy.proxy.protocol.http import HTTPMode
|
||||
@ -27,8 +29,49 @@ class StreamIO(typing.NamedTuple):
|
||||
w: asyncio.StreamWriter
|
||||
|
||||
|
||||
class TimeoutWatchdog:
|
||||
last_activity: float
|
||||
CONNECTION_TIMEOUT = 120
|
||||
can_timeout: asyncio.Event
|
||||
blocker: int
|
||||
|
||||
def __init__(self, callback: typing.Callable[[], typing.Any]):
|
||||
self.callback = callback
|
||||
self.last_activity = time.time()
|
||||
self.can_timeout = asyncio.Event()
|
||||
self.can_timeout.set()
|
||||
self.blocker = 0
|
||||
|
||||
def register_activity(self):
|
||||
self.last_activity = time.time()
|
||||
|
||||
async def watch(self):
|
||||
while True:
|
||||
await self.can_timeout.wait()
|
||||
try:
|
||||
await asyncio.sleep(self.CONNECTION_TIMEOUT - (time.time() - self.last_activity))
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
if self.last_activity + self.CONNECTION_TIMEOUT < time.time():
|
||||
await self.callback()
|
||||
return
|
||||
|
||||
@contextmanager
|
||||
def disarm(self):
|
||||
self.can_timeout.clear()
|
||||
self.blocker += 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.blocker -= 1
|
||||
if self.blocker == 0:
|
||||
self.register_activity()
|
||||
self.can_timeout.set()
|
||||
|
||||
|
||||
class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
transports: typing.MutableMapping[Connection, StreamIO]
|
||||
timeout_watchdog: TimeoutWatchdog
|
||||
|
||||
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
|
||||
addr = writer.get_extra_info('peername')
|
||||
@ -36,6 +79,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
self.client = Client(addr)
|
||||
self.context = Context(self.client, options)
|
||||
self.layer = layer.NextLayer(self.context)
|
||||
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
|
||||
|
||||
# Ask for the first layer right away.
|
||||
# In a reverse proxy scenario, this is necessary as we would otherwise hang
|
||||
@ -50,12 +94,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
# FIXME: Work around log suppression in core.
|
||||
logging.getLogger('asyncio').setLevel(logging.DEBUG)
|
||||
|
||||
watch = asyncio.ensure_future(self.timeout_watchdog.watch())
|
||||
self.log("[sans-io] clientconnect")
|
||||
|
||||
self.server_event(events.Start())
|
||||
await self.handle_connection(self.client)
|
||||
|
||||
self.log("[sans-io] clientdisconnected")
|
||||
watch.cancel()
|
||||
|
||||
if self.transports:
|
||||
self.log("[sans-io] closing transports...")
|
||||
@ -65,6 +111,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
])
|
||||
self.log("[sans-io] transports closed!")
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
self.log(f"Closing connection due to inactivity: {self.client}")
|
||||
await self.close_connection(self.client)
|
||||
|
||||
async def close_connection(self, connection: Connection) -> None:
|
||||
self.log(f"closing {connection}", "debug")
|
||||
connection.state = ConnectionState.CLOSED
|
||||
@ -123,45 +173,43 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
|
||||
print(message)
|
||||
|
||||
def server_event(self, event: events.Event) -> None:
|
||||
self.timeout_watchdog.register_activity()
|
||||
try:
|
||||
self._server_event(event)
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
for command in layer_commands:
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
asyncio.ensure_future(
|
||||
self.open_connection(command)
|
||||
)
|
||||
elif isinstance(command, commands.SendData):
|
||||
try:
|
||||
io = self.transports[command.connection]
|
||||
except KeyError:
|
||||
raise RuntimeError(f"Cannot write to closed connection: {command.connection}")
|
||||
else:
|
||||
io.w.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
if command.connection == self.client:
|
||||
asyncio.ensure_future(
|
||||
self.close_connection(command.connection)
|
||||
)
|
||||
else:
|
||||
asyncio.ensure_future(
|
||||
self.shutdown_connection(command.connection)
|
||||
)
|
||||
elif isinstance(command, glue.GlueGetConnectionHandler):
|
||||
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
|
||||
elif isinstance(command, commands.Hook):
|
||||
asyncio.ensure_future(
|
||||
self.handle_hook(command)
|
||||
)
|
||||
elif isinstance(command, commands.Log):
|
||||
self.log(command.message, command.level)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected command: {command}")
|
||||
except Exception:
|
||||
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
|
||||
|
||||
def _server_event(self, event: events.Event) -> None:
|
||||
layer_commands = self.layer.handle_event(event)
|
||||
for command in layer_commands:
|
||||
if isinstance(command, commands.OpenConnection):
|
||||
asyncio.ensure_future(
|
||||
self.open_connection(command)
|
||||
)
|
||||
elif isinstance(command, commands.SendData):
|
||||
try:
|
||||
io = self.transports[command.connection]
|
||||
except KeyError:
|
||||
raise RuntimeError(f"Cannot write to closed connection: {command.connection}")
|
||||
else:
|
||||
io.w.write(command.data)
|
||||
elif isinstance(command, commands.CloseConnection):
|
||||
if command.connection == self.client:
|
||||
asyncio.ensure_future(
|
||||
self.close_connection(command.connection)
|
||||
)
|
||||
else:
|
||||
asyncio.ensure_future(
|
||||
self.shutdown_connection(command.connection)
|
||||
)
|
||||
elif isinstance(command, glue.GlueGetConnectionHandler):
|
||||
self.server_event(glue.GlueGetConnectionHandlerReply(command, self))
|
||||
elif isinstance(command, commands.Hook):
|
||||
asyncio.ensure_future(
|
||||
self.handle_hook(command)
|
||||
)
|
||||
elif isinstance(command, commands.Log):
|
||||
self.log(command.message, command.level)
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected command: {command}")
|
||||
|
||||
|
||||
class SimpleConnectionHandler(ConnectionHandler):
|
||||
"""Simple handler that does not really process any hooks."""
|
||||
@ -220,7 +268,7 @@ if __name__ == "__main__":
|
||||
else:
|
||||
flow.request.url = flow.request.url.replace("http://", "https://")
|
||||
if "redirect" in flow.request.path:
|
||||
flow.request.url = "https://httpbin.org/robots.txt"
|
||||
flow.request.host = "httpbin.org"
|
||||
|
||||
await SimpleConnectionHandler(reader, writer, opts, {
|
||||
"next_layer": next_layer,
|
||||
|
Loading…
Reference in New Issue
Block a user