fix vt code detection on Windows

This commit is contained in:
Maximilian Hils 2022-03-19 16:35:40 +01:00
parent 218c942808
commit 9243ba4e25
8 changed files with 72 additions and 11 deletions

View File

@ -15,6 +15,7 @@ from mitmproxy.contrib import click as miniclick
from mitmproxy.tcp import TCPFlow, TCPMessage from mitmproxy.tcp import TCPFlow, TCPMessage
from mitmproxy.utils import human from mitmproxy.utils import human
from mitmproxy.utils import strutils from mitmproxy.utils import strutils
from mitmproxy.utils import vt_codes
from mitmproxy.websocket import WebSocketData, WebSocketMessage from mitmproxy.websocket import WebSocketData, WebSocketMessage
@ -36,7 +37,7 @@ class Dumper:
def __init__(self, outfile: Optional[IO[str]] = None): def __init__(self, outfile: Optional[IO[str]] = None):
self.filter: Optional[flowfilter.TFilter] = None self.filter: Optional[flowfilter.TFilter] = None
self.outfp: IO[str] = outfile or sys.stdout self.outfp: IO[str] = outfile or sys.stdout
self.isatty = self.outfp.isatty() self.out_has_vt_codes = vt_codes.ensure_supported(self.outfp)
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
@ -71,7 +72,7 @@ class Dumper:
self.filter = None self.filter = None
def style(self, text: str, **style) -> str: def style(self, text: str, **style) -> str:
if style and self.isatty: if style and self.out_has_vt_codes:
text = miniclick.style(text, **style) text = miniclick.style(text, **style)
return text return text

View File

