mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
websockets: refactor to use http and header functions in http.py
This commit is contained in:
parent
e5f1264838
commit
3e0a71ea34
126
netlib/http.py
126
netlib/http.py
@ -4,7 +4,7 @@ import string
|
|||||||
import urlparse
|
import urlparse
|
||||||
import binascii
|
import binascii
|
||||||
import sys
|
import sys
|
||||||
from . import odict, utils, tcp
|
from . import odict, utils, tcp, http_status
|
||||||
|
|
||||||
|
|
||||||
class HttpError(Exception):
|
class HttpError(Exception):
|
||||||
@ -314,62 +314,6 @@ def parse_response_line(line):
|
|||||||
return (proto, code, msg)
|
return (proto, code, msg)
|
||||||
|
|
||||||
|
|
||||||
Response = collections.namedtuple(
|
|
||||||
"Response",
|
|
||||||
[
|
|
||||||
"httpversion",
|
|
||||||
"code",
|
|
||||||
"msg",
|
|
||||||
"headers",
|
|
||||||
"content"
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def read_response(rfile, request_method, body_size_limit, include_body=True):
|
|
||||||
"""
|
|
||||||
Return an (httpversion, code, msg, headers, content) tuple.
|
|
||||||
|
|
||||||
By default, both response header and body are read.
|
|
||||||
If include_body=False is specified, content may be one of the
|
|
||||||
following:
|
|
||||||
- None, if the response is technically allowed to have a response body
|
|
||||||
- "", if the response must not have a response body (e.g. it's a
|
|
||||||
response to a HEAD request)
|
|
||||||
"""
|
|
||||||
line = rfile.readline()
|
|
||||||
# Possible leftover from previous message
|
|
||||||
if line == "\r\n" or line == "\n":
|
|
||||||
line = rfile.readline()
|
|
||||||
if not line:
|
|
||||||
raise HttpErrorConnClosed(502, "Server disconnect.")
|
|
||||||
parts = parse_response_line(line)
|
|
||||||
if not parts:
|
|
||||||
raise HttpError(502, "Invalid server response: %s" % repr(line))
|
|
||||||
proto, code, msg = parts
|
|
||||||
httpversion = parse_http_protocol(proto)
|
|
||||||
if httpversion is None:
|
|
||||||
raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
|
|
||||||
headers = read_headers(rfile)
|
|
||||||
if headers is None:
|
|
||||||
raise HttpError(502, "Invalid headers.")
|
|
||||||
|
|
||||||
if include_body:
|
|
||||||
content = read_http_body(
|
|
||||||
rfile,
|
|
||||||
headers,
|
|
||||||
body_size_limit,
|
|
||||||
request_method,
|
|
||||||
code,
|
|
||||||
False
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if include_body==False then a None content means the body should be
|
|
||||||
# read separately
|
|
||||||
content = None
|
|
||||||
return Response(httpversion, code, msg, headers, content)
|
|
||||||
|
|
||||||
|
|
||||||
def read_http_body(*args, **kwargs):
|
def read_http_body(*args, **kwargs):
|
||||||
return "".join(
|
return "".join(
|
||||||
content for _, content, _ in read_http_body_chunked(*args, **kwargs)
|
content for _, content, _ in read_http_body_chunked(*args, **kwargs)
|
||||||
@ -579,3 +523,71 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
|
|||||||
headers,
|
headers,
|
||||||
content
|
content
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Response = collections.namedtuple(
|
||||||
|
"Response",
|
||||||
|
[
|
||||||
|
"httpversion",
|
||||||
|
"code",
|
||||||
|
"msg",
|
||||||
|
"headers",
|
||||||
|
"content"
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_response(rfile, request_method, body_size_limit, include_body=True):
|
||||||
|
"""
|
||||||
|
Return an (httpversion, code, msg, headers, content) tuple.
|
||||||
|
|
||||||
|
By default, both response header and body are read.
|
||||||
|
If include_body=False is specified, content may be one of the
|
||||||
|
following:
|
||||||
|
- None, if the response is technically allowed to have a response body
|
||||||
|
- "", if the response must not have a response body (e.g. it's a
|
||||||
|
response to a HEAD request)
|
||||||
|
"""
|
||||||
|
line = rfile.readline()
|
||||||
|
# Possible leftover from previous message
|
||||||
|
if line == "\r\n" or line == "\n":
|
||||||
|
line = rfile.readline()
|
||||||
|
if not line:
|
||||||
|
raise HttpErrorConnClosed(502, "Server disconnect.")
|
||||||
|
parts = parse_response_line(line)
|
||||||
|
if not parts:
|
||||||
|
raise HttpError(502, "Invalid server response: %s" % repr(line))
|
||||||
|
proto, code, msg = parts
|
||||||
|
httpversion = parse_http_protocol(proto)
|
||||||
|
if httpversion is None:
|
||||||
|
raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
|
||||||
|
headers = read_headers(rfile)
|
||||||
|
if headers is None:
|
||||||
|
raise HttpError(502, "Invalid headers.")
|
||||||
|
|
||||||
|
if include_body:
|
||||||
|
content = read_http_body(
|
||||||
|
rfile,
|
||||||
|
headers,
|
||||||
|
body_size_limit,
|
||||||
|
request_method,
|
||||||
|
code,
|
||||||
|
False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# if include_body==False then a None content means the body should be
|
||||||
|
# read separately
|
||||||
|
content = None
|
||||||
|
return Response(httpversion, code, msg, headers, content)
|
||||||
|
|
||||||
|
|
||||||
|
def request_preamble(method, resource, http_major="1", http_minor="1"):
|
||||||
|
return '%s %s HTTP/%s.%s' % (
|
||||||
|
method, resource, http_major, http_minor
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def response_preamble(code, message=None, http_major="1", http_minor="1"):
|
||||||
|
if message is None:
|
||||||
|
message = http_status.RESPONSES.get(code)
|
||||||
|
return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)
|
||||||
|
@ -2,13 +2,11 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import mimetools
|
|
||||||
import StringIO
|
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import io
|
import io
|
||||||
|
|
||||||
from . import utils
|
from . import utils, odict
|
||||||
|
|
||||||
# Colleciton of utility functions that implement small portions of the RFC6455
|
# Colleciton of utility functions that implement small portions of the RFC6455
|
||||||
# WebSockets Protocol Useful for building WebSocket clients and servers.
|
# WebSockets Protocol Useful for building WebSocket clients and servers.
|
||||||
@ -23,6 +21,7 @@ from . import utils
|
|||||||
# The magic sha that websocket servers must know to prove they understand
|
# The magic sha that websocket servers must know to prove they understand
|
||||||
# RFC6455
|
# RFC6455
|
||||||
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
|
||||||
|
VERSION = "13"
|
||||||
|
|
||||||
|
|
||||||
class CONST(object):
|
class CONST(object):
|
||||||
@ -151,9 +150,9 @@ class Frame(object):
|
|||||||
("opcode - " + str(self.opcode)),
|
("opcode - " + str(self.opcode)),
|
||||||
("mask_bit - " + str(self.mask_bit)),
|
("mask_bit - " + str(self.mask_bit)),
|
||||||
("payload_length_code - " + str(self.payload_length_code)),
|
("payload_length_code - " + str(self.payload_length_code)),
|
||||||
("masking_key - " + str(self.masking_key)),
|
("masking_key - " + repr(str(self.masking_key))),
|
||||||
("payload - " + str(self.payload)),
|
("payload - " + repr(str(self.payload))),
|
||||||
("decoded_payload - " + str(self.decoded_payload)),
|
("decoded_payload - " + repr(str(self.decoded_payload))),
|
||||||
("actual_payload_length - " + str(self.actual_payload_length))
|
("actual_payload_length - " + str(self.actual_payload_length))
|
||||||
])
|
])
|
||||||
|
|
||||||
@ -198,24 +197,24 @@ class Frame(object):
|
|||||||
|
|
||||||
second_byte = (self.mask_bit << 7) | self.payload_length_code
|
second_byte = (self.mask_bit << 7) | self.payload_length_code
|
||||||
|
|
||||||
bytes = chr(first_byte) + chr(second_byte)
|
b = chr(first_byte) + chr(second_byte)
|
||||||
|
|
||||||
if self.actual_payload_length < 126:
|
if self.actual_payload_length < 126:
|
||||||
pass
|
pass
|
||||||
elif self.actual_payload_length < CONST.MAX_16_BIT_INT:
|
elif self.actual_payload_length < CONST.MAX_16_BIT_INT:
|
||||||
# '!H' pack as 16 bit unsigned short
|
# '!H' pack as 16 bit unsigned short
|
||||||
# add 2 byte extended payload length
|
# add 2 byte extended payload length
|
||||||
bytes += struct.pack('!H', self.actual_payload_length)
|
b += struct.pack('!H', self.actual_payload_length)
|
||||||
elif self.actual_payload_length < CONST.MAX_64_BIT_INT:
|
elif self.actual_payload_length < CONST.MAX_64_BIT_INT:
|
||||||
# '!Q' = pack as 64 bit unsigned long long
|
# '!Q' = pack as 64 bit unsigned long long
|
||||||
# add 8 bytes extended payload length
|
# add 8 bytes extended payload length
|
||||||
bytes += struct.pack('!Q', self.actual_payload_length)
|
b += struct.pack('!Q', self.actual_payload_length)
|
||||||
|
|
||||||
if self.masking_key is not None:
|
if self.masking_key is not None:
|
||||||
bytes += self.masking_key
|
b += self.masking_key
|
||||||
|
|
||||||
bytes += self.payload # already will be encoded if neccessary
|
b += self.payload # already will be encoded if neccessary
|
||||||
return bytes
|
return b
|
||||||
|
|
||||||
def to_file(self, writer):
|
def to_file(self, writer):
|
||||||
writer.write(self.to_bytes())
|
writer.write(self.to_bytes())
|
||||||
@ -313,58 +312,35 @@ def random_masking_key():
|
|||||||
return os.urandom(4)
|
return os.urandom(4)
|
||||||
|
|
||||||
|
|
||||||
def create_client_handshake(host, port, key, version, resource):
|
def client_handshake_headers(key=None, version=VERSION):
|
||||||
"""
|
"""
|
||||||
WebSockets connections are intiated by the client with a valid HTTP
|
Create the headers for a valid HTTP upgrade request. If Key is not
|
||||||
upgrade request
|
specified, it is generated, and can be found in sec-websocket-key in
|
||||||
|
the returned header set.
|
||||||
|
|
||||||
|
Returns an instance of ODictCaseless
|
||||||
"""
|
"""
|
||||||
headers = [
|
if not key:
|
||||||
('Host', '%s:%s' % (host, port)),
|
key = base64.b64encode(os.urandom(16)).decode('utf-8')
|
||||||
|
return odict.ODictCaseless([
|
||||||
('Connection', 'Upgrade'),
|
('Connection', 'Upgrade'),
|
||||||
('Upgrade', 'websocket'),
|
('Upgrade', 'websocket'),
|
||||||
('Sec-WebSocket-Key', key),
|
('Sec-WebSocket-Key', key),
|
||||||
('Sec-WebSocket-Version', version)
|
('Sec-WebSocket-Version', version)
|
||||||
]
|
])
|
||||||
request = "GET %s HTTP/1.1" % resource
|
|
||||||
return build_handshake(headers, request)
|
|
||||||
|
|
||||||
|
|
||||||
def create_server_handshake(key):
|
def server_handshake_headers(key):
|
||||||
"""
|
"""
|
||||||
The server response is a valid HTTP 101 response.
|
The server response is a valid HTTP 101 response.
|
||||||
"""
|
"""
|
||||||
headers = [
|
return odict.ODictCaseless(
|
||||||
|
[
|
||||||
('Connection', 'Upgrade'),
|
('Connection', 'Upgrade'),
|
||||||
('Upgrade', 'websocket'),
|
('Upgrade', 'websocket'),
|
||||||
('Sec-WebSocket-Accept', create_server_nonce(key))
|
('Sec-WebSocket-Accept', create_server_nonce(key))
|
||||||
]
|
]
|
||||||
request = "HTTP/1.1 101 Switching Protocols"
|
)
|
||||||
return build_handshake(headers, request)
|
|
||||||
|
|
||||||
|
|
||||||
def build_handshake(headers, request):
|
|
||||||
handshake = [request.encode('utf-8')]
|
|
||||||
for header, value in headers:
|
|
||||||
handshake.append(("%s: %s" % (header, value)).encode('utf-8'))
|
|
||||||
handshake.append(b'\r\n')
|
|
||||||
return b'\r\n'.join(handshake)
|
|
||||||
|
|
||||||
|
|
||||||
def read_handshake(reader, num_bytes_per_read):
|
|
||||||
"""
|
|
||||||
From provided function that reads bytes, read in a
|
|
||||||
complete HTTP request, which terminates with a CLRF
|
|
||||||
"""
|
|
||||||
response = b''
|
|
||||||
doubleCLRF = b'\r\n\r\n'
|
|
||||||
while True:
|
|
||||||
bytes = reader.read(num_bytes_per_read)
|
|
||||||
if not bytes:
|
|
||||||
break
|
|
||||||
response += bytes
|
|
||||||
if doubleCLRF in response:
|
|
||||||
break
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def get_payload_length_pair(payload_bytestring):
|
def get_payload_length_pair(payload_bytestring):
|
||||||
@ -384,33 +360,19 @@ def get_payload_length_pair(payload_bytestring):
|
|||||||
return (length_code, actual_length)
|
return (length_code, actual_length)
|
||||||
|
|
||||||
|
|
||||||
def process_handshake_from_client(handshake):
|
def check_client_handshake(req):
|
||||||
headers = headers_from_http_message(handshake)
|
if req.headers.get_first("upgrade", None) != "websocket":
|
||||||
if headers.get("Upgrade", None) != "websocket":
|
|
||||||
return
|
return
|
||||||
key = headers['Sec-WebSocket-Key']
|
return req.headers.get_first('sec-websocket-key')
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
def process_handshake_from_server(handshake):
|
def check_server_handshake(resp):
|
||||||
headers = headers_from_http_message(handshake)
|
if resp.headers.get_first("upgrade", None) != "websocket":
|
||||||
if headers.get("Upgrade", None) != "websocket":
|
|
||||||
return
|
return
|
||||||
key = headers['Sec-WebSocket-Accept']
|
return resp.headers.get_first('sec-websocket-accept')
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
def headers_from_http_message(http_message):
|
|
||||||
return mimetools.Message(
|
|
||||||
StringIO.StringIO(http_message.split('\r\n', 1)[1])
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_server_nonce(client_nonce):
|
def create_server_nonce(client_nonce):
|
||||||
return base64.b64encode(
|
return base64.b64encode(
|
||||||
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_client_nonce():
|
|
||||||
return base64.b64encode(os.urandom(16)).decode('utf-8')
|
|
||||||
|
@ -1,6 +1,4 @@
|
|||||||
from netlib import tcp
|
from netlib import tcp, test, websockets, http, odict
|
||||||
from netlib import test
|
|
||||||
from netlib import websockets
|
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
from nose.tools import raises
|
from nose.tools import raises
|
||||||
@ -21,18 +19,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
|||||||
self.read_next_message()
|
self.read_next_message()
|
||||||
|
|
||||||
def read_next_message(self):
|
def read_next_message(self):
|
||||||
decoded = websockets.Frame.from_file(self.rfile).decoded_payload
|
frame = websockets.Frame.from_file(self.rfile)
|
||||||
self.on_message(decoded)
|
self.on_message(frame.decoded_payload)
|
||||||
|
|
||||||
def send_message(self, message):
|
def send_message(self, message):
|
||||||
frame = websockets.Frame.default(message, from_client = False)
|
frame = websockets.Frame.default(message, from_client = False)
|
||||||
frame.to_file(self.wfile)
|
frame.to_file(self.wfile)
|
||||||
|
|
||||||
def handshake(self):
|
def handshake(self):
|
||||||
client_hs = websockets.read_handshake(self.rfile, 1)
|
req = http.read_request(self.rfile)
|
||||||
key = websockets.process_handshake_from_client(client_hs)
|
key = websockets.check_client_handshake(req)
|
||||||
response = websockets.create_server_handshake(key)
|
|
||||||
self.wfile.write(response)
|
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||||
|
headers = websockets.server_handshake_headers(key)
|
||||||
|
self.wfile.write(headers.format() + "\r\n")
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
self.handshake_done = True
|
self.handshake_done = True
|
||||||
|
|
||||||
@ -44,28 +44,20 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
|||||||
class WebSocketsClient(tcp.TCPClient):
|
class WebSocketsClient(tcp.TCPClient):
|
||||||
def __init__(self, address, source_address=None):
|
def __init__(self, address, source_address=None):
|
||||||
super(WebSocketsClient, self).__init__(address, source_address)
|
super(WebSocketsClient, self).__init__(address, source_address)
|
||||||
self.version = "13"
|
self.client_nonce = None
|
||||||
self.client_nonce = websockets.create_client_nonce()
|
|
||||||
self.resource = "/"
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
super(WebSocketsClient, self).connect()
|
super(WebSocketsClient, self).connect()
|
||||||
|
|
||||||
handshake = websockets.create_client_handshake(
|
preamble = http.request_preamble("GET", "/")
|
||||||
self.address.host,
|
self.wfile.write(preamble + "\r\n")
|
||||||
self.address.port,
|
headers = websockets.client_handshake_headers()
|
||||||
self.client_nonce,
|
self.client_nonce = headers.get_first("sec-websocket-key")
|
||||||
self.version,
|
self.wfile.write(headers.format() + "\r\n")
|
||||||
self.resource
|
|
||||||
)
|
|
||||||
|
|
||||||
self.wfile.write(handshake)
|
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
server_handshake = websockets.read_handshake(self.rfile, 1)
|
resp = http.read_response(self.rfile, "get", None)
|
||||||
server_nonce = websockets.process_handshake_from_server(
|
server_nonce = websockets.check_server_handshake(resp)
|
||||||
server_handshake
|
|
||||||
)
|
|
||||||
|
|
||||||
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
|
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
|
||||||
self.close()
|
self.close()
|
||||||
@ -140,51 +132,43 @@ class TestWebSockets(test.ServerTestBase):
|
|||||||
frame.actual_payload_length = 1 # corrupt the frame
|
frame.actual_payload_length = 1 # corrupt the frame
|
||||||
frame.safe_to_bytes()
|
frame.safe_to_bytes()
|
||||||
|
|
||||||
def test_handshake(self):
|
def test_check_server_handshake(self):
|
||||||
bad_upgrade = "not_websockets"
|
resp = http.Response(
|
||||||
bad_header_handshake = websockets.build_handshake([
|
(1, 1),
|
||||||
('Host', '%s:%s' % ("a", "b")),
|
101,
|
||||||
('Connection', "c"),
|
"Switching Protocols",
|
||||||
('Upgrade', bad_upgrade),
|
websockets.server_handshake_headers("key"),
|
||||||
('Sec-WebSocket-Key', "d"),
|
""
|
||||||
('Sec-WebSocket-Version', "e")
|
|
||||||
], "f")
|
|
||||||
|
|
||||||
# check behavior when required header values are missing
|
|
||||||
assert None is websockets.process_handshake_from_server(
|
|
||||||
bad_header_handshake
|
|
||||||
)
|
)
|
||||||
assert None is websockets.process_handshake_from_client(
|
assert websockets.check_server_handshake(resp)
|
||||||
bad_header_handshake
|
resp.headers["Upgrade"] = ["not_websocket"]
|
||||||
|
assert not websockets.check_server_handshake(resp)
|
||||||
|
|
||||||
|
def test_check_client_handshake(self):
|
||||||
|
resp = http.Request(
|
||||||
|
"relative",
|
||||||
|
"get",
|
||||||
|
"http",
|
||||||
|
"host",
|
||||||
|
22,
|
||||||
|
"/",
|
||||||
|
(1, 1),
|
||||||
|
websockets.client_handshake_headers("key"),
|
||||||
|
""
|
||||||
)
|
)
|
||||||
|
assert websockets.check_client_handshake(resp) == "key"
|
||||||
key = "test_key"
|
resp.headers["Upgrade"] = ["not_websocket"]
|
||||||
|
assert not websockets.check_client_handshake(resp)
|
||||||
client_handshake = websockets.create_client_handshake(
|
|
||||||
"a", "b", key, "d", "e"
|
|
||||||
)
|
|
||||||
assert key == websockets.process_handshake_from_client(
|
|
||||||
client_handshake
|
|
||||||
)
|
|
||||||
|
|
||||||
server_handshake = websockets.create_server_handshake(key)
|
|
||||||
assert websockets.create_server_nonce(key) == websockets.process_handshake_from_server(server_handshake)
|
|
||||||
|
|
||||||
handshake = websockets.create_client_handshake("a", "b", "c", "d", "e")
|
|
||||||
stream = io.BytesIO(handshake)
|
|
||||||
assert handshake == websockets.read_handshake(stream, 1)
|
|
||||||
|
|
||||||
# ensure readhandshake doesn't loop forever on empty stream
|
|
||||||
empty_stream = io.BytesIO("")
|
|
||||||
assert "" == websockets.read_handshake(empty_stream, 1)
|
|
||||||
|
|
||||||
|
|
||||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||||
def handshake(self):
|
def handshake(self):
|
||||||
client_hs = websockets.read_handshake(self.rfile, 1)
|
client_hs = http.read_request(self.rfile)
|
||||||
websockets.process_handshake_from_client(client_hs)
|
websockets.check_client_handshake(client_hs)
|
||||||
response = websockets.create_server_handshake("malformed_key")
|
|
||||||
self.wfile.write(response)
|
self.wfile.write(http.response_preamble(101) + "\r\n")
|
||||||
|
headers = websockets.server_handshake_headers("malformed key")
|
||||||
|
self.wfile.write(headers.format() + "\r\n")
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
self.handshake_done = True
|
self.handshake_done = True
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user