mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-30 03:14:22 +00:00
websockets: handshake checks only take headers
This commit is contained in:
parent
4fb49c8e55
commit
42a87a1d8b
@ -33,7 +33,7 @@ def _is_valid_host(host):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_line(fp):
|
def get_request_line(fp):
|
||||||
"""
|
"""
|
||||||
Get a line, possibly preceded by a blank.
|
Get a line, possibly preceded by a blank.
|
||||||
"""
|
"""
|
||||||
@ -41,8 +41,6 @@ def get_line(fp):
|
|||||||
if line == "\r\n" or line == "\n":
|
if line == "\r\n" or line == "\n":
|
||||||
# Possible leftover from previous message
|
# Possible leftover from previous message
|
||||||
line = fp.readline()
|
line = fp.readline()
|
||||||
if line == "":
|
|
||||||
raise tcp.NetLibDisconnect()
|
|
||||||
return line
|
return line
|
||||||
|
|
||||||
|
|
||||||
@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
|
|||||||
httpversion, host, port, scheme, method, path, headers, content = (
|
httpversion, host, port, scheme, method, path, headers, content = (
|
||||||
None, None, None, None, None, None, None, None)
|
None, None, None, None, None, None, None, None)
|
||||||
|
|
||||||
request_line = get_line(rfile)
|
request_line = get_request_line(rfile)
|
||||||
|
if not request_line:
|
||||||
|
raise tcp.NetLibDisconnect()
|
||||||
|
|
||||||
request_line_parts = parse_init(request_line)
|
request_line_parts = parse_init(request_line)
|
||||||
if not request_line_parts:
|
if not request_line_parts:
|
||||||
|
@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring):
|
|||||||
return (length_code, actual_length)
|
return (length_code, actual_length)
|
||||||
|
|
||||||
|
|
||||||
def check_client_handshake(req):
|
def check_client_handshake(headers):
|
||||||
if req.headers.get_first("upgrade", None) != "websocket":
|
if headers.get_first("upgrade", None) != "websocket":
|
||||||
return
|
return
|
||||||
return req.headers.get_first('sec-websocket-key')
|
return headers.get_first('sec-websocket-key')
|
||||||
|
|
||||||
|
|
||||||
def check_server_handshake(resp):
|
def check_server_handshake(headers):
|
||||||
if resp.headers.get_first("upgrade", None) != "websocket":
|
if headers.get_first("upgrade", None) != "websocket":
|
||||||
return
|
return
|
||||||
return resp.headers.get_first('sec-websocket-accept')
|
return headers.get_first('sec-websocket-accept')
|
||||||
|
|
||||||
|
|
||||||
def create_server_nonce(client_nonce):
|
def create_server_nonce(client_nonce):
|
||||||
|
@ -412,10 +412,10 @@ def test_parse_http_basic_auth():
|
|||||||
assert not http.parse_http_basic_auth(v)
|
assert not http.parse_http_basic_auth(v)
|
||||||
|
|
||||||
|
|
||||||
def test_get_line():
|
def test_get_request_line():
|
||||||
r = cStringIO.StringIO("\nfoo")
|
r = cStringIO.StringIO("\nfoo")
|
||||||
assert http.get_line(r) == "foo"
|
assert http.get_request_line(r) == "foo"
|
||||||
tutils.raises(tcp.NetLibDisconnect, http.get_line, r)
|
assert not http.get_request_line(r)
|
||||||
|
|
||||||
|
|
||||||
class TestReadRequest():
|
class TestReadRequest():
|
||||||
|
@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
|||||||
|
|
||||||
def handshake(self):
|
def handshake(self):
|
||||||
req = http.read_request(self.rfile)
|
req = http.read_request(self.rfile)
|
||||||
key = websockets.check_client_handshake(req)
|
key = websockets.check_client_handshake(req.headers)
|
||||||
|
|
||||||
self.wfile.write(http.response_preamble(101) + "\r\n")
|
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||||
headers = websockets.server_handshake_headers(key)
|
headers = websockets.server_handshake_headers(key)
|
||||||
@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient):
|
|||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
resp = http.read_response(self.rfile, "get", None)
|
resp = http.read_response(self.rfile, "get", None)
|
||||||
server_nonce = websockets.check_server_handshake(resp)
|
server_nonce = websockets.check_server_handshake(resp.headers)
|
||||||
|
|
||||||
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
|
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
|
||||||
self.close()
|
self.close()
|
||||||
@ -153,38 +153,22 @@ class TestWebSockets(test.ServerTestBase):
|
|||||||
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
|
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
|
||||||
|
|
||||||
def test_check_server_handshake(self):
|
def test_check_server_handshake(self):
|
||||||
resp = http.Response(
|
headers = websockets.server_handshake_headers("key")
|
||||||
(1, 1),
|
assert websockets.check_server_handshake(headers)
|
||||||
101,
|
headers["Upgrade"] = ["not_websocket"]
|
||||||
"Switching Protocols",
|
assert not websockets.check_server_handshake(headers)
|
||||||
websockets.server_handshake_headers("key"),
|
|
||||||
""
|
|
||||||
)
|
|
||||||
assert websockets.check_server_handshake(resp)
|
|
||||||
resp.headers["Upgrade"] = ["not_websocket"]
|
|
||||||
assert not websockets.check_server_handshake(resp)
|
|
||||||
|
|
||||||
def test_check_client_handshake(self):
|
def test_check_client_handshake(self):
|
||||||
resp = http.Request(
|
headers = websockets.client_handshake_headers("key")
|
||||||
"relative",
|
assert websockets.check_client_handshake(headers) == "key"
|
||||||
"get",
|
headers["Upgrade"] = ["not_websocket"]
|
||||||
"http",
|
assert not websockets.check_client_handshake(headers)
|
||||||
"host",
|
|
||||||
22,
|
|
||||||
"/",
|
|
||||||
(1, 1),
|
|
||||||
websockets.client_handshake_headers("key"),
|
|
||||||
""
|
|
||||||
)
|
|
||||||
assert websockets.check_client_handshake(resp) == "key"
|
|
||||||
resp.headers["Upgrade"] = ["not_websocket"]
|
|
||||||
assert not websockets.check_client_handshake(resp)
|
|
||||||
|
|
||||||
|
|
||||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||||
def handshake(self):
|
def handshake(self):
|
||||||
client_hs = http.read_request(self.rfile)
|
client_hs = http.read_request(self.rfile)
|
||||||
websockets.check_client_handshake(client_hs)
|
websockets.check_client_handshake(client_hs.headers)
|
||||||
|
|
||||||
self.wfile.write(http.response_preamble(101) + "\r\n")
|
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||||
headers = websockets.server_handshake_headers("malformed key")
|
headers = websockets.server_handshake_headers("malformed key")
|
||||||
|
Loading…
Reference in New Issue
Block a user