[sans-io] add timeout watchdog to close lingering tcp streams

This commit is contained in:
Maximilian Hils 2019-11-07 18:18:23 +01:00
parent 549e41ee40
commit fbe1d73eab

View File

@ -11,8 +11,10 @@ import asyncio
import logging import logging
import socket import socket
import sys import sys
import time
import traceback import traceback
import typing import typing
from contextlib import contextmanager
from mitmproxy import http, options as moptions from mitmproxy import http, options as moptions
from mitmproxy.proxy.protocol.http import HTTPMode from mitmproxy.proxy.protocol.http import HTTPMode
@ -27,8 +29,49 @@ class StreamIO(typing.NamedTuple):
w: asyncio.StreamWriter w: asyncio.StreamWriter
class TimeoutWatchdog:
last_activity: float
CONNECTION_TIMEOUT = 120
can_timeout: asyncio.Event
blocker: int
def __init__(self, callback: typing.Callable[[], typing.Any]):
self.callback = callback
self.last_activity = time.time()
self.can_timeout = asyncio.Event()
self.can_timeout.set()
self.blocker = 0
def register_activity(self):
self.last_activity = time.time()
async def watch(self):
while True:
await self.can_timeout.wait()
try:
await asyncio.sleep(self.CONNECTION_TIMEOUT - (time.time() - self.last_activity))
except asyncio.CancelledError:
return
if self.last_activity + self.CONNECTION_TIMEOUT < time.time():
await self.callback()
return
@contextmanager
def disarm(self):
self.can_timeout.clear()
self.blocker += 1
try:
yield
finally:
self.blocker -= 1
if self.blocker == 0:
self.register_activity()
self.can_timeout.set()
class ConnectionHandler(metaclass=abc.ABCMeta): class ConnectionHandler(metaclass=abc.ABCMeta):
transports: typing.MutableMapping[Connection, StreamIO] transports: typing.MutableMapping[Connection, StreamIO]
timeout_watchdog: TimeoutWatchdog
def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None: def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, options: moptions.Options) -> None:
addr = writer.get_extra_info('peername') addr = writer.get_extra_info('peername')
@ -36,6 +79,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.client = Client(addr) self.client = Client(addr)
self.context = Context(self.client, options) self.context = Context(self.client, options)
self.layer = layer.NextLayer(self.context) self.layer = layer.NextLayer(self.context)
self.timeout_watchdog = TimeoutWatchdog(self.on_timeout)
# Ask for the first layer right away. # Ask for the first layer right away.
# In a reverse proxy scenario, this is necessary as we would otherwise hang # In a reverse proxy scenario, this is necessary as we would otherwise hang
@ -50,12 +94,14 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
# FIXME: Work around log suppression in core. # FIXME: Work around log suppression in core.
logging.getLogger('asyncio').setLevel(logging.DEBUG) logging.getLogger('asyncio').setLevel(logging.DEBUG)
watch = asyncio.ensure_future(self.timeout_watchdog.watch())
self.log("[sans-io] clientconnect") self.log("[sans-io] clientconnect")
self.server_event(events.Start()) self.server_event(events.Start())
await self.handle_connection(self.client) await self.handle_connection(self.client)
self.log("[sans-io] clientdisconnected") self.log("[sans-io] clientdisconnected")
watch.cancel()
if self.transports: if self.transports:
self.log("[sans-io] closing transports...") self.log("[sans-io] closing transports...")
@ -65,6 +111,10 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
]) ])
self.log("[sans-io] transports closed!") self.log("[sans-io] transports closed!")
async def on_timeout(self) -> None:
self.log(f"Closing connection due to inactivity: {self.client}")
await self.close_connection(self.client)
async def close_connection(self, connection: Connection) -> None: async def close_connection(self, connection: Connection) -> None:
self.log(f"closing {connection}", "debug") self.log(f"closing {connection}", "debug")
connection.state = ConnectionState.CLOSED connection.state = ConnectionState.CLOSED
@ -123,12 +173,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
print(message) print(message)
def server_event(self, event: events.Event) -> None: def server_event(self, event: events.Event) -> None:
self.timeout_watchdog.register_activity()
try: try:
self._server_event(event)
except Exception:
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
def _server_event(self, event: events.Event) -> None:
layer_commands = self.layer.handle_event(event) layer_commands = self.layer.handle_event(event)
for command in layer_commands: for command in layer_commands:
if isinstance(command, commands.OpenConnection): if isinstance(command, commands.OpenConnection):
@ -161,6 +207,8 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
self.log(command.message, command.level) self.log(command.message, command.level)
else: else:
raise RuntimeError(f"Unexpected command: {command}") raise RuntimeError(f"Unexpected command: {command}")
except Exception:
self.log(f"mitmproxy has crashed!\n{traceback.format_exc()}", level="error")
class SimpleConnectionHandler(ConnectionHandler): class SimpleConnectionHandler(ConnectionHandler):
@ -220,7 +268,7 @@ if __name__ == "__main__":
else: else:
flow.request.url = flow.request.url.replace("http://", "https://") flow.request.url = flow.request.url.replace("http://", "https://")
if "redirect" in flow.request.path: if "redirect" in flow.request.path:
flow.request.url = "https://httpbin.org/robots.txt" flow.request.host = "httpbin.org"
await SimpleConnectionHandler(reader, writer, opts, { await SimpleConnectionHandler(reader, writer, opts, {
"next_layer": next_layer, "next_layer": next_layer,