Merge pull request #3691 from mhils/sans-io-adjustments

Update mypy, sans-io adjustments
This commit is contained in:
Maximilian Hils 2019-11-12 05:04:05 +01:00 committed by GitHub
commit dac0bfe786
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 174 additions and 166 deletions

View File

@ -86,7 +86,7 @@ def get_cookies(flow: http.HTTPFlow) -> Cookies:
return {name: value for name, value in flow.request.cookies.fields}
def find_unclaimed_URLs(body: str, requestUrl: bytes) -> None:
def find_unclaimed_URLs(body, requestUrl):
""" Look for unclaimed URLs in script tags and log them if found"""
def getValue(attrs: List[Tuple[str, str]], attrName: str) -> Optional[str]:
for name, value in attrs:
@ -111,7 +111,7 @@ def find_unclaimed_URLs(body: str, requestUrl: bytes) -> None:
try:
socket.gethostbyname(domain)
except socket.gaierror:
ctx.log.error("XSS found in %s due to unclaimed URL \"%s\"." % (requestUrl, url))
ctx.log.error(f"XSS found in {requestUrl} due to unclaimed URL \"{url}\".")
def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData:

View File

@ -126,20 +126,18 @@ class Cut:
format is UTF-8 encoded CSV. If there is exactly one row and one
column, the data is written to file as-is, with raw bytes preserved.
"""
v: typing.Union[str, bytes]
fp = io.StringIO(newline="")
if len(cuts) == 1 and len(flows) == 1:
v = extract(cuts[0], flows[0])
if isinstance(v, bytes):
fp.write(strutils.always_str(v))
else:
fp.write(v)
fp.write(strutils.always_str(v)) # type: ignore
ctx.log.alert("Clipped single cut.")
else:
writer = csv.writer(fp)
for f in flows:
vals = [extract(c, f) for c in cuts]
writer.writerow(
[strutils.always_str(v) or "" for v in vals] # type: ignore
[strutils.always_str(v) for v in vals]
)
ctx.log.alert("Clipped %s cuts as CSV." % len(cuts))
try:

View File

@ -14,7 +14,7 @@ class EventStore:
self.sig_refresh = blinker.Signal()
@property
def size(self) -> int:
def size(self) -> typing.Optional[int]:
return self.data.maxlen
def log(self, entry: LogEntry) -> None:

View File

@ -16,7 +16,7 @@ from mitmproxy import ctx
import mitmproxy.types as mtypes
def load_script(path: str) -> types.ModuleType:
def load_script(path: str) -> typing.Optional[types.ModuleType]:
fullname = "__mitmproxy_script__.{}".format(
os.path.splitext(os.path.basename(path))[0]
)

View File

@ -215,8 +215,8 @@ class Session:
def __init__(self):
self.db_store: SessionDB = None
self._hot_store: collections.OrderedDict = collections.OrderedDict()
self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str]]] = {}
self._view: typing.List[typing.Tuple[typing.Union[int, float, str], str]] = []
self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str, None]]] = {}
self._view: typing.List[typing.Tuple[typing.Union[int, float, str, None], str]] = []
self.order: str = orders[0]
self.filter = matchall
self._flush_period: float = self._FP_DEFAULT

View File

@ -53,6 +53,7 @@ class StickyCookie:
self.flt = None
def response(self, flow: http.HTTPFlow):
assert flow.response
if self.flt:
for name, (value, attrs) in flow.response.cookies.items(multi=True):
# FIXME: We now know that Cookie.py screws up some cookies with

View File

@ -584,7 +584,7 @@ class Focus:
"""
def __init__(self, v: View) -> None:
self.view = v
self._flow: mitmproxy.flow.Flow = None
self._flow: typing.Optional[mitmproxy.flow.Flow] = None
self.sig_change = blinker.Signal()
if len(self.view):
self.flow = self.view[0]

View File

