mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
websockets: handshake checks only take headers
This commit is contained in:
parent
4fb49c8e55
commit
42a87a1d8b
@ -33,7 +33,7 @@ def _is_valid_host(host):
|
||||
return True
|
||||
|
||||
|
||||
def get_line(fp):
|
||||
def get_request_line(fp):
|
||||
"""
|
||||
Get a line, possibly preceded by a blank.
|
||||
"""
|
||||
@ -41,8 +41,6 @@ def get_line(fp):
|
||||
if line == "\r\n" or line == "\n":
|
||||
# Possible leftover from previous message
|
||||
line = fp.readline()
|
||||
if line == "":
|
||||
raise tcp.NetLibDisconnect()
|
||||
return line
|
||||
|
||||
|
||||
@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
|
||||
httpversion, host, port, scheme, method, path, headers, content = (
|
||||
None, None, None, None, None, None, None, None)
|
||||
|
||||
request_line = get_line(rfile)
|
||||
request_line = get_request_line(rfile)
|
||||
if not request_line:
|
||||
raise tcp.NetLibDisconnect()
|
||||
|
||||
request_line_parts = parse_init(request_line)
|
||||
if not request_line_parts:
|
||||
|
@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring):
|
||||
return (length_code, actual_length)
|
||||
|
||||
|
||||
def check_client_handshake(req):
|
||||
if req.headers.get_first("upgrade", None) != "websocket":
|
||||
def check_client_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return req.headers.get_first('sec-websocket-key')
|
||||
return headers.get_first('sec-websocket-key')
|
||||
|
||||
|
||||
def check_server_handshake(resp):
|
||||
if resp.headers.get_first("upgrade", None) != "websocket":
|
||||
def check_server_handshake(headers):
|
||||
if headers.get_first("upgrade", None) != "websocket":
|
||||
return
|
||||
return resp.headers.get_first('sec-websocket-accept')
|
||||
return headers.get_first('sec-websocket-accept')
|
||||
|
||||
|
||||
def create_server_nonce(client_nonce):
|
||||
|
@ -412,10 +412,10 @@ def test_parse_http_basic_auth():
|
||||
assert not http.parse_http_basic_auth(v)
|
||||
|
||||
|
||||
def test_get_line():
|
||||
def test_get_request_line():
|
||||
r = cStringIO.StringIO("\nfoo")
|
||||
assert http.get_line(r) == "foo"
|
||||
tutils.raises(tcp.NetLibDisconnect, http.get_line, r)
|
||||
assert http.get_request_line(r) == "foo"
|
||||
assert not http.get_request_line(r)
|
||||
|
||||
|
||||
class TestReadRequest():
|
||||
|
@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
|
||||
def handshake(self):
|
||||
req = http.read_request(self.rfile)
|
||||
key = websockets.check_client_handshake(req)
|
||||
key = websockets.check_client_handshake(req.headers)
|
||||
|
||||
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||
headers = websockets.server_handshake_headers(key)
|
||||
@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
self.wfile.flush()
|
||||
|
||||
resp = http.read_response(self.rfile, "get", None)
|
||||
server_nonce = websockets.check_server_handshake(resp)
|
||||
server_nonce = websockets.check_server_handshake(resp.headers)
|
||||
|
||||
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
|
||||
self.close()
|
||||
@ -153,38 +153,22 @@ class TestWebSockets(test.ServerTestBase):
|
||||
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
|
||||
|
||||
def test_check_server_handshake(self):
|
||||
resp = http.Response(
|
||||
(1, 1),
|
||||
101,
|
||||
"Switching Protocols",
|
||||
websockets.server_handshake_headers("key"),
|
||||
""
|
||||
)
|
||||
assert websockets.check_server_handshake(resp)
|
||||
resp.headers["Upgrade"] = ["not_websocket"]
|
||||
assert not websockets.check_server_handshake(resp)
|
||||
headers = websockets.server_handshake_headers("key")
|
||||
assert websockets.check_server_handshake(headers)
|
||||
headers["Upgrade"] = ["not_websocket"]
|
||||
assert not websockets.check_server_handshake(headers)
|
||||
|
||||
def test_check_client_handshake(self):
|
||||
resp = http.Request(
|
||||
"relative",
|
||||
"get",
|
||||
"http",
|
||||
"host",
|
||||
22,
|
||||
"/",
|
||||
(1, 1),
|
||||
websockets.client_handshake_headers("key"),
|
||||
""
|
||||
)
|
||||
assert websockets.check_client_handshake(resp) == "key"
|
||||
resp.headers["Upgrade"] = ["not_websocket"]
|
||||
assert not websockets.check_client_handshake(resp)
|
||||
headers = websockets.client_handshake_headers("key")
|
||||
assert websockets.check_client_handshake(headers) == "key"
|
||||
headers["Upgrade"] = ["not_websocket"]
|
||||
assert not websockets.check_client_handshake(headers)
|
||||
|
||||
|
||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
def handshake(self):
|
||||
client_hs = http.read_request(self.rfile)
|
||||
websockets.check_client_handshake(client_hs)
|
||||
websockets.check_client_handshake(client_hs.headers)
|
||||
|
||||
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||
headers = websockets.server_handshake_headers("malformed key")
|
||||
|
Loading…
Reference in New Issue
Block a user