Add a .values method to messages, simplify a lot of stuff as a consequence.

This commit is contained in:
Aldo Cortesi 2012-10-30 16:04:48 +13:00
parent a09584b9e6
commit f8df0a1e75
3 changed files with 24 additions and 34 deletions

View File

@ -88,14 +88,7 @@ def serve(msg, fp, settings, request_host=None):
msg = msg.resolve(settings, request_host)
started = time.time()
hdrs = msg.headervals(settings, request_host)
vals = msg.preamble(settings)
vals.append("\r\n")
vals.extend(hdrs)
vals.append("\r\n")
if msg.body:
vals.append(msg.body.value.get_generator(settings))
vals = msg.values(settings)
vals.reverse()
actions = msg.actions[:]
@ -222,7 +215,7 @@ class _Token(object):
"""
return None
def resolve(self, msg, settings, request_host): # pragma: no cover
def resolve(self, msg, settings): # pragma: no cover
"""
Resolves this token to ready it for transmission. This means that
the calculated offsets of actions are fixed.
@ -553,13 +546,13 @@ class _Action(_Token):
def __init__(self, offset):
self.offset = offset
def resolve(self, msg, settings, request_host):
def resolve(self, msg, settings):
"""
Resolves offset specifications to a numeric offset. Returns a copy
of the action object.
"""
c = copy.copy(self)
l = msg.length(settings, request_host)
l = msg.length(settings)
if c.offset == "r":
c.offset = random.randrange(l)
elif c.offset == "a":
@ -677,18 +670,11 @@ class _Message(object):
def headers(self):
return self._get_tokens(_Header)
def length(self, settings, request_host):
def length(self, settings):
"""
Calculate the length of the base message without any applied actions.
"""
l = sum(len(x) for x in self.preamble(settings))
l += 2
for h in self.headervals(settings, request_host):
l += len(h)
l += 2
if self.body:
l += len(self.body.value.get_generator(settings))
return l
return sum(len(x) for x in self.values(settings))
def preview_safe(self):
"""
@ -697,11 +683,11 @@ class _Message(object):
tokens = [i for i in self.tokens if not isinstance(i, PauseAt)]
return self.__class__(tokens)
def maximum_length(self, settings, request_host):
def maximum_length(self, settings):
"""
Calculate the maximum length of the base message with all applied actions.
"""
l = self.length(settings, request_host)
l = self.length(settings)
for i in self.actions:
if isinstance(i, InjectAt):
l += len(i.value.get_generator(settings))
@ -734,13 +720,7 @@ class _Message(object):
)
)
intermediate = self.__class__(tokens)
return self.__class__([i.resolve(intermediate, settings, request_host) for i in tokens])
def headervals(self, settings, request_host):
values = []
for h in self.headers:
values.extend(h.values(settings))
return values
return self.__class__([i.resolve(intermediate, settings) for i in tokens])
@abc.abstractmethod
def preamble(self, settings): # pragma: no cover
@ -768,6 +748,16 @@ class _Message(object):
ret["spec"] = self.spec()
return ret
def values(self, settings):
vals = self.preamble(settings)
vals.append("\r\n")
for h in self.headers:
vals.extend(h.values(settings))
vals.append("\r\n")
if self.body:
vals.append(self.body.value.get_generator(settings))
return vals
Sep = pp.Optional(pp.Literal(":")).suppress()

View File

@ -217,7 +217,7 @@ class Pathod(tcp.TCPServer):
A policy check that verifies the request size is withing limits.
"""
try:
l = req.maximum_length(settings, None)
l = req.maximum_length(settings)
except language.FileAccessDenied, v:
return "File access denied."
if self.sizelimit and l > self.sizelimit:

View File

@ -260,7 +260,7 @@ class Test_Action:
def test_resolve(self):
r = language.parse_request({}, 'GET:"/foo"')
e = language.DisconnectAt("r")
ret = e.resolve(r, {}, None)
ret = e.resolve(r, {})
assert isinstance(ret.offset, int)
def test_repr(self):
@ -444,7 +444,7 @@ class TestParseResponse:
def test_parse_stress(self):
r = language.parse_response({}, "400:b@100g")
assert r.length({}, None)
assert r.length({})
def test_spec(self):
def rt(s):
@ -583,7 +583,7 @@ class TestResponse:
def testlen(x):
s = cStringIO.StringIO()
language.serve(x, s, {})
assert x.length({}, None) == len(s.getvalue())
assert x.length({}) == len(s.getvalue())
testlen(language.parse_response({}, "400:m'msg':r"))
testlen(language.parse_response({}, "400:m'msg':h'foo'='bar':r"))
testlen(language.parse_response({}, "400:m'msg':h'foo'='bar':b@100b:r"))
@ -591,7 +591,7 @@ class TestResponse:
def test_maximum_length(self):
def testlen(x):
s = cStringIO.StringIO()
m = x.maximum_length({}, None)
m = x.maximum_length({})
language.serve(x, s, {})
assert m >= len(s.getvalue())