mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
Merge pull request #1511 from arjun23496/count_in_replace
Fixes #1495 - Added count argument for replacing contents in body
This commit is contained in:
commit
b4b2e5fd34
@ -158,7 +158,7 @@ class Headers(multidict.MultiDict):
|
|||||||
else:
|
else:
|
||||||
return super(Headers, self).items()
|
return super(Headers, self).items()
|
||||||
|
|
||||||
def replace(self, pattern, repl, flags=0):
|
def replace(self, pattern, repl, flags=0, count=0):
|
||||||
"""
|
"""
|
||||||
Replaces a regular expression pattern with repl in each "name: value"
|
Replaces a regular expression pattern with repl in each "name: value"
|
||||||
header line.
|
header line.
|
||||||
@ -172,10 +172,10 @@ class Headers(multidict.MultiDict):
|
|||||||
repl = strutils.escaped_str_to_bytes(repl)
|
repl = strutils.escaped_str_to_bytes(repl)
|
||||||
pattern = re.compile(pattern, flags)
|
pattern = re.compile(pattern, flags)
|
||||||
replacements = 0
|
replacements = 0
|
||||||
|
flag_count = count > 0
|
||||||
fields = []
|
fields = []
|
||||||
for name, value in self.fields:
|
for name, value in self.fields:
|
||||||
line, n = pattern.subn(repl, name + b": " + value)
|
line, n = pattern.subn(repl, name + b": " + value, count=count)
|
||||||
try:
|
try:
|
||||||
name, value = line.split(b": ", 1)
|
name, value = line.split(b": ", 1)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -184,6 +184,10 @@ class Headers(multidict.MultiDict):
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
replacements += n
|
replacements += n
|
||||||
|
if flag_count:
|
||||||
|
count -= n
|
||||||
|
if count == 0:
|
||||||
|
break
|
||||||
fields.append((name, value))
|
fields.append((name, value))
|
||||||
self.fields = tuple(fields)
|
self.fields = tuple(fields)
|
||||||
return replacements
|
return replacements
|
||||||
|
@ -260,7 +260,7 @@ class Message(basetypes.Serializable):
|
|||||||
if "content-encoding" not in self.headers:
|
if "content-encoding" not in self.headers:
|
||||||
raise ValueError("Invalid content encoding {}".format(repr(e)))
|
raise ValueError("Invalid content encoding {}".format(repr(e)))
|
||||||
|
|
||||||
def replace(self, pattern, repl, flags=0):
|
def replace(self, pattern, repl, flags=0, count=0):
|
||||||
"""
|
"""
|
||||||
Replaces a regular expression pattern with repl in both the headers
|
Replaces a regular expression pattern with repl in both the headers
|
||||||
and the body of the message. Encoded body will be decoded
|
and the body of the message. Encoded body will be decoded
|
||||||
@ -276,9 +276,9 @@ class Message(basetypes.Serializable):
|
|||||||
replacements = 0
|
replacements = 0
|
||||||
if self.content:
|
if self.content:
|
||||||
self.content, replacements = re.subn(
|
self.content, replacements = re.subn(
|
||||||
pattern, repl, self.content, flags=flags
|
pattern, repl, self.content, flags=flags, count=count
|
||||||
)
|
)
|
||||||
replacements += self.headers.replace(pattern, repl, flags)
|
replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
|
||||||
return replacements
|
return replacements
|
||||||
|
|
||||||
# Legacy
|
# Legacy
|
||||||
|
@ -80,7 +80,7 @@ class Request(message.Message):
|
|||||||
self.method, hostport, path
|
self.method, hostport, path
|
||||||
)
|
)
|
||||||
|
|
||||||
def replace(self, pattern, repl, flags=0):
|
def replace(self, pattern, repl, flags=0, count=0):
|
||||||
"""
|
"""
|
||||||
Replaces a regular expression pattern with repl in the headers, the
|
Replaces a regular expression pattern with repl in the headers, the
|
||||||
request path and the body of the request. Encoded content will be
|
request path and the body of the request. Encoded content will be
|
||||||
@ -94,9 +94,9 @@ class Request(message.Message):
|
|||||||
if isinstance(repl, six.text_type):
|
if isinstance(repl, six.text_type):
|
||||||
repl = strutils.escaped_str_to_bytes(repl)
|
repl = strutils.escaped_str_to_bytes(repl)
|
||||||
|
|
||||||
c = super(Request, self).replace(pattern, repl, flags)
|
c = super(Request, self).replace(pattern, repl, flags, count)
|
||||||
self.path, pc = re.subn(
|
self.path, pc = re.subn(
|
||||||
pattern, repl, self.data.path, flags=flags
|
pattern, repl, self.data.path, flags=flags, count=count
|
||||||
)
|
)
|
||||||
c += pc
|
c += pc
|
||||||
return c
|
return c
|
||||||
|
@ -75,6 +75,11 @@ class TestHeaders(object):
|
|||||||
assert replacements == 0
|
assert replacements == 0
|
||||||
assert headers["Host"] == "example.com"
|
assert headers["Host"] == "example.com"
|
||||||
|
|
||||||
|
def test_replace_with_count(self):
|
||||||
|
headers = Headers(Host="foobarfoo.com", Accept="foo/bar")
|
||||||
|
replacements = headers.replace("foo", "bar", count=1)
|
||||||
|
assert replacements == 1
|
||||||
|
|
||||||
|
|
||||||
def test_parse_content_type():
|
def test_parse_content_type():
|
||||||
p = parse_content_type
|
p = parse_content_type
|
||||||
|
@ -99,6 +99,16 @@ class TestMessage(object):
|
|||||||
def test_http_version(self):
|
def test_http_version(self):
|
||||||
_test_decoded_attr(tresp(), "http_version")
|
_test_decoded_attr(tresp(), "http_version")
|
||||||
|
|
||||||
|
def test_replace(self):
|
||||||
|
r = tresp()
|
||||||
|
r.content = b"foofootoo"
|
||||||
|
r.replace(b"foo", "gg")
|
||||||
|
assert r.content == b"ggggtoo"
|
||||||
|
|
||||||
|
r.content = b"foofootoo"
|
||||||
|
r.replace(b"foo", "gg", count=1)
|
||||||
|
assert r.content == b"ggfootoo"
|
||||||
|
|
||||||
|
|
||||||
class TestMessageContentEncoding(object):
|
class TestMessageContentEncoding(object):
|
||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
|
@ -26,6 +26,16 @@ class TestRequestCore(object):
|
|||||||
request.host = None
|
request.host = None
|
||||||
assert repr(request) == "Request(GET /path)"
|
assert repr(request) == "Request(GET /path)"
|
||||||
|
|
||||||
|
def replace(self):
|
||||||
|
r = treq()
|
||||||
|
r.path = b"foobarfoo"
|
||||||
|
r.replace(b"foo", "bar")
|
||||||
|
assert r.path == b"barbarbar"
|
||||||
|
|
||||||
|
r.path = b"foobarfoo"
|
||||||
|
r.replace(b"foo", "bar", count=1)
|
||||||
|
assert r.path == b"barbarfoo"
|
||||||
|
|
||||||
def test_first_line_format(self):
|
def test_first_line_format(self):
|
||||||
_test_passthrough_attr(treq(), "first_line_format")
|
_test_passthrough_attr(treq(), "first_line_format")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user