update mypy

This commit is contained in:
Maximilian Hils 2019-11-12 02:59:01 +01:00
parent f97996126f
commit bdc15cbe0c
40 changed files with 153 additions and 147 deletions

View File

@ -86,7 +86,7 @@ def get_cookies(flow: http.HTTPFlow) -> Cookies:
return {name: value for name, value in flow.request.cookies.fields} return {name: value for name, value in flow.request.cookies.fields}
def find_unclaimed_URLs(body: str, requestUrl: bytes) -> None: def find_unclaimed_URLs(body, requestUrl):
""" Look for unclaimed URLs in script tags and log them if found""" """ Look for unclaimed URLs in script tags and log them if found"""
def getValue(attrs: List[Tuple[str, str]], attrName: str) -> Optional[str]: def getValue(attrs: List[Tuple[str, str]], attrName: str) -> Optional[str]:
for name, value in attrs: for name, value in attrs:
@ -111,7 +111,7 @@ def find_unclaimed_URLs(body: str, requestUrl: bytes) -> None:
try: try:
socket.gethostbyname(domain) socket.gethostbyname(domain)
except socket.gaierror: except socket.gaierror:
ctx.log.error("XSS found in %s due to unclaimed URL \"%s\"." % (requestUrl, url)) ctx.log.error(f"XSS found in {requestUrl} due to unclaimed URL \"{url}\".")
def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData:

View File

@ -126,20 +126,18 @@ class Cut:
format is UTF-8 encoded CSV. If there is exactly one row and one format is UTF-8 encoded CSV. If there is exactly one row and one
column, the data is written to file as-is, with raw bytes preserved. column, the data is written to file as-is, with raw bytes preserved.
""" """
v: typing.Union[str, bytes]
fp = io.StringIO(newline="") fp = io.StringIO(newline="")
if len(cuts) == 1 and len(flows) == 1: if len(cuts) == 1 and len(flows) == 1:
v = extract(cuts[0], flows[0]) v = extract(cuts[0], flows[0])
if isinstance(v, bytes): fp.write(strutils.always_str(v)) # type: ignore
fp.write(strutils.always_str(v))
else:
fp.write(v)
ctx.log.alert("Clipped single cut.") ctx.log.alert("Clipped single cut.")
else: else:
writer = csv.writer(fp) writer = csv.writer(fp)
for f in flows: for f in flows:
vals = [extract(c, f) for c in cuts] vals = [extract(c, f) for c in cuts]
writer.writerow( writer.writerow(
[strutils.always_str(v) or "" for v in vals] # type: ignore [strutils.always_str(v) for v in vals]
) )
ctx.log.alert("Clipped %s cuts as CSV." % len(cuts)) ctx.log.alert("Clipped %s cuts as CSV." % len(cuts))
try: try:

View File

@ -14,7 +14,7 @@ class EventStore:
self.sig_refresh = blinker.Signal() self.sig_refresh = blinker.Signal()
@property @property
def size(self) -> int: def size(self) -> typing.Optional[int]:
return self.data.maxlen return self.data.maxlen
def log(self, entry: LogEntry) -> None: def log(self, entry: LogEntry) -> None:

View File

@ -16,7 +16,7 @@ from mitmproxy import ctx
import mitmproxy.types as mtypes import mitmproxy.types as mtypes
def load_script(path: str) -> types.ModuleType: def load_script(path: str) -> typing.Optional[types.ModuleType]:
fullname = "__mitmproxy_script__.{}".format( fullname = "__mitmproxy_script__.{}".format(
os.path.splitext(os.path.basename(path))[0] os.path.splitext(os.path.basename(path))[0]
) )

View File

@ -215,8 +215,8 @@ class Session:
def __init__(self): def __init__(self):
self.db_store: SessionDB = None self.db_store: SessionDB = None
self._hot_store: collections.OrderedDict = collections.OrderedDict() self._hot_store: collections.OrderedDict = collections.OrderedDict()
self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str]]] = {} self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str, None]]] = {}
self._view: typing.List[typing.Tuple[typing.Union[int, float, str], str]] = [] self._view: typing.List[typing.Tuple[typing.Union[int, float, str, None], str]] = []
self.order: str = orders[0] self.order: str = orders[0]
self.filter = matchall self.filter = matchall
self._flush_period: float = self._FP_DEFAULT self._flush_period: float = self._FP_DEFAULT

View File

@ -53,6 +53,7 @@ class StickyCookie:
self.flt = None self.flt = None
def response(self, flow: http.HTTPFlow): def response(self, flow: http.HTTPFlow):
assert flow.response
if self.flt: if self.flt:
for name, (value, attrs) in flow.response.cookies.items(multi=True): for name, (value, attrs) in flow.response.cookies.items(multi=True):
# FIXME: We now know that Cookie.py screws up some cookies with # FIXME: We now know that Cookie.py screws up some cookies with

