pyupgrade --keep-runtime-typing --py38-plus

This commit is contained in:
Maximilian Hils 2022-03-29 14:00:41 +02:00
parent beb49ab121
commit 9d1e3107e8
32 changed files with 74 additions and 76 deletions

View File

@ -21,7 +21,7 @@ class MyAddon:
totals[f.request.host] = totals.setdefault(f.request.host, 0) + 1
with open(path, "w+") as fp:
for cnt, dom in sorted([(v, k) for (k, v) in totals.items()]):
for cnt, dom in sorted((v, k) for (k, v) in totals.items()):
fp.write(f"{cnt}: {dom}\n")
ctx.log.alert("done")

View File

@ -15,7 +15,7 @@ def websocket_message(flow: http.HTTPFlow):
last_message = flow.websocket.messages[-1]
if last_message.is_text and "secret" in last_message.text:
last_message.drop()
ctx.master.commands.call("inject.websocket", flow, last_message.from_client, "ssssssh".encode())
ctx.master.commands.call("inject.websocket", flow, last_message.from_client, b"ssssssh")
# Complex example: Schedule a periodic timer

View File

@ -3,14 +3,15 @@ import binascii
import socket
import typing
from ntlm_auth import gss_channel_bindings, ntlm
from mitmproxy import addonmanager, http
from mitmproxy import ctx
from mitmproxy import http, addonmanager
from mitmproxy.net.http import http1
from mitmproxy.proxy import layer, commands
from mitmproxy.proxy import commands, layer
from mitmproxy.proxy.context import Context
from mitmproxy.proxy.layers.http import HttpConnectUpstreamHook, HttpLayer, HttpStream
from mitmproxy.proxy.layers.http._upstream_proxy import HttpUpstreamProxy
from ntlm_auth import gss_channel_bindings, ntlm
class NTLMUpstreamAuth:
@ -148,7 +149,7 @@ class CustomNTLMContext:
ntlm_compatibility: int = ctx.options.upstream_ntlm_compatibility
username, password = tuple(auth.split(":"))
workstation = socket.gethostname().upper()
ctx.log.debug('\nntlm context with the details: "{}\\{}", *****'.format(domain, username))
ctx.log.debug(f'\nntlm context with the details: "{domain}\\{username}", *****')
self.ctx_log = ctx.log
self.preferred_type = preferred_type
self.ntlm_context = ntlm.NtlmContext(
@ -163,7 +164,7 @@ class CustomNTLMContext:
negotiate_message = self.ntlm_context.step()
negotiate_message_base_64_in_bytes = base64.b64encode(negotiate_message)
negotiate_message_base_64_ascii = negotiate_message_base_64_in_bytes.decode("ascii")
negotiate_message_base_64_final = u'%s %s' % (self.preferred_type, negotiate_message_base_64_ascii)
negotiate_message_base_64_final = f'{self.preferred_type} {negotiate_message_base_64_ascii}'
self.ctx_log.debug(
f'{self.preferred_type} Authentication, negotiate message: {negotiate_message_base_64_final}'
)
@ -174,11 +175,11 @@ class CustomNTLMContext:
try:
challenge_message_ascii_bytes = base64.b64decode(challenge_message, validate=True)
except binascii.Error as err:
self.ctx_log.debug('{} Authentication fail with error {}'.format(self.preferred_type, err.__str__()))
self.ctx_log.debug(f'{self.preferred_type} Authentication fail with error {err.__str__()}')
return False
authenticate_message = self.ntlm_context.step(challenge_message_ascii_bytes)
negotiate_message_base_64 = u'%s %s' % (self.preferred_type,
base64.b64encode(authenticate_message).decode('ascii'))
negotiate_message_base_64 = '{} {}'.format(self.preferred_type,
base64.b64encode(authenticate_message).decode('ascii'))
self.ctx_log.debug(
f'{self.preferred_type} Authentication, response to challenge message: {negotiate_message_base_64}'
)

View File

@ -178,7 +178,7 @@ class Core:
req.url = val
except ValueError as e:
raise exceptions.CommandError(
"URL {} is invalid: {}".format(repr(val), e)
f"URL {repr(val)} is invalid: {e}"
) from e
else:
self.rupdate = False
@ -199,7 +199,7 @@ class Core:
updated.append(f)
ctx.master.addons.trigger(hooks.UpdateHook(updated))
ctx.log.alert("Set {} on {} flows.".format(attr, len(updated)))
ctx.log.alert(f"Set {attr} on {len(updated)} flows.")
@command.command("flow.decode")
def decode(self, flows: typing.Sequence[flow.Flow], part: str) -> None:

View File

@ -63,7 +63,7 @@ def curl_command(f: flow.Flow) -> str:
server_addr = f.server_conn.peername[0] if f.server_conn.peername else None
if ctx.options.export_preserve_original_ip and server_addr and request.pretty_host != server_addr:
resolve = "{}:{}:[{}]".format(request.pretty_host, request.port, server_addr)
resolve = f"{request.pretty_host}:{request.port}:[{server_addr}]"
args.append("--resolve")
args.append(resolve)

View File

@ -112,7 +112,7 @@ class ServerPlayback:
@command.command("replay.server.count")
def count(self) -> int:
return sum([len(i) for i in self.flowmap.values()])
return sum(len(i) for i in self.flowmap.values())
def _hash(self, flow: http.HTTPFlow) -> typing.Hashable:
"""

View File

@ -265,11 +265,11 @@ class CommandManager:
parts, _ = self.parse_partial(cmdstr)
if not parts:
raise exceptions.CommandError(f"Invalid command: {cmdstr!r}")
command_name, *args = [
command_name, *args = (
unquote(part.value)
for part in parts
if part.type != mitmproxy.types.Space
]
)
return self.call_strings(command_name, args)
def dump(self, out=sys.stdout) -> None:

View File

@ -67,4 +67,4 @@ if __name__ == "__main__": # pragma: no cover
t = time.time()
x = beautify(data)
print("Beautifying vendor.css took {:.2}s".format(time.time() - t))
print(f"Beautifying vendor.css took {time.time() - t:.2}s")

View File

@ -20,7 +20,7 @@ def format_query_list(data: typing.List[typing.Any]):
num_queries = len(data) - 1
result = ""
for i, op in enumerate(data):
result += "--- {i}/{num_queries}\n".format(i=i, num_queries=num_queries)
result += f"--- {i}/{num_queries}\n"
result += format_graphql(op)
return result

View File

@ -200,7 +200,7 @@ class ProtoParser:
# read field key (tag and wire_type)
offset, key = ProtoParser._read_base128le(wire_data[pos:])
# casting raises exception for invalid WireTypes
wt = ProtoParser.WireTypes((key & 7))
wt = ProtoParser.WireTypes(key & 7)
tag = (key >> 3)
pos += offset
@ -232,7 +232,7 @@ class ProtoParser:
wt == ProtoParser.WireTypes.group_start or
wt == ProtoParser.WireTypes.group_end
):
raise ValueError("deprecated field: {}".format(wt))
raise ValueError(f"deprecated field: {wt}")
elif wt == ProtoParser.WireTypes.bit_32:
offset, val = ProtoParser._read_u32(wire_data[pos:])
pos += offset
@ -735,8 +735,7 @@ class ProtoParser:
yield field_desc_dict
# add sub-fields of messages or packed fields
for f in decoded_val:
for field_dict in f.gen_flat_decoded_field_dicts():
yield field_dict
yield from f.gen_flat_decoded_field_dicts()
else:
field_desc_dict["val"] = decoded_val
yield field_desc_dict
@ -767,8 +766,7 @@ class ProtoParser:
def gen_flat_decoded_field_dicts(self) -> Generator[Dict, None, None]:
for f in self.root_fields:
for field_dict in f.gen_flat_decoded_field_dicts():
yield field_dict
yield from f.gen_flat_decoded_field_dicts()
def gen_str_rows(self) -> Generator[Tuple[str, ...], None, None]:
for field_dict in self.gen_flat_decoded_field_dicts():
@ -879,8 +877,7 @@ def hack_generator_to_list(generator_func):
def format_pbuf(message: bytes, parser_options: ProtoParser.ParserOptions, rules: List[ProtoParser.ParserRule]):
for l in format_table(ProtoParser(data=message, parser_options=parser_options, rules=rules).gen_str_rows()):
yield l
yield from format_table(ProtoParser(data=message, parser_options=parser_options, rules=rules).gen_str_rows())
def format_grpc(
@ -895,12 +892,11 @@ def format_grpc(
compression_scheme if compressed else compressed) + ')'
yield [("text", headline)]
for l in format_pbuf(
yield from format_pbuf(
message=pb_message,
parser_options=parser_options,
rules=rules
):
yield l
)
@dataclass

View File

@ -3,10 +3,10 @@ import typing
from kaitaistruct import KaitaiStream
from mitmproxy.contrib.kaitaistruct import png
from mitmproxy.contrib.kaitaistruct import gif
from mitmproxy.contrib.kaitaistruct import jpeg
from mitmproxy.contrib.kaitaistruct import ico
from mitmproxy.contrib.kaitaistruct import jpeg
from mitmproxy.contrib.kaitaistruct import png
Metadata = typing.List[typing.Tuple[str, str]]
@ -91,12 +91,13 @@ def parse_ico(data: bytes) -> Metadata:
for i, image in enumerate(img.images):
parts.append(
(
'Image {}'.format(i + 1), "Size: {} x {}\n"
"{: >18}Bits per pixel: {}\n"
"{: >18}PNG: {}".format(256 if not image.width else image.width,
256 if not image.height else image.height,
'', image.bpp,
'', image.is_png)
f'Image {i + 1}',
"Size: {} x {}\n"
"{: >18}Bits per pixel: {}\n"
"{: >18}PNG: {}".format(256 if not image.width else image.width,
256 if not image.height else image.height,
'', image.bpp,
'', image.is_png)
)
)

View File

@ -376,7 +376,7 @@ class FSrc(_Rex):
def __call__(self, f):
if not f.client_conn or not f.client_conn.peername:
return False
r = "{}:{}".format(f.client_conn.peername[0], f.client_conn.peername[1])
r = f"{f.client_conn.peername[0]}:{f.client_conn.peername[1]}"
return f.client_conn.peername and self.re.search(r)
@ -388,7 +388,7 @@ class FDst(_Rex):
def __call__(self, f):
if not f.server_conn or not f.server_conn.address:
return False
r = "{}:{}".format(f.server_conn.address[0], f.server_conn.address[1])
r = f"{f.server_conn.address[0]}:{f.server_conn.address[1]}"
return f.server_conn.address and self.re.search(r)

View File

@ -510,7 +510,7 @@ class Message(serializable.Serializable):
self.headers["content-encoding"] = encoding
self.content = self.raw_content
if "content-encoding" not in self.headers:
raise ValueError("Invalid content encoding {}".format(repr(encoding)))
raise ValueError(f"Invalid content encoding {repr(encoding)}")
def json(self, **kwargs: Any) -> Any:
"""
@ -1033,7 +1033,7 @@ class Response(Message):
reason = reason.encode("ascii", "strict")
if isinstance(content, str):
raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
raise ValueError(f"Content must be bytes, not {type(content).__name__}")
if not isinstance(headers, Headers):
headers = Headers(headers)
if trailers is not None and not isinstance(trailers, Headers):

View File

@ -146,7 +146,7 @@ def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int:
write(span)
return size + 1 + len(span)
else:
raise ValueError("unserializable object: {} ({})".format(value, type(value)))
raise ValueError(f"unserializable object: {value} ({type(value)})")
def loads(string: bytes) -> TSerializable:

View File

@ -63,7 +63,7 @@ class WebsocketConnection(wsproto.Connection):
frame_buf: List[bytes]
def __init__(self, *args, conn: connection.Connection, **kwargs):
super(WebsocketConnection, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.conn = conn
self.frame_buf = [b""]

View File

@ -153,7 +153,7 @@ class ConnectionHandler(metaclass=abc.ABCMeta):
try:
command.connection.timestamp_start = time.time()
reader, writer = await asyncio.open_connection(*command.connection.address)
except (IOError, asyncio.CancelledError) as e:
except (OSError, asyncio.CancelledError) as e:
err = str(e)
if not err: # str(CancelledError()) returns empty string.
err = "connection cancelled"

View File

@ -55,7 +55,7 @@ class TCPFlow(flow.Flow):
_stateobject_attributes["messages"] = List[TCPMessage]
def __repr__(self):
return "<TCPFlow ({} messages)>".format(len(self.messages))
return f"<TCPFlow ({len(self.messages)} messages)>"
__all__ = [

View File

@ -46,7 +46,7 @@ class EventLog(urwid.ListBox, layoutwidget.LayoutWidget):
def add_event(self, event_store, entry: log.LogEntry):
if log.log_tier(self.master.options.console_eventlog_verbosity) < log.log_tier(entry.level):
return
txt = "{}: {}".format(entry.level, str(entry.msg))
txt = f"{entry.level}: {str(entry.msg)}"
if entry.level in ("error", "warn", "alert"):
e = urwid.Text((entry.level, txt))
else:

View File

@ -89,7 +89,7 @@ class ConsoleMaster(master.Master):
signals.status_message.send(
message = (
entry.level,
"{}: {}".format(entry.level.title(), str(entry.msg).lstrip())
f"{entry.level.title()}: {str(entry.msg).lstrip()}"
),
expire=5
)

View File

@ -108,7 +108,7 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget):
self.master = master
self.choices = choices
self.callback = callback
choicewidth = max([len(i) for i in choices])
choicewidth = max(len(i) for i in choices)
self.width = max(choicewidth, len(title)) + 7
self.walker = ChooserListWalker(choices, current)

View File

@ -100,7 +100,7 @@ def run(
opts.update(**extra(args))
except exceptions.OptionsError as e:
print("{}: {}".format(sys.argv[0], e), file=sys.stderr)
print(f"{sys.argv[0]}: {e}", file=sys.stderr)
sys.exit(1)
loop = asyncio.get_running_loop()

View File

@ -206,7 +206,7 @@ class RequestHandler(tornado.web.RequestHandler):
try:
return json.loads(self.request.body.decode())
except Exception as e:
raise APIError(400, "Malformed JSON: {}".format(str(e)))
raise APIError(400, f"Malformed JSON: {str(e)}")
@property
def filecontents(self):

View File

@ -20,7 +20,7 @@ def dump_system_info():
data = [
f"Mitmproxy: {mitmproxy_version}",
f"Python: {platform.python_version()}",
"OpenSSL: {}".format(SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode()),
f"OpenSSL: {SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode()}",
f"Platform: {platform.platform()}",
]
return "\n".join(data)

View File

@ -31,7 +31,7 @@ def pretty_size(size: int) -> str:
raise AssertionError
@functools.lru_cache()
@functools.lru_cache
def parse_size(s: typing.Optional[str]) -> typing.Optional[int]:
"""
Parse a size with an optional k/m/... suffix.
@ -65,7 +65,7 @@ def pretty_duration(secs: typing.Optional[float]) -> str:
if secs >= limit:
return formatter.format(secs)
# less than 1 sec
return "{:.0f}ms".format(secs * 1000)
return f"{secs * 1000:.0f}ms"
def format_timestamp(s):
@ -79,7 +79,7 @@ def format_timestamp_with_milli(s):
return d.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
@functools.lru_cache()
@functools.lru_cache
def format_address(address: typing.Optional[tuple]) -> str:
"""
This function accepts IPv4/IPv6 tuples and
@ -90,12 +90,12 @@ def format_address(address: typing.Optional[tuple]) -> str:
try:
host = ipaddress.ip_address(address[0])
if host.is_unspecified:
return "*:{}".format(address[1])
return f"*:{address[1]}"
if isinstance(host, ipaddress.IPv4Address):
return "{}:{}".format(str(host), address[1])
return f"{str(host)}:{address[1]}"
# If IPv6 is mapped to IPv4
elif host.ipv4_mapped:
return "{}:{}".format(str(host.ipv4_mapped), address[1])
return "[{}]:{}".format(str(host), address[1])
return f"{str(host.ipv4_mapped)}:{address[1]}"
return f"[{str(host)}]:{address[1]}"
except ValueError:
return "{}:{}".format(address[0], address[1])
return f"{address[0]}:{address[1]}"

View File

@ -22,7 +22,7 @@ def always_bytes(str_or_bytes: Union[None, str, bytes], *encode_args) -> Union[N
elif isinstance(str_or_bytes, str):
return str_or_bytes.encode(*encode_args)
else:
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
raise TypeError(f"Expected str or bytes, but got {type(str_or_bytes).__name__}.")
@overload
@ -45,7 +45,7 @@ def always_str(str_or_bytes: Union[None, str, bytes], *decode_args) -> Union[Non
elif isinstance(str_or_bytes, bytes):
return str_or_bytes.decode(*decode_args)
else:
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
raise TypeError(f"Expected str or bytes, but got {type(str_or_bytes).__name__}.")
# Translate control characters to "safe" characters. This implementation
@ -73,7 +73,7 @@ def escape_control_characters(text: str, keep_spacing=True) -> str:
keep_spacing: If True, tabs and newlines will not be replaced.
"""
if not isinstance(text, str):
raise ValueError("text type must be unicode but is {}".format(type(text).__name__))
raise ValueError(f"text type must be unicode but is {type(text).__name__}")
trans = _control_char_trans_newline if keep_spacing else _control_char_trans
return text.translate(trans)

View File

@ -388,7 +388,7 @@ def build_pyinstaller(be: BuildEnviron) -> None: # pragma: no cover
click.echo(subprocess.check_output([executable, "--version"]).decode())
archive.add(str(executable), str(executable.name))
click.echo("Packed {}.".format(be.archive_path.name))
click.echo(f"Packed {be.archive_path.name}.")
def build_wininstaller(be: BuildEnviron) -> None: # pragma: no cover

View File

@ -108,7 +108,7 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config):
terminalreporter.write(msg, **markup)
for name in sorted(coverage_values.keys()):
msg = 'Coverage for {}: {:.2f}%\n'.format(name, coverage_values[name][0])
msg = f'Coverage for {name}: {coverage_values[name][0]:.2f}%\n'
if coverage_values[name][0] < 100:
markup = {'red': True, 'bold': True}
for s, v in sorted(coverage_values[name][1]):

View File

@ -62,7 +62,7 @@ def get_random_object(random=random, depth=0):
else:
return -1 * random.randint(0, MAXINT)
n = random.randint(0, 100)
return bytes([random.randint(32, 126) for _ in range(n)])
return bytes(random.randint(32, 126) for _ in range(n))
class Test_Format(unittest.TestCase):

View File

@ -22,7 +22,7 @@ SAMPLE_SETTINGS = {
}
class FrameFactory(object):
class FrameFactory:
"""
A class containing lots of helper methods and state to build frames. This
allows test cases to easily build correct HTTP/2 frames to feed to

View File

@ -1312,9 +1312,9 @@ def test_invalid_content_length(tctx):
flow = Placeholder(HTTPFlow)
assert (
Playbook(http.HttpLayer(tctx, HTTPMode.regular))
>> DataReceived(tctx.client, ("GET http://example.com/ HTTP/1.1\r\n"
"Host: example.com\r\n"
"Content-Length: NaN\r\n\r\n").encode())
>> DataReceived(tctx.client, (b"GET http://example.com/ HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: NaN\r\n\r\n"))
<< SendData(tctx.client, err)
<< CloseConnection(tctx.client)
<< http.HttpRequestHeadersHook(flow)

View File

@ -27,7 +27,7 @@ class _Masked:
other[1] &= 0b0111_1111 # remove mask bit
assert other[1] < 126 # (we don't support extended payload length here)
mask = other[2:6]
payload = bytes([x ^ mask[i % 4] for i, x in enumerate(other[6:])])
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(other[6:]))
return self.unmasked == other[:2] + payload
@ -41,7 +41,7 @@ def masked_bytes(unmasked: bytes) -> bytes:
assert header[1] < 126 # assert that this is neither masked nor extended payload
header[1] |= 0b1000_0000
mask = secrets.token_bytes(4)
masked = bytes([x ^ mask[i % 4] for i, x in enumerate(unmasked[2:])])
masked = bytes(x ^ mask[i % 4] for i, x in enumerate(unmasked[2:]))
return bytes(header + mask + masked)

View File

@ -264,13 +264,13 @@ class TestCert:
def test_multi_valued_rdns(self, tdata):
subject = x509.Name([
x509.RelativeDistinguishedName([
x509.NameAttribute(NameOID.TITLE, u'Test'),
x509.NameAttribute(NameOID.COMMON_NAME, u'Multivalue'),
x509.NameAttribute(NameOID.SURNAME, u'RDNs'),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u'TSLA'),
x509.NameAttribute(NameOID.TITLE, 'Test'),
x509.NameAttribute(NameOID.COMMON_NAME, 'Multivalue'),
x509.NameAttribute(NameOID.SURNAME, 'RDNs'),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, 'TSLA'),
]),
x509.RelativeDistinguishedName([
x509.NameAttribute(NameOID.ORGANIZATION_NAME, u'PyCA')
x509.NameAttribute(NameOID.ORGANIZATION_NAME, 'PyCA')
]),
])
expected = [('2.5.4.12', 'Test'), ('CN', 'Multivalue'), ('2.5.4.4', 'RDNs'), ('O', 'TSLA'), ('O', 'PyCA')]