mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
refactor HTTP/1 as protocol
This commit is contained in:
parent
230c16122b
commit
808b294865
@ -9,475 +9,488 @@ from netlib import odict, utils, tcp, http
|
||||
from .. import status_codes
|
||||
from ..exceptions import *
|
||||
|
||||
class HTTP1Protocol(object):
|
||||
|
||||
def get_request_line(fp):
|
||||
"""
|
||||
Get a line, possibly preceded by a blank.
|
||||
"""
|
||||
line = fp.readline()
|
||||
if line == "\r\n" or line == "\n":
|
||||
# Possible leftover from previous message
|
||||
line = fp.readline()
|
||||
return line
|
||||
# TODO: make this a regular class - just like Response
|
||||
Request = collections.namedtuple(
|
||||
"Request",
|
||||
[
|
||||
"form_in",
|
||||
"method",
|
||||
"scheme",
|
||||
"host",
|
||||
"port",
|
||||
"path",
|
||||
"httpversion",
|
||||
"headers",
|
||||
"content"
|
||||
]
|
||||
)
|
||||
|
||||
def read_headers(fp):
|
||||
"""
|
||||
Read a set of headers from a file pointer. Stop once a blank line is
|
||||
reached. Return a ODictCaseless object, or None if headers are invalid.
|
||||
"""
|
||||
ret = []
|
||||
name = ''
|
||||
while True:
|
||||
line = fp.readline()
|
||||
if not line or line == '\r\n' or line == '\n':
|
||||
break
|
||||
if line[0] in ' \t':
|
||||
if not ret:
|
||||
return None
|
||||
# continued header
|
||||
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
|
||||
else:
|
||||
i = line.find(':')
|
||||
# We're being liberal in what we accept, here.
|
||||
if i > 0:
|
||||
name = line[:i]
|
||||
value = line[i + 1:].strip()
|
||||
ret.append([name, value])
|
||||
def __init__(self, tcp_handler):
|
||||
self.tcp_handler = tcp_handler
|
||||
|
||||
def get_request_line(self):
|
||||
"""
|
||||
Get a line, possibly preceded by a blank.
|
||||
"""
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
if line == "\r\n" or line == "\n":
|
||||
# Possible leftover from previous message
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
return line
|
||||
|
||||
def read_headers(self):
|
||||
"""
|
||||
Read a set of headers.
|
||||
Stop once a blank line is reached.
|
||||
|
||||
Return a ODictCaseless object, or None if headers are invalid.
|
||||
"""
|
||||
ret = []
|
||||
name = ''
|
||||
while True:
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
if not line or line == '\r\n' or line == '\n':
|
||||
break
|
||||
if line[0] in ' \t':
|
||||
if not ret:
|
||||
return None
|
||||
# continued header
|
||||
ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
|
||||
else:
|
||||
return None
|
||||
return odict.ODictCaseless(ret)
|
||||
i = line.find(':')
|
||||
# We're being liberal in what we accept, here.
|
||||
if i > 0:
|
||||
name = line[:i]
|
||||
value = line[i + 1:].strip()
|
||||
ret.append([name, value])
|
||||
else:
|
||||
return None
|
||||
return odict.ODictCaseless(ret)
|
||||
|
||||
|
||||
def read_chunked(fp, limit, is_request):
|
||||
"""
|
||||
Read a chunked HTTP body.
|
||||
def read_chunked(self, limit, is_request):
|
||||
"""
|
||||
Read a chunked HTTP body.
|
||||
|
||||
May raise HttpError.
|
||||
"""
|
||||
# FIXME: Should check if chunked is the final encoding in the headers
|
||||
# http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3
|
||||
# 3.3 2.
|
||||
total = 0
|
||||
code = 400 if is_request else 502
|
||||
while True:
|
||||
line = fp.readline(128)
|
||||
if line == "":
|
||||
raise HttpErrorConnClosed(code, "Connection closed prematurely")
|
||||
if line != '\r\n' and line != '\n':
|
||||
try:
|
||||
length = int(line, 16)
|
||||
except ValueError:
|
||||
raise HttpError(
|
||||
code,
|
||||
"Invalid chunked encoding length: %s" % line
|
||||
)
|
||||
total += length
|
||||
if limit is not None and total > limit:
|
||||
msg = "HTTP Body too large. Limit is %s," \
|
||||
" chunked content longer than %s" % (limit, total)
|
||||
raise HttpError(code, msg)
|
||||
chunk = fp.read(length)
|
||||
suffix = fp.readline(5)
|
||||
if suffix != '\r\n':
|
||||
raise HttpError(code, "Malformed chunked body")
|
||||
yield line, chunk, '\r\n'
|
||||
if length == 0:
|
||||
return
|
||||
May raise HttpError.
|
||||
"""
|
||||
# FIXME: Should check if chunked is the final encoding in the headers
|
||||
# http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3
|
||||
# 3.3 2.
|
||||
total = 0
|
||||
code = 400 if is_request else 502
|
||||
while True:
|
||||
line = self.tcp_handler.rfile.readline(128)
|
||||
if line == "":
|
||||
raise HttpErrorConnClosed(code, "Connection closed prematurely")
|
||||
if line != '\r\n' and line != '\n':
|
||||
try:
|
||||
length = int(line, 16)
|
||||
except ValueError:
|
||||
raise HttpError(
|
||||
code,
|
||||
"Invalid chunked encoding length: %s" % line
|
||||
)
|
||||
total += length
|
||||
if limit is not None and total > limit:
|
||||
msg = "HTTP Body too large. Limit is %s," \
|
||||
" chunked content longer than %s" % (limit, total)
|
||||
raise HttpError(code, msg)
|
||||
chunk = self.tcp_handler.rfile.read(length)
|
||||
suffix = self.tcp_handler.rfile.readline(5)
|
||||
if suffix != '\r\n':
|
||||
raise HttpError(code, "Malformed chunked body")
|
||||
yield line, chunk, '\r\n'
|
||||
if length == 0:
|
||||
return
|
||||
|
||||
|
||||
def has_chunked_encoding(headers):
|
||||
return "chunked" in [
|
||||
i.lower() for i in http.get_header_tokens(headers, "transfer-encoding")
|
||||
]
|
||||
@classmethod
|
||||
def has_chunked_encoding(self, headers):
|
||||
return "chunked" in [
|
||||
i.lower() for i in http.get_header_tokens(headers, "transfer-encoding")
|
||||
]
|
||||
|
||||
|
||||
def parse_http_protocol(s):
|
||||
"""
|
||||
Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or
|
||||
None.
|
||||
"""
|
||||
if not s.startswith("HTTP/"):
|
||||
return None
|
||||
_, version = s.split('/', 1)
|
||||
if "." not in version:
|
||||
return None
|
||||
major, minor = version.split('.', 1)
|
||||
try:
|
||||
major = int(major)
|
||||
minor = int(minor)
|
||||
except ValueError:
|
||||
return None
|
||||
return major, minor
|
||||
|
||||
|
||||
def parse_init(line):
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
httpversion = parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
if not utils.isascii(method):
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
def parse_init_connect(line):
|
||||
"""
|
||||
Returns (host, port, httpversion) if line is a valid CONNECT line.
|
||||
http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
|
||||
"""
|
||||
v = parse_init(line)
|
||||
if not v:
|
||||
return None
|
||||
method, url, httpversion = v
|
||||
|
||||
if method.upper() != 'CONNECT':
|
||||
return None
|
||||
try:
|
||||
host, port = url.split(":")
|
||||
except ValueError:
|
||||
return None
|
||||
try:
|
||||
port = int(port)
|
||||
except ValueError:
|
||||
return None
|
||||
if not http.is_valid_port(port):
|
||||
return None
|
||||
if not http.is_valid_host(host):
|
||||
return None
|
||||
return host, port, httpversion
|
||||
|
||||
|
||||
def parse_init_proxy(line):
|
||||
v = parse_init(line)
|
||||
if not v:
|
||||
return None
|
||||
method, url, httpversion = v
|
||||
|
||||
parts = http.parse_url(url)
|
||||
if not parts:
|
||||
return None
|
||||
scheme, host, port, path = parts
|
||||
return method, scheme, host, port, path, httpversion
|
||||
|
||||
|
||||
def parse_init_http(line):
|
||||
"""
|
||||
Returns (method, url, httpversion)
|
||||
"""
|
||||
v = parse_init(line)
|
||||
if not v:
|
||||
return None
|
||||
method, url, httpversion = v
|
||||
if not utils.isascii(url):
|
||||
return None
|
||||
if not (url.startswith("/") or url == "*"):
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
def connection_close(httpversion, headers):
|
||||
"""
|
||||
Checks the message to see if the client connection should be closed
|
||||
according to RFC 2616 Section 8.1 Note that a connection should be
|
||||
closed as well if the response has been read until end of the stream.
|
||||
"""
|
||||
# At first, check if we have an explicit Connection header.
|
||||
if "connection" in headers:
|
||||
toks = http.get_header_tokens(headers, "connection")
|
||||
if "close" in toks:
|
||||
return True
|
||||
elif "keep-alive" in toks:
|
||||
return False
|
||||
# If we don't have a Connection header, HTTP 1.1 connections are assumed to
|
||||
# be persistent
|
||||
if httpversion == (1, 1):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def parse_response_line(line):
|
||||
parts = line.strip().split(" ", 2)
|
||||
if len(parts) == 2: # handle missing message gracefully
|
||||
parts.append("")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
proto, code, msg = parts
|
||||
try:
|
||||
code = int(code)
|
||||
except ValueError:
|
||||
return None
|
||||
return (proto, code, msg)
|
||||
|
||||
|
||||
def read_http_body(*args, **kwargs):
|
||||
return "".join(
|
||||
content for _, content, _ in read_http_body_chunked(*args, **kwargs)
|
||||
)
|
||||
|
||||
|
||||
def read_http_body_chunked(
|
||||
rfile,
|
||||
headers,
|
||||
limit,
|
||||
request_method,
|
||||
response_code,
|
||||
is_request,
|
||||
max_chunk_size=None
|
||||
):
|
||||
"""
|
||||
Read an HTTP message body:
|
||||
|
||||
rfile: A file descriptor to read from
|
||||
headers: An ODictCaseless object
|
||||
limit: Size limit.
|
||||
is_request: True if the body to read belongs to a request, False
|
||||
otherwise
|
||||
"""
|
||||
if max_chunk_size is None:
|
||||
max_chunk_size = limit or sys.maxsize
|
||||
|
||||
expected_size = expected_http_body_size(
|
||||
headers, is_request, request_method, response_code
|
||||
)
|
||||
|
||||
if expected_size is None:
|
||||
if has_chunked_encoding(headers):
|
||||
# Python 3: yield from
|
||||
for x in read_chunked(rfile, limit, is_request):
|
||||
yield x
|
||||
else: # pragma: nocover
|
||||
raise HttpError(
|
||||
400 if is_request else 502,
|
||||
"Content-Length unknown but no chunked encoding"
|
||||
)
|
||||
elif expected_size >= 0:
|
||||
if limit is not None and expected_size > limit:
|
||||
raise HttpError(
|
||||
400 if is_request else 509,
|
||||
"HTTP Body too large. Limit is %s, content-length was %s" % (
|
||||
limit, expected_size
|
||||
)
|
||||
)
|
||||
bytes_left = expected_size
|
||||
while bytes_left:
|
||||
chunk_size = min(bytes_left, max_chunk_size)
|
||||
yield "", rfile.read(chunk_size), ""
|
||||
bytes_left -= chunk_size
|
||||
else:
|
||||
bytes_left = limit or -1
|
||||
while bytes_left:
|
||||
chunk_size = min(bytes_left, max_chunk_size)
|
||||
content = rfile.read(chunk_size)
|
||||
if not content:
|
||||
return
|
||||
yield "", content, ""
|
||||
bytes_left -= chunk_size
|
||||
not_done = rfile.read(1)
|
||||
if not_done:
|
||||
raise HttpError(
|
||||
400 if is_request else 509,
|
||||
"HTTP Body too large. Limit is %s," % limit
|
||||
)
|
||||
|
||||
|
||||
def expected_http_body_size(headers, is_request, request_method, response_code):
|
||||
"""
|
||||
Returns the expected body length:
|
||||
- a positive integer, if the size is known in advance
|
||||
- None, if the size in unknown in advance (chunked encoding or invalid
|
||||
data)
|
||||
- -1, if all data should be read until end of stream.
|
||||
|
||||
May raise HttpError.
|
||||
"""
|
||||
# Determine response size according to
|
||||
# http://tools.ietf.org/html/rfc7230#section-3.3
|
||||
if request_method:
|
||||
request_method = request_method.upper()
|
||||
|
||||
if (not is_request and (
|
||||
request_method == "HEAD" or
|
||||
(request_method == "CONNECT" and response_code == 200) or
|
||||
response_code in [204, 304] or
|
||||
100 <= response_code <= 199)):
|
||||
return 0
|
||||
if has_chunked_encoding(headers):
|
||||
return None
|
||||
if "content-length" in headers:
|
||||
@classmethod
|
||||
def parse_http_protocol(self, line):
|
||||
"""
|
||||
Parse an HTTP protocol declaration.
|
||||
Returns a (major, minor) tuple, or None.
|
||||
"""
|
||||
if not line.startswith("HTTP/"):
|
||||
return None
|
||||
_, version = line.split('/', 1)
|
||||
if "." not in version:
|
||||
return None
|
||||
major, minor = version.split('.', 1)
|
||||
try:
|
||||
size = int(headers["content-length"][0])
|
||||
if size < 0:
|
||||
raise ValueError()
|
||||
return size
|
||||
major = int(major)
|
||||
minor = int(minor)
|
||||
except ValueError:
|
||||
return None
|
||||
if is_request:
|
||||
return 0
|
||||
return -1
|
||||
return major, minor
|
||||
|
||||
|
||||
# TODO: make this a regular class - just like Response
|
||||
Request = collections.namedtuple(
|
||||
"Request",
|
||||
[
|
||||
"form_in",
|
||||
"method",
|
||||
"scheme",
|
||||
"host",
|
||||
"port",
|
||||
"path",
|
||||
"httpversion",
|
||||
"headers",
|
||||
"content"
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def parse_init(self, line):
|
||||
try:
|
||||
method, url, protocol = string.split(line)
|
||||
except ValueError:
|
||||
return None
|
||||
httpversion = self.parse_http_protocol(protocol)
|
||||
if not httpversion:
|
||||
return None
|
||||
if not utils.isascii(method):
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
|
||||
def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
|
||||
"""
|
||||
Parse an HTTP request from a file stream
|
||||
@classmethod
|
||||
def parse_init_connect(self, line):
|
||||
"""
|
||||
Returns (host, port, httpversion) if line is a valid CONNECT line.
|
||||
http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
|
||||
"""
|
||||
v = self.parse_init(line)
|
||||
if not v:
|
||||
return None
|
||||
method, url, httpversion = v
|
||||
|
||||
Args:
|
||||
rfile (file): Input file to read from
|
||||
include_body (bool): Read response body as well
|
||||
body_size_limit (bool): Maximum body size
|
||||
wfile (file): If specified, HTTP Expect headers are handled
|
||||
automatically, by writing a HTTP 100 CONTINUE response to the stream.
|
||||
if method.upper() != 'CONNECT':
|
||||
return None
|
||||
try:
|
||||
host, port = url.split(":")
|
||||
except ValueError:
|
||||
return None
|
||||
try:
|
||||
port = int(port)
|
||||
except ValueError:
|
||||
return None
|
||||
if not http.is_valid_port(port):
|
||||
return None
|
||||
if not http.is_valid_host(host):
|
||||
return None
|
||||
return host, port, httpversion
|
||||
|
||||
Returns:
|
||||
Request: The HTTP request
|
||||
@classmethod
|
||||
def parse_init_proxy(self, line):
|
||||
v = self.parse_init(line)
|
||||
if not v:
|
||||
return None
|
||||
method, url, httpversion = v
|
||||
|
||||
Raises:
|
||||
HttpError: If the input is invalid.
|
||||
"""
|
||||
httpversion, host, port, scheme, method, path, headers, content = (
|
||||
None, None, None, None, None, None, None, None)
|
||||
parts = http.parse_url(url)
|
||||
if not parts:
|
||||
return None
|
||||
scheme, host, port, path = parts
|
||||
return method, scheme, host, port, path, httpversion
|
||||
|
||||
request_line = get_request_line(rfile)
|
||||
if not request_line:
|
||||
raise tcp.NetLibDisconnect()
|
||||
@classmethod
|
||||
def parse_init_http(self, line):
|
||||
"""
|
||||
Returns (method, url, httpversion)
|
||||
"""
|
||||
v = self.parse_init(line)
|
||||
if not v:
|
||||
return None
|
||||
method, url, httpversion = v
|
||||
if not utils.isascii(url):
|
||||
return None
|
||||
if not (url.startswith("/") or url == "*"):
|
||||
return None
|
||||
return method, url, httpversion
|
||||
|
||||
request_line_parts = parse_init(request_line)
|
||||
if not request_line_parts:
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
method, path, httpversion = request_line_parts
|
||||
|
||||
if path == '*' or path.startswith("/"):
|
||||
form_in = "relative"
|
||||
if not utils.isascii(path):
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
elif method.upper() == 'CONNECT':
|
||||
form_in = "authority"
|
||||
r = parse_init_connect(request_line)
|
||||
if not r:
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
host, port, _ = r
|
||||
path = None
|
||||
else:
|
||||
form_in = "absolute"
|
||||
r = parse_init_proxy(request_line)
|
||||
if not r:
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
_, scheme, host, port, path, _ = r
|
||||
@classmethod
|
||||
def connection_close(self, httpversion, headers):
|
||||
"""
|
||||
Checks the message to see if the client connection should be closed
|
||||
according to RFC 2616 Section 8.1 Note that a connection should be
|
||||
closed as well if the response has been read until end of the stream.
|
||||
"""
|
||||
# At first, check if we have an explicit Connection header.
|
||||
if "connection" in headers:
|
||||
toks = http.get_header_tokens(headers, "connection")
|
||||
if "close" in toks:
|
||||
return True
|
||||
elif "keep-alive" in toks:
|
||||
return False
|
||||
|
||||
headers = read_headers(rfile)
|
||||
if headers is None:
|
||||
raise HttpError(400, "Invalid headers")
|
||||
# If we don't have a Connection header, HTTP 1.1 connections are assumed to
|
||||
# be persistent
|
||||
return httpversion != (1, 1)
|
||||
|
||||
expect_header = headers.get_first("expect", "").lower()
|
||||
if expect_header == "100-continue" and httpversion >= (1, 1):
|
||||
wfile.write(
|
||||
'HTTP/1.1 100 Continue\r\n'
|
||||
'\r\n'
|
||||
)
|
||||
wfile.flush()
|
||||
del headers['expect']
|
||||
|
||||
if include_body:
|
||||
content = read_http_body(
|
||||
rfile, headers, body_size_limit, method, None, True
|
||||
@classmethod
|
||||
def parse_response_line(self, line):
|
||||
parts = line.strip().split(" ", 2)
|
||||
if len(parts) == 2: # handle missing message gracefully
|
||||
parts.append("")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
proto, code, msg = parts
|
||||
try:
|
||||
code = int(code)
|
||||
except ValueError:
|
||||
return None
|
||||
return (proto, code, msg)
|
||||
|
||||
|
||||
def read_http_body(self, *args, **kwargs):
|
||||
return "".join(
|
||||
content for _, content, _ in self.read_http_body_chunked(*args, **kwargs)
|
||||
)
|
||||
|
||||
return Request(
|
||||
form_in,
|
||||
method,
|
||||
scheme,
|
||||
host,
|
||||
port,
|
||||
path,
|
||||
httpversion,
|
||||
|
||||
def read_http_body_chunked(
|
||||
self,
|
||||
headers,
|
||||
content
|
||||
)
|
||||
limit,
|
||||
request_method,
|
||||
response_code,
|
||||
is_request,
|
||||
max_chunk_size=None
|
||||
):
|
||||
"""
|
||||
Read an HTTP message body:
|
||||
headers: An ODictCaseless object
|
||||
limit: Size limit.
|
||||
is_request: True if the body to read belongs to a request, False
|
||||
otherwise
|
||||
"""
|
||||
if max_chunk_size is None:
|
||||
max_chunk_size = limit or sys.maxsize
|
||||
|
||||
|
||||
def read_response(rfile, request_method, body_size_limit, include_body=True):
|
||||
"""
|
||||
Returns an http.Response
|
||||
|
||||
By default, both response header and body are read.
|
||||
If include_body=False is specified, content may be one of the
|
||||
following:
|
||||
- None, if the response is technically allowed to have a response body
|
||||
- "", if the response must not have a response body (e.g. it's a
|
||||
response to a HEAD request)
|
||||
"""
|
||||
|
||||
line = rfile.readline()
|
||||
# Possible leftover from previous message
|
||||
if line == "\r\n" or line == "\n":
|
||||
line = rfile.readline()
|
||||
if not line:
|
||||
raise HttpErrorConnClosed(502, "Server disconnect.")
|
||||
parts = parse_response_line(line)
|
||||
if not parts:
|
||||
raise HttpError(502, "Invalid server response: %s" % repr(line))
|
||||
proto, code, msg = parts
|
||||
httpversion = parse_http_protocol(proto)
|
||||
if httpversion is None:
|
||||
raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
|
||||
headers = read_headers(rfile)
|
||||
if headers is None:
|
||||
raise HttpError(502, "Invalid headers.")
|
||||
|
||||
if include_body:
|
||||
content = read_http_body(
|
||||
rfile,
|
||||
headers,
|
||||
body_size_limit,
|
||||
request_method,
|
||||
code,
|
||||
False
|
||||
expected_size = self.expected_http_body_size(
|
||||
headers, is_request, request_method, response_code
|
||||
)
|
||||
else:
|
||||
# if include_body==False then a None content means the body should be
|
||||
# read separately
|
||||
content = None
|
||||
return http.Response(httpversion, code, msg, headers, content)
|
||||
|
||||
if expected_size is None:
|
||||
if self.has_chunked_encoding(headers):
|
||||
# Python 3: yield from
|
||||
for x in self.read_chunked(limit, is_request):
|
||||
yield x
|
||||
else: # pragma: nocover
|
||||
raise HttpError(
|
||||
400 if is_request else 502,
|
||||
"Content-Length unknown but no chunked encoding"
|
||||
)
|
||||
elif expected_size >= 0:
|
||||
if limit is not None and expected_size > limit:
|
||||
raise HttpError(
|
||||
400 if is_request else 509,
|
||||
"HTTP Body too large. Limit is %s, content-length was %s" % (
|
||||
limit, expected_size
|
||||
)
|
||||
)
|
||||
bytes_left = expected_size
|
||||
while bytes_left:
|
||||
chunk_size = min(bytes_left, max_chunk_size)
|
||||
yield "", self.tcp_handler.rfile.read(chunk_size), ""
|
||||
bytes_left -= chunk_size
|
||||
else:
|
||||
bytes_left = limit or -1
|
||||
while bytes_left:
|
||||
chunk_size = min(bytes_left, max_chunk_size)
|
||||
content = self.tcp_handler.rfile.read(chunk_size)
|
||||
if not content:
|
||||
return
|
||||
yield "", content, ""
|
||||
bytes_left -= chunk_size
|
||||
not_done = self.tcp_handler.rfile.read(1)
|
||||
if not_done:
|
||||
raise HttpError(
|
||||
400 if is_request else 509,
|
||||
"HTTP Body too large. Limit is %s," % limit
|
||||
)
|
||||
|
||||
|
||||
def request_preamble(method, resource, http_major="1", http_minor="1"):
|
||||
return '%s %s HTTP/%s.%s' % (
|
||||
method, resource, http_major, http_minor
|
||||
)
|
||||
@classmethod
|
||||
def expected_http_body_size(self, headers, is_request, request_method, response_code):
|
||||
"""
|
||||
Returns the expected body length:
|
||||
- a positive integer, if the size is known in advance
|
||||
- None, if the size in unknown in advance (chunked encoding or invalid
|
||||
data)
|
||||
- -1, if all data should be read until end of stream.
|
||||
|
||||
May raise HttpError.
|
||||
"""
|
||||
# Determine response size according to
|
||||
# http://tools.ietf.org/html/rfc7230#section-3.3
|
||||
if request_method:
|
||||
request_method = request_method.upper()
|
||||
|
||||
if (not is_request and (
|
||||
request_method == "HEAD" or
|
||||
(request_method == "CONNECT" and response_code == 200) or
|
||||
response_code in [204, 304] or
|
||||
100 <= response_code <= 199)):
|
||||
return 0
|
||||
if self.has_chunked_encoding(headers):
|
||||
return None
|
||||
if "content-length" in headers:
|
||||
try:
|
||||
size = int(headers["content-length"][0])
|
||||
if size < 0:
|
||||
raise ValueError()
|
||||
return size
|
||||
except ValueError:
|
||||
return None
|
||||
if is_request:
|
||||
return 0
|
||||
return -1
|
||||
|
||||
|
||||
def response_preamble(code, message=None, http_major="1", http_minor="1"):
|
||||
if message is None:
|
||||
message = status_codes.RESPONSES.get(code)
|
||||
return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)
|
||||
def read_request(self, include_body=True, body_size_limit=None):
|
||||
"""
|
||||
Parse an HTTP request from a file stream
|
||||
|
||||
Args:
|
||||
include_body (bool): Read response body as well
|
||||
body_size_limit (bool): Maximum body size
|
||||
wfile (file): If specified, HTTP Expect headers are handled
|
||||
automatically, by writing a HTTP 100 CONTINUE response to the stream.
|
||||
|
||||
Returns:
|
||||
Request: The HTTP request
|
||||
|
||||
Raises:
|
||||
HttpError: If the input is invalid.
|
||||
"""
|
||||
httpversion, host, port, scheme, method, path, headers, content = (
|
||||
None, None, None, None, None, None, None, None)
|
||||
|
||||
request_line = self.get_request_line()
|
||||
if not request_line:
|
||||
raise tcp.NetLibDisconnect()
|
||||
|
||||
request_line_parts = self.parse_init(request_line)
|
||||
if not request_line_parts:
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
method, path, httpversion = request_line_parts
|
||||
|
||||
if path == '*' or path.startswith("/"):
|
||||
form_in = "relative"
|
||||
if not utils.isascii(path):
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
elif method.upper() == 'CONNECT':
|
||||
form_in = "authority"
|
||||
r = self.parse_init_connect(request_line)
|
||||
if not r:
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
host, port, _ = r
|
||||
path = None
|
||||
else:
|
||||
form_in = "absolute"
|
||||
r = self.parse_init_proxy(request_line)
|
||||
if not r:
|
||||
raise HttpError(
|
||||
400,
|
||||
"Bad HTTP request line: %s" % repr(request_line)
|
||||
)
|
||||
_, scheme, host, port, path, _ = r
|
||||
|
||||
headers = self.read_headers()
|
||||
if headers is None:
|
||||
raise HttpError(400, "Invalid headers")
|
||||
|
||||
expect_header = headers.get_first("expect", "").lower()
|
||||
if expect_header == "100-continue" and httpversion >= (1, 1):
|
||||
self.tcp_handler.wfile.write(
|
||||
'HTTP/1.1 100 Continue\r\n'
|
||||
'\r\n'
|
||||
)
|
||||
self.tcp_handler.wfile.flush()
|
||||
del headers['expect']
|
||||
|
||||
if include_body:
|
||||
content = self.read_http_body(
|
||||
headers,
|
||||
body_size_limit,
|
||||
method,
|
||||
None,
|
||||
True
|
||||
)
|
||||
|
||||
return self.Request(
|
||||
form_in,
|
||||
method,
|
||||
scheme,
|
||||
host,
|
||||
port,
|
||||
path,
|
||||
httpversion,
|
||||
headers,
|
||||
content
|
||||
)
|
||||
|
||||
|
||||
def read_response(self, request_method, body_size_limit, include_body=True):
|
||||
"""
|
||||
Returns an http.Response
|
||||
|
||||
By default, both response header and body are read.
|
||||
If include_body=False is specified, content may be one of the
|
||||
following:
|
||||
- None, if the response is technically allowed to have a response body
|
||||
- "", if the response must not have a response body (e.g. it's a
|
||||
response to a HEAD request)
|
||||
"""
|
||||
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
# Possible leftover from previous message
|
||||
if line == "\r\n" or line == "\n":
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
if not line:
|
||||
raise HttpErrorConnClosed(502, "Server disconnect.")
|
||||
parts = self.parse_response_line(line)
|
||||
if not parts:
|
||||
raise HttpError(502, "Invalid server response: %s" % repr(line))
|
||||
proto, code, msg = parts
|
||||
httpversion = self.parse_http_protocol(proto)
|
||||
if httpversion is None:
|
||||
raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
|
||||
headers = self.read_headers()
|
||||
if headers is None:
|
||||
raise HttpError(502, "Invalid headers.")
|
||||
|
||||
if include_body:
|
||||
content = self.read_http_body(
|
||||
headers,
|
||||
body_size_limit,
|
||||
request_method,
|
||||
code,
|
||||
False
|
||||
)
|
||||
else:
|
||||
# if include_body==False then a None content means the body should be
|
||||
# read separately
|
||||
content = None
|
||||
return http.Response(httpversion, code, msg, headers, content)
|
||||
|
||||
|
||||
@classmethod
|
||||
def request_preamble(self, method, resource, http_major="1", http_minor="1"):
|
||||
return '%s %s HTTP/%s.%s' % (
|
||||
method, resource, http_major, http_minor
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def response_preamble(self, code, message=None, http_major="1", http_minor="1"):
|
||||
if message is None:
|
||||
message = status_codes.RESPONSES.get(code)
|
||||
return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)
|
||||
|
@ -3,70 +3,79 @@ import textwrap
|
||||
import binascii
|
||||
|
||||
from netlib import http, odict, tcp
|
||||
from netlib.http.http1 import protocol
|
||||
from netlib.http.http1 import HTTP1Protocol
|
||||
from ... import tutils, tservers
|
||||
|
||||
|
||||
def mock_protocol(data='', chunked=False):
|
||||
class TCPHandlerMock(object):
|
||||
pass
|
||||
tcp_handler = TCPHandlerMock()
|
||||
tcp_handler.rfile = cStringIO.StringIO(data)
|
||||
tcp_handler.wfile = cStringIO.StringIO()
|
||||
return HTTP1Protocol(tcp_handler)
|
||||
|
||||
|
||||
|
||||
def test_has_chunked_encoding():
|
||||
h = odict.ODictCaseless()
|
||||
assert not protocol.has_chunked_encoding(h)
|
||||
assert not HTTP1Protocol.has_chunked_encoding(h)
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
assert protocol.has_chunked_encoding(h)
|
||||
assert HTTP1Protocol.has_chunked_encoding(h)
|
||||
|
||||
|
||||
def test_read_chunked():
|
||||
|
||||
h = odict.ODictCaseless()
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
s = cStringIO.StringIO("1\r\na\r\n0\r\n")
|
||||
|
||||
data = "1\r\na\r\n0\r\n"
|
||||
tutils.raises(
|
||||
"malformed chunked body",
|
||||
protocol.read_http_body,
|
||||
s, h, None, "GET", None, True
|
||||
mock_protocol(data).read_http_body,
|
||||
h, None, "GET", None, True
|
||||
)
|
||||
|
||||
s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
|
||||
assert protocol.read_http_body(s, h, None, "GET", None, True) == "a"
|
||||
data = "1\r\na\r\n0\r\n\r\n"
|
||||
assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a"
|
||||
|
||||
s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n")
|
||||
assert protocol.read_http_body(s, h, None, "GET", None, True) == "a"
|
||||
data = "\r\n\r\n1\r\na\r\n0\r\n\r\n"
|
||||
assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a"
|
||||
|
||||
s = cStringIO.StringIO("\r\n")
|
||||
data = "\r\n"
|
||||
tutils.raises(
|
||||
"closed prematurely",
|
||||
protocol.read_http_body,
|
||||
s, h, None, "GET", None, True
|
||||
mock_protocol(data).read_http_body,
|
||||
h, None, "GET", None, True
|
||||
)
|
||||
|
||||
s = cStringIO.StringIO("1\r\nfoo")
|
||||
data = "1\r\nfoo"
|
||||
tutils.raises(
|
||||
"malformed chunked body",
|
||||
protocol.read_http_body,
|
||||
s, h, None, "GET", None, True
|
||||
mock_protocol(data).read_http_body,
|
||||
h, None, "GET", None, True
|
||||
)
|
||||
|
||||
s = cStringIO.StringIO("foo\r\nfoo")
|
||||
data = "foo\r\nfoo"
|
||||
tutils.raises(
|
||||
protocol.HttpError,
|
||||
protocol.read_http_body,
|
||||
s, h, None, "GET", None, True
|
||||
http.HttpError,
|
||||
mock_protocol(data).read_http_body,
|
||||
h, None, "GET", None, True
|
||||
)
|
||||
|
||||
s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")
|
||||
tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True)
|
||||
data = "5\r\naaaaa\r\n0\r\n\r\n"
|
||||
tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True)
|
||||
|
||||
|
||||
def test_connection_close():
|
||||
h = odict.ODictCaseless()
|
||||
assert protocol.connection_close((1, 0), h)
|
||||
assert not protocol.connection_close((1, 1), h)
|
||||
assert HTTP1Protocol.connection_close((1, 0), h)
|
||||
assert not HTTP1Protocol.connection_close((1, 1), h)
|
||||
|
||||
h["connection"] = ["keep-alive"]
|
||||
assert not protocol.connection_close((1, 1), h)
|
||||
assert not HTTP1Protocol.connection_close((1, 1), h)
|
||||
|
||||
h["connection"] = ["close"]
|
||||
assert protocol.connection_close((1, 1), h)
|
||||
assert HTTP1Protocol.connection_close((1, 1), h)
|
||||
|
||||
|
||||
def test_get_header_tokens():
|
||||
@ -82,119 +91,119 @@ def test_get_header_tokens():
|
||||
|
||||
def test_read_http_body_request():
|
||||
h = odict.ODictCaseless()
|
||||
r = cStringIO.StringIO("testing")
|
||||
assert protocol.read_http_body(r, h, None, "GET", None, True) == ""
|
||||
data = "testing"
|
||||
assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == ""
|
||||
|
||||
|
||||
def test_read_http_body_response():
|
||||
h = odict.ODictCaseless()
|
||||
s = tcp.Reader(cStringIO.StringIO("testing"))
|
||||
assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing"
|
||||
data = "testing"
|
||||
assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing"
|
||||
|
||||
|
||||
def test_read_http_body():
|
||||
# test default case
|
||||
h = odict.ODictCaseless()
|
||||
h["content-length"] = [7]
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing"
|
||||
data = "testing"
|
||||
assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
|
||||
|
||||
# test content length: invalid header
|
||||
h["content-length"] = ["foo"]
|
||||
s = cStringIO.StringIO("testing")
|
||||
data = "testing"
|
||||
tutils.raises(
|
||||
protocol.HttpError,
|
||||
protocol.read_http_body,
|
||||
s, h, None, "GET", 200, False
|
||||
http.HttpError,
|
||||
mock_protocol(data).read_http_body,
|
||||
h, None, "GET", 200, False
|
||||
)
|
||||
|
||||
# test content length: invalid header #2
|
||||
h["content-length"] = [-1]
|
||||
s = cStringIO.StringIO("testing")
|
||||
data = "testing"
|
||||
tutils.raises(
|
||||
protocol.HttpError,
|
||||
protocol.read_http_body,
|
||||
s, h, None, "GET", 200, False
|
||||
http.HttpError,
|
||||
mock_protocol(data).read_http_body,
|
||||
h, None, "GET", 200, False
|
||||
)
|
||||
|
||||
# test content length: content length > actual content
|
||||
h["content-length"] = [5]
|
||||
s = cStringIO.StringIO("testing")
|
||||
data = "testing"
|
||||
tutils.raises(
|
||||
protocol.HttpError,
|
||||
protocol.read_http_body,
|
||||
s, h, 4, "GET", 200, False
|
||||
http.HttpError,
|
||||
mock_protocol(data).read_http_body,
|
||||
h, 4, "GET", 200, False
|
||||
)
|
||||
|
||||
# test content length: content length < actual content
|
||||
s = cStringIO.StringIO("testing")
|
||||
assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5
|
||||
data = "testing"
|
||||
assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5
|
||||
|
||||
# test no content length: limit > actual content
|
||||
h = odict.ODictCaseless()
|
||||
s = tcp.Reader(cStringIO.StringIO("testing"))
|
||||
assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7
|
||||
data = "testing"
|
||||
assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7
|
||||
|
||||
# test no content length: limit < actual content
|
||||
s = tcp.Reader(cStringIO.StringIO("testing"))
|
||||
data = "testing"
|
||||
tutils.raises(
|
||||
protocol.HttpError,
|
||||
protocol.read_http_body,
|
||||
s, h, 4, "GET", 200, False
|
||||
http.HttpError,
|
||||
mock_protocol(data, chunked=True).read_http_body,
|
||||
h, 4, "GET", 200, False
|
||||
)
|
||||
|
||||
# test chunked
|
||||
h = odict.ODictCaseless()
|
||||
h["transfer-encoding"] = ["chunked"]
|
||||
s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n"))
|
||||
assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa"
|
||||
data = "5\r\naaaaa\r\n0\r\n\r\n"
|
||||
assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
|
||||
|
||||
|
||||
def test_expected_http_body_size():
|
||||
# gibber in the content-length field
|
||||
h = odict.ODictCaseless()
|
||||
h["content-length"] = ["foo"]
|
||||
assert protocol.expected_http_body_size(h, False, "GET", 200) is None
|
||||
assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None
|
||||
# negative number in the content-length field
|
||||
h = odict.ODictCaseless()
|
||||
h["content-length"] = ["-7"]
|
||||
assert protocol.expected_http_body_size(h, False, "GET", 200) is None
|
||||
assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None
|
||||
# explicit length
|
||||
h = odict.ODictCaseless()
|
||||
h["content-length"] = ["5"]
|
||||
assert protocol.expected_http_body_size(h, False, "GET", 200) == 5
|
||||
assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5
|
||||
# no length
|
||||
h = odict.ODictCaseless()
|
||||
assert protocol.expected_http_body_size(h, False, "GET", 200) == -1
|
||||
assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1
|
||||
# no length request
|
||||
h = odict.ODictCaseless()
|
||||
assert protocol.expected_http_body_size(h, True, "GET", None) == 0
|
||||
assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0
|
||||
|
||||
|
||||
def test_parse_http_protocol():
|
||||
assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
|
||||
assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
|
||||
assert not protocol.parse_http_protocol("HTTP/a.1")
|
||||
assert not protocol.parse_http_protocol("HTTP/1.a")
|
||||
assert not protocol.parse_http_protocol("foo/0.0")
|
||||
assert not protocol.parse_http_protocol("HTTP/x")
|
||||
assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
|
||||
assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
|
||||
assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1")
|
||||
assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a")
|
||||
assert not HTTP1Protocol.parse_http_protocol("foo/0.0")
|
||||
assert not HTTP1Protocol.parse_http_protocol("HTTP/x")
|
||||
|
||||
|
||||
def test_parse_init_connect():
|
||||
assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("bogus")
|
||||
assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
|
||||
assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0")
|
||||
assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("bogus")
|
||||
assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
|
||||
assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0")
|
||||
|
||||
|
||||
def test_parse_init_proxy():
|
||||
u = "GET http://foo.com:8888/test HTTP/1.1"
|
||||
m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
|
||||
m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u)
|
||||
assert m == "GET"
|
||||
assert s == "http"
|
||||
assert h == "foo.com"
|
||||
@ -203,27 +212,27 @@ def test_parse_init_proxy():
|
||||
assert httpversion == (1, 1)
|
||||
|
||||
u = "G\xfeET http://foo.com:8888/test HTTP/1.1"
|
||||
assert not protocol.parse_init_proxy(u)
|
||||
assert not HTTP1Protocol.parse_init_proxy(u)
|
||||
|
||||
assert not protocol.parse_init_proxy("invalid")
|
||||
assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
|
||||
assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
|
||||
assert not HTTP1Protocol.parse_init_proxy("invalid")
|
||||
assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1")
|
||||
assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
|
||||
|
||||
|
||||
def test_parse_init_http():
|
||||
u = "GET /test HTTP/1.1"
|
||||
m, u, httpversion = protocol.parse_init_http(u)
|
||||
m, u, httpversion = HTTP1Protocol.parse_init_http(u)
|
||||
assert m == "GET"
|
||||
assert u == "/test"
|
||||
assert httpversion == (1, 1)
|
||||
|
||||
u = "G\xfeET /test HTTP/1.1"
|
||||
assert not protocol.parse_init_http(u)
|
||||
assert not HTTP1Protocol.parse_init_http(u)
|
||||
|
||||
assert not protocol.parse_init_http("invalid")
|
||||
assert not protocol.parse_init_http("GET invalid HTTP/1.1")
|
||||
assert not protocol.parse_init_http("GET /test foo/1.1")
|
||||
assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1")
|
||||
assert not HTTP1Protocol.parse_init_http("invalid")
|
||||
assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1")
|
||||
assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1")
|
||||
assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1")
|
||||
|
||||
|
||||
class TestReadHeaders:
|
||||
@ -232,8 +241,7 @@ class TestReadHeaders:
|
||||
if not verbatim:
|
||||
data = textwrap.dedent(data)
|
||||
data = data.strip()
|
||||
s = cStringIO.StringIO(data)
|
||||
return protocol.read_headers(s)
|
||||
return mock_protocol(data).read_headers()
|
||||
|
||||
def test_read_simple(self):
|
||||
data = """
|
||||
@ -287,16 +295,15 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase):
|
||||
def test_no_content_length(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
resp = protocol.read_response(c.rfile, "GET", None)
|
||||
resp = HTTP1Protocol(c).read_response("GET", None)
|
||||
assert resp.content == "bar\r\n\r\n"
|
||||
|
||||
|
||||
def test_read_response():
|
||||
def tst(data, method, limit, include_body=True):
|
||||
data = textwrap.dedent(data)
|
||||
r = cStringIO.StringIO(data)
|
||||
return protocol.read_response(
|
||||
r, method, limit, include_body=include_body
|
||||
return mock_protocol(data).read_response(
|
||||
method, limit, include_body=include_body
|
||||
)
|
||||
|
||||
tutils.raises("server disconnect", tst, "", "GET", None)
|
||||
@ -358,16 +365,16 @@ def test_read_response():
|
||||
|
||||
|
||||
def test_get_request_line():
|
||||
r = cStringIO.StringIO("\nfoo")
|
||||
assert protocol.get_request_line(r) == "foo"
|
||||
assert not protocol.get_request_line(r)
|
||||
data = "\nfoo"
|
||||
p = mock_protocol(data)
|
||||
assert p.get_request_line() == "foo"
|
||||
assert not p.get_request_line()
|
||||
|
||||
|
||||
class TestReadRequest():
|
||||
|
||||
def tst(self, data, **kwargs):
|
||||
r = cStringIO.StringIO(data)
|
||||
return protocol.read_request(r, **kwargs)
|
||||
return mock_protocol(data).read_request(**kwargs)
|
||||
|
||||
def test_invalid(self):
|
||||
tutils.raises(
|
||||
@ -421,14 +428,15 @@ class TestReadRequest():
|
||||
assert v.host == "foo.com"
|
||||
|
||||
def test_expect(self):
|
||||
w = cStringIO.StringIO()
|
||||
r = cStringIO.StringIO(
|
||||
data = "".join(
|
||||
"GET / HTTP/1.1\r\n"
|
||||
"Content-Length: 3\r\n"
|
||||
"Expect: 100-continue\r\n\r\n"
|
||||
"foobar",
|
||||
"foobar"
|
||||
)
|
||||
v = protocol.read_request(r, wfile=w)
|
||||
assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
|
||||
|
||||
p = mock_protocol(data)
|
||||
v = p.read_request()
|
||||
assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
|
||||
assert v.content == "foo"
|
||||
assert r.read(3) == "bar"
|
||||
assert p.tcp_handler.rfile.read(3) == "bar"
|
||||
|
@ -4,6 +4,7 @@ from nose.tools import raises
|
||||
|
||||
from netlib import tcp, http, websockets
|
||||
from netlib.http.exceptions import *
|
||||
from netlib.http.http1 import HTTP1Protocol
|
||||
from .. import tutils, tservers
|
||||
|
||||
|
||||
@ -32,10 +33,13 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
|
||||
frame.to_file(self.wfile)
|
||||
|
||||
def handshake(self):
|
||||
req = http.http1.read_request(self.rfile)
|
||||
http1_protocol = HTTP1Protocol(self)
|
||||
|
||||
req = http1_protocol.read_request()
|
||||
key = self.protocol.check_client_handshake(req.headers)
|
||||
|
||||
self.wfile.write(http.http1.response_preamble(101) + "\r\n")
|
||||
preamble = http1_protocol.response_preamble(101)
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.server_handshake_headers(key)
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
self.wfile.flush()
|
||||
@ -56,14 +60,16 @@ class WebSocketsClient(tcp.TCPClient):
|
||||
def connect(self):
|
||||
super(WebSocketsClient, self).connect()
|
||||
|
||||
preamble = http.http1.protocol.request_preamble("GET", "/")
|
||||
http1_protocol = HTTP1Protocol(self)
|
||||
|
||||
preamble = http1_protocol.request_preamble("GET", "/")
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.client_handshake_headers()
|
||||
self.client_nonce = headers.get_first("sec-websocket-key")
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
self.wfile.flush()
|
||||
|
||||
resp = http.http1.protocol.read_response(self.rfile, "get", None)
|
||||
resp = http1_protocol.read_response("get", None)
|
||||
server_nonce = self.protocol.check_server_handshake(resp.headers)
|
||||
|
||||
if not server_nonce == self.protocol.create_server_nonce(
|
||||
@ -151,10 +157,13 @@ class TestWebSockets(tservers.ServerTestBase):
|
||||
class BadHandshakeHandler(WebSocketsEchoHandler):
|
||||
|
||||
def handshake(self):
|
||||
client_hs = http.http1.protocol.read_request(self.rfile)
|
||||
http1_protocol = HTTP1Protocol(self)
|
||||
|
||||
client_hs = http1_protocol.read_request()
|
||||
self.protocol.check_client_handshake(client_hs.headers)
|
||||
|
||||
self.wfile.write(http.http1.protocol.response_preamble(101) + "\r\n")
|
||||
preamble = http1_protocol.response_preamble(101)
|
||||
self.wfile.write(preamble + "\r\n")
|
||||
headers = self.protocol.server_handshake_headers("malformed key")
|
||||
self.wfile.write(headers.format() + "\r\n")
|
||||
self.wfile.flush()
|
||||
|
Loading…
Reference in New Issue
Block a user