add on-the-wire representation methods

This commit is contained in:
Thomas Kriechbaumer 2015-07-29 11:27:43 +02:00
parent 827fe824d9
commit c7fcc2cca5
5 changed files with 324 additions and 157 deletions

View File

@ -7,6 +7,7 @@ import urlparse
import time
from netlib import odict, utils, tcp, http
from netlib.http import semantics
from .. import status_codes
from ..exceptions import *
@ -15,7 +16,7 @@ class TCPHandler(object):
self.rfile = rfile
self.wfile = wfile
class HTTP1Protocol(object):
class HTTP1Protocol(semantics.ProtocolMixin):
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
@ -195,6 +196,32 @@ class HTTP1Protocol(object):
)
def assemble_request(self, request):
assert isinstance(request, semantics.Request)
if request.body == semantics.CONTENT_MISSING:
raise http.HttpError(
502,
"Cannot assemble flow with CONTENT_MISSING"
)
first_line = self._assemble_request_first_line(request)
headers = self._assemble_request_headers(request)
return "%s\r\n%s\r\n%s" % (first_line, headers, request.body)
def assemble_response(self, response):
assert isinstance(response, semantics.Response)
if response.body == semantics.CONTENT_MISSING:
raise http.HttpError(
502,
"Cannot assemble flow with CONTENT_MISSING"
)
first_line = self._assemble_response_first_line(response)
headers = self._assemble_response_headers(response)
return "%s\r\n%s\r\n%s" % (first_line, headers, response.body)
def read_headers(self):
"""
Read a set of headers.
@ -363,7 +390,6 @@ class HTTP1Protocol(object):
return line
def _read_chunked(self, limit, is_request):
"""
Read a chunked HTTP body.
@ -526,3 +552,74 @@ class HTTP1Protocol(object):
except ValueError:
return None
return (proto, code, msg)
@classmethod
def _assemble_request_first_line(self, request):
if request.form_in == "relative":
request_line = '%s %s HTTP/%s.%s' % (
request.method,
request.path,
request.httpversion[0],
request.httpversion[1],
)
elif request.form_in == "authority":
request_line = '%s %s:%s HTTP/%s.%s' % (
request.method,
request.host,
request.port,
request.httpversion[0],
request.httpversion[1],
)
elif request.form_in == "absolute":
request_line = '%s %s://%s:%s%s HTTP/%s.%s' % (
request.method,
request.scheme,
request.host,
request.port,
request.path,
request.httpversion[0],
request.httpversion[1],
)
else:
raise http.HttpError(400, "Invalid request form")
return request_line
def _assemble_request_headers(self, request):
headers = request.headers.copy()
for k in request._headers_to_strip_off:
del headers[k]
if 'host' not in headers and request.scheme and request.host and request.port:
headers["Host"] = [utils.hostport(request.scheme,
request.host,
request.port)]
# If content is defined (i.e. not None or CONTENT_MISSING), we always
# add a content-length header.
if request.body or request.body == "":
headers["Content-Length"] = [str(len(request.body))]
return headers.format()
def _assemble_response_first_line(self, response):
return 'HTTP/%s.%s %s %s' % (
response.httpversion[0],
response.httpversion[1],
response.status_code,
response.msg,
)
def _assemble_response_headers(self, response, preserve_transfer_encoding=False):
headers = response.headers.copy()
for k in response._headers_to_strip_off:
del headers[k]
if not preserve_transfer_encoding:
del headers['Transfer-Encoding']
# If body is defined (i.e. not None or CONTENT_MISSING), we always
# add a content-length header.
if response.body or response.body == "":
headers["Content-Length"] = [str(len(response.body))]
return headers.format()

View File

