Remove check argument to serve() methods.

Refactoring means we can now do this without a callback.

Also introduce the maximum_length method that estimates the max possible
message length.
This commit is contained in:
Aldo Cortesi 2012-10-27 17:40:22 +13:00
parent 06864e5a1b
commit ac5aacce44
6 changed files with 51 additions and 64 deletions

View File

@ -129,11 +129,16 @@ def _preview(is_request):
s = cStringIO.StringIO() s = cStringIO.StringIO()
args["pauses"] = r.preview_safe() args["pauses"] = r.preview_safe()
c = app.config["pathod"].check_policy(r)
if c:
args["error"] = c
return render(template, False, **args)
if is_request: if is_request:
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy, host="example.com") r.serve(app.config["pathod"].request_settings, s, host="example.com")
else: else:
r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy) r.serve(app.config["pathod"].request_settings, s)
args["output"] = utils.escape_unprintables(s.getvalue()) args["output"] = utils.escape_unprintables(s.getvalue())
return render(template, False, **args) return render(template, False, **args)

View File

@ -557,18 +557,14 @@ class Message:
self.actions = [i for i in self.actions if not isinstance(i, PauseAt)] self.actions = [i for i in self.actions if not isinstance(i, PauseAt)]
return pauses return pauses
def effective_length(self, settings, request_host): def maximum_length(self, settings, request_host):
""" """
Calculate the length of the base message with all applied actions. Calculate the maximum length of the base message with all applied actions.
""" """
# Order matters here, and must match the order of application in
# write_values.
l = self.length(settings, request_host) l = self.length(settings, request_host)
for i in reversed(self.ready_actions(settings, request_host)): for i in self.actions:
if i[1] == "disconnect": if isinstance(i, InjectAt):
return i[0] l += len(i.value.get_generator(settings))
elif i[1] == "inject":
l += len(i[2])
return l return l
def headervals(self, settings, request_host): def headervals(self, settings, request_host):
@ -609,15 +605,10 @@ class Message:
actions.reverse() actions.reverse()
return [i.intermediate(settings) for i in actions] return [i.intermediate(settings) for i in actions]
def serve(self, settings, fp, check, request_host): def serve(self, settings, fp, request_host):
""" """
fp: The file pointer to write to. fp: The file pointer to write to.
check: A function called with the effective actions (after random
values have been calculated). If it returns False service proceeds,
otherwise the return is treated as an error message to be sent to
the client, and service stops.
request_host: If this a request, this is the connecting host. If request_host: If this a request, this is the connecting host. If
None, we assume it's a response. Used to decide what standard None, we assume it's a response. Used to decide what standard
modifications to make if raw is not set. modifications to make if raw is not set.
@ -636,15 +627,6 @@ class Message:
vals.append(self.body) vals.append(self.body)
vals.reverse() vals.reverse()
actions = self.ready_actions(settings, request_host) actions = self.ready_actions(settings, request_host)
if check:
ret = check(self, actions)
if ret:
err = PathodErrorResponse(ret)
err.serve(settings, fp)
return dict(
disconnect = True,
error = ret
)
disconnect = write_values(fp, vals, actions[:]) disconnect = write_values(fp, vals, actions[:])
duration = time.time() - started duration = time.time() - started
@ -751,8 +733,8 @@ class CraftedRequest(Request):
for i in tokens: for i in tokens:
i.accept(settings, self) i.accept(settings, self)
def serve(self, settings, fp, check, host): def serve(self, settings, fp, host):
d = Request.serve(self, settings, fp, check, host) d = Request.serve(self, settings, fp, host)
d["spec"] = self.spec d["spec"] = self.spec
return d return d
@ -764,8 +746,8 @@ class CraftedResponse(Response):
for i in tokens: for i in tokens:
i.accept(settings, self) i.accept(settings, self)
def serve(self, settings, fp, check): def serve(self, settings, fp):
d = Response.serve(self, settings, fp, check, None) d = Response.serve(self, settings, fp, None)
d["spec"] = self.spec d["spec"] = self.spec
return d return d
@ -780,8 +762,8 @@ class PathodErrorResponse(Response):
Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")), Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")),
] ]
def serve(self, settings, fp, check=None): def serve(self, settings, fp):
d = Response.serve(self, settings, fp, check, None) d = Response.serve(self, settings, fp, None)
d["internal"] = True d["internal"] = True
return d return d

