mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 00:01:36 +00:00
Remove check argument to serve() methods.
Refactoring means we can now do this without a callback. Also introduce the maximum_length method that estimates the max possible message length.
This commit is contained in:
parent
06864e5a1b
commit
ac5aacce44
@ -129,11 +129,16 @@ def _preview(is_request):
|
||||
|
||||
s = cStringIO.StringIO()
|
||||
args["pauses"] = r.preview_safe()
|
||||
|
||||
c = app.config["pathod"].check_policy(r)
|
||||
if c:
|
||||
args["error"] = c
|
||||
return render(template, False, **args)
|
||||
|
||||
if is_request:
|
||||
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy, host="example.com")
|
||||
r.serve(app.config["pathod"].request_settings, s, host="example.com")
|
||||
else:
|
||||
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy)
|
||||
r.serve(app.config["pathod"].request_settings, s)
|
||||
|
||||
args["output"] = utils.escape_unprintables(s.getvalue())
|
||||
return render(template, False, **args)
|
||||
|
@ -557,18 +557,14 @@ class Message:
|
||||
self.actions = [i for i in self.actions if not isinstance(i, PauseAt)]
|
||||
return pauses
|
||||
|
||||
def effective_length(self, settings, request_host):
|
||||
def maximum_length(self, settings, request_host):
|
||||
"""
|
||||
Calculate the length of the base message with all applied actions.
|
||||
Calculate the maximum length of the base message with all applied actions.
|
||||
"""
|
||||
# Order matters here, and must match the order of application in
|
||||
# write_values.
|
||||
l = self.length(settings, request_host)
|
||||
for i in reversed(self.ready_actions(settings, request_host)):
|
||||
if i[1] == "disconnect":
|
||||
return i[0]
|
||||
elif i[1] == "inject":
|
||||
l += len(i[2])
|
||||
for i in self.actions:
|
||||
if isinstance(i, InjectAt):
|
||||
l += len(i.value.get_generator(settings))
|
||||
return l
|
||||
|
||||
def headervals(self, settings, request_host):
|
||||
@ -609,15 +605,10 @@ class Message:
|
||||
actions.reverse()
|
||||
return [i.intermediate(settings) for i in actions]
|
||||
|
||||
def serve(self, settings, fp, check, request_host):
|
||||
def serve(self, settings, fp, request_host):
|
||||
"""
|
||||
fp: The file pointer to write to.
|
||||
|
||||
check: A function called with the effective actions (after random
|
||||
values have been calculated). If it returns False service proceeds,
|
||||
otherwise the return is treated as an error message to be sent to
|
||||
the client, and service stops.
|
||||
|
||||
request_host: If this a request, this is the connecting host. If
|
||||
None, we assume it's a response. Used to decide what standard
|
||||
modifications to make if raw is not set.
|
||||
@ -636,15 +627,6 @@ class Message:
|
||||
vals.append(self.body)
|
||||
vals.reverse()
|
||||
actions = self.ready_actions(settings, request_host)
|
||||
if check:
|
||||
ret = check(self, actions)
|
||||
if ret:
|
||||
err = PathodErrorResponse(ret)
|
||||
err.serve(settings, fp)
|
||||
return dict(
|
||||
disconnect = True,
|
||||
error = ret
|
||||
)
|
||||
|
||||
disconnect = write_values(fp, vals, actions[:])
|
||||
duration = time.time() - started
|
||||
@ -751,8 +733,8 @@ class CraftedRequest(Request):
|
||||
for i in tokens:
|
||||
i.accept(settings, self)
|
||||
|
||||
def serve(self, settings, fp, check, host):
|
||||
d = Request.serve(self, settings, fp, check, host)
|
||||
def serve(self, settings, fp, host):
|
||||
d = Request.serve(self, settings, fp, host)
|
||||
d["spec"] = self.spec
|
||||
return d
|
||||
|
||||
@ -764,8 +746,8 @@ class CraftedResponse(Response):
|
||||
for i in tokens:
|
||||
i.accept(settings, self)
|
||||
|
||||
def serve(self, settings, fp, check):
|
||||
d = Response.serve(self, settings, fp, check, None)
|
||||
def serve(self, settings, fp):
|
||||
d = Response.serve(self, settings, fp, None)
|
||||
d["spec"] = self.spec
|
||||
return d
|
||||
|
||||
@ -780,8 +762,8 @@ class PathodErrorResponse(Response):
|
||||
Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")),
|
||||
]
|
||||
|
||||
def serve(self, settings, fp, check=None):
|
||||
d = Response.serve(self, settings, fp, check, None)
|
||||
def serve(self, settings, fp):
|
||||
d = Response.serve(self, settings, fp, None)
|
||||
d["internal"] = True
|
||||
return d
|
||||
|
||||
|
@ -22,7 +22,7 @@ class Pathoc(tcp.TCPClient):
|
||||
language.FileAccessDenied.
|
||||
"""
|
||||
r = language.parse_request(self.settings, spec)
|
||||
ret = r.serve(self.settings, self.wfile, None, self.host)
|
||||
ret = r.serve(self.settings, self.wfile, self.host)
|
||||
self.wfile.flush()
|
||||
return http.read_response(self.rfile, r.method, None)
|
||||
|
||||
@ -68,7 +68,7 @@ class Pathoc(tcp.TCPClient):
|
||||
if showresp:
|
||||
self.rfile.start_log()
|
||||
try:
|
||||
req = r.serve(self.settings, self.wfile, None, self.host)
|
||||
req = r.serve(self.settings, self.wfile, self.host)
|
||||
self.wfile.flush()
|
||||
resp = http.read_response(self.rfile, r.method, None)
|
||||
except http.HttpError, v:
|
||||
|
@ -18,7 +18,17 @@ class PathodHandler(tcp.BaseHandler):
|
||||
self.sni = connection.get_servername()
|
||||
|
||||
def serve_crafted(self, crafted, request_log):
|
||||
response_log = crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy)
|
||||
c = self.server.check_policy(crafted)
|
||||
if c:
|
||||
err = language.PathodErrorResponse(c)
|
||||
err.serve(self.server.request_settings, self.wfile)
|
||||
log = dict(
|
||||
type = "error",
|
||||
msg = c
|
||||
)
|
||||
return False, log
|
||||
|
||||
response_log = crafted.serve(self.server.request_settings, self.wfile)
|
||||
log = dict(
|
||||
type = "crafted",
|
||||
request=request_log,
|
||||
@ -96,7 +106,7 @@ class PathodHandler(tcp.BaseHandler):
|
||||
return self.serve_crafted(crafted, request_log)
|
||||
elif self.server.noweb:
|
||||
crafted = language.PathodErrorResponse("Access Denied")
|
||||
crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy)
|
||||
crafted.serve(self.server.request_settings, self.wfile)
|
||||
return False, dict(type = "error", msg="Access denied: web interface disabled")
|
||||
else:
|
||||
self.info("app: %s %s"%(method, path))
|
||||
@ -205,13 +215,13 @@ class Pathod(tcp.TCPServer):
|
||||
raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v)))
|
||||
self.anchors.append((arex, i[1]))
|
||||
|
||||
def check_policy(self, req, actions):
|
||||
def check_policy(self, req):
|
||||
"""
|
||||
A policy check that verifies the request size is withing limits.
|
||||
"""
|
||||
if self.sizelimit and req.effective_length({}, None) > self.sizelimit:
|
||||
if self.sizelimit and req.maximum_length({}, None) > self.sizelimit:
|
||||
return "Response too large."
|
||||
if self.nohang and any([i[1] == "pause" for i in actions]):
|
||||
if self.nohang and any([isinstance(i, language.PauseAt) for i in req.actions]):
|
||||
return "Pauses have been disabled."
|
||||
return False
|
||||
|
||||
|
@ -279,7 +279,7 @@ class TestInject:
|
||||
def test_serve(self):
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:i0,'foo'")
|
||||
assert r.serve({}, s, None)
|
||||
assert r.serve({}, s)
|
||||
|
||||
def test_spec(self):
|
||||
e = language.InjectAt.expr()
|
||||
@ -344,7 +344,7 @@ class TestParseRequest:
|
||||
def test_render(self):
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_request({}, "GET:'/foo'")
|
||||
assert r.serve({}, s, None, "foo.com")
|
||||
assert r.serve({}, s, "foo.com")
|
||||
|
||||
def test_str(self):
|
||||
r = language.parse_request({}, 'GET:"/foo"')
|
||||
@ -479,15 +479,15 @@ class TestWriteValues:
|
||||
def test_write_values_after(self):
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:da")
|
||||
r.serve({}, s, None)
|
||||
r.serve({}, s)
|
||||
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:pa,0")
|
||||
r.serve({}, s, None)
|
||||
r.serve({}, s)
|
||||
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:ia,'xx'")
|
||||
r.serve({}, s, None)
|
||||
r.serve({}, s)
|
||||
assert s.getvalue().endswith('xx')
|
||||
|
||||
|
||||
@ -511,29 +511,22 @@ class TestResponse:
|
||||
assert r.body[:]
|
||||
assert str(r)
|
||||
|
||||
def test_checkfunc(self):
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:b@100k")
|
||||
def check(req, acts):
|
||||
return "errmsg"
|
||||
assert r.serve({}, s, check=check)["error"] == "errmsg"
|
||||
|
||||
def test_render(self):
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400'msg'")
|
||||
assert r.serve({}, s, None)
|
||||
assert r.serve({}, s)
|
||||
|
||||
def test_raw(self):
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:b'foo'")
|
||||
r.serve({}, s, None)
|
||||
r.serve({}, s)
|
||||
v = s.getvalue()
|
||||
assert "Content-Length" in v
|
||||
assert "Date" in v
|
||||
|
||||
s = cStringIO.StringIO()
|
||||
r = language.parse_response({}, "400:b'foo':r")
|
||||
r.serve({}, s, None)
|
||||
r.serve({}, s)
|
||||
v = s.getvalue()
|
||||
assert not "Content-Length" in v
|
||||
assert not "Date" in v
|
||||
@ -541,21 +534,18 @@ class TestResponse:
|
||||
def test_length(self):
|
||||
def testlen(x):
|
||||
s = cStringIO.StringIO()
|
||||
x.serve({}, s, None)
|
||||
x.serve({}, s)
|
||||
assert x.length({}, None) == len(s.getvalue())
|
||||
testlen(language.parse_response({}, "400'msg'"))
|
||||
testlen(language.parse_response({}, "400'msg':h'foo'='bar'"))
|
||||
testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b"))
|
||||
|
||||
def test_effective_length(self):
|
||||
l = [None]
|
||||
def check(req, actions):
|
||||
l[0] = req.effective_length({}, None)
|
||||
|
||||
def test_maximum_length(self):
|
||||
def testlen(x, actions):
|
||||
s = cStringIO.StringIO()
|
||||
x.serve({}, s, check)
|
||||
assert l[0] == len(s.getvalue())
|
||||
m = x.maximum_length({}, None)
|
||||
x.serve({}, s)
|
||||
assert m >= len(s.getvalue())
|
||||
|
||||
r = language.parse_response({}, "400'msg':b@100")
|
||||
|
||||
|
@ -58,7 +58,7 @@ class TestNohang(tutils.DaemonTests):
|
||||
r = self.get("200:p0,0")
|
||||
assert r.status_code == 800
|
||||
l = self.d.last_log()
|
||||
assert "Pauses have been disabled" in l["response"]["error"]
|
||||
assert "Pauses have been disabled" in l["msg"]
|
||||
|
||||
|
||||
class TestHexdump(tutils.DaemonTests):
|
||||
@ -77,7 +77,7 @@ class CommonTests(tutils.DaemonTests):
|
||||
r = self.get("200:b@1g")
|
||||
assert r.status_code == 800
|
||||
l = self.d.last_log()
|
||||
assert "too large" in l["response"]["error"]
|
||||
assert "too large" in l["msg"]
|
||||
|
||||
def test_preline(self):
|
||||
v = self.pathoc(r"get:'/p/200':i0,'\r\n'")
|
||||
|
Loading…
Reference in New Issue
Block a user