View File

@ -584,7 +584,7 @@ class Focus:
""" """
def __init__(self, v: View) -> None: def __init__(self, v: View) -> None:
self.view = v self.view = v
self._flow: mitmproxy.flow.Flow = None self._flow: typing.Optional[mitmproxy.flow.Flow] = None
self.sig_change = blinker.Signal() self.sig_change = blinker.Signal()
if len(self.view): if len(self.view):
self.flow = self.view[0] self.flow = self.view[0]

View File

@ -123,7 +123,7 @@ def dummy_cert(privkey, cacert, commonname, sans, organization):
) )
]) ])
cert.set_pubkey(cacert.get_pubkey()) cert.set_pubkey(cacert.get_pubkey())
cert.sign(privkey, b"sha256") cert.sign(privkey, "sha256")
return Cert(cert) return Cert(cert)

View File

@ -44,6 +44,8 @@ def typename(t: type) -> str:
class Command: class Command:
returntype: typing.Optional[typing.Type]
def __init__(self, manager, path, func) -> None: def __init__(self, manager, path, func) -> None:
self.path = path self.path = path
self.manager = manager self.manager = manager
@ -177,7 +179,7 @@ class CommandManager(mitmproxy.types._CommandBase):
parse: typing.List[ParseResult] = [] parse: typing.List[ParseResult] = []
params: typing.List[type] = [] params: typing.List[type] = []
typ: typing.Type = None typ: typing.Type
for i in range(len(parts)): for i in range(len(parts)):
if i == 0: if i == 0:
typ = mitmproxy.types.Cmd typ = mitmproxy.types.Cmd

View File