@ -315,7 +315,12 @@ class CertStore:
ret.append(b"*." + b".".join(parts[i:]))
return ret
def get_cert(self, commonname: typing.Optional[bytes], sans: typing.List[bytes], organization: typing.Optional[bytes] = None):
def get_cert(
self,
commonname: typing.Optional[bytes],
sans: typing.List[bytes],
organization: typing.Optional[bytes] = None
) -> typing.Tuple["Cert", OpenSSL.SSL.PKey, str]:
"""
Returns an (cert, privkey, cert_chain) tuple.

View File

@ -44,6 +44,8 @@ def typename(t: type) -> str:
class Command:
returntype: typing.Optional[typing.Type]
def __init__(self, manager, path, func) -> None:
self.path = path
self.manager = manager
@ -177,7 +179,7 @@ class CommandManager(mitmproxy.types._CommandBase):
parse: typing.List[ParseResult] = []
params: typing.List[type] = []
typ: typing.Type = None
typ: typing.Type
for i in range(len(parts)):
if i == 0:
typ = mitmproxy.types.Cmd

View File

@ -135,7 +135,9 @@ def get_content_view(viewmode: View, data: bytes, **metadata):
# Third-party viewers can fail in unexpected ways...
except Exception:
desc = "Couldn't parse: falling back to Raw"
_, content = get("Raw")(data, **metadata)
raw = get("Raw")
assert raw
content = raw(data, **metadata)[1]
error = "{} Content viewer failed: \n{}".format(
getattr(viewmode, "name"),
traceback.format_exc()

View File

@ -9,8 +9,8 @@ TViewResult = typing.Tuple[str, typing.Iterator[TViewLine]]
class View:
name: str = None
content_types: typing.List[str] = []
name: typing.ClassVar[str]
content_types: typing.ClassVar[typing.List[str]] = []
def __call__(self, data: bytes, **metadata) -> TViewResult:
"""

View File

@ -1,7 +1,7 @@
import io
import re
import textwrap
from typing import Iterable
from typing import Iterable, Optional
from mitmproxy.contentviews import base
from mitmproxy.utils import sliding_window
@ -124,14 +124,14 @@ def indent_text(data: str, prefix: str) -> str:
return textwrap.indent(dedented, prefix[:32])
def is_inline_text(a: Token, b: Token, c: Token) -> bool:
def is_inline_text(a: Optional[Token], b: Optional[Token], c: Optional[Token]) -> bool:
if isinstance(a, Tag) and isinstance(b, Text) and isinstance(c, Tag):
if a.is_opening and "\n" not in b.data and c.is_closing and a.tag == c.tag:
return True
return False
def is_inline(prev2: Token, prev1: Token, t: Token, next1: Token, next2: Token) -> bool:
def is_inline(prev2: Optional[Token], prev1: Optional[Token], t: Optional[Token], next1: Optional[Token], next2: Optional[Token]) -> bool:
if isinstance(t, Text):
return is_inline_text(prev1, t, next1)
elif isinstance(t, Tag):

View File

@ -1,7 +1,7 @@
import mitmproxy.master # noqa
import mitmproxy.log # noqa
import mitmproxy.options # noqa
import mitmproxy.log
import mitmproxy.master
import mitmproxy.options
master = None # type: mitmproxy.master.Master
log: mitmproxy.log.Log = None
options: mitmproxy.options.Options = None
log: "mitmproxy.log.Log"
master: "mitmproxy.master.Master"
options: "mitmproxy.options.Options"

View File

@ -44,7 +44,7 @@ from mitmproxy import flow
from mitmproxy.utils import strutils
import pyparsing as pp
from typing import Callable, Sequence, Type # noqa
from typing import Callable, Sequence, Type, Optional, ClassVar
def only(*types):
@ -69,8 +69,8 @@ class _Token:
class _Action(_Token):
code: str = None
help: str = None
code: ClassVar[str]
help: ClassVar[str]
@classmethod
def make(klass, s, loc, toks):
@ -539,7 +539,7 @@ bnf = _make()
TFilter = Callable[[flow.Flow], bool]
def parse(s: str) -> TFilter:
def parse(s: str) -> Optional[TFilter]:
try:
flt = bnf.parseString(s, parseAll=True)[0]
flt.pattern = s