@ -4,6 +4,7 @@ import time
from hpack.hpack import Encoder, Decoder
from netlib import http, utils, odict
from netlib.http import semantics
from . import frame
@ -13,7 +14,7 @@ class TCPHandler(object):
self.wfile = wfile
class HTTP2Protocol(object):
class HTTP2Protocol(semantics.ProtocolMixin):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
@ -59,26 +60,104 @@ class HTTP2Protocol(object):
self.current_stream_id = None
self.connection_preface_performed = False
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != self.ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
break
stream_id, headers, body = self._receive_transmission(include_body)
def _read_settings_ack(self, hide=False): # pragma no cover
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
assert frm.flags & frame.Frame.FLAG_ACK
assert len(frm.settings) == 0
break
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
port = '' # TODO: parse port number?
request = http.Request(
"",
headers.get_first(':method', ['']),
headers.get_first(':scheme', ['']),
headers.get_first(':host', ['']),
port,
headers.get_first(':path', ['']),
(2, 0),
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(include_body)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = http.Response(
(2, 0),
headers[':status'][0],
"",
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def assemble_request(self, request):
assert isinstance(request, semantics.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = [
(b':method', bytes(request.method)),
(b':path', bytes(request.path)),
(b':scheme', b'https'),
(b':authority', authority),
] + request.headers.items()
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(request.body is None)),
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
assert isinstance(response, semantics.Response)
headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items()
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(response.body is None)),
self._create_body(response.body, stream_id),
))
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
@ -100,18 +179,6 @@ class HTTP2Protocol(object):
self.send_frame(frame.SettingsFrame(state=self), hide=True)
self._receive_settings(hide=True)
def next_stream_id(self):
if self.current_stream_id is None:
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def send_frame(self, frm, hide=False):
raw_bytes = frm.to_bytes()
self.tcp_handler.wfile.write(raw_bytes)
@ -128,6 +195,39 @@ class HTTP2Protocol(object):
return frm
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != self.ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
break
def _read_settings_ack(self, hide=False): # pragma no cover
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
assert frm.flags & frame.Frame.FLAG_ACK
assert len(frm.settings) == 0
break
def _next_stream_id(self):
if self.current_stream_id is None:
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
@ -181,89 +281,6 @@ class HTTP2Protocol(object):
return [frm.to_bytes()]
def create_request(self, method, path, headers=None, body=None):
if headers is None:
headers = []
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = [
(b':method', bytes(method)),
(b':path', bytes(path)),
(b':scheme', b'https'),
(b':authority', authority),
] + headers
stream_id = self.next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(body is None)),
self._create_body(body, stream_id)))
def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(include_body)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = http.Response(
(2, 0),
headers[':status'][0],
"",
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(include_body)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
port = '' # TODO: parse port number?
request = http.Request(
"",
headers.get_first(':method', ['']),
headers.get_first(':scheme', ['']),
headers.get_first(':host', ['']),
port,
headers.get_first(':path', ['']),
(2, 0),
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def _receive_transmission(self, include_body=True):
body_expected = True
@ -295,19 +312,3 @@ class HTTP2Protocol(object):
headers.add(header, value)
return stream_id, headers, body
def create_response(self, code, stream_id=None, headers=None, body=None):
if headers is None:
headers = []
if isinstance(headers, odict.ODict):
headers = headers.items()
headers = [(b':status', bytes(str(code)))] + headers
if not stream_id:
stream_id = self.next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(body is None)),
self._create_body(body, stream_id),
))

View File

@ -7,6 +7,32 @@ import urlparse
from .. import utils, odict
CONTENT_MISSING = 0
class ProtocolMixin(object):
def read_request(self):
raise NotImplemented
def read_response(self):
raise NotImplemented
def assemble(self, message):
if isinstance(message, Request):
return self.assemble_request(message)
elif isinstance(message, Response):
return self.assemble_response(message)
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
raise NotImplemented
def assemble_response(self, response):
raise NotImplemented
class Request(object):
def __init__(
@ -18,12 +44,14 @@ class Request(object):
port,
path,
httpversion,
headers,
body,
headers=None,
body=None,
timestamp_start=None,
timestamp_end=None,
):
assert isinstance(headers, odict.ODictCaseless) or not headers
if not headers:
headers = odict.ODictCaseless()
assert isinstance(headers, odict.ODictCaseless)
self.form_in = form_in
self.method = method
@ -37,6 +65,7 @@ class Request(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
@ -80,14 +109,16 @@ class Response(object):
self,
httpversion,
status_code,
msg,
headers,
body,
msg=None,
headers=None,
body=None,
sslinfo=None,
timestamp_start=None,
timestamp_end=None,
):
assert isinstance(headers, odict.ODictCaseless) or not headers
if not headers:
headers = odict.ODictCaseless()
assert isinstance(headers, odict.ODictCaseless)
self.httpversion = httpversion
self.status_code = status_code
@ -98,6 +129,7 @@ class Response(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]

