mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
Pull HTTP functionality out of language._Message in prep for frames
This commit is contained in:
parent
b0ab5297d1
commit
5405a4d458
@ -619,7 +619,7 @@ class WS(_Component):
|
||||
|
||||
@classmethod
|
||||
def expr(klass):
|
||||
spec = pp.Literal("ws")
|
||||
spec = pp.CaselessLiteral("ws")
|
||||
spec = spec.setParseAction(lambda x: klass(*x))
|
||||
return spec
|
||||
|
||||
@ -829,7 +829,6 @@ class InjectAt(_Action):
|
||||
|
||||
class _Message(object):
|
||||
__metaclass__ = abc.ABCMeta
|
||||
version = "HTTP/1.1"
|
||||
logattrs = []
|
||||
|
||||
def __init__(self, tokens):
|
||||
@ -917,16 +916,6 @@ class _Message(object):
|
||||
ret["spec"] = self.spec()
|
||||
return ret
|
||||
|
||||
def values(self, settings):
|
||||
vals = self.preamble(settings)
|
||||
vals.append("\r\n")
|
||||
for h in self.headers:
|
||||
vals.extend(h.values(settings))
|
||||
vals.append("\r\n")
|
||||
if self.body:
|
||||
vals.append(self.body.value.get_generator(settings))
|
||||
return vals
|
||||
|
||||
def freeze(self, settings):
|
||||
r = self.resolve(settings)
|
||||
return self.__class__([i.freeze(settings) for i in r.tokens])
|
||||
@ -938,7 +927,21 @@ class _Message(object):
|
||||
Sep = pp.Optional(pp.Literal(":")).suppress()
|
||||
|
||||
|
||||
class Response(_Message):
|
||||
class _HTTPMessage(_Message):
|
||||
version = "HTTP/1.1"
|
||||
|
||||
def values(self, settings):
|
||||
vals = self.preamble(settings)
|
||||
vals.append("\r\n")
|
||||
for h in self.headers:
|
||||
vals.extend(h.values(settings))
|
||||
vals.append("\r\n")
|
||||
if self.body:
|
||||
vals.append(self.body.value.get_generator(settings))
|
||||
return vals
|
||||
|
||||
|
||||
class Response(_HTTPMessage):
|
||||
comps = (
|
||||
Body,
|
||||
Header,
|
||||
@ -966,12 +969,8 @@ class Response(_Message):
|
||||
|
||||
def preamble(self, settings):
|
||||
l = [self.version, " "]
|
||||
if self.code:
|
||||
l.extend(self.code.values(settings))
|
||||
code = int(self.code.code)
|
||||
elif self.ws:
|
||||
l.extend(Code(101).values(settings))
|
||||
code = 101
|
||||
l.append(" ")
|
||||
if self.reason:
|
||||
l.extend(self.reason.values(settings))
|
||||
@ -1042,7 +1041,7 @@ class Response(_Message):
|
||||
return ":".join([i.spec() for i in self.tokens])
|
||||
|
||||
|
||||
class Request(_Message):
|
||||
class Request(_HTTPMessage):
|
||||
comps = (
|
||||
Body,
|
||||
Header,
|
||||
@ -1222,7 +1221,12 @@ def parse_requests(s):
|
||||
try:
|
||||
parts = pp.OneOrMore(
|
||||
pp.Group(
|
||||
Request.expr()
|
||||
pp.Or(
|
||||
[
|
||||
Request.expr(),
|
||||
WebsocketFrame.expr(),
|
||||
]
|
||||
)
|
||||
)
|
||||
).parseString(s, parseAll=True)
|
||||
return [Request(i) for i in parts]
|
||||
|
@ -14,6 +14,13 @@ def parse_request(s):
|
||||
return language.parse_requests(s)[0]
|
||||
|
||||
|
||||
class TestWS:
|
||||
def test_expr(self):
|
||||
v = language.WS("foo")
|
||||
assert v.expr()
|
||||
assert v.values(language.Settings())
|
||||
|
||||
|
||||
class TestValueNakedLiteral:
|
||||
def test_expr(self):
|
||||
v = language.ValueNakedLiteral("foo")
|
||||
@ -572,7 +579,6 @@ class TestRequest:
|
||||
language.Settings(request_host = "foo.com")
|
||||
)
|
||||
|
||||
|
||||
def test_multiline(self):
|
||||
l = """
|
||||
GET
|
||||
@ -632,7 +638,15 @@ class TestRequest:
|
||||
|
||||
|
||||
|
||||
class TestWebsocketFrame:
|
||||
|
||||
def test_spec(self):
|
||||
e = language.WebsocketFrame.expr()
|
||||
assert e.parseString("wf:foo")
|
||||
|
||||
|
||||
class TestWriteValues:
|
||||
|
||||
def test_send_chunk(self):
|
||||
v = "foobarfoobar"
|
||||
for bs in range(1, len(v) + 2):
|
||||
@ -675,14 +689,18 @@ class TestWriteValues:
|
||||
for bs in range(1, len(tst) + 2):
|
||||
for off in range(len(tst)):
|
||||
s = cStringIO.StringIO()
|
||||
language.write_values(s, [tst], [(off, "disconnect")], blocksize=bs)
|
||||
language.write_values(
|
||||
s, [tst], [(off, "disconnect")], blocksize=bs
|
||||
)
|
||||
assert s.getvalue() == tst[:off]
|
||||
|
||||
def test_write_values_pauses(self):
|
||||
tst = "".join(str(i) for i in range(10))
|
||||
for i in range(2, 10):
|
||||
s = cStringIO.StringIO()
|
||||
language.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i)
|
||||
language.write_values(
|
||||
s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i
|
||||
)
|
||||
assert s.getvalue() == tst
|
||||
|
||||
for i in range(2, 10):
|
||||
|
@ -188,6 +188,9 @@ class CommonTests(tutils.DaemonTests):
|
||||
r = self.pathoc("ws:/p/")
|
||||
assert r.status_code == 101
|
||||
|
||||
r = self.pathoc("ws:/p/ws")
|
||||
assert r.status_code == 101
|
||||
|
||||
|
||||
class TestDaemon(CommonTests):
|
||||
ssl = False
|
||||
|
Loading…
Reference in New Issue
Block a user