View File

@ -1,15 +1,13 @@
import html
from typing import Optional
from mitmproxy import connections
from mitmproxy import flow
from mitmproxy.net import http
from mitmproxy import version
from mitmproxy import connections # noqa
from mitmproxy.net import http
class HTTPRequest(http.Request):
"""
A mitmproxy HTTP request.
"""
@ -85,10 +83,10 @@ class HTTPRequest(http.Request):
class HTTPResponse(http.Response):
"""
A mitmproxy HTTP response.
"""
# This is a very thin wrapper on top of :py:class:`mitmproxy.net.http.Response` and
# may be removed in the future.
@ -136,34 +134,28 @@ class HTTPResponse(http.Response):
class HTTPFlow(flow.Flow):
"""
An HTTPFlow is a collection of objects representing a single HTTP
transaction.
"""
request: HTTPRequest
response: Optional[HTTPResponse] = None
error: Optional[flow.Error] = None
"""
Note that it's possible for a Flow to have both a response and an error
object. This might happen, for instance, when a response was received
from the server, but there was an error sending it back to the client.
"""
server_conn: connections.ServerConnection
client_conn: connections.ClientConnection
intercepted: bool = False
""" Is this flow currently being intercepted? """
mode: str
""" What mode was the proxy layer in when receiving this request? """
def __init__(self, client_conn, server_conn, live=None, mode="regular"):
super().__init__("http", client_conn, server_conn, live)
self.request: HTTPRequest = None
""" :py:class:`HTTPRequest` object """
self.response: HTTPResponse = None
""" :py:class:`HTTPResponse` object """
self.error: flow.Error = None
""" :py:class:`Error` object
Note that it's possible for a Flow to have both a response and an error
object. This might happen, for instance, when a response was received
from the server, but there was an error sending it back to the client.
"""
self.server_conn: connections.ServerConnection = server_conn
""" :py:class:`ServerConnection` object """
self.client_conn: connections.ClientConnection = client_conn
""":py:class:`ClientConnection` object """
self.intercepted: bool = False
""" Is this flow currently being intercepted? """
self.mode = mode
""" What mode was the proxy layer in when receiving this request? """
_stateobject_attributes = flow.Flow._stateobject_attributes.copy()
# mypy doesn't support update with kwargs
@ -205,8 +197,8 @@ class HTTPFlow(flow.Flow):
def make_error_response(
status_code: int,
message: str="",
headers: Optional[http.Headers]=None,
message: str = "",
headers: Optional[http.Headers] = None,
) -> HTTPResponse:
reason = http.status_codes.RESPONSES.get(status_code, "Unknown")
body = """

View File

@ -192,22 +192,22 @@ def parse(data_type: int, data: bytes) -> TSerializable:
try:
return int(data)
except ValueError:
raise ValueError("not a tnetstring: invalid integer literal: {}".format(data))
raise ValueError(f"not a tnetstring: invalid integer literal: {data!r}")
if data_type == ord(b'^'):
try:
return float(data)
except ValueError:
raise ValueError("not a tnetstring: invalid float literal: {}".format(data))
raise ValueError(f"not a tnetstring: invalid float literal: {data!r}")
if data_type == ord(b'!'):
if data == b'true':
return True
elif data == b'false':
return False
else:
raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data))
raise ValueError(f"not a tnetstring: invalid boolean literal: {data!r}")
if data_type == ord(b'~'):
if data:
raise ValueError("not a tnetstring: invalid null literal")
raise ValueError(f"not a tnetstring: invalid null literal: {data!r}")
return None
if data_type == ord(b']'):
l = []
@ -236,7 +236,7 @@ def pop(data: bytes) -> typing.Tuple[TSerializable, bytes]:
blength, data = data.split(b':', 1)
length = int(blength)
except ValueError:
raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data))
raise ValueError(f"not a tnetstring: missing or invalid length prefix: {data!r}")
try:
data, data_type, remain = data[:length], data[length], data[length + 1:]
except IndexError:

View File

@ -82,7 +82,7 @@ class Message(serializable.Serializable):
def raw_content(self, content):
self.data.content = content
def get_content(self, strict: bool=True) -> bytes:
def get_content(self, strict: bool=True) -> Optional[bytes]:
"""
The uncompressed HTTP message body as bytes.
@ -195,10 +195,9 @@ class Message(serializable.Serializable):
See also: :py:attr:`content`, :py:class:`raw_content`
"""
if self.raw_content is None:
return None
content = self.get_content(strict)
if content is None:
return None
enc = self._guess_encoding(content)
try:
return encoding.decode(content, enc)

View File

@ -295,6 +295,17 @@ def create_client_context(
return context
def accept_all(
conn_: SSL.Connection,
x509: SSL.X509,
errno: int,
err_depth: int,
is_cert_verified: bool,
) -> bool:
# Return true to prevent cert verification error
return True
def create_server_context(
cert: typing.Union[certs.Cert, str],
key: SSL.PKey,
@ -324,16 +335,6 @@ def create_server_context(
until then we're conservative.
"""
def accept_all(
conn_: SSL.Connection,
x509: SSL.X509,
errno: int,
err_depth: int,
is_cert_verified: bool,
) -> bool:
# Return true to prevent cert verification error
return True
if request_client_cert:
verify = SSL.VERIFY_PEER
else:
@ -425,7 +426,7 @@ class ClientHello:
return self._client_hello.cipher_suites.cipher_suites
@property
def sni(self):
def sni(self) -> typing.Optional[bytes]:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
is_valid_sni_extension = (
@ -435,7 +436,7 @@ class ClientHello:
check.is_valid_host(extension.body.server_names[0].host_name)
)
if is_valid_sni_extension:
return extension.body.server_names[0].host_name.decode("idna")
return extension.body.server_names[0].host_name
return None
@property
@ -473,10 +474,8 @@ class ClientHello:
return cls(raw_client_hello)
except EOFError as e:
raise exceptions.TlsProtocolException(
'Cannot parse Client Hello: %s, Raw Client Hello: %s' %
(repr(e), binascii.hexlify(raw_client_hello))
f"Cannot parse Client Hello: {e!r}, Raw Client Hello: {binascii.hexlify(raw_client_hello)!r}"
)
def __repr__(self):
return "ClientHello(sni: %s, alpn_protocols: %s, cipher_suites: %s)" % \
(self.sni, self.alpn_protocols, self.cipher_suites)
return f"ClientHello(sni: {self.sni}, alpn_protocols: {self.alpn_protocols})"

