diff --git a/libpathod/language.py b/libpathod/language.py index a74412b6e..6fa39f6ab 100644 --- a/libpathod/language.py +++ b/libpathod/language.py @@ -309,18 +309,8 @@ class _Spec(object): """ return None - @abc.abstractmethod - def accept(self, r): # pragma: no cover - """ - Notifies the component to register itself with message r. - """ - return None - class Raw(_Spec): - def accept(self, r): - r.raw = True - @classmethod def expr(klass): e = pp.Literal("r").suppress() @@ -357,9 +347,6 @@ class _Header(_Component): "\r\n", ] - def accept(self, r): - r.headers.append(self) - class Header(_Header): @classmethod @@ -397,9 +384,6 @@ class Body(_Component): def __init__(self, value): self.value = value - def accept(self, r): - r.body = self - @classmethod def expr(klass): e = pp.Literal("b").suppress() @@ -418,9 +402,6 @@ class Path(_Component): value = ValueLiteral(value) self.value = value - def accept(self, r): - r.path = self - @classmethod def expr(klass): e = NakedValue.copy() @@ -451,9 +432,6 @@ class Method(_Component): value = ValueLiteral(value.upper()) self.value = value - def accept(self, r): - r.method = self - @classmethod def expr(klass): parts = [pp.CaselessLiteral(i) for i in klass.methods] @@ -472,9 +450,6 @@ class Code(_Component): def __init__(self, code): self.code = str(code) - def accept(self, r): - r.code = self - @classmethod def expr(klass): e = v_integer.copy() @@ -488,8 +463,6 @@ class Reason(_Component): def __init__(self, value): self.value = value - def accept(self, r): - r.reason = self @classmethod def expr(klass): @@ -529,9 +502,6 @@ class _Action(_Spec): def __repr__(self): return self.spec() - def accept(self, r): - r.actions.append(self) - @abc.abstractmethod def spec(self): # pragma: no cover pass @@ -610,11 +580,32 @@ class InjectAt(_Action): class _Message(object): __metaclass__ = abc.ABCMeta version = "HTTP/1.1" - def __init__(self): - self.body = None - self.headers = [] - self.actions = [] - self.raw = False + def __init__(self, tokens): + self.tokens = tokens + + def _get_tokens(self, klass): + return [i for i in self.tokens if isinstance(i, klass)] + + def _get_token(self, klass): + l = self._get_tokens(klass) + if l: + return l[0] + + @property + def raw(self): + return bool(self._get_token(Raw)) + + @property + def actions(self): + return self._get_tokens(_Action) + + @property + def body(self): + return self._get_token(Body) + + @property + def headers(self): + return self._get_tokens(_Header) def length(self, settings, request_host): """ @@ -634,7 +625,7 @@ class _Message(object): Modify this message to be safe for previews. Returns a list of elided actions. """ pauses = [i for i in self.actions if isinstance(i, PauseAt)] - self.actions = [i for i in self.actions if not isinstance(i, PauseAt)] + #self.actions = [i for i in self.actions if not isinstance(i, PauseAt)] return pauses def maximum_length(self, settings, request_host): @@ -738,7 +729,7 @@ class _Message(object): Sep = pp.Optional(pp.Literal(":")).suppress() -class Response(_Message): +class _Response(_Message): comps = ( Body, Header, @@ -751,10 +742,13 @@ class Response(_Message): Reason ) logattrs = ["code", "reason", "version", "body"] - def __init__(self): - _Message.__init__(self) - self.code = None - self.reason = None + @property + def code(self): + return self._get_token(Code) + + @property + def reason(self): + return self._get_token(Reason) def preamble(self, settings): l = [self.version, " "] @@ -779,7 +773,7 @@ class Response(_Message): return resp -class Request(_Message): +class _Request(_Message): comps = ( Body, Header, @@ -790,10 +784,13 @@ class Request(_Message): Raw ) logattrs = ["method", "path", "body"] - def __init__(self): - _Message.__init__(self) - self.method = None - self.path = None + @property + def method(self): + return self._get_token(Method) + + @property + def path(self): + return self._get_token(Path) def preamble(self, settings): v = self.method.values(settings) @@ -818,44 +815,40 @@ class Request(_Message): return resp -class CraftedRequest(Request): - def __init__(self, settings, spec, tokens): - Request.__init__(self) +class CraftedRequest(_Request): + def __init__(self, spec, tokens): + _Request.__init__(self, tokens) self.spec, self.tokens = spec, tokens - for i in tokens: - i.accept(self) def serve(self, fp, settings, host): - d = Request.serve(self, fp, settings, host) + d = _Request.serve(self, fp, settings, host) d["spec"] = self.spec return d -class CraftedResponse(Response): - def __init__(self, settings, spec, tokens): - Response.__init__(self) +class CraftedResponse(_Response): + def __init__(self, spec, tokens): + _Response.__init__(self, tokens) self.spec, self.tokens = spec, tokens - for i in tokens: - i.accept(self) def serve(self, fp, settings): - d = Response.serve(self, fp, settings, None) + d = _Response.serve(self, fp, settings, None) d["spec"] = self.spec return d -class PathodErrorResponse(Response): - def __init__(self, msg, body=None): - Response.__init__(self) - self.code = Code("800") - self.msg = LiteralGenerator(msg) - self.body = Body(ValueLiteral("pathod error: " + (body or msg))) - self.headers = [ +class PathodErrorResponse(_Response): + def __init__(self, reason, body=None): + tokens = [ + Code("800"), Header(ValueLiteral("Content-Type"), ValueLiteral("text/plain")), + Reason(ValueLiteral(reason)), + Body(ValueLiteral("pathod error: " + (body or reason))), ] + _Response.__init__(self, tokens) def serve(self, fp, settings): - d = Response.serve(self, fp, settings, None) + d = _Response.serve(self, fp, settings, None) d["internal"] = True return d @@ -888,7 +881,7 @@ def parse_response(settings, s): if s.startswith(FILESTART): s = read_file(settings, s) try: - return CraftedResponse(settings, s, Response.expr().parseString(s, parseAll=True)) + return CraftedResponse(s, _Response.expr().parseString(s, parseAll=True)) except pp.ParseException, v: raise ParseException(v.msg, v.line, v.col) @@ -904,6 +897,6 @@ def parse_request(settings, s): if s.startswith(FILESTART): s = read_file(settings, s) try: - return CraftedRequest(settings, s, Request.expr().parseString(s, parseAll=True)) + return CraftedRequest(s, _Request.expr().parseString(s, parseAll=True)) except pp.ParseException, v: raise ParseException(v.msg, v.line, v.col) diff --git a/test/test_language.py b/test/test_language.py index cb7d7d1b1..009f4ddd1 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -542,32 +542,20 @@ class TestResponse: testlen(language.parse_response({}, "400:m'msg':h'foo'='bar':b@100b")) def test_maximum_length(self): - def testlen(x, actions): + def testlen(x): s = cStringIO.StringIO() m = x.maximum_length({}, None) x.serve(s, {}) assert m >= len(s.getvalue()) - r = language.parse_response({}, "400:m'msg':b@100") + r = language.parse_response({}, "400:m'msg':b@100:d0") + testlen(r) - actions = [ - language.DisconnectAt(0) - ] - r.actions = actions - testlen(r, actions) + r = language.parse_response({}, "400:m'msg':b@100:d0:i0,'foo'") + testlen(r) - actions = [ - language.DisconnectAt(0), - language.InjectAt(0, language.ValueLiteral("foo")) - ] - r.actions = actions - testlen(r, actions) - - actions = [ - language.InjectAt(0, language.ValueLiteral("foo")) - ] - r.actions = actions - testlen(r, actions) + r = language.parse_response({}, "400:m'msg':b@100:d0:i0,'foo'") + testlen(r) def test_render(self): r = language.parse_response({}, "400:p0,100:dr")