mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-30 11:19:23 +00:00
Merge pull request #84 from Kriechi/http2-wip
[WIP] Protocol Refactoring for HTTP/2
This commit is contained in:
commit
199f2a44fe
@ -4,8 +4,10 @@ import collections
|
||||
import string
|
||||
import sys
|
||||
import urlparse
|
||||
import time
|
||||
|
||||
from netlib import odict, utils, tcp, http
|
||||
from netlib.http import semantics
|
||||
from .. import status_codes
|
||||
from ..exceptions import *
|
||||
|
||||
@ -14,13 +16,10 @@ class TCPHandler(object):
|
||||
self.rfile = rfile
|
||||
self.wfile = wfile
|
||||
|
||||
class HTTP1Protocol(object):
|
||||
class HTTP1Protocol(semantics.ProtocolMixin):
|
||||
|
||||
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
|
||||
if tcp_handler:
|
||||
self.tcp_handler = tcp_handler
|
||||
else:
|
||||
self.tcp_handler = TCPHandler(rfile, wfile)
|
||||
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
|
||||
|
||||
|
||||
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
|
||||
@ -39,6 +38,10 @@ class HTTP1Protocol(object):
|
||||
Raises:
|
||||
HttpError: If the input is invalid.
|
||||
"""
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
httpversion, host, port, scheme, method, path, headers, body = (
|
||||
None, None, None, None, None, None, None, None)
|
||||
|
||||
@ -106,6 +109,12 @@ class HTTP1Protocol(object):
|
||||
True
|
||||
)
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
timestamp_end = time.time()
|
||||
|
||||
return http.Request(
|
||||
form_in,
|
||||
method,
|
||||
@ -115,7 +124,9 @@ class HTTP1Protocol(object):
|
||||
path,
|
||||
httpversion,
|
||||
headers,
|
||||
body
|
||||
body,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
)
|
||||
|
||||
|
||||
@ -124,12 +135,15 @@ class HTTP1Protocol(object):
|
||||
Returns an http.Response
|
||||
|
||||
By default, both response header and body are read.
|
||||
If include_body=False is specified, content may be one of the
|
||||
If include_body=False is specified, body may be one of the
|
||||
following:
|
||||
- None, if the response is technically allowed to have a response body
|
||||
- "", if the response must not have a response body (e.g. it's a
|
||||
response to a HEAD request)
|
||||
"""
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
line = self.tcp_handler.rfile.readline()
|
||||
# Possible leftover from previous message
|
||||
@ -149,7 +163,7 @@ class HTTP1Protocol(object):
|
||||
raise HttpError(502, "Invalid headers.")
|
||||
|
||||
if include_body:
|
||||
content = self.read_http_body(
|
||||
body = self.read_http_body(
|
||||
headers,
|
||||
body_size_limit,
|
||||
request_method,
|
||||
@ -157,10 +171,55 @@ class HTTP1Protocol(object):
|
||||
False
|
||||
)
|
||||
else:
|
||||
# if include_body==False then a None content means the body should be
|
||||
# if include_body==False then a None body means the body should be
|
||||
# read separately
|
||||
content = None
|
||||
return http.Response(httpversion, code, msg, headers, content)
|
||||
body = None
|
||||
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
if include_body:
|
||||
timestamp_end = time.time()
|
||||
else:
|
||||
timestamp_end = None
|
||||
|
||||
return http.Response(
|
||||
httpversion,
|
||||
code,
|
||||
msg,
|
||||
headers,
|
||||
body,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
|
||||
|
||||
def assemble_request(self, request):
|
||||
assert isinstance(request, semantics.Request)
|
||||
|
||||
if request.body == semantics.CONTENT_MISSING:
|
||||
raise http.HttpError(
|
||||
502,
|
||||
"Cannot assemble flow with CONTENT_MISSING"
|
||||
)
|
||||
first_line = self._assemble_request_first_line(request)
|
||||
headers = self._assemble_request_headers(request)
|
||||
return "%s\r\n%s\r\n%s" % (first_line, headers, request.body)
|
||||
|
||||
|
||||
def assemble_response(self, response):
|
||||
assert isinstance(response, semantics.Response)
|
||||
|
||||
if response.body == semantics.CONTENT_MISSING:
|
||||
raise http.HttpError(
|
||||
502,
|
||||
"Cannot assemble flow with CONTENT_MISSING"
|
||||
)
|
||||
first_line = self._assemble_response_first_line(response)
|
||||
headers = self._assemble_response_headers(response)
|
||||
return "%s\r\n%s\r\n%s" % (first_line, headers, response.body)
|
||||
|
||||
|
||||
def read_headers(self):
|
||||
@ -331,7 +390,6 @@ class HTTP1Protocol(object):
|
||||
return line
|
||||
|
||||
|
||||
|
||||
def _read_chunked(self, limit, is_request):
|
||||
"""
|
||||
Read a chunked HTTP body.
|
||||
@ -494,3 +552,74 @@ class HTTP1Protocol(object):
|
||||
except ValueError:
|
||||
return None
|
||||
return (proto, code, msg)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _assemble_request_first_line(self, request):
|
||||
if request.form_in == "relative":
|
||||
request_line = '%s %s HTTP/%s.%s' % (
|
||||
request.method,
|
||||
request.path,
|
||||
request.httpversion[0],
|
||||
request.httpversion[1],
|
||||
)
|
||||
elif request.form_in == "authority":
|
||||
request_line = '%s %s:%s HTTP/%s.%s' % (
|
||||
request.method,
|
||||
request.host,
|
||||
request.port,
|
||||
request.httpversion[0],
|
||||
request.httpversion[1],
|
||||
)
|
||||
elif request.form_in == "absolute":
|
||||
request_line = '%s %s://%s:%s%s HTTP/%s.%s' % (
|
||||
request.method,
|
||||
request.scheme,
|
||||
request.host,
|
||||
request.port,
|
||||
request.path,
|
||||
request.httpversion[0],
|
||||
request.httpversion[1],
|
||||
)
|
||||
else:
|
||||
raise http.HttpError(400, "Invalid request form")
|
||||
return request_line
|
||||
|
||||
def _assemble_request_headers(self, request):
|
||||
headers = request.headers.copy()
|
||||
for k in request._headers_to_strip_off:
|
||||
del headers[k]
|
||||
if 'host' not in headers and request.scheme and request.host and request.port:
|
||||
headers["Host"] = [utils.hostport(request.scheme,
|
||||
request.host,
|
||||
request.port)]
|
||||
|
||||
# If content is defined (i.e. not None or CONTENT_MISSING), we always
|
||||
# add a content-length header.
|
||||
if request.body or request.body == "":
|
||||
headers["Content-Length"] = [str(len(request.body))]
|
||||
|
||||
return headers.format()
|
||||
|
||||
|
||||
def _assemble_response_first_line(self, response):
|
||||
return 'HTTP/%s.%s %s %s' % (
|
||||
response.httpversion[0],
|
||||
response.httpversion[1],
|
||||
response.status_code,
|
||||
response.msg,
|
||||
)
|
||||
|
||||
def _assemble_response_headers(self, response, preserve_transfer_encoding=False):
|
||||
headers = response.headers.copy()
|
||||
for k in response._headers_to_strip_off:
|
||||
del headers[k]
|
||||
if not preserve_transfer_encoding:
|
||||
del headers['Transfer-Encoding']
|
||||
|
||||
# If body is defined (i.e. not None or CONTENT_MISSING), we always
|
||||
# add a content-length header.
|
||||
if response.body or response.body == "":
|
||||
headers["Content-Length"] = [str(len(response.body))]
|
||||
|
||||
return headers.format()
|
||||
|
@ -1,12 +1,20 @@
|
||||
from __future__ import (absolute_import, print_function, division)
|
||||
import itertools
|
||||
import time
|
||||
|
||||
from hpack.hpack import Encoder, Decoder
|
||||
from netlib import http, utils, odict
|
||||
from netlib.http import semantics
|
||||
from . import frame
|
||||
|
||||
|
||||
class HTTP2Protocol(object):
|
||||
class TCPHandler(object):
|
||||
def __init__(self, rfile, wfile=None):
|
||||
self.rfile = rfile
|
||||
self.wfile = wfile
|
||||
|
||||
|
||||
class HTTP2Protocol(semantics.ProtocolMixin):
|
||||
|
||||
ERROR_CODES = utils.BiDi(
|
||||
NO_ERROR=0x0,
|
||||
@ -31,16 +39,182 @@ class HTTP2Protocol(object):
|
||||
|
||||
ALPN_PROTO_H2 = 'h2'
|
||||
|
||||
def __init__(self, tcp_handler, is_server=False, dump_frames=False):
|
||||
self.tcp_handler = tcp_handler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tcp_handler=None,
|
||||
rfile=None,
|
||||
wfile=None,
|
||||
is_server=False,
|
||||
dump_frames=False,
|
||||
encoder=None,
|
||||
decoder=None,
|
||||
):
|
||||
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
|
||||
self.is_server = is_server
|
||||
self.dump_frames = dump_frames
|
||||
self.encoder = encoder or Encoder()
|
||||
self.decoder = decoder or Decoder()
|
||||
|
||||
self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy()
|
||||
self.current_stream_id = None
|
||||
self.encoder = Encoder()
|
||||
self.decoder = Decoder()
|
||||
self.connection_preface_performed = False
|
||||
self.dump_frames = dump_frames
|
||||
|
||||
def read_request(self, include_body=True, body_size_limit=None, allow_empty=False):
|
||||
self.perform_connection_preface()
|
||||
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
stream_id, headers, body = self._receive_transmission(include_body)
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
timestamp_end = time.time()
|
||||
|
||||
request = http.Request(
|
||||
"relative", # TODO: use the correct value
|
||||
headers.get_first(':method', 'GET'),
|
||||
headers.get_first(':scheme', 'https'),
|
||||
headers.get_first(':host', 'localhost'),
|
||||
443, # TODO: parse port number from host?
|
||||
headers.get_first(':path', '/'),
|
||||
(2, 0),
|
||||
headers,
|
||||
body,
|
||||
timestamp_start,
|
||||
timestamp_end,
|
||||
)
|
||||
request.stream_id = stream_id
|
||||
|
||||
return request
|
||||
|
||||
def read_response(self, request_method='', body_size_limit=None, include_body=True):
|
||||
self.perform_connection_preface()
|
||||
|
||||
timestamp_start = time.time()
|
||||
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
|
||||
self.tcp_handler.rfile.reset_timestamps()
|
||||
|
||||
stream_id, headers, body = self._receive_transmission(include_body)
|
||||
|
||||
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
|
||||
# more accurate timestamp_start
|
||||
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
|
||||
|
||||
if include_body:
|
||||
timestamp_end = time.time()
|
||||
else:
|
||||
timestamp_end = None
|
||||
|
||||
response = http.Response(
|
||||
(2, 0),
|
||||
int(headers.get_first(':status')),
|
||||
"",
|
||||
headers,
|
||||
body,
|
||||
timestamp_start=timestamp_start,
|
||||
timestamp_end=timestamp_end,
|
||||
)
|
||||
response.stream_id = stream_id
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def assemble_request(self, request):
|
||||
assert isinstance(request, semantics.Request)
|
||||
|
||||
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
|
||||
if self.tcp_handler.address.port != 443:
|
||||
authority += ":%d" % self.tcp_handler.address.port
|
||||
|
||||
headers = request.headers.copy()
|
||||
|
||||
if not ':authority' in headers.keys():
|
||||
headers.add(':authority', bytes(authority), prepend=True)
|
||||
if not ':scheme' in headers.keys():
|
||||
headers.add(':scheme', bytes(request.scheme), prepend=True)
|
||||
if not ':path' in headers.keys():
|
||||
headers.add(':path', bytes(request.path), prepend=True)
|
||||
if not ':method' in headers.keys():
|
||||
headers.add(':method', bytes(request.method), prepend=True)
|
||||
|
||||
headers = headers.items()
|
||||
|
||||
if hasattr(request, 'stream_id'):
|
||||
stream_id = request.stream_id
|
||||
else:
|
||||
stream_id = self._next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
|
||||
self._create_body(request.body, stream_id)))
|
||||
|
||||
def assemble_response(self, response):
|
||||
assert isinstance(response, semantics.Response)
|
||||
|
||||
headers = response.headers.copy()
|
||||
|
||||
if not ':status' in headers.keys():
|
||||
headers.add(':status', bytes(str(response.status_code)), prepend=True)
|
||||
|
||||
headers = headers.items()
|
||||
|
||||
if hasattr(response, 'stream_id'):
|
||||
stream_id = response.stream_id
|
||||
else:
|
||||
stream_id = self._next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
|
||||
self._create_body(response.body, stream_id),
|
||||
))
|
||||
|
||||
def perform_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
if self.is_server:
|
||||
self.perform_server_connection_preface(force)
|
||||
else:
|
||||
self.perform_client_connection_preface(force)
|
||||
|
||||
def perform_server_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
|
||||
magic = self.tcp_handler.rfile.safe_read(magic_length)
|
||||
assert magic == self.CLIENT_CONNECTION_PREFACE
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self), hide=True)
|
||||
self._receive_settings(hide=True)
|
||||
|
||||
def perform_client_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self), hide=True)
|
||||
self._receive_settings(hide=True)
|
||||
|
||||
def send_frame(self, frm, hide=False):
|
||||
raw_bytes = frm.to_bytes()
|
||||
self.tcp_handler.wfile.write(raw_bytes)
|
||||
self.tcp_handler.wfile.flush()
|
||||
if not hide and self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable(">>"))
|
||||
|
||||
def read_frame(self, hide=False):
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
if not hide and self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable("<<"))
|
||||
if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK:
|
||||
self._apply_settings(frm.settings, hide)
|
||||
|
||||
return frm
|
||||
|
||||
def check_alpn(self):
|
||||
alp = self.tcp_handler.get_alpn_proto_negotiated()
|
||||
@ -63,27 +237,7 @@ class HTTP2Protocol(object):
|
||||
assert len(frm.settings) == 0
|
||||
break
|
||||
|
||||
def perform_server_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
|
||||
magic = self.tcp_handler.rfile.safe_read(magic_length)
|
||||
assert magic == self.CLIENT_CONNECTION_PREFACE
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self), hide=True)
|
||||
self._receive_settings(hide=True)
|
||||
|
||||
def perform_client_connection_preface(self, force=False):
|
||||
if force or not self.connection_preface_performed:
|
||||
self.connection_preface_performed = True
|
||||
|
||||
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
|
||||
|
||||
self.send_frame(frame.SettingsFrame(state=self), hide=True)
|
||||
self._receive_settings(hide=True)
|
||||
|
||||
def next_stream_id(self):
|
||||
def _next_stream_id(self):
|
||||
if self.current_stream_id is None:
|
||||
if self.is_server:
|
||||
# servers must use even stream ids
|
||||
@ -95,22 +249,6 @@ class HTTP2Protocol(object):
|
||||
self.current_stream_id += 2
|
||||
return self.current_stream_id
|
||||
|
||||
def send_frame(self, frm, hide=False):
|
||||
raw_bytes = frm.to_bytes()
|
||||
self.tcp_handler.wfile.write(raw_bytes)
|
||||
self.tcp_handler.wfile.flush()
|
||||
if not hide and self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable(">>"))
|
||||
|
||||
def read_frame(self, hide=False):
|
||||
frm = frame.Frame.from_file(self.tcp_handler.rfile, self)
|
||||
if not hide and self.dump_frames: # pragma no cover
|
||||
print(frm.human_readable("<<"))
|
||||
if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK:
|
||||
self._apply_settings(frm.settings, hide)
|
||||
|
||||
return frm
|
||||
|
||||
def _apply_settings(self, settings, hide=False):
|
||||
for setting, value in settings.items():
|
||||
old_value = self.http2_settings[setting]
|
||||
@ -164,51 +302,7 @@ class HTTP2Protocol(object):
|
||||
|
||||
return [frm.to_bytes()]
|
||||
|
||||
|
||||
def create_request(self, method, path, headers=None, body=None):
|
||||
if headers is None:
|
||||
headers = []
|
||||
|
||||
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
|
||||
if self.tcp_handler.address.port != 443:
|
||||
authority += ":%d" % self.tcp_handler.address.port
|
||||
|
||||
headers = [
|
||||
(b':method', bytes(method)),
|
||||
(b':path', bytes(path)),
|
||||
(b':scheme', b'https'),
|
||||
(b':authority', authority),
|
||||
] + headers
|
||||
|
||||
stream_id = self.next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(body is None)),
|
||||
self._create_body(body, stream_id)))
|
||||
|
||||
def read_response(self, *args):
|
||||
stream_id, headers, body = self._receive_transmission()
|
||||
|
||||
status = headers[':status'][0]
|
||||
response = http.Response("HTTP/2", status, "", headers, body)
|
||||
response.stream_id = stream_id
|
||||
return response
|
||||
|
||||
def read_request(self):
|
||||
stream_id, headers, body = self._receive_transmission()
|
||||
|
||||
form_in = ""
|
||||
method = headers.get(':method', [''])[0]
|
||||
scheme = headers.get(':scheme', [''])[0]
|
||||
host = headers.get(':host', [''])[0]
|
||||
port = '' # TODO: parse port number?
|
||||
path = headers.get(':path', [''])[0]
|
||||
|
||||
request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body)
|
||||
request.stream_id = stream_id
|
||||
return request
|
||||
|
||||
def _receive_transmission(self):
|
||||
def _receive_transmission(self, include_body=True):
|
||||
body_expected = True
|
||||
|
||||
stream_id = 0
|
||||
@ -239,19 +333,3 @@ class HTTP2Protocol(object):
|
||||
headers.add(header, value)
|
||||
|
||||
return stream_id, headers, body
|
||||
|
||||
def create_response(self, code, stream_id=None, headers=None, body=None):
|
||||
if headers is None:
|
||||
headers = []
|
||||
if isinstance(headers, odict.ODict):
|
||||
headers = headers.items()
|
||||
|
||||
headers = [(b':status', bytes(str(code)))] + headers
|
||||
|
||||
if not stream_id:
|
||||
stream_id = self.next_stream_id()
|
||||
|
||||
return list(itertools.chain(
|
||||
self._create_headers(headers, stream_id, end_stream=(body is None)),
|
||||
self._create_body(body, stream_id),
|
||||
))
|
||||
|
@ -7,6 +7,32 @@ import urlparse
|
||||
|
||||
from .. import utils, odict
|
||||
|
||||
CONTENT_MISSING = 0
|
||||
|
||||
|
||||
class ProtocolMixin(object):
|
||||
|
||||
def read_request(self):
|
||||
raise NotImplemented
|
||||
|
||||
def read_response(self):
|
||||
raise NotImplemented
|
||||
|
||||
def assemble(self, message):
|
||||
if isinstance(message, Request):
|
||||
return self.assemble_request(message)
|
||||
elif isinstance(message, Response):
|
||||
return self.assemble_response(message)
|
||||
else:
|
||||
raise ValueError("HTTP message not supported.")
|
||||
|
||||
def assemble_request(self, request):
|
||||
raise NotImplemented
|
||||
|
||||
def assemble_response(self, response):
|
||||
raise NotImplemented
|
||||
|
||||
|
||||
class Request(object):
|
||||
|
||||
def __init__(
|
||||
@ -18,9 +44,15 @@ class Request(object):
|
||||
port,
|
||||
path,
|
||||
httpversion,
|
||||
headers,
|
||||
body,
|
||||
headers=None,
|
||||
body=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
):
|
||||
if not headers:
|
||||
headers = odict.ODictCaseless()
|
||||
assert isinstance(headers, odict.ODictCaseless)
|
||||
|
||||
self.form_in = form_in
|
||||
self.method = method
|
||||
self.scheme = scheme
|
||||
@ -30,17 +62,31 @@ class Request(object):
|
||||
self.httpversion = httpversion
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
return self_d == other_d
|
||||
except:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return "Request(%s - %s, %s)" % (self.method, self.host, self.path)
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content):
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
|
||||
class EmptyRequest(Request):
|
||||
def __init__(self):
|
||||
@ -63,28 +109,59 @@ class Response(object):
|
||||
self,
|
||||
httpversion,
|
||||
status_code,
|
||||
msg,
|
||||
headers,
|
||||
body,
|
||||
msg=None,
|
||||
headers=None,
|
||||
body=None,
|
||||
sslinfo=None,
|
||||
timestamp_start=None,
|
||||
timestamp_end=None,
|
||||
):
|
||||
if not headers:
|
||||
headers = odict.ODictCaseless()
|
||||
assert isinstance(headers, odict.ODictCaseless)
|
||||
|
||||
self.httpversion = httpversion
|
||||
self.status_code = status_code
|
||||
self.msg = msg
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
self.sslinfo = sslinfo
|
||||
self.timestamp_start = timestamp_start
|
||||
self.timestamp_end = timestamp_end
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
try:
|
||||
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
|
||||
return self_d == other_d
|
||||
except:
|
||||
return False
|
||||
|
||||
def __repr__(self):
|
||||
return "Response(%s - %s)" % (self.status_code, self.msg)
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
# TODO: remove deprecated getter
|
||||
return self.body
|
||||
|
||||
@content.setter
|
||||
def content(self, content):
|
||||
# TODO: remove deprecated setter
|
||||
self.body = content
|
||||
|
||||
@property
|
||||
def code(self):
|
||||
# TODO: remove deprecated getter
|
||||
return self.status_code
|
||||
|
||||
@code.setter
|
||||
def code(self, code):
|
||||
# TODO: remove deprecated setter
|
||||
self.status_code = code
|
||||
|
||||
|
||||
|
||||
def is_valid_port(port):
|
||||
if not 0 <= port <= 65535:
|
||||
|
@ -96,7 +96,10 @@ class ODict(object):
|
||||
return True
|
||||
return False
|
||||
|
||||
def add(self, key, value):
|
||||
def add(self, key, value, prepend=False):
|
||||
if prepend:
|
||||
self.lst.insert(0, [key, value])
|
||||
else:
|
||||
self.lst.append([key, value])
|
||||
|
||||
def get(self, k, d=None):
|
||||
|
@ -129,3 +129,13 @@ class Data(object):
|
||||
if not os.path.exists(fullpath):
|
||||
raise ValueError("dataPath: %s does not exist." % fullpath)
|
||||
return fullpath
|
||||
|
||||
|
||||
def hostport(scheme, host, port):
|
||||
"""
|
||||
Returns the host component, with a port specifcation if needed.
|
||||
"""
|
||||
if (port, scheme) in [(80, "http"), (443, "https")]:
|
||||
return host
|
||||
else:
|
||||
return "%s:%s" % (host, port)
|
||||
|
@ -297,10 +297,10 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase):
|
||||
|
||||
|
||||
def test_read_response():
|
||||
def tst(data, method, limit, include_body=True):
|
||||
def tst(data, method, body_size_limit, include_body=True):
|
||||
data = textwrap.dedent(data)
|
||||
return mock_protocol(data).read_response(
|
||||
method, limit, include_body=include_body
|
||||
method, body_size_limit, include_body=include_body
|
||||
)
|
||||
|
||||
tutils.raises("server disconnect", tst, "", "GET", None)
|
||||
|
@ -1,6 +1,6 @@
|
||||
import OpenSSL
|
||||
|
||||
from netlib import tcp, odict
|
||||
from netlib import tcp, odict, http
|
||||
from netlib.http import http2
|
||||
from netlib.http.http2.frame import *
|
||||
from ... import tutils, tservers
|
||||
@ -117,11 +117,11 @@ class TestClientStreamIds():
|
||||
|
||||
def test_client_stream_ids(self):
|
||||
assert self.protocol.current_stream_id is None
|
||||
assert self.protocol.next_stream_id() == 1
|
||||
assert self.protocol._next_stream_id() == 1
|
||||
assert self.protocol.current_stream_id == 1
|
||||
assert self.protocol.next_stream_id() == 3
|
||||
assert self.protocol._next_stream_id() == 3
|
||||
assert self.protocol.current_stream_id == 3
|
||||
assert self.protocol.next_stream_id() == 5
|
||||
assert self.protocol._next_stream_id() == 5
|
||||
assert self.protocol.current_stream_id == 5
|
||||
|
||||
|
||||
@ -131,11 +131,11 @@ class TestServerStreamIds():
|
||||
|
||||
def test_server_stream_ids(self):
|
||||
assert self.protocol.current_stream_id is None
|
||||
assert self.protocol.next_stream_id() == 2
|
||||
assert self.protocol._next_stream_id() == 2
|
||||
assert self.protocol.current_stream_id == 2
|
||||
assert self.protocol.next_stream_id() == 4
|
||||
assert self.protocol._next_stream_id() == 4
|
||||
assert self.protocol.current_stream_id == 4
|
||||
assert self.protocol.next_stream_id() == 6
|
||||
assert self.protocol._next_stream_id() == 6
|
||||
assert self.protocol.current_stream_id == 6
|
||||
|
||||
|
||||
@ -215,17 +215,36 @@ class TestCreateBody():
|
||||
# TODO: add test for too large frames
|
||||
|
||||
|
||||
class TestCreateRequest():
|
||||
class TestAssembleRequest():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_create_request_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/')
|
||||
def test_assemble_request_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
None,
|
||||
None,
|
||||
))
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
|
||||
|
||||
def test_create_request_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).create_request(
|
||||
'GET', '/', [(b'foo', b'bar')], 'foobar')
|
||||
def test_assemble_request_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
|
||||
'',
|
||||
'GET',
|
||||
'https',
|
||||
'',
|
||||
'',
|
||||
'/',
|
||||
(2, 0),
|
||||
odict.ODictCaseless([('foo', 'bar')]),
|
||||
'foobar',
|
||||
))
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
|
||||
@ -250,11 +269,12 @@ class TestReadResponse(tservers.ServerTestBase):
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.httpversion == "HTTP/2"
|
||||
assert resp.status_code == "200"
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b'foobar'
|
||||
@ -275,12 +295,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_response()
|
||||
|
||||
assert resp.stream_id
|
||||
assert resp.httpversion == "HTTP/2"
|
||||
assert resp.status_code == "200"
|
||||
assert resp.httpversion == (2, 0)
|
||||
assert resp.status_code == 200
|
||||
assert resp.msg == ""
|
||||
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
|
||||
assert resp.body == b''
|
||||
@ -303,6 +324,7 @@ class TestReadRequest(tservers.ServerTestBase):
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
protocol = http2.HTTP2Protocol(c, is_server=True)
|
||||
protocol.connection_preface_performed = True
|
||||
|
||||
resp = protocol.read_request()
|
||||
|
||||
@ -315,16 +337,24 @@ class TestCreateResponse():
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
|
||||
def test_create_response_simple(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
|
||||
(2, 0),
|
||||
200,
|
||||
))
|
||||
assert len(bytes) == 1
|
||||
assert bytes[0] ==\
|
||||
'00000101050000000288'.decode('hex')
|
||||
|
||||
def test_create_response_with_body(self):
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
|
||||
200, 1, [(b'foo', b'bar')], 'foobar')
|
||||
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
|
||||
(2, 0),
|
||||
200,
|
||||
'',
|
||||
odict.ODictCaseless([('foo', 'bar')]),
|
||||
'foobar'
|
||||
))
|
||||
assert len(bytes) == 2
|
||||
assert bytes[0] ==\
|
||||
'00000901040000000188408294e7838c767f'.decode('hex')
|
||||
'00000901040000000288408294e7838c767f'.decode('hex')
|
||||
assert bytes[1] ==\
|
||||
'000006000100000001666f6f626172'.decode('hex')
|
||||
'000006000100000002666f6f626172'.decode('hex')
|
||||
|
Loading…
Reference in New Issue
Block a user