View File

@ -551,7 +551,9 @@ def serialize(opts: OptManager, text: str, defaults: bool = False) -> str:
for k in list(data.keys()):
if k not in opts._options:
del data[k]
return ruamel.yaml.round_trip_dump(data)
ret = ruamel.yaml.round_trip_dump(data)
assert ret
return ret
def save(opts: OptManager, path: str, defaults: bool =False) -> None:

View File

@ -1,7 +1,7 @@
import re
import socket
import sys
from typing import Tuple
from typing import Callable, Optional, Tuple
def init_transparent_mode() -> None:
@ -10,30 +10,34 @@ def init_transparent_mode() -> None:
"""
def original_addr(csock: socket.socket) -> Tuple[str, int]:
"""
Get the original destination for the given socket.
This function will be None if transparent mode is not supported.
"""
original_addr: Optional[Callable[[socket.socket], Tuple[str, int]]]
"""
Get the original destination for the given socket.
This function will be None if transparent mode is not supported.
"""
if re.match(r"linux(?:2)?", sys.platform):
from . import linux
original_addr = linux.original_addr # noqa
original_addr = linux.original_addr
elif sys.platform == "darwin" or sys.platform.startswith("freebsd"):
from . import osx
original_addr = osx.original_addr # noqa
original_addr = osx.original_addr
elif sys.platform.startswith("openbsd"):
from . import openbsd
original_addr = openbsd.original_addr # noqa
original_addr = openbsd.original_addr
elif sys.platform == "win32":
from . import windows
resolver = windows.Resolver()
init_transparent_mode = resolver.setup # noqa
original_addr = resolver.original_addr # noqa
original_addr = resolver.original_addr
else:
original_addr = None # noqa
original_addr = None
__all__ = [
"original_addr",
"init_transparent_mode"
]

View File

@ -34,9 +34,9 @@ class ProxyConfig:
def __init__(self, options: moptions.Options) -> None:
self.options = options
self.check_filter: HostMatcher = None
self.check_tcp: HostMatcher = None
self.certstore: certs.CertStore = None
self.certstore: certs.CertStore
self.check_filter: typing.Optional[HostMatcher] = None
self.check_tcp: typing.Optional[HostMatcher] = None
self.upstream_server: typing.Optional[server_spec.ServerSpec] = None
self.configure(options, set(options.keys()))
options.changed.connect(self.configure)

View File

@ -1,7 +1,7 @@
import threading
import time
import functools
from typing import Dict, Callable, Any, List # noqa
from typing import Dict, Callable, Any, List, Optional # noqa
import h2.exceptions
from h2 import connection
@ -382,15 +382,15 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
)
self.h2_connection = h2_connection
self.zombie: float = None
self.zombie: Optional[float] = None
self.client_stream_id: int = stream_id
self.server_stream_id: int = None
self.server_stream_id: Optional[int] = None
self.request_headers = request_headers
self.response_headers: mitmproxy.net.http.Headers = None
self.response_headers: Optional[mitmproxy.net.http.Headers] = None
self.pushed = False
self.timestamp_start: float = None
self.timestamp_end: float = None
self.timestamp_start: Optional[float] = None
self.timestamp_end: Optional[float] = None
self.request_arrived = threading.Event()
self.request_data_queue: queue.Queue[bytes] = queue.Queue()
@ -404,9 +404,9 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.no_body = False
self.priority_exclusive: bool = None
self.priority_depends_on: int = None
self.priority_weight: int = None
self.priority_exclusive: bool
self.priority_depends_on: Optional[int] = None
self.priority_weight: Optional[int] = None
self.handled_priority_event: Any = None
def kill(self):

View File

@ -196,17 +196,14 @@ CIPHER_ID_NAME_MAP = {
}
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
# https://mozilla.github.io/server-side-tls/ssl-config-generator/?server=apache-2.2.15&openssl=1.0.2&hsts=yes&profile=old
# https://ssl-config.mozilla.org/#config=old
DEFAULT_CLIENT_CIPHERS = (
"ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES256-GCM-SHA384:"
"ECDHE-ECDSA-AES256-GCM-SHA384:DHE-RSA-AES128-GCM-SHA256:DHE-DSS-AES128-GCM-SHA256:kEDH+AESGCM:"
"ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES128-SHA:"
"ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA:ECDHE-ECDSA-AES256-SHA:"
"DHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA:DHE-DSS-AES128-SHA256:DHE-RSA-AES256-SHA256:DHE-DSS-AES256-SHA:"
"DHE-RSA-AES256-SHA:ECDHE-RSA-DES-CBC3-SHA:ECDHE-ECDSA-DES-CBC3-SHA:AES128-GCM-SHA256:AES256-GCM-SHA384:"
"AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:AES:DES-CBC3-SHA:"
"HIGH:!aNULL:!eNULL:!EXPORT:!DES:!RC4:!MD5:!PSK:!aECDH:"
"!EDH-DSS-DES-CBC3-SHA:!EDH-RSA-DES-CBC3-SHA:!KRB5-DES-CBC3-SHA"
"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:"
"ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:"
"DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:"
"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:"
"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA256:AES128-GCM-SHA256:"
"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA"
)
@ -323,14 +320,18 @@ class TlsLayer(base.Layer):
return self._server_tls
@property
def server_sni(self):
def server_sni(self) -> Optional[str]:
"""
The Server Name Indication we want to send with the next server TLS handshake.
"""
if self._custom_server_sni is False:
return None
elif self._custom_server_sni:
return self._custom_server_sni
elif self._client_hello and self._client_hello.sni:
return self._client_hello.sni.decode("idna")
else:
return self._custom_server_sni or self._client_hello and self._client_hello.sni
return None
@property
def alpn_for_client_connection(self):
@ -391,11 +392,12 @@ class TlsLayer(base.Layer):
# raises ann error.
self.client_conn.rfile.peek(1)
except exceptions.TlsException as e:
sni_str = self._client_hello.sni and self._client_hello.sni.decode("idna")
raise exceptions.ClientHandshakeException(
"Cannot establish TLS with client (sni: {sni}): {e}".format(
sni=self._client_hello.sni, e=repr(e)
sni=sni_str, e=repr(e)
),
self._client_hello.sni or repr(self.server_conn.address)
sni_str or repr(self.server_conn.address)
)
def _establish_tls_with_server(self):
@ -493,7 +495,7 @@ class TlsLayer(base.Layer):
organization = upstream_cert.organization
# Also add SNI values.
if self._client_hello.sni:
sans.add(self._client_hello.sni.encode("idna"))
sans.add(self._client_hello.sni)
if self._custom_server_sni:
sans.add(self._custom_server_sni.encode("idna"))

View File

@ -57,7 +57,8 @@ class RootContext:
except exceptions.TlsProtocolException as e:
self.log("Cannot parse Client Hello: %s" % repr(e), "error")
else:
is_filtered = self.config.check_filter((client_hello.sni, 443))
sni_str = client_hello.sni and client_hello.sni.decode("idna")
is_filtered = self.config.check_filter((sni_str, 443))
if is_filtered:
return protocol.RawTCPLayer(top_layer, ignore=True)

View File

@ -35,6 +35,7 @@ class DummyServer:
class ProxyServer(tcp.TCPServer):
allow_reuse_address = True
bound = True
channel: controller.Channel
def __init__(self, config: config.ProxyConfig) -> None:
"""
@ -53,7 +54,6 @@ class ProxyServer(tcp.TCPServer):
raise exceptions.ServerException(
'Error starting proxy server: ' + repr(e)
) from e
self.channel: controller.Channel = None
def set_channel(self, channel):
self.channel = channel