View File

@ -22,7 +22,7 @@ class Pathoc(tcp.TCPClient):
language.FileAccessDenied. language.FileAccessDenied.
""" """
r = language.parse_request(self.settings, spec) r = language.parse_request(self.settings, spec)
ret = r.serve(self.settings, self.wfile, None, self.host) ret = r.serve(self.settings, self.wfile, self.host)
self.wfile.flush() self.wfile.flush()
return http.read_response(self.rfile, r.method, None) return http.read_response(self.rfile, r.method, None)
@ -68,7 +68,7 @@ class Pathoc(tcp.TCPClient):
if showresp: if showresp:
self.rfile.start_log() self.rfile.start_log()
try: try:
req = r.serve(self.settings, self.wfile, None, self.host) req = r.serve(self.settings, self.wfile, self.host)
self.wfile.flush() self.wfile.flush()
resp = http.read_response(self.rfile, r.method, None) resp = http.read_response(self.rfile, r.method, None)
except http.HttpError, v: except http.HttpError, v:

View File

@ -18,7 +18,17 @@ class PathodHandler(tcp.BaseHandler):
self.sni = connection.get_servername() self.sni = connection.get_servername()
def serve_crafted(self, crafted, request_log): def serve_crafted(self, crafted, request_log):
response_log = crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy) c = self.server.check_policy(crafted)
if c:
err = language.PathodErrorResponse(c)
err.serve(self.server.request_settings, self.wfile)
log = dict(
type = "error",
msg = c
)
return False, log
response_log = crafted.serve(self.server.request_settings, self.wfile)
log = dict( log = dict(
type = "crafted", type = "crafted",
request=request_log, request=request_log,
@ -96,7 +106,7 @@ class PathodHandler(tcp.BaseHandler):
return self.serve_crafted(crafted, request_log) return self.serve_crafted(crafted, request_log)
elif self.server.noweb: elif self.server.noweb:
crafted = language.PathodErrorResponse("Access Denied") crafted = language.PathodErrorResponse("Access Denied")
crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy) crafted.serve(self.server.request_settings, self.wfile)
return False, dict(type = "error", msg="Access denied: web interface disabled") return False, dict(type = "error", msg="Access denied: web interface disabled")
else: else:
self.info("app: %s %s"%(method, path)) self.info("app: %s %s"%(method, path))
@ -205,13 +215,13 @@ class Pathod(tcp.TCPServer):
raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v))) raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v)))
self.anchors.append((arex, i[1])) self.anchors.append((arex, i[1]))
def check_policy(self, req, actions): def check_policy(self, req):
""" """
A policy check that verifies the request size is withing limits. A policy check that verifies the request size is withing limits.
""" """
if self.sizelimit and req.effective_length({}, None) > self.sizelimit: if self.sizelimit and req.maximum_length({}, None) > self.sizelimit:
return "Response too large." return "Response too large."
if self.nohang and any([i[1] == "pause" for i in actions]): if self.nohang and any([isinstance(i, language.PauseAt) for i in req.actions]):
return "Pauses have been disabled." return "Pauses have been disabled."
return False return False

View File

