diff --git a/mitmproxy/addons/core.py b/mitmproxy/addons/core.py index a908dbb31..5c9bbcd09 100644 --- a/mitmproxy/addons/core.py +++ b/mitmproxy/addons/core.py @@ -289,7 +289,7 @@ class Core: """ The possible values for an encoding specification. """ - return ["gzip", "deflate", "br"] + return ["gzip", "deflate", "br", "zstd"] @command.command("options.load") def options_load(self, path: mitmproxy.types.Path) -> None: diff --git a/mitmproxy/net/http/encoding.py b/mitmproxy/net/http/encoding.py index 8cb96e5c3..16d399ca6 100644 --- a/mitmproxy/net/http/encoding.py +++ b/mitmproxy/net/http/encoding.py @@ -9,6 +9,7 @@ from io import BytesIO import gzip import zlib import brotli +import zstandard as zstd from typing import Union, Optional, AnyStr # noqa @@ -52,7 +53,7 @@ def decode( decoded = custom_decode[encoding](encoded) except KeyError: decoded = codecs.decode(encoded, encoding, errors) - if encoding in ("gzip", "deflate", "br"): + if encoding in ("gzip", "deflate", "br", "zstd"): _cache = CachedDecode(encoded, encoding, errors, decoded) return decoded except TypeError: @@ -93,7 +94,7 @@ def encode(decoded: Optional[str], encoding: str, errors: str='strict') -> Optio encoded = custom_encode[encoding](decoded) except KeyError: encoded = codecs.encode(decoded, encoding, errors) - if encoding in ("gzip", "deflate", "br"): + if encoding in ("gzip", "deflate", "br", "zstd"): _cache = CachedDecode(encoded, encoding, errors, decoded) return encoded except TypeError: @@ -140,6 +141,23 @@ def encode_brotli(content: bytes) -> bytes: return brotli.compress(content) +def decode_zstd(content: bytes) -> bytes: + if not content: + return b"" + zstd_ctx = zstd.ZstdDecompressor() + try: + return zstd_ctx.decompress(content) + except zstd.ZstdError: + # If the zstd stream is streamed without a size header, + # try decoding with a 10MiB output buffer + return zstd_ctx.decompress(content, max_output_size=10 * 2**20) + + +def encode_zstd(content: bytes) -> bytes: + zstd_ctx = zstd.ZstdCompressor() + return zstd_ctx.compress(content) + + def decode_deflate(content: bytes) -> bytes: """ Returns decompressed data for DEFLATE. Some servers may respond with @@ -170,6 +188,7 @@ custom_decode = { "gzip": decode_gzip, "deflate": decode_deflate, "br": decode_brotli, + "zstd": decode_zstd, } custom_encode = { "none": identity, @@ -177,6 +196,7 @@ custom_encode = { "gzip": encode_gzip, "deflate": encode_deflate, "br": encode_brotli, + "zstd": encode_zstd, } __all__ = ["encode", "decode"] diff --git a/mitmproxy/net/http/message.py b/mitmproxy/net/http/message.py index 86782e8ac..6830c6cd3 100644 --- a/mitmproxy/net/http/message.py +++ b/mitmproxy/net/http/message.py @@ -236,7 +236,7 @@ class Message(serializable.Serializable): def encode(self, e): """ - Encodes body with the encoding e, where e is "gzip", "deflate", "identity", or "br". + Encodes body with the encoding e, where e is "gzip", "deflate", "identity", "br", or "zstd". Any existing content-encodings are overwritten, the content is not decoded beforehand. diff --git a/mitmproxy/net/http/request.py b/mitmproxy/net/http/request.py index 959fdd339..3516e9496 100644 --- a/mitmproxy/net/http/request.py +++ b/mitmproxy/net/http/request.py @@ -421,7 +421,7 @@ class Request(message.Message): self.headers["accept-encoding"] = ( ', '.join( e - for e in {"gzip", "identity", "deflate", "br"} + for e in {"gzip", "identity", "deflate", "br", "zstd"} if e in accept_encoding ) ) diff --git a/setup.py b/setup.py index 7f83de631..12439f4eb 100644 --- a/setup.py +++ b/setup.py @@ -81,6 +81,7 @@ setup( "tornado>=4.3,<5.2", "urwid>=2.0.1,<2.1", "wsproto>=0.13.0,<0.14.0", + "zstandard>=0.11.0,<0.13.0", ], extras_require={ ':sys_platform == "win32"': [ diff --git a/test/mitmproxy/net/http/test_encoding.py b/test/mitmproxy/net/http/test_encoding.py index 8dac12cbc..7f768f39d 100644 --- a/test/mitmproxy/net/http/test_encoding.py +++ b/test/mitmproxy/net/http/test_encoding.py @@ -19,6 +19,7 @@ def test_identity(encoder): 'gzip', 'br', 'deflate', + 'zstd', ]) def test_encoders(encoder): """