View File

@ -1,7 +1,5 @@
import typing
from typing import Any # noqa
from typing import MutableMapping # noqa
import json
import typing
from mitmproxy.coretypes import serializable
from mitmproxy.utils import typecheck
@ -15,7 +13,7 @@ class StateObject(serializable.Serializable):
or StateObject instances themselves.
"""
_stateobject_attributes: MutableMapping[str, Any] = None
_stateobject_attributes: typing.ClassVar[typing.MutableMapping[str, typing.Any]]
"""
An attribute-name -> class-or-type dict containing all attributes that
should be serialized. If the attribute is a class, it must implement the
@ -42,7 +40,7 @@ class StateObject(serializable.Serializable):
if val is None:
setattr(self, attr, val)
else:
curr = getattr(self, attr)
curr = getattr(self, attr, None)
if hasattr(curr, "set_state"):
curr.set_state(val)
else:

View File

@ -6,17 +6,16 @@ Feel free to import and use whatever new package you deem necessary.
import os
import sys
import asyncio
import argparse # noqa
import signal # noqa
import typing # noqa
import argparse
import signal
import typing
from mitmproxy.tools import cmdline # noqa
from mitmproxy import exceptions, master # noqa
from mitmproxy import options # noqa
from mitmproxy import optmanager # noqa
from mitmproxy import proxy # noqa
from mitmproxy import log # noqa
from mitmproxy.utils import debug, arg_check # noqa
from mitmproxy.tools import cmdline
from mitmproxy import exceptions, master
from mitmproxy import options
from mitmproxy import optmanager
from mitmproxy import proxy
from mitmproxy.utils import debug, arg_check
OPTIONS_FILE_NAME = "config.yaml"