@ -279,7 +279,7 @@ class TestInject:
def test_serve(self): def test_serve(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400:i0,'foo'") r = language.parse_response({}, "400:i0,'foo'")
assert r.serve({}, s, None) assert r.serve({}, s)
def test_spec(self): def test_spec(self):
e = language.InjectAt.expr() e = language.InjectAt.expr()
@ -344,7 +344,7 @@ class TestParseRequest:
def test_render(self): def test_render(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_request({}, "GET:'/foo'") r = language.parse_request({}, "GET:'/foo'")
assert r.serve({}, s, None, "foo.com") assert r.serve({}, s, "foo.com")
def test_str(self): def test_str(self):
r = language.parse_request({}, 'GET:"/foo"') r = language.parse_request({}, 'GET:"/foo"')
@ -479,15 +479,15 @@ class TestWriteValues:
def test_write_values_after(self): def test_write_values_after(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400:da") r = language.parse_response({}, "400:da")
r.serve({}, s, None) r.serve({}, s)
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400:pa,0") r = language.parse_response({}, "400:pa,0")
r.serve({}, s, None) r.serve({}, s)
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400:ia,'xx'") r = language.parse_response({}, "400:ia,'xx'")
r.serve({}, s, None) r.serve({}, s)
assert s.getvalue().endswith('xx') assert s.getvalue().endswith('xx')
@ -511,29 +511,22 @@ class TestResponse:
assert r.body[:] assert r.body[:]
assert str(r) assert str(r)
def test_checkfunc(self):
s = cStringIO.StringIO()
r = language.parse_response({}, "400:b@100k")
def check(req, acts):
return "errmsg"
assert r.serve({}, s, check=check)["error"] == "errmsg"
def test_render(self): def test_render(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400'msg'") r = language.parse_response({}, "400'msg'")
assert r.serve({}, s, None) assert r.serve({}, s)
def test_raw(self): def test_raw(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400:b'foo'") r = language.parse_response({}, "400:b'foo'")
r.serve({}, s, None) r.serve({}, s)
v = s.getvalue() v = s.getvalue()
assert "Content-Length" in v assert "Content-Length" in v
assert "Date" in v assert "Date" in v
s = cStringIO.StringIO() s = cStringIO.StringIO()
r = language.parse_response({}, "400:b'foo':r") r = language.parse_response({}, "400:b'foo':r")
r.serve({}, s, None) r.serve({}, s)
v = s.getvalue() v = s.getvalue()
assert not "Content-Length" in v assert not "Content-Length" in v
assert not "Date" in v assert not "Date" in v
@ -541,21 +534,18 @@ class TestResponse:
def test_length(self): def test_length(self):
def testlen(x): def testlen(x):
s = cStringIO.StringIO() s = cStringIO.StringIO()
x.serve({}, s, None) x.serve({}, s)
assert x.length({}, None) == len(s.getvalue()) assert x.length({}, None) == len(s.getvalue())
testlen(language.parse_response({}, "400'msg'")) testlen(language.parse_response({}, "400'msg'"))
testlen(language.parse_response({}, "400'msg':h'foo'='bar'")) testlen(language.parse_response({}, "400'msg':h'foo'='bar'"))
testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b")) testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b"))
def test_effective_length(self): def test_maximum_length(self):
l = [None]
def check(req, actions):
l[0] = req.effective_length({}, None)
def testlen(x, actions): def testlen(x, actions):
s = cStringIO.StringIO() s = cStringIO.StringIO()
x.serve({}, s, check) m = x.maximum_length({}, None)
assert l[0] == len(s.getvalue()) x.serve({}, s)
assert m >= len(s.getvalue())
r = language.parse_response({}, "400'msg':b@100") r = language.parse_response({}, "400'msg':b@100")

View File

@ -58,7 +58,7 @@ class TestNohang(tutils.DaemonTests):
r = self.get("200:p0,0") r = self.get("200:p0,0")
assert r.status_code == 800 assert r.status_code == 800
l = self.d.last_log() l = self.d.last_log()
assert "Pauses have been disabled" in l["response"]["error"] assert "Pauses have been disabled" in l["msg"]
class TestHexdump(tutils.DaemonTests): class TestHexdump(tutils.DaemonTests):
@ -77,7 +77,7 @@ class CommonTests(tutils.DaemonTests):
r = self.get("200:b@1g") r = self.get("200:b@1g")
assert r.status_code == 800 assert r.status_code == 800
l = self.d.last_log() l = self.d.last_log()
assert "too large" in l["response"]["error"] assert "too large" in l["msg"]
def test_preline(self): def test_preline(self):
v = self.pathoc(r"get:'/p/200':i0,'\r\n'") v = self.pathoc(r"get:'/p/200':i0,'\r\n'")