Remove check argument to serve() methods.

Refactoring means we can now do this without a callback.

Also introduce the maximum_length method that estimates the max possible
message length.
This commit is contained in:
Aldo Cortesi 2012-10-27 17:40:22 +13:00
parent 06864e5a1b
commit ac5aacce44
6 changed files with 51 additions and 64 deletions

View File

@ -130,10 +130,15 @@ def _preview(is_request):
s = cStringIO.StringIO()
args["pauses"] = r.preview_safe()
c = app.config["pathod"].check_policy(r)
if c:
args["error"] = c
return render(template, False, **args)
if is_request:
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy, host="example.com")
r.serve(app.config["pathod"].request_settings, s, host="example.com")
else:
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy)
r.serve(app.config["pathod"].request_settings, s)
args["output"] = utils.escape_unprintables(s.getvalue())
return render(template, False, **args)

View File

@ -557,18 +557,14 @@ class Message:
self.actions = [i for i in self.actions if not isinstance(i, PauseAt)]
return pauses
def effective_length(self, settings, request_host):
def maximum_length(self, settings, request_host):
"""
Calculate the length of the base message with all applied actions.
Calculate the maximum length of the base message with all applied actions.
"""
# Order matters here, and must match the order of application in
# write_values.
l = self.length(settings, request_host)
for i in reversed(self.ready_actions(settings, request_host)):
if i[1] == "disconnect":
return i[0]
elif i[1] == "inject":
l += len(i[2])
for i in self.actions:
if isinstance(i, InjectAt):
l += len(i.value.get_generator(settings))
return l
def headervals(self, settings, request_host):
@ -609,15 +605,10 @@ class Message:
actions.reverse()
return [i.intermediate(settings) for i in actions]
def serve(self, settings, fp, check, request_host):
def serve(self, settings, fp, request_host):
"""
fp: The file pointer to write to.
check: A function called with the effective actions (after random
values have been calculated). If it returns False service proceeds,
otherwise the return is treated as an error message to be sent to
the client, and service stops.
request_host: If this a request, this is the connecting host. If
None, we assume it's a response. Used to decide what standard
modifications to make if raw is not set.
@ -636,15 +627,6 @@ class Message:
vals.append(self.body)
vals.reverse()
actions = self.ready_actions(settings, request_host)
if check:
ret = check(self, actions)
if ret:
err = PathodErrorResponse(ret)
err.serve(settings, fp)
return dict(
disconnect = True,
error = ret
)
disconnect = write_values(fp, vals, actions[:])
duration = time.time() - started
@ -751,8 +733,8 @@ class CraftedRequest(Request):
for i in tokens:
i.accept(settings, self)
def serve(self, settings, fp, check, host):
d = Request.serve(self, settings, fp, check, host)
def serve(self, settings, fp, host):
d = Request.serve(self, settings, fp, host)
d["spec"] = self.spec
return d
@ -764,8 +746,8 @@ class CraftedResponse(Response):
for i in tokens:
i.accept(settings, self)
def serve(self, settings, fp, check):
d = Response.serve(self, settings, fp, check, None)
def serve(self, settings, fp):
d = Response.serve(self, settings, fp, None)
d["spec"] = self.spec
return d
@ -780,8 +762,8 @@ class PathodErrorResponse(Response):
Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")),
]
def serve(self, settings, fp, check=None):
d = Response.serve(self, settings, fp, check, None)
def serve(self, settings, fp):
d = Response.serve(self, settings, fp, None)
d["internal"] = True
return d

View File

@ -22,7 +22,7 @@ class Pathoc(tcp.TCPClient):
language.FileAccessDenied.
"""
r = language.parse_request(self.settings, spec)
ret = r.serve(self.settings, self.wfile, None, self.host)
ret = r.serve(self.settings, self.wfile, self.host)
self.wfile.flush()
return http.read_response(self.rfile, r.method, None)
@ -68,7 +68,7 @@ class Pathoc(tcp.TCPClient):
if showresp:
self.rfile.start_log()
try:
req = r.serve(self.settings, self.wfile, None, self.host)
req = r.serve(self.settings, self.wfile, self.host)
self.wfile.flush()
resp = http.read_response(self.rfile, r.method, None)
except http.HttpError, v:

View File