View File

@ -55,7 +55,7 @@ class CommandBuffer:
self.text = self.flatten(start)
# Cursor is always within the range [0:len(buffer)].
self._cursor = len(self.text)
self.completion: CompletionState = None
self.completion: typing.Optional[CompletionState] = None
@property
def cursor(self) -> int:

View File

@ -38,7 +38,7 @@ KEY_MAX = 30
def format_keyvals(
entries: typing.List[typing.Tuple[str, typing.Union[None, str, urwid.Widget]]],
entries: typing.Iterable[typing.Tuple[str, typing.Union[None, str, urwid.Widget]]],
key_format: str = "key",
value_format: str = "text",
indent: int = 0

View File

@ -254,7 +254,7 @@ FIRST_WIDTH_MAX = 40
class BaseGridEditor(urwid.WidgetWrap):
title = ""
title: str = ""
keyctx = "grideditor"
def __init__(
@ -402,8 +402,8 @@ class BaseGridEditor(urwid.WidgetWrap):
class GridEditor(BaseGridEditor):
title: str = None
columns: typing.Sequence[Column] = None
title = ""
columns: typing.Sequence[Column] = ()
keyctx = "grideditor"
def __init__(

View File

@ -107,7 +107,7 @@ class CookieAttributeEditor(base.FocusEditor):
col_text.Column("Name"),
col_text.Column("Value"),
]
grideditor: base.BaseGridEditor = None
grideditor: base.BaseGridEditor
def data_in(self, data):
return [(k, v or "") for k, v in data]
@ -169,7 +169,7 @@ class SetCookieEditor(base.FocusEditor):
class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget):
title: str = None
title = ""
columns = [
col_text.Column("")
]
@ -189,7 +189,7 @@ class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget):
class DataViewer(base.GridEditor, layoutwidget.LayoutWidget):
title: str = None
title = ""
def __init__(
self,

View File

@ -42,7 +42,7 @@ class Palette:
'commander_command', 'commander_invalid', 'commander_hint'
]
_fields.extend(['gradient_%02d' % i for i in range(100)])
high: typing.Mapping[str, typing.Sequence[str]] = None
high: typing.Optional[typing.Mapping[str, typing.Sequence[str]]] = None
def palette(self, transparent):
l = []