@ -4,6 +4,7 @@ from typing import IO, Optional
from mitmproxy import ctx from mitmproxy import ctx
from mitmproxy import log from mitmproxy import log
from mitmproxy.contrib import click as miniclick from mitmproxy.contrib import click as miniclick
from mitmproxy.utils import vt_codes
LOG_COLORS = {'error': "red", 'warn': "yellow", 'alert': "magenta"} LOG_COLORS = {'error': "red", 'warn': "yellow", 'alert': "magenta"}
@ -15,9 +16,9 @@ class TermLog:
err: Optional[IO[str]] = None, err: Optional[IO[str]] = None,
): ):
self.out_file: IO[str] = out or sys.stdout self.out_file: IO[str] = out or sys.stdout
self.out_isatty = self.out_file.isatty() self.out_has_vt_codes = vt_codes.ensure_supported(self.out_file)
self.err_file: IO[str] = err or sys.stderr self.err_file: IO[str] = err or sys.stderr
self.err_isatty = self.err_file.isatty() self.err_has_vt_codes = vt_codes.ensure_supported(self.err_file)
def load(self, loader): def load(self, loader):
loader.add_option( loader.add_option(
@ -30,13 +31,13 @@ class TermLog:
if log.log_tier(ctx.options.termlog_verbosity) >= log.log_tier(e.level): if log.log_tier(ctx.options.termlog_verbosity) >= log.log_tier(e.level):
if e.level == "error": if e.level == "error":
f = self.err_file f = self.err_file
isatty = self.err_isatty has_vt_codes = self.err_has_vt_codes
else: else:
f = self.out_file f = self.out_file
isatty = self.out_isatty has_vt_codes = self.out_has_vt_codes
msg = e.msg msg = e.msg
if isatty: if has_vt_codes:
msg = miniclick.style( msg = miniclick.style(
e.msg, e.msg,
fg=LOG_COLORS.get(e.level), fg=LOG_COLORS.get(e.level),

View File

@ -4,6 +4,7 @@ from ctypes.wintypes import BOOL, DWORD, WCHAR, WORD, SHORT, UINT, HANDLE, LPDWO
# https://docs.microsoft.com/de-de/windows/console/getstdhandle # https://docs.microsoft.com/de-de/windows/console/getstdhandle
STD_INPUT_HANDLE = -10 STD_INPUT_HANDLE = -10
STD_OUTPUT_HANDLE = -11 STD_OUTPUT_HANDLE = -11
STD_ERROR_HANDLE = -12
# https://docs.microsoft.com/de-de/windows/console/setconsolemode # https://docs.microsoft.com/de-de/windows/console/setconsolemode
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004 ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004

View File

@ -0,0 +1,51 @@
"""
This module provides a method to detect if a given file object supports virtual terminal escape codes.
"""
import os
import sys
from typing import IO
if os.name == "nt":
from ctypes import byref, windll # type: ignore
from ctypes.wintypes import BOOL, DWORD, HANDLE, LPDWORD
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004
STD_OUTPUT_HANDLE = -11
STD_ERROR_HANDLE = -12
# https://docs.microsoft.com/de-de/windows/console/getstdhandle
GetStdHandle = windll.kernel32.GetStdHandle
GetStdHandle.argtypes = [DWORD]
GetStdHandle.restype = HANDLE
# https://docs.microsoft.com/de-de/windows/console/getconsolemode
GetConsoleMode = windll.kernel32.GetConsoleMode
GetConsoleMode.argtypes = [HANDLE, LPDWORD]
GetConsoleMode.restype = BOOL
# https://docs.microsoft.com/de-de/windows/console/setconsolemode
SetConsoleMode = windll.kernel32.SetConsoleMode
SetConsoleMode.argtypes = [HANDLE, DWORD]
SetConsoleMode.restype = BOOL
def ensure_supported(f: IO[str]) -> bool:
if not f.isatty():
return False
if f == sys.stdout:
h = STD_OUTPUT_HANDLE
elif f == sys.stderr:
h = STD_ERROR_HANDLE
else:
return False
handle = GetStdHandle(h)
console_mode = DWORD()
ok = GetConsoleMode(handle, byref(console_mode))
if not ok:
return False
ok = SetConsoleMode(handle, console_mode.value | ENABLE_VIRTUAL_TERMINAL_PROCESSING)
return ok
else:
def ensure_supported(f: IO[str]) -> bool:
return f.isatty()

View File

@ -70,4 +70,5 @@ exclude =
mitmproxy/proxy/server.py mitmproxy/proxy/server.py
mitmproxy/proxy/layers/tls.py mitmproxy/proxy/layers/tls.py
mitmproxy/utils/bits.py mitmproxy/utils/bits.py
mitmproxy/utils/vt_codes.py
release/hooks release/hooks

View File

@ -243,9 +243,9 @@ def test_http2():
def test_styling(): def test_styling():
sio = io.StringIO() sio = io.StringIO()
sio.isatty = lambda: True
d = dumper.Dumper(sio) d = dumper.Dumper(sio)
d.out_has_vt_codes = True
with taddons.context(d): with taddons.context(d):
d.response(tflow.tflow(resp=True)) d.response(tflow.tflow(resp=True))
assert "\x1b[" in sio.getvalue() assert "\x1b[" in sio.getvalue()

View File

@ -19,11 +19,10 @@ def test_output(capsys):
assert err.strip().splitlines() == ["four"] assert err.strip().splitlines() == ["four"]
def test_styling() -> None: def test_styling(monkeypatch) -> None:
f = io.StringIO() f = io.StringIO()
f.isatty = lambda: True
t = termlog.TermLog(out=f) t = termlog.TermLog(out=f)
t.out_has_vt_codes = True
with taddons.context(t) as tctx: with taddons.context(t) as tctx:
tctx.configure(t) tctx.configure(t)
t.add_log(log.LogEntry("hello world", "info")) t.add_log(log.LogEntry("hello world", "info"))

View File

@ -0,0 +1,7 @@
import io
from mitmproxy.utils.vt_codes import ensure_supported
def test_simple():
assert not ensure_supported(io.StringIO())