mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
Merge pull request #1511 from arjun23496/count_in_replace
Fixes #1495 - Added count argument for replacing contents in body
This commit is contained in:
commit
b4b2e5fd34
@ -158,7 +158,7 @@ class Headers(multidict.MultiDict):
|
||||
else:
|
||||
return super(Headers, self).items()
|
||||
|
||||
def replace(self, pattern, repl, flags=0):
|
||||
def replace(self, pattern, repl, flags=0, count=0):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in each "name: value"
|
||||
header line.
|
||||
@ -172,10 +172,10 @@ class Headers(multidict.MultiDict):
|
||||
repl = strutils.escaped_str_to_bytes(repl)
|
||||
pattern = re.compile(pattern, flags)
|
||||
replacements = 0
|
||||
|
||||
flag_count = count > 0
|
||||
fields = []
|
||||
for name, value in self.fields:
|
||||
line, n = pattern.subn(repl, name + b": " + value)
|
||||
line, n = pattern.subn(repl, name + b": " + value, count=count)
|
||||
try:
|
||||
name, value = line.split(b": ", 1)
|
||||
except ValueError:
|
||||
@ -184,6 +184,10 @@ class Headers(multidict.MultiDict):
|
||||
pass
|
||||
else:
|
||||
replacements += n
|
||||
if flag_count:
|
||||
count -= n
|
||||
if count == 0:
|
||||
break
|
||||
fields.append((name, value))
|
||||
self.fields = tuple(fields)
|
||||
return replacements
|
||||
|
@ -260,7 +260,7 @@ class Message(basetypes.Serializable):
|
||||
if "content-encoding" not in self.headers:
|
||||
raise ValueError("Invalid content encoding {}".format(repr(e)))
|
||||
|
||||
def replace(self, pattern, repl, flags=0):
|
||||
def replace(self, pattern, repl, flags=0, count=0):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in both the headers
|
||||
and the body of the message. Encoded body will be decoded
|
||||
@ -276,9 +276,9 @@ class Message(basetypes.Serializable):
|
||||
replacements = 0
|
||||
if self.content:
|
||||
self.content, replacements = re.subn(
|
||||
pattern, repl, self.content, flags=flags
|
||||
pattern, repl, self.content, flags=flags, count=count
|
||||
)
|
||||
replacements += self.headers.replace(pattern, repl, flags)
|
||||
replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
|
||||
return replacements
|
||||
|
||||
# Legacy
|
||||
|
@ -80,7 +80,7 @@ class Request(message.Message):
|
||||
self.method, hostport, path
|
||||
)
|
||||
|
||||
def replace(self, pattern, repl, flags=0):
|
||||
def replace(self, pattern, repl, flags=0, count=0):
|
||||
"""
|
||||
Replaces a regular expression pattern with repl in the headers, the
|
||||
request path and the body of the request. Encoded content will be
|
||||
@ -94,9 +94,9 @@ class Request(message.Message):
|
||||
if isinstance(repl, six.text_type):
|
||||
repl = strutils.escaped_str_to_bytes(repl)
|
||||
|
||||
c = super(Request, self).replace(pattern, repl, flags)
|
||||
c = super(Request, self).replace(pattern, repl, flags, count)
|
||||
self.path, pc = re.subn(
|
||||
pattern, repl, self.data.path, flags=flags
|
||||
pattern, repl, self.data.path, flags=flags, count=count
|
||||
)
|
||||
c += pc
|
||||
return c
|
||||
|
@ -75,6 +75,11 @@ class TestHeaders(object):
|
||||
assert replacements == 0
|
||||
assert headers["Host"] == "example.com"
|
||||
|
||||
def test_replace_with_count(self):
|
||||
headers = Headers(Host="foobarfoo.com", Accept="foo/bar")
|
||||
replacements = headers.replace("foo", "bar", count=1)
|
||||
assert replacements == 1
|
||||
|
||||
|
||||
def test_parse_content_type():
|
||||
p = parse_content_type
|
||||
|
@ -99,6 +99,16 @@ class TestMessage(object):
|
||||
def test_http_version(self):
|
||||
_test_decoded_attr(tresp(), "http_version")
|
||||
|
||||
def test_replace(self):
|
||||
r = tresp()
|
||||
r.content = b"foofootoo"
|
||||
r.replace(b"foo", "gg")
|
||||
assert r.content == b"ggggtoo"
|
||||
|
||||
r.content = b"foofootoo"
|
||||
r.replace(b"foo", "gg", count=1)
|
||||
assert r.content == b"ggfootoo"
|
||||
|
||||
|
||||
class TestMessageContentEncoding(object):
|
||||
def test_simple(self):
|
||||
|
@ -26,6 +26,16 @@ class TestRequestCore(object):
|
||||
request.host = None
|
||||
assert repr(request) == "Request(GET /path)"
|
||||
|
||||
def replace(self):
|
||||
r = treq()
|
||||
r.path = b"foobarfoo"
|
||||
r.replace(b"foo", "bar")
|
||||
assert r.path == b"barbarbar"
|
||||
|
||||
r.path = b"foobarfoo"
|
||||
r.replace(b"foo", "bar", count=1)
|
||||
assert r.path == b"barbarfoo"
|
||||
|
||||
def test_first_line_format(self):
|
||||
_test_passthrough_attr(treq(), "first_line_format")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user