Pull HTTP functionality out of language._Message in prep for frames

This commit is contained in:
Aldo Cortesi 2015-04-24 07:35:17 +12:00
parent b0ab5297d1
commit 5405a4d458
3 changed files with 52 additions and 27 deletions

View File

@ -619,7 +619,7 @@ class WS(_Component):
@classmethod @classmethod
def expr(klass): def expr(klass):
spec = pp.Literal("ws") spec = pp.CaselessLiteral("ws")
spec = spec.setParseAction(lambda x: klass(*x)) spec = spec.setParseAction(lambda x: klass(*x))
return spec return spec
@ -829,7 +829,6 @@ class InjectAt(_Action):
class _Message(object): class _Message(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
version = "HTTP/1.1"
logattrs = [] logattrs = []
def __init__(self, tokens): def __init__(self, tokens):
@ -917,16 +916,6 @@ class _Message(object):
ret["spec"] = self.spec() ret["spec"] = self.spec()
return ret return ret
def values(self, settings):
vals = self.preamble(settings)
vals.append("\r\n")
for h in self.headers:
vals.extend(h.values(settings))
vals.append("\r\n")
if self.body:
vals.append(self.body.value.get_generator(settings))
return vals
def freeze(self, settings): def freeze(self, settings):
r = self.resolve(settings) r = self.resolve(settings)
return self.__class__([i.freeze(settings) for i in r.tokens]) return self.__class__([i.freeze(settings) for i in r.tokens])
@ -938,7 +927,21 @@ class _Message(object):
Sep = pp.Optional(pp.Literal(":")).suppress() Sep = pp.Optional(pp.Literal(":")).suppress()
class Response(_Message): class _HTTPMessage(_Message):
version = "HTTP/1.1"
def values(self, settings):
vals = self.preamble(settings)
vals.append("\r\n")
for h in self.headers:
vals.extend(h.values(settings))
vals.append("\r\n")
if self.body:
vals.append(self.body.value.get_generator(settings))
return vals
class Response(_HTTPMessage):
comps = ( comps = (
Body, Body,
Header, Header,
@ -966,12 +969,8 @@ class Response(_Message):
def preamble(self, settings): def preamble(self, settings):
l = [self.version, " "] l = [self.version, " "]
if self.code: l.extend(self.code.values(settings))
l.extend(self.code.values(settings)) code = int(self.code.code)
code = int(self.code.code)
elif self.ws:
l.extend(Code(101).values(settings))
code = 101
l.append(" ") l.append(" ")
if self.reason: if self.reason:
l.extend(self.reason.values(settings)) l.extend(self.reason.values(settings))
@ -1042,7 +1041,7 @@ class Response(_Message):
return ":".join([i.spec() for i in self.tokens]) return ":".join([i.spec() for i in self.tokens])
class Request(_Message): class Request(_HTTPMessage):
comps = ( comps = (
Body, Body,
Header, Header,
@ -1222,7 +1221,12 @@ def parse_requests(s):
try: try:
parts = pp.OneOrMore( parts = pp.OneOrMore(
pp.Group( pp.Group(
Request.expr() pp.Or(
[
Request.expr(),
WebsocketFrame.expr(),
]
)
) )
).parseString(s, parseAll=True) ).parseString(s, parseAll=True)
return [Request(i) for i in parts] return [Request(i) for i in parts]

View File

@ -14,6 +14,13 @@ def parse_request(s):
return language.parse_requests(s)[0] return language.parse_requests(s)[0]
class TestWS:
def test_expr(self):
v = language.WS("foo")
assert v.expr()
assert v.values(language.Settings())
class TestValueNakedLiteral: class TestValueNakedLiteral:
def test_expr(self): def test_expr(self):
v = language.ValueNakedLiteral("foo") v = language.ValueNakedLiteral("foo")
@ -572,7 +579,6 @@ class TestRequest:
language.Settings(request_host = "foo.com") language.Settings(request_host = "foo.com")
) )
def test_multiline(self): def test_multiline(self):
l = """ l = """
GET GET
@ -632,10 +638,18 @@ class TestRequest:
class TestWebsocketFrame:
def test_spec(self):
e = language.WebsocketFrame.expr()
assert e.parseString("wf:foo")
class TestWriteValues: class TestWriteValues:
def test_send_chunk(self): def test_send_chunk(self):
v = "foobarfoobar" v = "foobarfoobar"
for bs in range(1, len(v)+2): for bs in range(1, len(v) + 2):
s = cStringIO.StringIO() s = cStringIO.StringIO()
language.send_chunk(s, v, bs, 0, len(v)) language.send_chunk(s, v, bs, 0, len(v))
assert s.getvalue() == v assert s.getvalue() == v
@ -662,7 +676,7 @@ class TestWriteValues:
def test_write_values_disconnects(self): def test_write_values_disconnects(self):
s = cStringIO.StringIO() s = cStringIO.StringIO()
tst = "foo"*100 tst = "foo" * 100
language.write_values(s, [tst], [(0, "disconnect")], blocksize=5) language.write_values(s, [tst], [(0, "disconnect")], blocksize=5)
assert not s.getvalue() assert not s.getvalue()
@ -675,14 +689,18 @@ class TestWriteValues:
for bs in range(1, len(tst) + 2): for bs in range(1, len(tst) + 2):
for off in range(len(tst)): for off in range(len(tst)):
s = cStringIO.StringIO() s = cStringIO.StringIO()
language.write_values(s, [tst], [(off, "disconnect")], blocksize=bs) language.write_values(
s, [tst], [(off, "disconnect")], blocksize=bs
)
assert s.getvalue() == tst[:off] assert s.getvalue() == tst[:off]
def test_write_values_pauses(self): def test_write_values_pauses(self):
tst = "".join(str(i) for i in range(10)) tst = "".join(str(i) for i in range(10))
for i in range(2, 10): for i in range(2, 10):
s = cStringIO.StringIO() s = cStringIO.StringIO()
language.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i) language.write_values(
s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i
)
assert s.getvalue() == tst assert s.getvalue() == tst
for i in range(2, 10): for i in range(2, 10):
@ -690,7 +708,7 @@ class TestWriteValues:
language.write_values(s, [tst], [(1, "pause", 0)], blocksize=i) language.write_values(s, [tst], [(1, "pause", 0)], blocksize=i)
assert s.getvalue() == tst assert s.getvalue() == tst
tst = ["".join(str(i) for i in range(10))]*5 tst = ["".join(str(i) for i in range(10))] * 5
for i in range(2, 10): for i in range(2, 10):
s = cStringIO.StringIO() s = cStringIO.StringIO()
language.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i) language.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i)

View File

@ -188,6 +188,9 @@ class CommonTests(tutils.DaemonTests):
r = self.pathoc("ws:/p/") r = self.pathoc("ws:/p/")
assert r.status_code == 101 assert r.status_code == 101
r = self.pathoc("ws:/p/ws")
assert r.status_code == 101
class TestDaemon(CommonTests): class TestDaemon(CommonTests):
ssl = False ssl = False