Merge pull request #1511 from arjun23496/count_in_replace

Fixes #1495 - Added count argument for replacing contents in body
This commit is contained in:
Thomas Kriechbaumer 2016-08-31 13:49:03 +02:00 committed by GitHub
commit b4b2e5fd34
6 changed files with 38 additions and 9 deletions

View File

@ -158,7 +158,7 @@ class Headers(multidict.MultiDict):
else: else:
return super(Headers, self).items() return super(Headers, self).items()
def replace(self, pattern, repl, flags=0): def replace(self, pattern, repl, flags=0, count=0):
""" """
Replaces a regular expression pattern with repl in each "name: value" Replaces a regular expression pattern with repl in each "name: value"
header line. header line.
@ -172,10 +172,10 @@ class Headers(multidict.MultiDict):
repl = strutils.escaped_str_to_bytes(repl) repl = strutils.escaped_str_to_bytes(repl)
pattern = re.compile(pattern, flags) pattern = re.compile(pattern, flags)
replacements = 0 replacements = 0
flag_count = count > 0
fields = [] fields = []
for name, value in self.fields: for name, value in self.fields:
line, n = pattern.subn(repl, name + b": " + value) line, n = pattern.subn(repl, name + b": " + value, count=count)
try: try:
name, value = line.split(b": ", 1) name, value = line.split(b": ", 1)
except ValueError: except ValueError:
@ -184,6 +184,10 @@ class Headers(multidict.MultiDict):
pass pass
else: else:
replacements += n replacements += n
if flag_count:
count -= n
if count == 0:
break
fields.append((name, value)) fields.append((name, value))
self.fields = tuple(fields) self.fields = tuple(fields)
return replacements return replacements

View File

@ -260,7 +260,7 @@ class Message(basetypes.Serializable):
if "content-encoding" not in self.headers: if "content-encoding" not in self.headers:
raise ValueError("Invalid content encoding {}".format(repr(e))) raise ValueError("Invalid content encoding {}".format(repr(e)))
def replace(self, pattern, repl, flags=0): def replace(self, pattern, repl, flags=0, count=0):
""" """
Replaces a regular expression pattern with repl in both the headers Replaces a regular expression pattern with repl in both the headers
and the body of the message. Encoded body will be decoded and the body of the message. Encoded body will be decoded
@ -276,9 +276,9 @@ class Message(basetypes.Serializable):
replacements = 0 replacements = 0
if self.content: if self.content:
self.content, replacements = re.subn( self.content, replacements = re.subn(
pattern, repl, self.content, flags=flags pattern, repl, self.content, flags=flags, count=count
) )
replacements += self.headers.replace(pattern, repl, flags) replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
return replacements return replacements
# Legacy # Legacy

View File

@ -80,7 +80,7 @@ class Request(message.Message):
self.method, hostport, path self.method, hostport, path
) )
def replace(self, pattern, repl, flags=0): def replace(self, pattern, repl, flags=0, count=0):
""" """
Replaces a regular expression pattern with repl in the headers, the Replaces a regular expression pattern with repl in the headers, the
request path and the body of the request. Encoded content will be request path and the body of the request. Encoded content will be
@ -94,9 +94,9 @@ class Request(message.Message):
if isinstance(repl, six.text_type): if isinstance(repl, six.text_type):
repl = strutils.escaped_str_to_bytes(repl) repl = strutils.escaped_str_to_bytes(repl)
c = super(Request, self).replace(pattern, repl, flags) c = super(Request, self).replace(pattern, repl, flags, count)
self.path, pc = re.subn( self.path, pc = re.subn(
pattern, repl, self.data.path, flags=flags pattern, repl, self.data.path, flags=flags, count=count
) )
c += pc c += pc
return c return c

View File

@ -75,6 +75,11 @@ class TestHeaders(object):
assert replacements == 0 assert replacements == 0
assert headers["Host"] == "example.com" assert headers["Host"] == "example.com"
def test_replace_with_count(self):
headers = Headers(Host="foobarfoo.com", Accept="foo/bar")
replacements = headers.replace("foo", "bar", count=1)
assert replacements == 1
def test_parse_content_type(): def test_parse_content_type():
p = parse_content_type p = parse_content_type

View File

@ -99,6 +99,16 @@ class TestMessage(object):
def test_http_version(self): def test_http_version(self):
_test_decoded_attr(tresp(), "http_version") _test_decoded_attr(tresp(), "http_version")
def test_replace(self):
r = tresp()
r.content = b"foofootoo"
r.replace(b"foo", "gg")
assert r.content == b"ggggtoo"
r.content = b"foofootoo"
r.replace(b"foo", "gg", count=1)
assert r.content == b"ggfootoo"
class TestMessageContentEncoding(object): class TestMessageContentEncoding(object):
def test_simple(self): def test_simple(self):

View File

@ -26,6 +26,16 @@ class TestRequestCore(object):
request.host = None request.host = None
assert repr(request) == "Request(GET /path)" assert repr(request) == "Request(GET /path)"
def replace(self):
r = treq()
r.path = b"foobarfoo"
r.replace(b"foo", "bar")
assert r.path == b"barbarbar"
r.path = b"foobarfoo"
r.replace(b"foo", "bar", count=1)
assert r.path == b"barbarfoo"
def test_first_line_format(self): def test_first_line_format(self):
_test_passthrough_attr(treq(), "first_line_format") _test_passthrough_attr(treq(), "first_line_format")