encoding: add support for zstd (zstandard)

Handles zstandard-compressed bodies labeled as zstd.
This commit is contained in:
Tero Saaristo 2019-09-05 17:52:04 +03:00
parent 0b0b4ccba6
commit dd3589ce34
6 changed files with 27 additions and 5 deletions

View File

@ -289,7 +289,7 @@ class Core:
""" """
The possible values for an encoding specification. The possible values for an encoding specification.
""" """
return ["gzip", "deflate", "br"] return ["gzip", "deflate", "br", "zstd"]
@command.command("options.load") @command.command("options.load")
def options_load(self, path: mitmproxy.types.Path) -> None: def options_load(self, path: mitmproxy.types.Path) -> None:

View File

@ -9,6 +9,7 @@ from io import BytesIO
import gzip import gzip
import zlib import zlib
import brotli import brotli
import zstandard as zstd
from typing import Union, Optional, AnyStr # noqa from typing import Union, Optional, AnyStr # noqa
@ -52,7 +53,7 @@ def decode(
decoded = custom_decode[encoding](encoded) decoded = custom_decode[encoding](encoded)
except KeyError: except KeyError:
decoded = codecs.decode(encoded, encoding, errors) decoded = codecs.decode(encoded, encoding, errors)
if encoding in ("gzip", "deflate", "br"): if encoding in ("gzip", "deflate", "br", "zstd"):
_cache = CachedDecode(encoded, encoding, errors, decoded) _cache = CachedDecode(encoded, encoding, errors, decoded)
return decoded return decoded
except TypeError: except TypeError:
@ -93,7 +94,7 @@ def encode(decoded: Optional[str], encoding: str, errors: str='strict') -> Optio
encoded = custom_encode[encoding](decoded) encoded = custom_encode[encoding](decoded)
except KeyError: except KeyError:
encoded = codecs.encode(decoded, encoding, errors) encoded = codecs.encode(decoded, encoding, errors)
if encoding in ("gzip", "deflate", "br"): if encoding in ("gzip", "deflate", "br", "zstd"):
_cache = CachedDecode(encoded, encoding, errors, decoded) _cache = CachedDecode(encoded, encoding, errors, decoded)
return encoded return encoded
except TypeError: except TypeError:
@ -140,6 +141,23 @@ def encode_brotli(content: bytes) -> bytes:
return brotli.compress(content) return brotli.compress(content)
def decode_zstd(content: bytes) -> bytes:
if not content:
return b""
zstd_ctx = zstd.ZstdDecompressor()
try:
return zstd_ctx.decompress(content)
except zstd.ZstdError:
# If the zstd stream is streamed without a size header,
# try decoding with a 10MiB output buffer
return zstd_ctx.decompress(content, max_output_size=10 * 2**20)
def encode_zstd(content: bytes) -> bytes:
zstd_ctx = zstd.ZstdCompressor()
return zstd_ctx.compress(content)
def decode_deflate(content: bytes) -> bytes: def decode_deflate(content: bytes) -> bytes:
""" """
Returns decompressed data for DEFLATE. Some servers may respond with Returns decompressed data for DEFLATE. Some servers may respond with
@ -170,6 +188,7 @@ custom_decode = {
"gzip": decode_gzip, "gzip": decode_gzip,
"deflate": decode_deflate, "deflate": decode_deflate,
"br": decode_brotli, "br": decode_brotli,
"zstd": decode_zstd,
} }
custom_encode = { custom_encode = {
"none": identity, "none": identity,
@ -177,6 +196,7 @@ custom_encode = {
"gzip": encode_gzip, "gzip": encode_gzip,
"deflate": encode_deflate, "deflate": encode_deflate,
"br": encode_brotli, "br": encode_brotli,
"zstd": encode_zstd,
} }
__all__ = ["encode", "decode"] __all__ = ["encode", "decode"]

View File

@ -236,7 +236,7 @@ class Message(serializable.Serializable):
def encode(self, e): def encode(self, e):
""" """
Encodes body with the encoding e, where e is "gzip", "deflate", "identity", or "br". Encodes body with the encoding e, where e is "gzip", "deflate", "identity", "br", or "zstd".
Any existing content-encodings are overwritten, Any existing content-encodings are overwritten,
the content is not decoded beforehand. the content is not decoded beforehand.

View File

@ -421,7 +421,7 @@ class Request(message.Message):
self.headers["accept-encoding"] = ( self.headers["accept-encoding"] = (
', '.join( ', '.join(
e e
for e in {"gzip", "identity", "deflate", "br"} for e in {"gzip", "identity", "deflate", "br", "zstd"}
if e in accept_encoding if e in accept_encoding
) )
) )

View File

@ -81,6 +81,7 @@ setup(
"tornado>=4.3,<5.2", "tornado>=4.3,<5.2",
"urwid>=2.0.1,<2.1", "urwid>=2.0.1,<2.1",
"wsproto>=0.13.0,<0.14.0", "wsproto>=0.13.0,<0.14.0",
"zstandard>=0.11.0,<0.13.0",
], ],
extras_require={ extras_require={
':sys_platform == "win32"': [ ':sys_platform == "win32"': [

View File

@ -19,6 +19,7 @@ def test_identity(encoder):
'gzip', 'gzip',
'br', 'br',
'deflate', 'deflate',
'zstd',
]) ])
def test_encoders(encoder): def test_encoders(encoder):
""" """