@ -135,7 +135,9 @@ def get_content_view(viewmode: View, data: bytes, **metadata):
# Third-party viewers can fail in unexpected ways... # Third-party viewers can fail in unexpected ways...
except Exception: except Exception:
desc = "Couldn't parse: falling back to Raw" desc = "Couldn't parse: falling back to Raw"
_, content = get("Raw")(data, **metadata) raw = get("Raw")
assert raw
content = raw(data, **metadata)[1]
error = "{} Content viewer failed: \n{}".format( error = "{} Content viewer failed: \n{}".format(
getattr(viewmode, "name"), getattr(viewmode, "name"),
traceback.format_exc() traceback.format_exc()

View File

@ -9,8 +9,8 @@ TViewResult = typing.Tuple[str, typing.Iterator[TViewLine]]
class View: class View:
name: str = None name: typing.ClassVar[str]
content_types: typing.List[str] = [] content_types: typing.ClassVar[typing.List[str]] = []
def __call__(self, data: bytes, **metadata) -> TViewResult: def __call__(self, data: bytes, **metadata) -> TViewResult:
""" """

View File

@ -1,7 +1,7 @@
import io import io
import re import re
import textwrap import textwrap
from typing import Iterable from typing import Iterable, Optional
from mitmproxy.contentviews import base from mitmproxy.contentviews import base
from mitmproxy.utils import sliding_window from mitmproxy.utils import sliding_window
@ -124,14 +124,14 @@ def indent_text(data: str, prefix: str) -> str:
return textwrap.indent(dedented, prefix[:32]) return textwrap.indent(dedented, prefix[:32])
def is_inline_text(a: Token, b: Token, c: Token) -> bool: def is_inline_text(a: Optional[Token], b: Optional[Token], c: Optional[Token]) -> bool:
if isinstance(a, Tag) and isinstance(b, Text) and isinstance(c, Tag): if isinstance(a, Tag) and isinstance(b, Text) and isinstance(c, Tag):
if a.is_opening and "\n" not in b.data and c.is_closing and a.tag == c.tag: if a.is_opening and "\n" not in b.data and c.is_closing and a.tag == c.tag:
return True return True
return False return False
def is_inline(prev2: Token, prev1: Token, t: Token, next1: Token, next2: Token) -> bool: def is_inline(prev2: Optional[Token], prev1: Optional[Token], t: Optional[Token], next1: Optional[Token], next2: Optional[Token]) -> bool:
if isinstance(t, Text): if isinstance(t, Text):
return is_inline_text(prev1, t, next1) return is_inline_text(prev1, t, next1)
elif isinstance(t, Tag): elif isinstance(t, Tag):

View File

@ -1,7 +1,7 @@
import mitmproxy.master # noqa import mitmproxy.log
import mitmproxy.log # noqa import mitmproxy.master
import mitmproxy.options # noqa import mitmproxy.options
master = None # type: mitmproxy.master.Master log: "mitmproxy.log.Log"
log: mitmproxy.log.Log = None master: "mitmproxy.master.Master"
options: mitmproxy.options.Options = None options: "mitmproxy.options.Options"

View File

@ -44,7 +44,7 @@ from mitmproxy import flow
from mitmproxy.utils import strutils from mitmproxy.utils import strutils
import pyparsing as pp import pyparsing as pp
from typing import Callable, Sequence, Type # noqa from typing import Callable, Sequence, Type, Optional, ClassVar
def only(*types): def only(*types):
@ -69,8 +69,8 @@ class _Token:
class _Action(_Token): class _Action(_Token):
code: str = None code: ClassVar[str]
help: str = None help: ClassVar[str]
@classmethod @classmethod
def make(klass, s, loc, toks): def make(klass, s, loc, toks):
@ -539,7 +539,7 @@ bnf = _make()
TFilter = Callable[[flow.Flow], bool] TFilter = Callable[[flow.Flow], bool]
def parse(s: str) -> TFilter: def parse(s: str) -> Optional[TFilter]:
try: try:
flt = bnf.parseString(s, parseAll=True)[0] flt = bnf.parseString(s, parseAll=True)[0]
flt.pattern = s flt.pattern = s

View File

@ -1,15 +1,13 @@
import html import html
from typing import Optional from typing import Optional
from mitmproxy import connections
from mitmproxy import flow from mitmproxy import flow
from mitmproxy.net import http
from mitmproxy import version from mitmproxy import version
from mitmproxy import connections # noqa from mitmproxy.net import http
class HTTPRequest(http.Request): class HTTPRequest(http.Request):
""" """
A mitmproxy HTTP request. A mitmproxy HTTP request.
""" """
@ -85,10 +83,10 @@ class HTTPRequest(http.Request):
class HTTPResponse(http.Response): class HTTPResponse(http.Response):
""" """
A mitmproxy HTTP response. A mitmproxy HTTP response.
""" """
# This is a very thin wrapper on top of :py:class:`mitmproxy.net.http.Response` and # This is a very thin wrapper on top of :py:class:`mitmproxy.net.http.Response` and
# may be removed in the future. # may be removed in the future.
@ -136,35 +134,29 @@ class HTTPResponse(http.Response):
class HTTPFlow(flow.Flow): class HTTPFlow(flow.Flow):
""" """
An HTTPFlow is a collection of objects representing a single HTTP An HTTPFlow is a collection of objects representing a single HTTP
transaction. transaction.
""" """
request: HTTPRequest
def __init__(self, client_conn, server_conn, live=None, mode="regular"): response: Optional[HTTPResponse] = None
super().__init__("http", client_conn, server_conn, live) error: Optional[flow.Error] = None
"""
self.request: HTTPRequest = None
""" :py:class:`HTTPRequest` object """
self.response: HTTPResponse = None
""" :py:class:`HTTPResponse` object """
self.error: flow.Error = None
""" :py:class:`Error` object
Note that it's possible for a Flow to have both a response and an error Note that it's possible for a Flow to have both a response and an error
object. This might happen, for instance, when a response was received object. This might happen, for instance, when a response was received
from the server, but there was an error sending it back to the client. from the server, but there was an error sending it back to the client.
""" """
self.server_conn: connections.ServerConnection = server_conn server_conn: connections.ServerConnection
""" :py:class:`ServerConnection` object """ client_conn: connections.ClientConnection
self.client_conn: connections.ClientConnection = client_conn intercepted: bool = False
""":py:class:`ClientConnection` object """
self.intercepted: bool = False
""" Is this flow currently being intercepted? """ """ Is this flow currently being intercepted? """
self.mode = mode mode: str
""" What mode was the proxy layer in when receiving this request? """ """ What mode was the proxy layer in when receiving this request? """
def __init__(self, client_conn, server_conn, live=None, mode="regular"):
super().__init__("http", client_conn, server_conn, live)
self.mode = mode
_stateobject_attributes = flow.Flow._stateobject_attributes.copy() _stateobject_attributes = flow.Flow._stateobject_attributes.copy()
# mypy doesn't support update with kwargs # mypy doesn't support update with kwargs
_stateobject_attributes.update(dict( _stateobject_attributes.update(dict(
@ -205,8 +197,8 @@ class HTTPFlow(flow.Flow):
def make_error_response( def make_error_response(
status_code: int, status_code: int,
message: str="", message: str = "",
headers: Optional[http.Headers]=None, headers: Optional[http.Headers] = None,
) -> HTTPResponse: ) -> HTTPResponse:
reason = http.status_codes.RESPONSES.get(status_code, "Unknown") reason = http.status_codes.RESPONSES.get(status_code, "Unknown")
body = """ body = """

View File

@ -192,22 +192,22 @@ def parse(data_type: int, data: bytes) -> TSerializable:
try: try:
return int(data) return int(data)
except ValueError: except ValueError:
raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) raise ValueError(f"not a tnetstring: invalid integer literal: {data!r}")
if data_type == ord(b'^'): if data_type == ord(b'^'):
try: try:
return float(data) return float(data)
except ValueError: except ValueError:
raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) raise ValueError(f"not a tnetstring: invalid float literal: {data!r}")
if data_type == ord(b'!'): if data_type == ord(b'!'):
if data == b'true': if data == b'true':
return True return True
elif data == b'false': elif data == b'false':
return False return False
else: else:
raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) raise ValueError(f"not a tnetstring: invalid boolean literal: {data!r}")
if data_type == ord(b'~'): if data_type == ord(b'~'):
if data: if data:
raise ValueError("not a tnetstring: invalid null literal") raise ValueError(f"not a tnetstring: invalid null literal: {data!r}")
return None return None
if data_type == ord(b']'): if data_type == ord(b']'):
l = [] l = []
@ -236,7 +236,7 @@ def pop(data: bytes) -> typing.Tuple[TSerializable, bytes]:
blength, data = data.split(b':', 1) blength, data = data.split(b':', 1)
length = int(blength) length = int(blength)
except ValueError: except ValueError:
raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) raise ValueError(f"not a tnetstring: missing or invalid length prefix: {data!r}")
try: try:
data, data_type, remain = data[:length], data[length], data[length + 1:] data, data_type, remain = data[:length], data[length], data[length + 1:]
except IndexError: except IndexError:

View File

@ -82,7 +82,7 @@ class Message(serializable.Serializable):
def raw_content(self, content): def raw_content(self, content):
self.data.content = content self.data.content = content
def get_content(self, strict: bool=True) -> bytes: def get_content(self, strict: bool=True) -> Optional[bytes]:
""" """
The uncompressed HTTP message body as bytes. The uncompressed HTTP message body as bytes.
@ -195,10 +195,9 @@ class Message(serializable.Serializable):
See also: :py:attr:`content`, :py:class:`raw_content` See also: :py:attr:`content`, :py:class:`raw_content`
""" """
if self.raw_content is None:
return None
content = self.get_content(strict) content = self.get_content(strict)
if content is None:
return None
enc = self._guess_encoding(content) enc = self._guess_encoding(content)
try: try:
return encoding.decode(content, enc) return encoding.decode(content, enc)

View File

@ -474,8 +474,7 @@ class ClientHello:
return cls(raw_client_hello) return cls(raw_client_hello)
except EOFError as e: except EOFError as e:
raise exceptions.TlsProtocolException( raise exceptions.TlsProtocolException(
'Cannot parse Client Hello: %s, Raw Client Hello: %s' % f"Cannot parse Client Hello: {e!r}, Raw Client Hello: {binascii.hexlify(raw_client_hello)!r}"
(repr(e), binascii.hexlify(raw_client_hello))
) )
def __repr__(self): def __repr__(self):

View File

@ -551,7 +551,9 @@ def serialize(opts: OptManager, text: str, defaults: bool = False) -> str:
for k in list(data.keys()): for k in list(data.keys()):
if k not in opts._options: if k not in opts._options:
del data[k] del data[k]
return ruamel.yaml.round_trip_dump(data) ret = ruamel.yaml.round_trip_dump(data)
assert ret
return ret
def save(opts: OptManager, path: str, defaults: bool =False) -> None: def save(opts: OptManager, path: str, defaults: bool =False) -> None:

View File

@ -1,7 +1,7 @@
import re import re
import socket import socket
import sys import sys
from typing import Tuple from typing import Callable, Optional, Tuple
def init_transparent_mode() -> None: def init_transparent_mode() -> None:
@ -10,30 +10,34 @@ def init_transparent_mode() -> None:
""" """
def original_addr(csock: socket.socket) -> Tuple[str, int]: original_addr: Optional[Callable[[socket.socket], Tuple[str, int]]]
""" """
Get the original destination for the given socket. Get the original destination for the given socket.
This function will be None if transparent mode is not supported. This function will be None if transparent mode is not supported.
""" """
if re.match(r"linux(?:2)?", sys.platform): if re.match(r"linux(?:2)?", sys.platform):
from . import linux from . import linux
original_addr = linux.original_addr # noqa original_addr = linux.original_addr
elif sys.platform == "darwin" or sys.platform.startswith("freebsd"): elif sys.platform == "darwin" or sys.platform.startswith("freebsd"):
from . import osx from . import osx
original_addr = osx.original_addr # noqa original_addr = osx.original_addr
elif sys.platform.startswith("openbsd"): elif sys.platform.startswith("openbsd"):
from . import openbsd from . import openbsd
original_addr = openbsd.original_addr # noqa original_addr = openbsd.original_addr
elif sys.platform == "win32": elif sys.platform == "win32":
from . import windows from . import windows
resolver = windows.Resolver() resolver = windows.Resolver()
init_transparent_mode = resolver.setup # noqa init_transparent_mode = resolver.setup # noqa
original_addr = resolver.original_addr # noqa original_addr = resolver.original_addr
else: else:
original_addr = None # noqa original_addr = None
__all__ = [
"original_addr",
"init_transparent_mode"
]

View File

@ -34,9 +34,9 @@ class ProxyConfig:
def __init__(self, options: moptions.Options) -> None: def __init__(self, options: moptions.Options) -> None:
self.options = options self.options = options
self.check_filter: HostMatcher = None self.certstore: certs.CertStore
self.check_tcp: HostMatcher = None self.check_filter: typing.Optional[HostMatcher] = None
self.certstore: certs.CertStore = None self.check_tcp: typing.Optional[HostMatcher] = None
self.upstream_server: typing.Optional[server_spec.ServerSpec] = None self.upstream_server: typing.Optional[server_spec.ServerSpec] = None
self.configure(options, set(options.keys())) self.configure(options, set(options.keys()))
options.changed.connect(self.configure) options.changed.connect(self.configure)

View File

@ -1,7 +1,7 @@
import threading import threading
import time import time
import functools import functools
from typing import Dict, Callable, Any, List # noqa from typing import Dict, Callable, Any, List, Optional # noqa
import h2.exceptions import h2.exceptions
from h2 import connection from h2 import connection
@ -382,15 +382,15 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
ctx, name="Http2SingleStreamLayer-{}".format(stream_id) ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
) )
self.h2_connection = h2_connection self.h2_connection = h2_connection
self.zombie: float = None self.zombie: Optional[float] = None
self.client_stream_id: int = stream_id self.client_stream_id: int = stream_id
self.server_stream_id: int = None self.server_stream_id: Optional[int] = None
self.request_headers = request_headers self.request_headers = request_headers
self.response_headers: mitmproxy.net.http.Headers = None self.response_headers: Optional[mitmproxy.net.http.Headers] = None
self.pushed = False self.pushed = False
self.timestamp_start: float = None self.timestamp_start: Optional[float] = None
self.timestamp_end: float = None self.timestamp_end: Optional[float] = None
self.request_arrived = threading.Event() self.request_arrived = threading.Event()
self.request_data_queue: queue.Queue[bytes] = queue.Queue() self.request_data_queue: queue.Queue[bytes] = queue.Queue()
@ -404,9 +404,9 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.no_body = False self.no_body = False
self.priority_exclusive: bool = None self.priority_exclusive: bool
self.priority_depends_on: int = None self.priority_depends_on: Optional[int] = None
self.priority_weight: int = None self.priority_weight: Optional[int] = None
self.handled_priority_event: Any = None self.handled_priority_event: Any = None
def kill(self): def kill(self):