View File

@ -5,6 +5,7 @@ import logging
import os.path
import re
from io import BytesIO
from typing import ClassVar, Optional
import tornado.escape
import tornado.web
@ -50,6 +51,8 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
f["error"] = flow.error.get_state()
if isinstance(flow, http.HTTPFlow):
content_length: Optional[int]
content_hash: Optional[str]
if flow.request:
if flow.request.raw_content:
content_length = len(flow.request.raw_content)
@ -193,7 +196,7 @@ class FilterHelp(RequestHandler):
class WebSocketEventBroadcaster(tornado.websocket.WebSocketHandler):
# raise an error if inherited class doesn't specify its own instance.
connections: set = None
connections: ClassVar[set]
def open(self):
self.connections.add(self)
@ -213,7 +216,7 @@ class WebSocketEventBroadcaster(tornado.websocket.WebSocketHandler):
class ClientConnection(WebSocketEventBroadcaster):
connections: set = set()
connections: ClassVar[set] = set()
class Flows(RequestHandler):

View File

@ -423,7 +423,7 @@ class TypeManager:
for t in types:
self.typemap[t.typ] = t()
def get(self, t: type, default=None) -> _BaseType:
def get(self, t: typing.Optional[typing.Type], default=None) -> _BaseType:
if type(t) in self.typemap:
return self.typemap[type(t)]
return self.typemap.get(t, default)

View File