@ -18,7 +18,17 @@ class PathodHandler(tcp.BaseHandler):
self.sni = connection.get_servername()
def serve_crafted(self, crafted, request_log):
response_log = crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy)
c = self.server.check_policy(crafted)
if c:
err = language.PathodErrorResponse(c)
err.serve(self.server.request_settings, self.wfile)
log = dict(
type = "error",
msg = c
)
return False, log
response_log = crafted.serve(self.server.request_settings, self.wfile)
log = dict(
type = "crafted",
request=request_log,
@ -96,7 +106,7 @@ class PathodHandler(tcp.BaseHandler):
return self.serve_crafted(crafted, request_log)
elif self.server.noweb:
crafted = language.PathodErrorResponse("Access Denied")
crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy)
crafted.serve(self.server.request_settings, self.wfile)
return False, dict(type = "error", msg="Access denied: web interface disabled")
else:
self.info("app: %s %s"%(method, path))
@ -205,13 +215,13 @@ class Pathod(tcp.TCPServer):
raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v)))
self.anchors.append((arex, i[1]))
def check_policy(self, req, actions):
def check_policy(self, req):
"""
A policy check that verifies the request size is withing limits.
"""
if self.sizelimit and req.effective_length({}, None) > self.sizelimit:
if self.sizelimit and req.maximum_length({}, None) > self.sizelimit:
return "Response too large."
if self.nohang and any([i[1] == "pause" for i in actions]):
if self.nohang and any([isinstance(i, language.PauseAt) for i in req.actions]):
return "Pauses have been disabled."
return False

View File

@ -279,7 +279,7 @@ class TestInject:
def test_serve(self):
s = cStringIO.StringIO()
r = language.parse_response({}, "400:i0,'foo'")
assert r.serve({}, s, None)
assert r.serve({}, s)
def test_spec(self):
e = language.InjectAt.expr()
@ -344,7 +344,7 @@ class TestParseRequest:
def test_render(self):
s = cStringIO.StringIO()
r = language.parse_request({}, "GET:'/foo'")
assert r.serve({}, s, None, "foo.com")
assert r.serve({}, s, "foo.com")
def test_str(self):
r = language.parse_request({}, 'GET:"/foo"')
@ -479,15 +479,15 @@ class TestWriteValues:
def test_write_values_after(self):
s = cStringIO.StringIO()
r = language.parse_response({}, "400:da")
r.serve({}, s, None)
r.serve({}, s)
s = cStringIO.StringIO()
r = language.parse_response({}, "400:pa,0")
r.serve({}, s, None)
r.serve({}, s)
s = cStringIO.StringIO()
r = language.parse_response({}, "400:ia,'xx'")
r.serve({}, s, None)
r.serve({}, s)
assert s.getvalue().endswith('xx')
@ -511,29 +511,22 @@ class TestResponse:
assert r.body[:]
assert str(r)
def test_checkfunc(self):
s = cStringIO.StringIO()
r = language.parse_response({}, "400:b@100k")
def check(req, acts):
return "errmsg"
assert r.serve({}, s, check=check)["error"] == "errmsg"
def test_render(self):
s = cStringIO.StringIO()
r = language.parse_response({}, "400'msg'")
assert r.serve({}, s, None)
assert r.serve({}, s)
def test_raw(self):
s = cStringIO.StringIO()
r = language.parse_response({}, "400:b'foo'")
r.serve({}, s, None)
r.serve({}, s)
v = s.getvalue()
assert "Content-Length" in v
assert "Date" in v
s = cStringIO.StringIO()
r = language.parse_response({}, "400:b'foo':r")
r.serve({}, s, None)
r.serve({}, s)
v = s.getvalue()
assert not "Content-Length" in v
assert not "Date" in v
@ -541,21 +534,18 @@ class TestResponse:
def test_length(self):
def testlen(x):
s = cStringIO.StringIO()
x.serve({}, s, None)
x.serve({}, s)
assert x.length({}, None) == len(s.getvalue())
testlen(language.parse_response({}, "400'msg'"))
testlen(language.parse_response({}, "400'msg':h'foo'='bar'"))
testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b"))
def test_effective_length(self):
l = [None]
def check(req, actions):
l[0] = req.effective_length({}, None)
def test_maximum_length(self):
def testlen(x, actions):
s = cStringIO.StringIO()
x.serve({}, s, check)
assert l[0] == len(s.getvalue())
m = x.maximum_length({}, None)
x.serve({}, s)
assert m >= len(s.getvalue())
r = language.parse_response({}, "400'msg':b@100")

View File

@ -58,7 +58,7 @@ class TestNohang(tutils.DaemonTests):
r = self.get("200:p0,0")
assert r.status_code == 800
l = self.d.last_log()
assert "Pauses have been disabled" in l["response"]["error"]
assert "Pauses have been disabled" in l["msg"]
class TestHexdump(tutils.DaemonTests):
@ -77,7 +77,7 @@ class CommonTests(tutils.DaemonTests):
r = self.get("200:b@1g")
assert r.status_code == 800
l = self.d.last_log()
assert "too large" in l["response"]["error"]
assert "too large" in l["msg"]
def test_preline(self):
v = self.pathoc(r"get:'/p/200':i0,'\r\n'")