diff --git a/libpathod/app.py b/libpathod/app.py index 1ebe99017..6cde801a0 100644 --- a/libpathod/app.py +++ b/libpathod/app.py @@ -129,11 +129,16 @@ def _preview(is_request): s = cStringIO.StringIO() args["pauses"] = r.preview_safe() + + c = app.config["pathod"].check_policy(r) + if c: + args["error"] = c + return render(template, False, **args) if is_request: - r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy, host="example.com") + r.serve(app.config["pathod"].request_settings, s, host="example.com") else: - r.serve(app.config["pathod"].request_settings, s, check=app.config["pathod"].check_policy) + r.serve(app.config["pathod"].request_settings, s) args["output"] = utils.escape_unprintables(s.getvalue()) return render(template, False, **args) diff --git a/libpathod/language.py b/libpathod/language.py index a9670c327..311d51e5d 100644 --- a/libpathod/language.py +++ b/libpathod/language.py @@ -557,18 +557,14 @@ class Message: self.actions = [i for i in self.actions if not isinstance(i, PauseAt)] return pauses - def effective_length(self, settings, request_host): + def maximum_length(self, settings, request_host): """ - Calculate the length of the base message with all applied actions. + Calculate the maximum length of the base message with all applied actions. """ - # Order matters here, and must match the order of application in - # write_values. l = self.length(settings, request_host) - for i in reversed(self.ready_actions(settings, request_host)): - if i[1] == "disconnect": - return i[0] - elif i[1] == "inject": - l += len(i[2]) + for i in self.actions: + if isinstance(i, InjectAt): + l += len(i.value.get_generator(settings)) return l def headervals(self, settings, request_host): @@ -609,15 +605,10 @@ class Message: actions.reverse() return [i.intermediate(settings) for i in actions] - def serve(self, settings, fp, check, request_host): + def serve(self, settings, fp, request_host): """ fp: The file pointer to write to. - check: A function called with the effective actions (after random - values have been calculated). If it returns False service proceeds, - otherwise the return is treated as an error message to be sent to - the client, and service stops. - request_host: If this a request, this is the connecting host. If None, we assume it's a response. Used to decide what standard modifications to make if raw is not set. @@ -636,15 +627,6 @@ class Message: vals.append(self.body) vals.reverse() actions = self.ready_actions(settings, request_host) - if check: - ret = check(self, actions) - if ret: - err = PathodErrorResponse(ret) - err.serve(settings, fp) - return dict( - disconnect = True, - error = ret - ) disconnect = write_values(fp, vals, actions[:]) duration = time.time() - started @@ -751,8 +733,8 @@ class CraftedRequest(Request): for i in tokens: i.accept(settings, self) - def serve(self, settings, fp, check, host): - d = Request.serve(self, settings, fp, check, host) + def serve(self, settings, fp, host): + d = Request.serve(self, settings, fp, host) d["spec"] = self.spec return d @@ -764,8 +746,8 @@ class CraftedResponse(Response): for i in tokens: i.accept(settings, self) - def serve(self, settings, fp, check): - d = Response.serve(self, settings, fp, check, None) + def serve(self, settings, fp): + d = Response.serve(self, settings, fp, None) d["spec"] = self.spec return d @@ -780,8 +762,8 @@ class PathodErrorResponse(Response): Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")), ] - def serve(self, settings, fp, check=None): - d = Response.serve(self, settings, fp, check, None) + def serve(self, settings, fp): + d = Response.serve(self, settings, fp, None) d["internal"] = True return d diff --git a/libpathod/pathoc.py b/libpathod/pathoc.py index 873a989cc..df291c59a 100644 --- a/libpathod/pathoc.py +++ b/libpathod/pathoc.py @@ -22,7 +22,7 @@ class Pathoc(tcp.TCPClient): language.FileAccessDenied. """ r = language.parse_request(self.settings, spec) - ret = r.serve(self.settings, self.wfile, None, self.host) + ret = r.serve(self.settings, self.wfile, self.host) self.wfile.flush() return http.read_response(self.rfile, r.method, None) @@ -68,7 +68,7 @@ class Pathoc(tcp.TCPClient): if showresp: self.rfile.start_log() try: - req = r.serve(self.settings, self.wfile, None, self.host) + req = r.serve(self.settings, self.wfile, self.host) self.wfile.flush() resp = http.read_response(self.rfile, r.method, None) except http.HttpError, v: diff --git a/libpathod/pathod.py b/libpathod/pathod.py index e0e30d17d..5d787c554 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -18,7 +18,17 @@ class PathodHandler(tcp.BaseHandler): self.sni = connection.get_servername() def serve_crafted(self, crafted, request_log): - response_log = crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy) + c = self.server.check_policy(crafted) + if c: + err = language.PathodErrorResponse(c) + err.serve(self.server.request_settings, self.wfile) + log = dict( + type = "error", + msg = c + ) + return False, log + + response_log = crafted.serve(self.server.request_settings, self.wfile) log = dict( type = "crafted", request=request_log, @@ -96,7 +106,7 @@ class PathodHandler(tcp.BaseHandler): return self.serve_crafted(crafted, request_log) elif self.server.noweb: crafted = language.PathodErrorResponse("Access Denied") - crafted.serve(self.server.request_settings, self.wfile, self.server.check_policy) + crafted.serve(self.server.request_settings, self.wfile) return False, dict(type = "error", msg="Access denied: web interface disabled") else: self.info("app: %s %s"%(method, path)) @@ -205,13 +215,13 @@ class Pathod(tcp.TCPServer): raise PathodError("Invalid page spec in anchor: '%s', %s"%(i[1], str(v))) self.anchors.append((arex, i[1])) - def check_policy(self, req, actions): + def check_policy(self, req): """ A policy check that verifies the request size is withing limits. """ - if self.sizelimit and req.effective_length({}, None) > self.sizelimit: + if self.sizelimit and req.maximum_length({}, None) > self.sizelimit: return "Response too large." - if self.nohang and any([i[1] == "pause" for i in actions]): + if self.nohang and any([isinstance(i, language.PauseAt) for i in req.actions]): return "Pauses have been disabled." return False diff --git a/test/test_language.py b/test/test_language.py index d3124c5a9..289f180c7 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -279,7 +279,7 @@ class TestInject: def test_serve(self): s = cStringIO.StringIO() r = language.parse_response({}, "400:i0,'foo'") - assert r.serve({}, s, None) + assert r.serve({}, s) def test_spec(self): e = language.InjectAt.expr() @@ -344,7 +344,7 @@ class TestParseRequest: def test_render(self): s = cStringIO.StringIO() r = language.parse_request({}, "GET:'/foo'") - assert r.serve({}, s, None, "foo.com") + assert r.serve({}, s, "foo.com") def test_str(self): r = language.parse_request({}, 'GET:"/foo"') @@ -479,15 +479,15 @@ class TestWriteValues: def test_write_values_after(self): s = cStringIO.StringIO() r = language.parse_response({}, "400:da") - r.serve({}, s, None) + r.serve({}, s) s = cStringIO.StringIO() r = language.parse_response({}, "400:pa,0") - r.serve({}, s, None) + r.serve({}, s) s = cStringIO.StringIO() r = language.parse_response({}, "400:ia,'xx'") - r.serve({}, s, None) + r.serve({}, s) assert s.getvalue().endswith('xx') @@ -511,29 +511,22 @@ class TestResponse: assert r.body[:] assert str(r) - def test_checkfunc(self): - s = cStringIO.StringIO() - r = language.parse_response({}, "400:b@100k") - def check(req, acts): - return "errmsg" - assert r.serve({}, s, check=check)["error"] == "errmsg" - def test_render(self): s = cStringIO.StringIO() r = language.parse_response({}, "400'msg'") - assert r.serve({}, s, None) + assert r.serve({}, s) def test_raw(self): s = cStringIO.StringIO() r = language.parse_response({}, "400:b'foo'") - r.serve({}, s, None) + r.serve({}, s) v = s.getvalue() assert "Content-Length" in v assert "Date" in v s = cStringIO.StringIO() r = language.parse_response({}, "400:b'foo':r") - r.serve({}, s, None) + r.serve({}, s) v = s.getvalue() assert not "Content-Length" in v assert not "Date" in v @@ -541,21 +534,18 @@ class TestResponse: def test_length(self): def testlen(x): s = cStringIO.StringIO() - x.serve({}, s, None) + x.serve({}, s) assert x.length({}, None) == len(s.getvalue()) testlen(language.parse_response({}, "400'msg'")) testlen(language.parse_response({}, "400'msg':h'foo'='bar'")) testlen(language.parse_response({}, "400'msg':h'foo'='bar':b@100b")) - def test_effective_length(self): - l = [None] - def check(req, actions): - l[0] = req.effective_length({}, None) - + def test_maximum_length(self): def testlen(x, actions): s = cStringIO.StringIO() - x.serve({}, s, check) - assert l[0] == len(s.getvalue()) + m = x.maximum_length({}, None) + x.serve({}, s) + assert m >= len(s.getvalue()) r = language.parse_response({}, "400'msg':b@100") diff --git a/test/test_pathod.py b/test/test_pathod.py index 7bbb5545c..195c73337 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -58,7 +58,7 @@ class TestNohang(tutils.DaemonTests): r = self.get("200:p0,0") assert r.status_code == 800 l = self.d.last_log() - assert "Pauses have been disabled" in l["response"]["error"] + assert "Pauses have been disabled" in l["msg"] class TestHexdump(tutils.DaemonTests): @@ -77,7 +77,7 @@ class CommonTests(tutils.DaemonTests): r = self.get("200:b@1g") assert r.status_code == 800 l = self.d.last_log() - assert "too large" in l["response"]["error"] + assert "too large" in l["msg"] def test_preline(self): v = self.pathoc(r"get:'/p/200':i0,'\r\n'")