@ -1,5 +1,5 @@
import itertools
from typing import TypeVar, Iterable, Iterator, Tuple, Optional
from typing import TypeVar, Iterable, Iterator, Tuple, Optional, List
T = TypeVar('T')
@ -18,7 +18,7 @@ def window(iterator: Iterable[T], behind: int = 0, ahead: int = 0) -> Iterator[T
2 3 None
"""
# TODO: move into utils
iters = list(itertools.tee(iterator, behind + 1 + ahead))
iters: List[Iterator[Optional[T]]] = list(itertools.tee(iterator, behind + 1 + ahead))
for i in range(behind):
iters[i] = itertools.chain((behind - i) * [None], iters[i])
for i in range(ahead):

View File

@ -1,10 +1,10 @@
import codecs
import io
import re
import codecs
from typing import AnyStr, Optional, cast, Iterable
from typing import Iterable, Optional, Union, cast
def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes]:
def always_bytes(str_or_bytes: Union[str, bytes, None], *encode_args) -> Optional[bytes]:
if isinstance(str_or_bytes, bytes) or str_or_bytes is None:
return cast(Optional[bytes], str_or_bytes)
elif isinstance(str_or_bytes, str):
@ -13,13 +13,15 @@ def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
def always_str(str_or_bytes: Optional[AnyStr], *decode_args) -> Optional[str]:
def always_str(str_or_bytes: Union[str, bytes, None], *decode_args) -> Optional[str]:
"""
Returns,
str_or_bytes unmodified, if
"""
if isinstance(str_or_bytes, str) or str_or_bytes is None:
return cast(Optional[str], str_or_bytes)
if str_or_bytes is None:
return None
if isinstance(str_or_bytes, str):
return cast(str, str_or_bytes)
elif isinstance(str_or_bytes, bytes):
return str_or_bytes.decode(*decode_args)
else:
@ -39,7 +41,6 @@ _control_char_trans_newline = _control_char_trans.copy()
for x in ("\r", "\n", "\t"):
del _control_char_trans_newline[ord(x)]
_control_char_trans = str.maketrans(_control_char_trans)
_control_char_trans_newline = str.maketrans(_control_char_trans_newline)

View File

@ -25,9 +25,9 @@ def get_dev_version() -> str:
stderr=subprocess.STDOUT,
cwd=here,
)
last_tag, tag_dist, commit = git_describe.decode().strip().rsplit("-", 2)
last_tag, tag_dist_str, commit = git_describe.decode().strip().rsplit("-", 2)
commit = commit.lstrip("g")[:7]
tag_dist = int(tag_dist)
tag_dist = int(tag_dist_str)
except Exception:
pass
else:

View File

@ -25,7 +25,7 @@ class Daemon:
def __enter__(self):
return self
def __exit__(self, type, value, traceback) -> bool:
def __exit__(self, type, value, traceback):
self.logfp.truncate(0)
self.shutdown()
return False

View File

@ -72,7 +72,7 @@ setup(
"kaitaistruct>=0.7,<0.9",
"ldap3>=2.6.1,<2.7",
"passlib>=1.6.5, <1.8",
"protobuf>=3.6.0, <3.10",
"protobuf>=3.6.0, <3.11",
"pyasn1>=0.3.1,<0.5",
"pyOpenSSL>=19.0.0,<20",
"pyparsing>=2.4.2,<2.5",
@ -93,7 +93,7 @@ setup(
"asynctest>=0.12.0",
"flake8>=3.7.8,<3.8",
"Flask>=1.0,<1.2",
"mypy>=0.590,<0.591",
"mypy>=0.740,<0.741",
"parver>=0.1,<2.0",
"pytest-asyncio>=0.10.0,<0.11",
"pytest-cov>=2.7.1,<3",

View File

@ -116,7 +116,7 @@ class TestClientHello:
)
c = tls.ClientHello(data)
assert repr(c)
assert c.sni == 'example.com'
assert c.sni == b'example.com'
assert c.cipher_suites == [
49195, 49199, 49196, 49200, 52393, 52392, 52244, 52243, 49161,
49171, 49162, 49172, 156, 157, 47, 53, 10