View File

@ -129,3 +129,13 @@ class Data(object):
if not os.path.exists(fullpath):
raise ValueError("dataPath: %s does not exist." % fullpath)
return fullpath
def hostport(scheme, host, port):
"""
Returns the host component, with a port specifcation if needed.
"""
if (port, scheme) in [(80, "http"), (443, "https")]:
return host
else:
return "%s:%s" % (host, port)

View File

@ -1,6 +1,6 @@
import OpenSSL
from netlib import tcp, odict
from netlib import tcp, odict, http
from netlib.http import http2
from netlib.http.http2.frame import *
from ... import tutils, tservers
@ -117,11 +117,11 @@ class TestClientStreamIds():
def test_client_stream_ids(self):
assert self.protocol.current_stream_id is None
assert self.protocol.next_stream_id() == 1
assert self.protocol._next_stream_id() == 1
assert self.protocol.current_stream_id == 1
assert self.protocol.next_stream_id() == 3
assert self.protocol._next_stream_id() == 3
assert self.protocol.current_stream_id == 3
assert self.protocol.next_stream_id() == 5
assert self.protocol._next_stream_id() == 5
assert self.protocol.current_stream_id == 5
@ -131,11 +131,11 @@ class TestServerStreamIds():
def test_server_stream_ids(self):
assert self.protocol.current_stream_id is None
assert self.protocol.next_stream_id() == 2
assert self.protocol._next_stream_id() == 2
assert self.protocol.current_stream_id == 2
assert self.protocol.next_stream_id() == 4
assert self.protocol._next_stream_id() == 4
assert self.protocol.current_stream_id == 4
assert self.protocol.next_stream_id() == 6
assert self.protocol._next_stream_id() == 6
assert self.protocol.current_stream_id == 6
@ -215,17 +215,36 @@ class TestCreateBody():
# TODO: add test for too large frames
class TestCreateRequest():
class TestAssembleRequest():
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_request_simple(self):
bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/')
def test_assemble_request_simple(self):
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'',
'',
'',
'/',
(2, 0),
None,
None,
))
assert len(bytes) == 1
assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
def test_create_request_with_body(self):
bytes = http2.HTTP2Protocol(self.c).create_request(
'GET', '/', [(b'foo', b'bar')], 'foobar')
def test_assemble_request_with_body(self):
bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
'',
'GET',
'',
'',
'',
'/',
(2, 0),
odict.ODictCaseless([('foo', 'bar')]),
'foobar',
))
assert len(bytes) == 2
assert bytes[0] ==\
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
@ -315,16 +334,24 @@ class TestCreateResponse():
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_response_simple(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200)
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
(2, 0),
200,
))
assert len(bytes) == 1
assert bytes[0] ==\
'00000101050000000288'.decode('hex')
def test_create_response_with_body(self):
bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(
200, 1, [(b'foo', b'bar')], 'foobar')
bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
(2, 0),
200,
'',
odict.ODictCaseless([('foo', 'bar')]),
'foobar'
))
assert len(bytes) == 2
assert bytes[0] ==\
'00000901040000000188408294e7838c767f'.decode('hex')
'00000901040000000288408294e7838c767f'.decode('hex')
assert bytes[1] ==\
'000006000100000001666f6f626172'.decode('hex')
'000006000100000002666f6f626172'.decode('hex')