mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
improve .replace() and move it into netlib
This commit is contained in:
parent
4ee8808b44
commit
806aa0f41c
@ -26,30 +26,6 @@ class MessageMixin(object):
|
||||
return self.content
|
||||
return encoding.decode(ce, self.content)
|
||||
|
||||
def replace(self, pattern, repl, *args, **kwargs):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in both the headers
|
||||
and the body of the message. Encoded body will be decoded
|
||||
before replacement, and re-encoded afterwards.
|
||||
|
||||
Returns the number of replacements made.
|
||||
"""
|
||||
count = 0
|
||||
if self.content:
|
||||
with decoded(self):
|
||||
self.content, count = utils.safe_subn(
|
||||
pattern, repl, self.content, *args, **kwargs
|
||||
)
|
||||
fields = []
|
||||
for name, value in self.headers.fields:
|
||||
name, c = utils.safe_subn(pattern, repl, name, *args, **kwargs)
|
||||
count += c
|
||||
value, c = utils.safe_subn(pattern, repl, value, *args, **kwargs)
|
||||
count += c
|
||||
fields.append([name, value])
|
||||
self.headers.fields = fields
|
||||
return count
|
||||
|
||||
|
||||
class HTTPRequest(MessageMixin, Request):
|
||||
|
||||
@ -186,22 +162,6 @@ class HTTPRequest(MessageMixin, Request):
|
||||
def set_auth(self, auth):
|
||||
self.data.headers.set_all("Proxy-Authorization", (auth,))
|
||||
|
||||
def replace(self, pattern, repl, *args, **kwargs):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in the headers, the
|
||||
request path and the body of the request. Encoded content will be
|
||||
decoded before replacement, and re-encoded afterwards.
|
||||
|
||||
Returns the number of replacements made.
|
||||
"""
|
||||
c = MessageMixin.replace(self, pattern, repl, *args, **kwargs)
|
||||
self.path, pc = utils.safe_subn(
|
||||
pattern, repl, self.path, *args, **kwargs
|
||||
)
|
||||
c += pc
|
||||
return c
|
||||
|
||||
|
||||
class HTTPResponse(MessageMixin, Response):
|
||||
|
||||
"""
|
||||
|
@ -165,12 +165,3 @@ def parse_size(s):
|
||||
return int(s) * mult
|
||||
except ValueError:
|
||||
raise ValueError("Invalid size specification: %s" % s)
|
||||
|
||||
|
||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
"""
|
||||
There are Unicode conversion problems with re.subn. We try to smooth
|
||||
that over by casting the pattern and replacement to strings. We really
|
||||
need a better solution that is aware of the actual content ecoding.
|
||||
"""
|
||||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||
|
@ -6,6 +6,8 @@ See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
|
||||
"""
|
||||
from __future__ import absolute_import, print_function, division
|
||||
|
||||
import re
|
||||
|
||||
try:
|
||||
from collections.abc import MutableMapping
|
||||
except ImportError: # pragma: no cover
|
||||
@ -198,4 +200,31 @@ class Headers(MutableMapping, Serializable):
|
||||
|
||||
@classmethod
|
||||
def from_state(cls, state):
|
||||
return cls([list(field) for field in state])
|
||||
return cls([list(field) for field in state])
|
||||
|
||||
@_always_byte_args
|
||||
def replace(self, pattern, repl, flags=0):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in each "name: value"
|
||||
header line.
|
||||
|
||||
Returns:
|
||||
The number of replacements made.
|
||||
"""
|
||||
pattern = re.compile(pattern, flags)
|
||||
replacements = 0
|
||||
|
||||
fields = []
|
||||
for name, value in self.fields:
|
||||
line, n = pattern.subn(repl, name + b": " + value)
|
||||
try:
|
||||
name, value = line.split(b": ", 1)
|
||||
except ValueError:
|
||||
# We get a ValueError if the replacement removed the ": "
|
||||
# There's not much we can do about this, so we just keep the header as-is.
|
||||
pass
|
||||
else:
|
||||
replacements += n
|
||||
fields.append([name, value])
|
||||
self.fields = fields
|
||||
return replacements
|
||||
|
@ -175,6 +175,25 @@ class Message(utils.Serializable):
|
||||
self.headers["content-encoding"] = e
|
||||
return True
|
||||
|
||||
def replace(self, pattern, repl, flags=0):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in both the headers
|
||||
and the body of the message. Encoded body will be decoded
|
||||
before replacement, and re-encoded afterwards.
|
||||
|
||||
Returns:
|
||||
The number of replacements made.
|
||||
"""
|
||||
# TODO: Proper distinction between text and bytes.
|
||||
replacements = 0
|
||||
if self.content:
|
||||
with decoded(self):
|
||||
self.content, replacements = utils.safe_subn(
|
||||
pattern, repl, self.content, flags=flags
|
||||
)
|
||||
replacements += self.headers.replace(pattern, repl, flags)
|
||||
return replacements
|
||||
|
||||
# Legacy
|
||||
|
||||
@property
|
||||
|
@ -54,6 +54,23 @@ class Request(Message):
|
||||
self.method, hostport, path
|
||||
)
|
||||
|
||||
def replace(self, pattern, repl, flags=0):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in the headers, the
|
||||
request path and the body of the request. Encoded content will be
|
||||
decoded before replacement, and re-encoded afterwards.
|
||||
|
||||
Returns:
|
||||
The number of replacements made.
|
||||
"""
|
||||
# TODO: Proper distinction between text and bytes.
|
||||
c = super(Request, self).replace(pattern, repl, flags)
|
||||
self.path, pc = utils.safe_subn(
|
||||
pattern, repl, self.path, flags=flags
|
||||
)
|
||||
c += pc
|
||||
return c
|
||||
|
||||
@property
|
||||
def first_line_format(self):
|
||||
"""
|
||||
|
@ -1,18 +1,8 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import re
|
||||
import copy
|
||||
import six
|
||||
|
||||
from .utils import Serializable
|
||||
|
||||
|
||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
"""
|
||||
There are Unicode conversion problems with re.subn. We try to smooth
|
||||
that over by casting the pattern and replacement to strings. We really
|
||||
need a better solution that is aware of the actual content ecoding.
|
||||
"""
|
||||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||
from .utils import Serializable, safe_subn
|
||||
|
||||
|
||||
class ODict(Serializable):
|
||||
|
@ -414,8 +414,18 @@ def http2_read_raw_frame(rfile):
|
||||
body = rfile.safe_read(length)
|
||||
return [header, body]
|
||||
|
||||
|
||||
def http2_read_frame(rfile):
|
||||
header, body = http2_read_raw_frame(rfile)
|
||||
frame, length = hyperframe.frame.Frame.parse_frame_header(header)
|
||||
frame.parse_body(memoryview(body))
|
||||
return frame
|
||||
|
||||
|
||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||
"""
|
||||
There are Unicode conversion problems with re.subn. We try to smooth
|
||||
that over by casting the pattern and replacement to strings. We really
|
||||
need a better solution that is aware of the actual content ecoding.
|
||||
"""
|
||||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||
|
@ -99,7 +99,3 @@ def test_parse_size():
|
||||
assert utils.parse_size("1g") == 1024**3
|
||||
tutils.raises(ValueError, utils.parse_size, "1f")
|
||||
tutils.raises(ValueError, utils.parse_size, "ak")
|
||||
|
||||
|
||||
def test_safe_subn():
|
||||
assert utils.safe_subn("foo", u"bar", "\xc2foo")
|
||||
|
@ -150,3 +150,22 @@ class TestHeaders(object):
|
||||
assert headers != headers2
|
||||
headers2.set_state(headers.get_state())
|
||||
assert headers == headers2
|
||||
|
||||
def test_replace_simple(self):
|
||||
headers = Headers(Host="example.com", Accept="text/plain")
|
||||
replacements = headers.replace("Host: ", "X-Host: ")
|
||||
assert replacements == 1
|
||||
assert headers["X-Host"] == "example.com"
|
||||
assert "Host" not in headers
|
||||
assert headers["Accept"] == "text/plain"
|
||||
|
||||
def test_replace_multi(self):
|
||||
headers = self._2host()
|
||||
headers.replace(r"Host: example\.com", r"Host: example.de")
|
||||
assert headers.get_all("Host") == ["example.de", "example.org"]
|
||||
|
||||
def test_replace_remove_spacer(self):
|
||||
headers = Headers(Host="example.com")
|
||||
replacements = headers.replace(r"Host: ", "X-Host ")
|
||||
assert replacements == 0
|
||||
assert headers["Host"] == "example.com"
|
||||
|
@ -166,3 +166,7 @@ class TestSerializable:
|
||||
a.set_state(1)
|
||||
assert a.i == 1
|
||||
assert b.i == 42
|
||||
|
||||
|
||||
def test_safe_subn():
|
||||
assert utils.safe_subn("foo", u"bar", "\xc2foo")
|
||||
|
Loading…
Reference in New Issue
Block a user