mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
Remove check argument to serve() methods.
Refactoring means we can now do this without a callback. Also introduce the maximum_length method that estimates the max possible message length.
This commit is contained in:
parent
06864e5a1b
commit
ac5aacce44
@ -129,11 +129,16 @@ def _preview(is_request):
|
|||||||
|
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
args["pauses"] = r.preview_safe()
|
args["pauses"] = r.preview_safe()
|
||||||
|
|
||||||
|
c = app.config["pathod"].check_policy(r)
|
||||||
|
if c:
|
||||||
|
args["error"] = c
|
||||||
|
return render(template, False, **args)
|
||||||
|
|
||||||
if is_request:
|
if is_request:
|
||||||
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy, host="example.com")
|
r.serve(app.config["pathod"].request_settings, s, host="example.com")
|
||||||
else:
|
else:
|
||||||
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy)
|
r.serve(app.config["pathod"].request_settings, s)
|
||||||
|
|
||||||
args["output"] = utils.escape_unprintables(s.getvalue())
|
args["output"] = utils.escape_unprintables(s.getvalue())
|
||||||
return render(template, False, **args)
|
return render(template, False, **args)
|
||||||
|
@ -557,18 +557,14 @@ class Message:
|
|||||||
self.actions = [i for i in self.actions if not isinstance(i, PauseAt)]
|
self.actions = [i for i in self.actions if not isinstance(i, PauseAt)]
|
||||||
return pauses
|
return pauses
|
||||||
|
|
||||||
def effective_length(self, settings, request_host):
|
def maximum_length(self, settings, request_host):
|
||||||
"""
|
"""
|
||||||
Calculate the length of the base message with all applied actions.
|
Calculate the maximum length of the base message with all applied actions.
|
||||||
"""
|
"""
|
||||||
# Order matters here, and must match the order of application in
|
|
||||||
# write_values.
|
|
||||||
l = self.length(settings, request_host)
|
l = self.length(settings, request_host)
|
||||||
for i in reversed(self.ready_actions(settings, request_host)):
|
for i in self.actions:
|
||||||
if i[1] == "disconnect":
|
if isinstance(i, InjectAt):
|
||||||
return i[0]
|
l += len(i.value.get_generator(settings))
|
||||||
elif i[1] == "inject":
|
|
||||||
l += len(i[2])
|
|
||||||
return l
|
return l
|
||||||
|
|
||||||
def headervals(self, settings, request_host):
|
def headervals(self, settings, request_host):
|
||||||
@ -609,15 +605,10 @@ class Message:
|
|||||||
actions.reverse()
|
actions.reverse()
|
||||||
return [i.intermediate(settings) for i in actions]
|
return [i.intermediate(settings) for i in actions]
|
||||||
|
|
||||||
def serve(self, settings, fp, check, request_host):
|
def serve(self, settings, fp, request_host):
|
||||||
"""
|
"""
|
||||||
fp: The file pointer to write to.
|
fp: The file pointer to write to.
|
||||||
|
|
||||||
check: A function called with the effective actions (after random
|
|
||||||
values have been calculated). If it returns False service proceeds,
|
|
||||||
otherwise the return is treated as an error message to be sent to
|
|
||||||
the client, and service stops.
|
|
||||||
|
|
||||||
request_host: If this a request, this is the connecting host. If
|
request_host: If this a request, this is the connecting host. If
|
||||||
None, we assume it's a response. Used to decide what standard
|
None, we assume it's a response. Used to decide what standard
|
||||||
modifications to make if raw is not set.
|
modifications to make if raw is not set.
|
||||||
@ -636,15 +627,6 @@ class Message:
|
|||||||
vals.append(self.body)
|
vals.append(self.body)
|
||||||
vals.reverse()
|
vals.reverse()
|
||||||
actions = self.ready_actions(settings, request_host)
|
actions = self.ready_actions(settings, request_host)
|
||||||
if check:
|
|
||||||
ret = check(self, actions)
|
|
||||||
if ret:
|
|
||||||
err = PathodErrorResponse(ret)
|
|
||||||
err.serve(settings, fp)
|
|
||||||
return dict(
|
|
||||||
disconnect = True,
|
|
||||||
error = ret
|
|
||||||
)
|
|
||||||
|
|
||||||
disconnect = write_values(fp, vals, actions[:])
|
disconnect = write_values(fp, vals, actions[:])
|
||||||
duration = time.time() - started
|
duration = time.time() - started
|
||||||
@ -751,8 +733,8 @@ class CraftedRequest(Request):
|
|||||||
for i in tokens:
|
for i in tokens:
|
||||||
i.accept(settings, self)
|
i.accept(settings, self)
|
||||||
|
|
||||||
def serve(self, settings, fp, check, host):
|
def serve(self, settings, fp, host):
|
||||||
d = Request.serve(self, settings, fp, check, host)
|
d = Request.serve(self, settings, fp, host)
|
||||||
d["spec"] = self.spec
|
d["spec"] = self.spec
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -764,8 +746,8 @@ class CraftedResponse(Response):
|
|||||||
for i in tokens:
|
for i in tokens:
|
||||||
i.accept(settings, self)
|
i.accept(settings, self)
|
||||||
|
|
||||||
def serve(self, settings, fp, check):
|
def serve(self, settings, fp):
|
||||||
d = Response.serve(self, settings, fp, check, None)
|
d = Response.serve(self, settings, fp, None)
|
||||||
d["spec"] = self.spec
|
d["spec"] = self.spec
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@ -780,8 +762,8 @@ class PathodErrorResponse(Response):
|
|||||||
Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")),
|
Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")),
|
||||||
]
|
]
|
||||||
|
|
||||||
def serve(self, settings, fp, check=None):
|
def serve(self, settings, fp):
|
||||||
d = Response.serve(self, settings, fp, check, None)
|
d = Response.serve(self, settings, fp, None)
|
||||||
d["internal"] = True
|
d["internal"] = True
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ class Pathoc(tcp.TCPClient):
|
|||||||
language.FileAccessDenied.
|
language.FileAccessDenied.
|
||||||
"""
|
"""
|
||||||
r = language.parse_request(self.settings, spec)
|
r = language.parse_request(self.settings, spec)
|
||||||
ret = r.serve(self.settings, self.wfile, None, self.host)
|
ret = r.serve(self.settings, self.wfile, self.host)
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
return http.read_response(self.rfile, r.method, None)
|
return http.read_response(self.rfile, r.method, None)
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ class Pathoc(tcp.TCPClient):
|
|||||||
if showresp:
|
if showresp:
|
||||||
self.rfile.start_log()
|
self.rfile.start_log()
|
||||||
try:
|
try:
|
||||||
req = r.serve(self.settings, self.wfile, None, self.host)
|
req = r.serve(self.settings, self.wfile, self.host)
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
resp = http.read_response(self.rfile, r.method, None)
|
resp = http.read_response(self.rfile, r.method, None)
|
||||||
except http.HttpError, v:
|
except http.HttpError, v:
|
||||||
|
@ -18,7 +18,17 @@ class PathodHandler(tcp.BaseHandler):
|
|||||||
self.sni = connection.get_servername()
|
self.sni = connection.get_servername()
|
||||||
|
|
||||||
def serve_crafted(self, crafted, request_log):
|
def serve_crafted(self, crafted, request_log):
|
||||||
response_log = crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy)
|
c = self.server.check_policy(crafted)
|
||||||
|
if c:
|
||||||
|
err = language.PathodErrorResponse(c)
|
||||||
|
err.serve(self.server.request_settings, self.wfile)
|
||||||
|
log = dict(
|
||||||
|
type = "error",
|
||||||
|
msg = c
|
||||||
|
)
|
||||||
|
return False, log
|
||||||
|
|
||||||
|
response_log = crafted.serve(self.server.request_settings, self.wfile)
|
||||||
log = dict(
|
log = dict(
|
||||||
type = "crafted",
|
type = "crafted",
|
||||||
request=request_log,
|
request=request_log,
|
||||||
@ -96,7 +106,7 @@ class PathodHandler(tcp.BaseHandler):
|
|||||||
return self.serve_crafted(crafted, request_log)
|
return self.serve_crafted(crafted, request_log)
|
||||||
elif self.server.noweb:
|
elif self.server.noweb:
|
||||||
crafted = language.PathodErrorResponse("Access Denied")
|
crafted = language.PathodErrorResponse("Access Denied")
|
||||||
crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy)
|
crafted.serve(self.server.request_settings, self.wfile)
|
||||||
return False, dict(type = "error", msg="Access denied: web interface disabled")
|
return False, dict(type = "error", msg="Access denied: web interface disabled")
|
||||||
else:
|
else:
|
||||||
self.info("app: %s %s"%(method, path))
|
self.info("app: %s %s"%(method, path))
|
||||||
@ -205,13 +215,13 @@ class Pathod(tcp.TCPServer):
|
|||||||
raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v)))
|
raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v)))
|
||||||
self.anchors.append((arex, i[1]))
|
self.anchors.append((arex, i[1]))
|
||||||
|
|
||||||
def check_policy(self, req, actions):
|
def check_policy(self, req):
|
||||||
"""
|
"""
|
||||||
A policy check that verifies the request size is withing limits.
|
A policy check that verifies the request size is withing limits.
|
||||||
"""
|
"""
|
||||||
if self.sizelimit and req.effective_length({}, None) > self.sizelimit:
|
if self.sizelimit and req.maximum_length({}, None) > self.sizelimit:
|
||||||
return "Response too large."
|
return "Response too large."
|
||||||
if self.nohang and any([i[1] == "pause" for i in actions]):
|
if self.nohang and any([isinstance(i, language.PauseAt) for i in req.actions]):
|
||||||
return "Pauses have been disabled."
|
return "Pauses have been disabled."
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -279,7 +279,7 @@ class TestInject:
|
|||||||
def test_serve(self):
|
def test_serve(self):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400:i0,'foo'")
|
r = language.parse_response({}, "400:i0,'foo'")
|
||||||
assert r.serve({}, s, None)
|
assert r.serve({}, s)
|
||||||
|
|
||||||
def test_spec(self):
|
def test_spec(self):
|
||||||
e = language.InjectAt.expr()
|
e = language.InjectAt.expr()
|
||||||
@ -344,7 +344,7 @@ class TestParseRequest:
|
|||||||
def test_render(self):
|
def test_render(self):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_request({}, "GET:'/foo'")
|
r = language.parse_request({}, "GET:'/foo'")
|
||||||
assert r.serve({}, s, None, "foo.com")
|
assert r.serve({}, s, "foo.com")
|
||||||
|
|
||||||
def test_str(self):
|
def test_str(self):
|
||||||
r = language.parse_request({}, 'GET:"/foo"')
|
r = language.parse_request({}, 'GET:"/foo"')
|
||||||
@ -479,15 +479,15 @@ class TestWriteValues:
|
|||||||
def test_write_values_after(self):
|
def test_write_values_after(self):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400:da")
|
r = language.parse_response({}, "400:da")
|
||||||
r.serve({}, s, None)
|
r.serve({}, s)
|
||||||
|
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400:pa,0")
|
r = language.parse_response({}, "400:pa,0")
|
||||||
r.serve({}, s, None)
|
r.serve({}, s)
|
||||||
|
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400:ia,'xx'")
|
r = language.parse_response({}, "400:ia,'xx'")
|
||||||
r.serve({}, s, None)
|
r.serve({}, s)
|
||||||
assert s.getvalue().endswith('xx')
|
assert s.getvalue().endswith('xx')
|
||||||
|
|
||||||
|
|
||||||
@ -511,29 +511,22 @@ class TestResponse:
|
|||||||
assert r.body[:]
|
assert r.body[:]
|
||||||
assert str(r)
|
assert str(r)
|
||||||
|
|
||||||
def test_checkfunc(self):
|
|
||||||
s = cStringIO.StringIO()
|
|
||||||
r = language.parse_response({}, "400:b@100k")
|
|
||||||
def check(req, acts):
|
|
||||||
return "errmsg"
|
|
||||||
assert r.serve({}, s, check=check)["error"] == "errmsg"
|
|
||||||
|
|
||||||
def test_render(self):
|
def test_render(self):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400'msg'")
|
r = language.parse_response({}, "400'msg'")
|
||||||
assert r.serve({}, s, None)
|
assert r.serve({}, s)
|
||||||
|
|
||||||
def test_raw(self):
|
def test_raw(self):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400:b'foo'")
|
r = language.parse_response({}, "400:b'foo'")
|
||||||
r.serve({}, s, None)
|
r.serve({}, s)
|
||||||
v = s.getvalue()
|
v = s.getvalue()
|
||||||
assert "Content-Length" in v
|
assert "Content-Length" in v
|
||||||
assert "Date" in v
|
assert "Date" in v
|
||||||
|
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
r = language.parse_response({}, "400:b'foo':r")
|
r = language.parse_response({}, "400:b'foo':r")
|
||||||
r.serve({}, s, None)
|
r.serve({}, s)
|
||||||
v = s.getvalue()
|
v = s.getvalue()
|
||||||
assert not "Content-Length" in v
|
assert not "Content-Length" in v
|
||||||
assert not "Date" in v
|
assert not "Date" in v
|
||||||
@ -541,21 +534,18 @@ class TestResponse:
|
|||||||
def test_length(self):
|
def test_length(self):
|
||||||
def testlen(x):
|
def testlen(x):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
x.serve({}, s, None)
|
x.serve({}, s)
|
||||||
assert x.length({}, None) == len(s.getvalue())
|
assert x.length({}, None) == len(s.getvalue())
|
||||||
testlen(language.parse_response({}, "400'msg'"))
|
testlen(language.parse_response({}, "400'msg'"))
|
||||||
testlen(language.parse_response({}, "400'msg':h'foo'='bar'"))
|
testlen(language.parse_response({}, "400'msg':h'foo'='bar'"))
|
||||||
testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b"))
|
testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b"))
|
||||||
|
|
||||||
def test_effective_length(self):
|
def test_maximum_length(self):
|
||||||
l = [None]
|
|
||||||
def check(req, actions):
|
|
||||||
l[0] = req.effective_length({}, None)
|
|
||||||
|
|
||||||
def testlen(x, actions):
|
def testlen(x, actions):
|
||||||
s = cStringIO.StringIO()
|
s = cStringIO.StringIO()
|
||||||
x.serve({}, s, check)
|
m = x.maximum_length({}, None)
|
||||||
assert l[0] == len(s.getvalue())
|
x.serve({}, s)
|
||||||
|
assert m >= len(s.getvalue())
|
||||||
|
|
||||||
r = language.parse_response({}, "400'msg':b@100")
|
r = language.parse_response({}, "400'msg':b@100")
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ class TestNohang(tutils.DaemonTests):
|
|||||||
r = self.get("200:p0,0")
|
r = self.get("200:p0,0")
|
||||||
assert r.status_code == 800
|
assert r.status_code == 800
|
||||||
l = self.d.last_log()
|
l = self.d.last_log()
|
||||||
assert "Pauses have been disabled" in l["response"]["error"]
|
assert "Pauses have been disabled" in l["msg"]
|
||||||
|
|
||||||
|
|
||||||
class TestHexdump(tutils.DaemonTests):
|
class TestHexdump(tutils.DaemonTests):
|
||||||
@ -77,7 +77,7 @@ class CommonTests(tutils.DaemonTests):
|
|||||||
r = self.get("200:b@1g")
|
r = self.get("200:b@1g")
|
||||||
assert r.status_code == 800
|
assert r.status_code == 800
|
||||||
l = self.d.last_log()
|
l = self.d.last_log()
|
||||||
assert "too large" in l["response"]["error"]
|
assert "too large" in l["msg"]
|
||||||
|
|
||||||
def test_preline(self):
|
def test_preline(self):
|
||||||
v = self.pathoc(r"get:'/p/200':i0,'\r\n'")
|
v = self.pathoc(r"get:'/p/200':i0,'\r\n'")
|
||||||
|
Loading…
Reference in New Issue
Block a user