From 806aa0f41c7816b2859a6961939ed19499b73fe7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 2 Apr 2016 14:38:33 +0200 Subject: [PATCH] improve .replace() and move it into netlib --- mitmproxy/models/http.py | 40 -------------------------------- mitmproxy/utils.py | 9 ------- netlib/http/headers.py | 31 ++++++++++++++++++++++++- netlib/http/message.py | 19 +++++++++++++++ netlib/http/request.py | 17 ++++++++++++++ netlib/odict.py | 12 +--------- netlib/utils.py | 10 ++++++++ test/mitmproxy/test_utils.py | 4 ---- test/netlib/http/test_headers.py | 19 +++++++++++++++ test/netlib/test_utils.py | 4 ++++ 10 files changed, 100 insertions(+), 65 deletions(-) diff --git a/mitmproxy/models/http.py b/mitmproxy/models/http.py index 4bba35f1b..11f466117 100644 --- a/mitmproxy/models/http.py +++ b/mitmproxy/models/http.py @@ -26,30 +26,6 @@ class MessageMixin(object): return self.content return encoding.decode(ce, self.content) - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the message. Encoded body will be decoded - before replacement, and re-encoded afterwards. - - Returns the number of replacements made. 
- """ - count = 0 - if self.content: - with decoded(self): - self.content, count = utils.safe_subn( - pattern, repl, self.content, *args, **kwargs - ) - fields = [] - for name, value in self.headers.fields: - name, c = utils.safe_subn(pattern, repl, name, *args, **kwargs) - count += c - value, c = utils.safe_subn(pattern, repl, value, *args, **kwargs) - count += c - fields.append([name, value]) - self.headers.fields = fields - return count - class HTTPRequest(MessageMixin, Request): @@ -186,22 +162,6 @@ class HTTPRequest(MessageMixin, Request): def set_auth(self, auth): self.data.headers.set_all("Proxy-Authorization", (auth,)) - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in the headers, the - request path and the body of the request. Encoded content will be - decoded before replacement, and re-encoded afterwards. - - Returns the number of replacements made. - """ - c = MessageMixin.replace(self, pattern, repl, *args, **kwargs) - self.path, pc = utils.safe_subn( - pattern, repl, self.path, *args, **kwargs - ) - c += pc - return c - - class HTTPResponse(MessageMixin, Response): """ diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py index 4bdd036e0..5fd062ea6 100644 --- a/mitmproxy/utils.py +++ b/mitmproxy/utils.py @@ -165,12 +165,3 @@ def parse_size(s): return int(s) * mult except ValueError: raise ValueError("Invalid size specification: %s" % s) - - -def safe_subn(pattern, repl, target, *args, **kwargs): - """ - There are Unicode conversion problems with re.subn. We try to smooth - that over by casting the pattern and replacement to strings. We really - need a better solution that is aware of the actual content ecoding. 
- """ - return re.subn(str(pattern), str(repl), target, *args, **kwargs) diff --git a/netlib/http/headers.py b/netlib/http/headers.py index bcb828da6..72739f900 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -6,6 +6,8 @@ See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ """ from __future__ import absolute_import, print_function, division +import re + try: from collections.abc import MutableMapping except ImportError: # pragma: no cover @@ -198,4 +200,31 @@ class Headers(MutableMapping, Serializable): @classmethod def from_state(cls, state): - return cls([list(field) for field in state]) \ No newline at end of file + return cls([list(field) for field in state]) + + @_always_byte_args + def replace(self, pattern, repl, flags=0): + """ + Replaces a regular expression pattern with repl in each "name: value" + header line. + + Returns: + The number of replacements made. + """ + pattern = re.compile(pattern, flags) + replacements = 0 + + fields = [] + for name, value in self.fields: + line, n = pattern.subn(repl, name + b": " + value) + try: + name, value = line.split(b": ", 1) + except ValueError: + # We get a ValueError if the replacement removed the ": " + # There's not much we can do about this, so we just keep the header as-is. + pass + else: + replacements += n + fields.append([name, value]) + self.fields = fields + return replacements diff --git a/netlib/http/message.py b/netlib/http/message.py index b265ac4ff..da9681a0b 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -175,6 +175,25 @@ class Message(utils.Serializable): self.headers["content-encoding"] = e return True + def replace(self, pattern, repl, flags=0): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the message. Encoded body will be decoded + before replacement, and re-encoded afterwards. + + Returns: + The number of replacements made. + """ + # TODO: Proper distinction between text and bytes. 
+ replacements = 0 + if self.content: + with decoded(self): + self.content, replacements = utils.safe_subn( + pattern, repl, self.content, flags=flags + ) + replacements += self.headers.replace(pattern, repl, flags) + return replacements + # Legacy @property diff --git a/netlib/http/request.py b/netlib/http/request.py index 5bd2547ed..07a11969d 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -54,6 +54,23 @@ class Request(Message): self.method, hostport, path ) + def replace(self, pattern, repl, flags=0): + """ + Replaces a regular expression pattern with repl in the headers, the + request path and the body of the request. Encoded content will be + decoded before replacement, and re-encoded afterwards. + + Returns: + The number of replacements made. + """ + # TODO: Proper distinction between text and bytes. + c = super(Request, self).replace(pattern, repl, flags) + self.path, pc = utils.safe_subn( + pattern, repl, self.path, flags=flags + ) + c += pc + return c + @property def first_line_format(self): """ diff --git a/netlib/odict.py b/netlib/odict.py index 1e6e381af..461192f77 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,18 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import re import copy import six -from .utils import Serializable - - -def safe_subn(pattern, repl, target, *args, **kwargs): - """ - There are Unicode conversion problems with re.subn. We try to smooth - that over by casting the pattern and replacement to strings. We really - need a better solution that is aware of the actual content ecoding. 
- """ - return re.subn(str(pattern), str(repl), target, *args, **kwargs) +from .utils import Serializable, safe_subn class ODict(Serializable): diff --git a/netlib/utils.py b/netlib/utils.py index 09be29d92..dda768086 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -414,8 +414,18 @@ def http2_read_raw_frame(rfile): body = rfile.safe_read(length) return [header, body] + def http2_read_frame(rfile): header, body = http2_read_raw_frame(rfile) frame, length = hyperframe.frame.Frame.parse_frame_header(header) frame.parse_body(memoryview(body)) return frame + + +def safe_subn(pattern, repl, target, *args, **kwargs): + """ + There are Unicode conversion problems with re.subn. We try to smooth + that over by casting the pattern and replacement to strings. We really + need a better solution that is aware of the actual content encoding. + """ + return re.subn(str(pattern), str(repl), target, *args, **kwargs) diff --git a/test/mitmproxy/test_utils.py b/test/mitmproxy/test_utils.py index ae6369ae1..db7dec4ab 100644 --- a/test/mitmproxy/test_utils.py +++ b/test/mitmproxy/test_utils.py @@ -99,7 +99,3 @@ def test_parse_size(): assert utils.parse_size("1g") == 1024**3 tutils.raises(ValueError, utils.parse_size, "1f") tutils.raises(ValueError, utils.parse_size, "ak") - - -def test_safe_subn(): - assert utils.safe_subn("foo", u"bar", "\xc2foo") diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index d50fee3e4..8c1db9dc5 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -150,3 +150,22 @@ class TestHeaders(object): assert headers != headers2 headers2.set_state(headers.get_state()) assert headers == headers2 + + def test_replace_simple(self): + headers = Headers(Host="example.com", Accept="text/plain") + replacements = headers.replace("Host: ", "X-Host: ") + assert replacements == 1 + assert headers["X-Host"] == "example.com" + assert "Host" not in headers + assert headers["Accept"] == "text/plain" + + def 
test_replace_multi(self): + headers = self._2host() + headers.replace(r"Host: example\.com", r"Host: example.de") + assert headers.get_all("Host") == ["example.de", "example.org"] + + def test_replace_remove_spacer(self): + headers = Headers(Host="example.com") + replacements = headers.replace(r"Host: ", "X-Host ") + assert replacements == 0 + assert headers["Host"] == "example.com" diff --git a/test/netlib/test_utils.py b/test/netlib/test_utils.py index fcb63eb2f..be2a59fcf 100644 --- a/test/netlib/test_utils.py +++ b/test/netlib/test_utils.py @@ -166,3 +166,7 @@ class TestSerializable: a.set_state(1) assert a.i == 1 assert b.i == 42 + + +def test_safe_subn(): + assert utils.safe_subn("foo", u"bar", "\xc2foo")