View File

@ -198,12 +198,12 @@ CIPHER_ID_NAME_MAP = {
# We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default. # We manually need to specify this, otherwise OpenSSL may select a non-HTTP2 cipher by default.
# https://ssl-config.mozilla.org/#config=old # https://ssl-config.mozilla.org/#config=old
DEFAULT_CLIENT_CIPHERS = ( DEFAULT_CLIENT_CIPHERS = (
b"ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:" "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:"
b"ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:" "ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:"
b"DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:" "DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:"
b"ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:" "ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:"
b"ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA256:AES128-GCM-SHA256:" "ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA256:AES128-GCM-SHA256:"
b"AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA" "AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA"
) )
@ -320,14 +320,18 @@ class TlsLayer(base.Layer):
return self._server_tls return self._server_tls
@property @property
def server_sni(self): def server_sni(self) -> Optional[str]:
""" """
The Server Name Indication we want to send with the next server TLS handshake. The Server Name Indication we want to send with the next server TLS handshake.
""" """
if self._custom_server_sni is False: if self._custom_server_sni is False:
return None return None
elif self._custom_server_sni:
return self._custom_server_sni
elif self._client_hello and self._client_hello.sni:
return self._client_hello.sni.decode("idna")
else: else:
return self._custom_server_sni or self._client_hello and self._client_hello.sni.decode("idna") return None
@property @property
def alpn_for_client_connection(self): def alpn_for_client_connection(self):
@ -388,11 +392,12 @@ class TlsLayer(base.Layer):
# raises ann error. # raises ann error.
self.client_conn.rfile.peek(1) self.client_conn.rfile.peek(1)
except exceptions.TlsException as e: except exceptions.TlsException as e:
sni_str = self._client_hello.sni and self._client_hello.sni.decode("idna")
raise exceptions.ClientHandshakeException( raise exceptions.ClientHandshakeException(
"Cannot establish TLS with client (sni: {sni}): {e}".format( "Cannot establish TLS with client (sni: {sni}): {e}".format(
sni=self._client_hello.sni.decode("idna"), e=repr(e) sni=sni_str, e=repr(e)
), ),
self._client_hello.sni.decode("idna") or repr(self.server_conn.address) sni_str or repr(self.server_conn.address)
) )
def _establish_tls_with_server(self): def _establish_tls_with_server(self):

View File

@ -57,7 +57,8 @@ class RootContext:
except exceptions.TlsProtocolException as e: except exceptions.TlsProtocolException as e:
self.log("Cannot parse Client Hello: %s" % repr(e), "error") self.log("Cannot parse Client Hello: %s" % repr(e), "error")
else: else:
is_filtered = self.config.check_filter((client_hello.sni.decode("idna"), 443)) sni_str = client_hello.sni and client_hello.sni.decode("idna")
is_filtered = self.config.check_filter((sni_str, 443))
if is_filtered: if is_filtered:
return protocol.RawTCPLayer(top_layer, ignore=True) return protocol.RawTCPLayer(top_layer, ignore=True)

View File

@ -35,6 +35,7 @@ class DummyServer:
class ProxyServer(tcp.TCPServer): class ProxyServer(tcp.TCPServer):
allow_reuse_address = True allow_reuse_address = True
bound = True bound = True
channel: controller.Channel
def __init__(self, config: config.ProxyConfig) -> None: def __init__(self, config: config.ProxyConfig) -> None:
""" """
@ -53,7 +54,6 @@ class ProxyServer(tcp.TCPServer):
raise exceptions.ServerException( raise exceptions.ServerException(
'Error starting proxy server: ' + repr(e) 'Error starting proxy server: ' + repr(e)
) from e ) from e
self.channel: controller.Channel = None
def set_channel(self, channel): def set_channel(self, channel):
self.channel = channel self.channel = channel

View File

@ -1,7 +1,5 @@
import typing
from typing import Any # noqa
from typing import MutableMapping # noqa
import json import json
import typing
from mitmproxy.coretypes import serializable from mitmproxy.coretypes import serializable
from mitmproxy.utils import typecheck from mitmproxy.utils import typecheck
@ -15,7 +13,7 @@ class StateObject(serializable.Serializable):
or StateObject instances themselves. or StateObject instances themselves.
""" """
_stateobject_attributes: MutableMapping[str, Any] = None _stateobject_attributes: typing.ClassVar[typing.MutableMapping[str, typing.Any]]
""" """
An attribute-name -> class-or-type dict containing all attributes that An attribute-name -> class-or-type dict containing all attributes that
should be serialized. If the attribute is a class, it must implement the should be serialized. If the attribute is a class, it must implement the
@ -42,7 +40,7 @@ class StateObject(serializable.Serializable):
if val is None: if val is None:
setattr(self, attr, val) setattr(self, attr, val)
else: else:
curr = getattr(self, attr) curr = getattr(self, attr, None)
if hasattr(curr, "set_state"): if hasattr(curr, "set_state"):
curr.set_state(val) curr.set_state(val)
else: else:

View File

@ -6,17 +6,16 @@ Feel free to import and use whatever new package you deem necessary.
import os import os
import sys import sys
import asyncio import asyncio
import argparse # noqa import argparse
import signal # noqa import signal
import typing # noqa import typing
from mitmproxy.tools import cmdline # noqa from mitmproxy.tools import cmdline
from mitmproxy import exceptions, master # noqa from mitmproxy import exceptions, master
from mitmproxy import options # noqa from mitmproxy import options
from mitmproxy import optmanager # noqa from mitmproxy import optmanager
from mitmproxy import proxy # noqa from mitmproxy import proxy
from mitmproxy import log # noqa from mitmproxy.utils import debug, arg_check
from mitmproxy.utils import debug, arg_check # noqa
OPTIONS_FILE_NAME = "config.yaml" OPTIONS_FILE_NAME = "config.yaml"

View File

@ -55,7 +55,7 @@ class CommandBuffer:
self.text = self.flatten(start) self.text = self.flatten(start)
# Cursor is always within the range [0:len(buffer)]. # Cursor is always within the range [0:len(buffer)].
self._cursor = len(self.text) self._cursor = len(self.text)
self.completion: CompletionState = None self.completion: typing.Optional[CompletionState] = None
@property @property
def cursor(self) -> int: def cursor(self) -> int:

View File

@ -38,7 +38,7 @@ KEY_MAX = 30
def format_keyvals( def format_keyvals(
entries: typing.List[typing.Tuple[str, typing.Union[None, str, urwid.Widget]]], entries: typing.Iterable[typing.Tuple[str, typing.Union[None, str, urwid.Widget]]],
key_format: str = "key", key_format: str = "key",
value_format: str = "text", value_format: str = "text",
indent: int = 0 indent: int = 0

View File

@ -254,7 +254,7 @@ FIRST_WIDTH_MAX = 40
class BaseGridEditor(urwid.WidgetWrap): class BaseGridEditor(urwid.WidgetWrap):
title = "" title: str = ""
keyctx = "grideditor" keyctx = "grideditor"
def __init__( def __init__(
@ -402,8 +402,8 @@ class BaseGridEditor(urwid.WidgetWrap):
class GridEditor(BaseGridEditor): class GridEditor(BaseGridEditor):
title: str = None title = ""
columns: typing.Sequence[Column] = None columns: typing.Sequence[Column] = ()
keyctx = "grideditor" keyctx = "grideditor"
def __init__( def __init__(

View File

@ -107,7 +107,7 @@ class CookieAttributeEditor(base.FocusEditor):
col_text.Column("Name"), col_text.Column("Name"),
col_text.Column("Value"), col_text.Column("Value"),
] ]
grideditor: base.BaseGridEditor = None grideditor: base.BaseGridEditor
def data_in(self, data): def data_in(self, data):
return [(k, v or "") for k, v in data] return [(k, v or "") for k, v in data]
@ -169,7 +169,7 @@ class SetCookieEditor(base.FocusEditor):
class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget): class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget):
title: str = None title = ""
columns = [ columns = [
col_text.Column("") col_text.Column("")
] ]
@ -189,7 +189,7 @@ class OptionsEditor(base.GridEditor, layoutwidget.LayoutWidget):
class DataViewer(base.GridEditor, layoutwidget.LayoutWidget): class DataViewer(base.GridEditor, layoutwidget.LayoutWidget):
title: str = None title = ""
def __init__( def __init__(
self, self,

View File

@ -42,7 +42,7 @@ class Palette:
'commander_command', 'commander_invalid', 'commander_hint' 'commander_command', 'commander_invalid', 'commander_hint'
] ]
_fields.extend(['gradient_%02d' % i for i in range(100)]) _fields.extend(['gradient_%02d' % i for i in range(100)])
high: typing.Mapping[str, typing.Sequence[str]] = None high: typing.Optional[typing.Mapping[str, typing.Sequence[str]]] = None
def palette(self, transparent): def palette(self, transparent):
l = [] l = []

View File

@ -5,6 +5,7 @@ import logging
import os.path import os.path
import re import re
from io import BytesIO from io import BytesIO
from typing import ClassVar, Optional
import tornado.escape import tornado.escape
import tornado.web import tornado.web
@ -50,6 +51,8 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:
f["error"] = flow.error.get_state() f["error"] = flow.error.get_state()
if isinstance(flow, http.HTTPFlow): if isinstance(flow, http.HTTPFlow):
content_length: Optional[int]
content_hash: Optional[str]
if flow.request: if flow.request:
if flow.request.raw_content: if flow.request.raw_content:
content_length = len(flow.request.raw_content) content_length = len(flow.request.raw_content)
@ -193,7 +196,7 @@ class FilterHelp(RequestHandler):
class WebSocketEventBroadcaster(tornado.websocket.WebSocketHandler): class WebSocketEventBroadcaster(tornado.websocket.WebSocketHandler):
# raise an error if inherited class doesn't specify its own instance. # raise an error if inherited class doesn't specify its own instance.
connections: set = None connections: ClassVar[set]
def open(self): def open(self):
self.connections.add(self) self.connections.add(self)
@ -213,7 +216,7 @@ class WebSocketEventBroadcaster(tornado.websocket.WebSocketHandler):
class ClientConnection(WebSocketEventBroadcaster): class ClientConnection(WebSocketEventBroadcaster):
connections: set = set() connections: ClassVar[set] = set()
class Flows(RequestHandler): class Flows(RequestHandler):

View File

@ -423,7 +423,7 @@ class TypeManager:
for t in types: for t in types:
self.typemap[t.typ] = t() self.typemap[t.typ] = t()
def get(self, t: type, default=None) -> _BaseType: def get(self, t: typing.Optional[typing.Type], default=None) -> _BaseType:
if type(t) in self.typemap: if type(t) in self.typemap:
return self.typemap[type(t)] return self.typemap[type(t)]
return self.typemap.get(t, default) return self.typemap.get(t, default)

View File

@ -1,5 +1,5 @@
import itertools import itertools
from typing import TypeVar, Iterable, Iterator, Tuple, Optional from typing import TypeVar, Iterable, Iterator, Tuple, Optional, List
T = TypeVar('T') T = TypeVar('T')
@ -18,7 +18,7 @@ def window(iterator: Iterable[T], behind: int = 0, ahead: int = 0) -> Iterator[T
2 3 None 2 3 None
""" """
# TODO: move into utils # TODO: move into utils
iters = list(itertools.tee(iterator, behind + 1 + ahead)) iters: List[Iterator[Optional[T]]] = list(itertools.tee(iterator, behind + 1 + ahead))
for i in range(behind): for i in range(behind):
iters[i] = itertools.chain((behind - i) * [None], iters[i]) iters[i] = itertools.chain((behind - i) * [None], iters[i])
for i in range(ahead): for i in range(ahead):

View File

@ -1,10 +1,10 @@
import codecs
import io import io
import re import re
import codecs from typing import Iterable, Optional, Union, cast
from typing import AnyStr, Optional, cast, Iterable
def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes]: def always_bytes(str_or_bytes: Union[str, bytes, None], *encode_args) -> Optional[bytes]:
if isinstance(str_or_bytes, bytes) or str_or_bytes is None: if isinstance(str_or_bytes, bytes) or str_or_bytes is None:
return cast(Optional[bytes], str_or_bytes) return cast(Optional[bytes], str_or_bytes)
elif isinstance(str_or_bytes, str): elif isinstance(str_or_bytes, str):
@ -13,13 +13,15 @@ def always_bytes(str_or_bytes: Optional[AnyStr], *encode_args) -> Optional[bytes
raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__)) raise TypeError("Expected str or bytes, but got {}.".format(type(str_or_bytes).__name__))
def always_str(str_or_bytes: Optional[AnyStr], *decode_args) -> Optional[str]: def always_str(str_or_bytes: Union[str, bytes, None], *decode_args) -> Optional[str]:
""" """
Returns, Returns,
str_or_bytes unmodified, if str_or_bytes unmodified, if
""" """
if isinstance(str_or_bytes, str) or str_or_bytes is None: if str_or_bytes is None:
return cast(Optional[str], str_or_bytes) return None
if isinstance(str_or_bytes, str):
return cast(str, str_or_bytes)
elif isinstance(str_or_bytes, bytes): elif isinstance(str_or_bytes, bytes):
return str_or_bytes.decode(*decode_args) return str_or_bytes.decode(*decode_args)
else: else:
@ -39,7 +41,6 @@ _control_char_trans_newline = _control_char_trans.copy()
for x in ("\r", "\n", "\t"): for x in ("\r", "\n", "\t"):
del _control_char_trans_newline[ord(x)] del _control_char_trans_newline[ord(x)]
_control_char_trans = str.maketrans(_control_char_trans) _control_char_trans = str.maketrans(_control_char_trans)
_control_char_trans_newline = str.maketrans(_control_char_trans_newline) _control_char_trans_newline = str.maketrans(_control_char_trans_newline)

View File

@ -25,9 +25,9 @@ def get_dev_version() -> str:
stderr=subprocess.STDOUT, stderr=subprocess.STDOUT,
cwd=here, cwd=here,
) )
last_tag, tag_dist, commit = git_describe.decode().strip().rsplit("-", 2) last_tag, tag_dist_str, commit = git_describe.decode().strip().rsplit("-", 2)
commit = commit.lstrip("g")[:7] commit = commit.lstrip("g")[:7]
tag_dist = int(tag_dist) tag_dist = int(tag_dist_str)
except Exception: except Exception:
pass pass
else: else:

View File

@ -25,7 +25,7 @@ class Daemon:
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, type, value, traceback) -> bool: def __exit__(self, type, value, traceback):
self.logfp.truncate(0) self.logfp.truncate(0)
self.shutdown() self.shutdown()
return False return False

View File

@ -72,7 +72,7 @@ setup(
"kaitaistruct>=0.7,<0.9", "kaitaistruct>=0.7,<0.9",
"ldap3>=2.6.1,<2.7", "ldap3>=2.6.1,<2.7",
"passlib>=1.6.5, <1.8", "passlib>=1.6.5, <1.8",
"protobuf>=3.6.0, <3.10", "protobuf>=3.6.0, <3.11",
"pyasn1>=0.3.1,<0.5", "pyasn1>=0.3.1,<0.5",
"pyOpenSSL>=19.0.0,<20", "pyOpenSSL>=19.0.0,<20",
"pyparsing>=2.4.2,<2.5", "pyparsing>=2.4.2,<2.5",
@ -93,7 +93,7 @@ setup(
"asynctest>=0.12.0", "asynctest>=0.12.0",
"flake8>=3.7.8,<3.8", "flake8>=3.7.8,<3.8",
"Flask>=1.0,<1.2", "Flask>=1.0,<1.2",
"mypy>=0.590,<0.591", "mypy>=0.740,<0.741",
"parver>=0.1,<2.0", "parver>=0.1,<2.0",
"pytest-asyncio>=0.10.0,<0.11", "pytest-asyncio>=0.10.0,<0.11",
"pytest-cov>=2.7.1,<3", "pytest-cov>=2.7.1,<3",

View File

@ -95,7 +95,7 @@ class TestServerConnection:
def test_repr(self): def test_repr(self):
c = tflow.tserver_conn() c = tflow.tserver_conn()
c.sni = b'foobar' c.sni = 'foobar'
c.tls_established = True c.tls_established = True
c.alpn_proto_negotiated = b'h2' c.alpn_proto_negotiated = b'h2'
assert 'address:22' in repr(c) assert 'address:22' in repr(c)