Merge pull request #1511 from arjun23496/count_in_replace

Fixes #1495 - Added count argument for replacing contents in body
This commit is contained in:
Thomas Kriechbaumer 2016-08-31 13:49:03 +02:00 committed by GitHub
commit b4b2e5fd34
6 changed files with 38 additions and 9 deletions

View File

@ -158,7 +158,7 @@ class Headers(multidict.MultiDict):
else:
return super(Headers, self).items()
def replace(self, pattern, repl, flags=0):
def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
header line.
@ -172,10 +172,10 @@ class Headers(multidict.MultiDict):
repl = strutils.escaped_str_to_bytes(repl)
pattern = re.compile(pattern, flags)
replacements = 0
flag_count = count > 0
fields = []
for name, value in self.fields:
line, n = pattern.subn(repl, name + b": " + value)
line, n = pattern.subn(repl, name + b": " + value, count=count)
try:
name, value = line.split(b": ", 1)
except ValueError:
@ -184,6 +184,10 @@ class Headers(multidict.MultiDict):
pass
else:
replacements += n
if flag_count:
count -= n
if count == 0:
break
fields.append((name, value))
self.fields = tuple(fields)
return replacements

View File

@ -260,7 +260,7 @@ class Message(basetypes.Serializable):
if "content-encoding" not in self.headers:
raise ValueError("Invalid content encoding {}".format(repr(e)))
def replace(self, pattern, repl, flags=0):
def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in both the headers
and the body of the message. Encoded body will be decoded
@ -276,9 +276,9 @@ class Message(basetypes.Serializable):
replacements = 0
if self.content:
self.content, replacements = re.subn(
pattern, repl, self.content, flags=flags
pattern, repl, self.content, flags=flags, count=count
)
replacements += self.headers.replace(pattern, repl, flags)
replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
return replacements
# Legacy

View File

@ -80,7 +80,7 @@ class Request(message.Message):
self.method, hostport, path
)
def replace(self, pattern, repl, flags=0):
def replace(self, pattern, repl, flags=0, count=0):
"""
Replaces a regular expression pattern with repl in the headers, the
request path and the body of the request. Encoded content will be
@ -94,9 +94,9 @@ class Request(message.Message):
if isinstance(repl, six.text_type):
repl = strutils.escaped_str_to_bytes(repl)
c = super(Request, self).replace(pattern, repl, flags)
c = super(Request, self).replace(pattern, repl, flags, count)
self.path, pc = re.subn(
pattern, repl, self.data.path, flags=flags
pattern, repl, self.data.path, flags=flags, count=count
)
c += pc
return c

View File

@ -75,6 +75,11 @@ class TestHeaders(object):
assert replacements == 0
assert headers["Host"] == "example.com"
def test_replace_with_count(self):
headers = Headers(Host="foobarfoo.com", Accept="foo/bar")
replacements = headers.replace("foo", "bar", count=1)
assert replacements == 1
def test_parse_content_type():
p = parse_content_type

View File

@ -99,6 +99,16 @@ class TestMessage(object):
def test_http_version(self):
_test_decoded_attr(tresp(), "http_version")
def test_replace(self):
r = tresp()
r.content = b"foofootoo"
r.replace(b"foo", "gg")
assert r.content == b"ggggtoo"
r.content = b"foofootoo"
r.replace(b"foo", "gg", count=1)
assert r.content == b"ggfootoo"
class TestMessageContentEncoding(object):
def test_simple(self):

View File

@ -26,6 +26,16 @@ class TestRequestCore(object):
request.host = None
assert repr(request) == "Request(GET /path)"
def replace(self):
r = treq()
r.path = b"foobarfoo"
r.replace(b"foo", "bar")
assert r.path == b"barbarbar"
r.path = b"foobarfoo"
r.replace(b"foo", "bar", count=1)
assert r.path == b"barbarfoo"
def test_first_line_format(self):
_test_passthrough_attr(treq(), "first_line_format")