move custom HTTP/2 stack from netlib to pathod

This commit is contained in:
Thomas Kriechbaumer 2016-06-17 14:15:48 +02:00
parent fcf5dc8728
commit eb3ed87100
19 changed files with 533 additions and 519 deletions

View File

@ -1,8 +1,6 @@
from __future__ import absolute_import, print_function, division
from .connections import HTTP2Protocol
from netlib.http.http2 import framereader
__all__ = [
"HTTP2Protocol",
"framereader",
]

View File

@ -1,432 +0,0 @@
from __future__ import (absolute_import, print_function, division)
import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from netlib import utils, strutils
from netlib.http import url
import netlib.http.headers
import netlib.http.response
import netlib.http.request
from netlib.http.http2 import framereader
class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
class HTTP2Protocol(object):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
PROTOCOL_ERROR=0x1,
INTERNAL_ERROR=0x2,
FLOW_CONTROL_ERROR=0x3,
SETTINGS_TIMEOUT=0x4,
STREAM_CLOSED=0x5,
FRAME_SIZE_ERROR=0x6,
REFUSED_STREAM=0x7,
CANCEL=0x8,
COMPRESSION_ERROR=0x9,
CONNECT_ERROR=0xa,
ENHANCE_YOUR_CALM=0xb,
INADEQUATE_SECURITY=0xc,
HTTP_1_1_REQUIRED=0xd
)
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
self,
tcp_handler=None,
rfile=None,
wfile=None,
is_server=False,
dump_frames=False,
encoder=None,
decoder=None,
unhandled_frame_cb=None,
):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
self.is_server = is_server
self.dump_frames = dump_frames
self.encoder = encoder or Encoder()
self.decoder = decoder or Decoder()
self.unhandled_frame_cb = unhandled_frame_cb
self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.connection_preface_performed = False
def read_request(
self,
__rfile,
include_body=True,
body_size_limit=None,
allow_empty=False,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
authority = headers.get(':authority', b'')
method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/')
headers.clear(":method")
headers.clear(":scheme")
headers.clear(":path")
host = None
port = None
if path == '*' or path.startswith("/"):
first_line_format = "relative"
elif method == 'CONNECT':
first_line_format = "authority"
if ":" in authority:
host, port = authority.split(":", 1)
else:
host = authority
else:
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
scheme, host, port, _ = url.parse(path)
scheme = scheme.decode('ascii')
host = host.decode('ascii')
if host is None:
host = 'localhost'
if port is None:
port = 80 if scheme == 'http' else 443
port = int(port)
request = netlib.http.request.Request(
first_line_format,
method.encode('ascii'),
scheme.encode('ascii'),
host.encode('ascii'),
port,
path.encode('ascii'),
b"HTTP/2.0",
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def read_response(
self,
__rfile,
request_method=b'',
body_size_limit=None,
include_body=True,
stream_id=None,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
stream_id=stream_id,
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = netlib.http.response.Response(
b"HTTP/2.0",
int(headers.get(':status', 502)),
b'',
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def assemble(self, message):
if isinstance(message, netlib.http.request.Request):
return self.assemble_request(message)
elif isinstance(message, netlib.http.response.Response):
return self.assemble_response(message)
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
assert isinstance(request, netlib.http.request.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = request.headers.copy()
if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii'))
headers.insert(0, b':scheme', request.scheme.encode('ascii'))
headers.insert(0, b':path', request.path.encode('ascii'))
headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
assert isinstance(response, netlib.http.response.Response)
headers = response.headers.copy()
if ':status' not in headers:
headers.insert(0, b':status', strutils.always_bytes(response.status_code))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
self._create_body(response.body, stream_id),
))
def perform_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
if self.is_server:
self.perform_server_connection_preface(force)
else:
self.perform_client_connection_preface(force)
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
frm = hyperframe.frame.SettingsFrame(settings={
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
def perform_client_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
def send_frame(self, frm, hide=False):
raw_bytes = frm.serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
def read_frame(self, hide=False):
while True:
frm = framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != b'h2':
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
def _handle_unexpected_frame(self, frm):
if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
def _next_stream_id(self):
if self.current_stream_id is None:
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
if not old_value:
old_value = '-'
self.http2_settings[setting] = value
frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
yield hyperframe.frame.HeadersFrame, i
else:
yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
stream_id=stream_id,
data=header_block_fragment[i:i + chunk_size]) for frm_cls, i in frame_cls(chunks)]
frms[-1].flags.add('END_HEADERS')
if end_stream:
frms[0].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _create_body(self, body, stream_id):
if body is None or len(body) == 0:
return b''
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
data=body[i:i + chunk_size]) for i in chunks]
frms[-1].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _receive_transmission(self, stream_id=None, include_body=True):
if not include_body:
raise NotImplementedError()
body_expected = True
header_blocks = b''
body = b''
while True:
frm = self.read_frame()
if (
(isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
header_blocks += frm.data
if 'END_STREAM' in frm.flags:
body_expected = False
if 'END_HEADERS' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
while body_expected:
frm = self.read_frame()
if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
headers = netlib.http.headers.Headers(
(k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
return stream_id, headers, body

View File

@ -11,18 +11,18 @@ import time
import OpenSSL.crypto
import six
import logging
from netlib.tutils import treq
from netlib import strutils
from netlib import tcp, certutils, websockets, socks
from netlib import exceptions
from netlib.http import http1
from netlib.http import http2
from netlib import basethread
from pathod import log, language
from . import log, language
from .protocols import http2
import logging
from netlib.tutils import treq
from netlib import strutils
logging.getLogger("hpack").setLevel(logging.WARNING)
@ -227,7 +227,7 @@ class Pathoc(tcp.TCPClient):
"Pathoc might not be working as expected without ALPN.",
timestamp=False
)
self.protocol = http2.HTTP2Protocol(self, dump_frames=self.http2_framedump)
self.protocol = http2.HTTP2StateProtocol(self, dump_frames=self.http2_framedump)
else:
self.protocol = http1

View File

@ -1,12 +1,445 @@
from netlib.http import http2
from __future__ import (absolute_import, print_function, division)
import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from netlib import utils, strutils
from netlib.http import url
from netlib.http.http2 import framereader
import netlib.http.headers
import netlib.http.response
import netlib.http.request
from .. import language
class HTTP2Protocol:
class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
class HTTP2StateProtocol(object):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
PROTOCOL_ERROR=0x1,
INTERNAL_ERROR=0x2,
FLOW_CONTROL_ERROR=0x3,
SETTINGS_TIMEOUT=0x4,
STREAM_CLOSED=0x5,
FRAME_SIZE_ERROR=0x6,
REFUSED_STREAM=0x7,
CANCEL=0x8,
COMPRESSION_ERROR=0x9,
CONNECT_ERROR=0xa,
ENHANCE_YOUR_CALM=0xb,
INADEQUATE_SECURITY=0xc,
HTTP_1_1_REQUIRED=0xd
)
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
self,
tcp_handler=None,
rfile=None,
wfile=None,
is_server=False,
dump_frames=False,
encoder=None,
decoder=None,
unhandled_frame_cb=None,
):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
self.is_server = is_server
self.dump_frames = dump_frames
self.encoder = encoder or Encoder()
self.decoder = decoder or Decoder()
self.unhandled_frame_cb = unhandled_frame_cb
self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.connection_preface_performed = False
def read_request(
self,
__rfile,
include_body=True,
body_size_limit=None,
allow_empty=False,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
authority = headers.get(':authority', b'')
method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/')
headers.clear(":method")
headers.clear(":scheme")
headers.clear(":path")
host = None
port = None
if path == '*' or path.startswith("/"):
first_line_format = "relative"
elif method == 'CONNECT':
first_line_format = "authority"
if ":" in authority:
host, port = authority.split(":", 1)
else:
host = authority
else:
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
scheme, host, port, _ = url.parse(path)
scheme = scheme.decode('ascii')
host = host.decode('ascii')
if host is None:
host = 'localhost'
if port is None:
port = 80 if scheme == 'http' else 443
port = int(port)
request = netlib.http.request.Request(
first_line_format,
method.encode('ascii'),
scheme.encode('ascii'),
host.encode('ascii'),
port,
path.encode('ascii'),
b"HTTP/2.0",
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def read_response(
self,
__rfile,
request_method=b'',
body_size_limit=None,
include_body=True,
stream_id=None,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
stream_id=stream_id,
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = netlib.http.response.Response(
b"HTTP/2.0",
int(headers.get(':status', 502)),
b'',
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def assemble(self, message):
if isinstance(message, netlib.http.request.Request):
return self.assemble_request(message)
elif isinstance(message, netlib.http.response.Response):
return self.assemble_response(message)
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
assert isinstance(request, netlib.http.request.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = request.headers.copy()
if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii'))
headers.insert(0, b':scheme', request.scheme.encode('ascii'))
headers.insert(0, b':path', request.path.encode('ascii'))
headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
assert isinstance(response, netlib.http.response.Response)
headers = response.headers.copy()
if ':status' not in headers:
headers.insert(0, b':status', strutils.always_bytes(response.status_code))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
self._create_body(response.body, stream_id),
))
def perform_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
if self.is_server:
self.perform_server_connection_preface(force)
else:
self.perform_client_connection_preface(force)
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
frm = hyperframe.frame.SettingsFrame(settings={
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
def perform_client_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
def send_frame(self, frm, hide=False):
raw_bytes = frm.serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
def read_frame(self, hide=False):
while True:
frm = framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != b'h2':
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALPN value: %s" % alp)
return True
def _handle_unexpected_frame(self, frm):
if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
def _next_stream_id(self):
if self.current_stream_id is None:
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
if not old_value:
old_value = '-'
self.http2_settings[setting] = value
frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
yield hyperframe.frame.HeadersFrame, i
else:
yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
stream_id=stream_id,
data=header_block_fragment[i:i + chunk_size]) for frm_cls, i in frame_cls(chunks)]
frms[-1].flags.add('END_HEADERS')
if end_stream:
frms[0].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _create_body(self, body, stream_id):
if body is None or len(body) == 0:
return b''
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
data=body[i:i + chunk_size]) for i in chunks]
frms[-1].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _receive_transmission(self, stream_id=None, include_body=True):
if not include_body:
raise NotImplementedError()
body_expected = True
header_blocks = b''
body = b''
while True:
frm = self.read_frame()
if (
(isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
header_blocks += frm.data
if 'END_STREAM' in frm.flags:
body_expected = False
if 'END_HEADERS' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
while body_expected:
frm = self.read_frame()
if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
headers = netlib.http.headers.Headers(
(k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
return stream_id, headers, body
class HTTP2Protocol(object):
def __init__(self, pathod_handler):
self.pathod_handler = pathod_handler
self.wire_protocol = http2.HTTP2Protocol(
self.wire_protocol = HTTP2StateProtocol(
self.pathod_handler, is_server=True, dump_frames=self.pathod_handler.http2_framedump
)

View File

@ -0,0 +1 @@
# foobar

1
test/pathod/__init__.py Normal file
View File

@ -0,0 +1 @@
from __future__ import (print_function, absolute_import, division)

View File

@ -1,11 +1,10 @@
from six import BytesIO
from pathod.language import actions
from pathod import language
from pathod.language import actions, parse_pathoc, parse_pathod, serve
def parse_request(s):
return next(language.parse_pathoc(s))
return next(parse_pathoc(s))
def test_unique_name():
@ -16,9 +15,9 @@ def test_unique_name():
class TestDisconnects:
def test_parse_pathod(self):
a = next(language.parse_pathod("400:d0")).actions[0]
a = next(parse_pathod("400:d0")).actions[0]
assert a.spec() == "d0"
a = next(language.parse_pathod("400:dr")).actions[0]
a = next(parse_pathod("400:dr")).actions[0]
assert a.spec() == "dr"
def test_at(self):
@ -42,12 +41,12 @@ class TestDisconnects:
class TestInject:
def test_parse_pathod(self):
a = next(language.parse_pathod("400:ir,@100")).actions[0]
a = next(parse_pathod("400:ir,@100")).actions[0]
assert a.offset == "r"
assert a.value.datatype == "bytes"
assert a.value.usize == 100
a = next(language.parse_pathod("400:ia,@100")).actions[0]
a = next(parse_pathod("400:ia,@100")).actions[0]
assert a.offset == "a"
def test_at(self):
@ -62,8 +61,8 @@ class TestInject:
def test_serve(self):
s = BytesIO()
r = next(language.parse_pathod("400:i0,'foo'"))
assert language.serve(r, s, {})
r = next(parse_pathod("400:i0,'foo'"))
assert serve(r, s, {})
def test_spec(self):
e = actions.InjectAt.expr()
@ -96,7 +95,7 @@ class TestPauses:
assert v.offset == "a"
def test_request(self):
r = next(language.parse_pathod('400:p10,10'))
r = next(parse_pathod('400:p10,10'))
assert r.actions[0].spec() == "p10,10"
def test_spec(self):

View File

@ -1,7 +1,8 @@
import os
from pathod import language
from pathod.language import base, exceptions
import tutils
from . import tutils
def parse_request(s):

View File

@ -1,7 +1,7 @@
import os
from pathod.language import generators
import tutils
from . import tutils
def test_randomgenerator():

View File

@ -1,7 +1,8 @@
from six import BytesIO
from pathod import language
from pathod.language import http, base
import tutils
from . import tutils
def parse_request(s):

View File

@ -1,12 +1,13 @@
from six import BytesIO
import netlib
from netlib import tcp
from netlib.http import user_agents
from pathod import language
from pathod.language import http2
import tutils
from pathod.protocols.http2 import HTTP2StateProtocol
from . import tutils
def parse_request(s):
@ -20,7 +21,7 @@ def parse_response(s):
def default_settings():
return language.Settings(
request_host="foo.com",
protocol=netlib.http.http2.HTTP2Protocol(tcp.TCPClient(('localhost', 1234)))
protocol=HTTP2StateProtocol(tcp.TCPClient(('localhost', 1234)))
)

View File

@ -2,7 +2,8 @@
from pathod import language
from pathod.language import websockets
import netlib.websockets
import tutils
from . import tutils
def parse_request(s):

View File

@ -5,11 +5,13 @@ from mock import Mock
from netlib import http
from netlib import tcp
from netlib.exceptions import NetlibException
from netlib.http import http1, http2
from netlib.http import http1
from netlib.tutils import raises
from pathod import pathoc, language
from netlib.tutils import raises
import tutils
from pathod.protocols.http2 import HTTP2StateProtocol
from . import tutils
def test_response():
@ -219,7 +221,7 @@ class TestDaemonHTTP2(PathocTestDaemon):
ssl=True,
use_http2=True,
)
assert isinstance(c.protocol, http2.HTTP2Protocol)
assert isinstance(c.protocol, HTTP2StateProtocol)
c = pathoc.Pathoc(
("127.0.0.1", self.d.port),

View File

@ -1,8 +1,10 @@
from pathod import pathoc_cmdline as cmdline
import tutils
from six.moves import cStringIO as StringIO
import mock
from pathod import pathoc_cmdline as cmdline
from . import tutils
@mock.patch("argparse.ArgumentParser.error")
def test_pathoc(perror):

View File

@ -3,7 +3,8 @@ from six.moves import cStringIO as StringIO
from pathod import pathod
from netlib import tcp
from netlib.exceptions import HttpException, TlsException
import tutils
from . import tutils
class TestPathod(object):

View File

@ -1,7 +1,9 @@
from pathod import pathod_cmdline as cmdline
import tutils
import mock
from pathod import pathod_cmdline as cmdline
from . import tutils
def test_parse_anchor_spec():
assert cmdline.parse_anchor_spec("foo=200") == ("foo", "200")

View File

@ -5,21 +5,22 @@ import hyperframe
from netlib import tcp, http
from netlib.tutils import raises
from netlib.exceptions import TcpDisconnect
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2 import framereader
from ... import tservers
from ..netlib import tservers as netlib_tservers
from pathod.protocols.http2 import HTTP2StateProtocol, TCPHandler
class TestTCPHandlerWrapper:
def test_wrapped(self):
h = TCPHandler(rfile='foo', wfile='bar')
p = HTTP2Protocol(h)
p = HTTP2StateProtocol(h)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
def test_direct(self):
p = HTTP2Protocol(rfile='foo', wfile='bar')
p = HTTP2StateProtocol(rfile='foo', wfile='bar')
assert isinstance(p.tcp_handler, TCPHandler)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
@ -36,10 +37,10 @@ class EchoHandler(tcp.BaseHandler):
class TestProtocol:
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface")
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface")
@mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface")
@mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface")
def test_perform_connection_preface(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=False)
protocol = HTTP2StateProtocol(is_server=False)
protocol.connection_preface_performed = True
protocol.perform_connection_preface()
@ -50,10 +51,10 @@ class TestProtocol:
assert mock_client_method.called
assert not mock_server_method.called
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface")
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface")
@mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface")
@mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface")
def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=True)
protocol = HTTP2StateProtocol(is_server=True)
protocol.connection_preface_performed = True
protocol.perform_connection_preface()
@ -65,7 +66,7 @@ class TestProtocol:
assert mock_server_method.called
class TestCheckALPNMatch(tservers.ServerTestBase):
class TestCheckALPNMatch(netlib_tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
alpn_select=b'h2',
@ -77,11 +78,11 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(alpn_protos=[b'h2'])
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
assert protocol.check_alpn()
class TestCheckALPNMismatch(tservers.ServerTestBase):
class TestCheckALPNMismatch(netlib_tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
alpn_select=None,
@ -93,12 +94,12 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(alpn_protos=[b'h2'])
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
with raises(NotImplementedError):
protocol.check_alpn()
class TestPerformServerConnectionPreface(tservers.ServerTestBase):
class TestPerformServerConnectionPreface(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
@ -125,7 +126,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
assert not protocol.connection_preface_performed
protocol.perform_server_connection_preface()
@ -135,12 +136,12 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
protocol.perform_server_connection_preface(force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
class TestPerformClientConnectionPreface(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
# check magic
assert self.rfile.read(24) == HTTP2Protocol.CLIENT_CONNECTION_PREFACE
assert self.rfile.read(24) == HTTP2StateProtocol.CLIENT_CONNECTION_PREFACE
# check empty settings frame
assert self.rfile.read(9) ==\
@ -161,7 +162,7 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
assert not protocol.connection_preface_performed
protocol.perform_client_connection_preface()
@ -170,7 +171,7 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
class TestClientStreamIds(object):
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
def test_client_stream_ids(self):
assert self.protocol.current_stream_id is None
@ -182,9 +183,9 @@ class TestClientStreamIds(object):
assert self.protocol.current_stream_id == 5
class TestServerStreamIds(object):
class TestserverstreamIds(object):
c = tcp.TCPClient(("127.0.0.1", 0))
protocol = HTTP2Protocol(c, is_server=True)
protocol = HTTP2StateProtocol(c, is_server=True)
def test_server_stream_ids(self):
assert self.protocol.current_stream_id is None
@ -196,7 +197,7 @@ class TestServerStreamIds(object):
assert self.protocol.current_stream_id == 6
class TestApplySettings(tservers.ServerTestBase):
class TestApplySettings(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
# check settings acknowledgement
@ -211,7 +212,7 @@ class TestApplySettings(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
protocol._apply_settings({
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
@ -239,12 +240,12 @@ class TestCreateHeaders(object):
(b':scheme', b'https'),
(b'foo', b'bar')])
bytes = HTTP2Protocol(self.c)._create_headers(
bytes = HTTP2StateProtocol(self.c)._create_headers(
headers, 1, end_stream=True)
assert b''.join(bytes) ==\
codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec')
bytes = HTTP2Protocol(self.c)._create_headers(
bytes = HTTP2StateProtocol(self.c)._create_headers(
headers, 1, end_stream=False)
assert b''.join(bytes) ==\
codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec')
@ -257,7 +258,7 @@ class TestCreateHeaders(object):
(b'foo', b'bar'),
(b'server', b'version')])
protocol = HTTP2Protocol(self.c)
protocol = HTTP2StateProtocol(self.c)
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8
bytes = protocol._create_headers(headers, 1, end_stream=True)
assert len(bytes) == 3
@ -270,17 +271,17 @@ class TestCreateBody(object):
c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_body_empty(self):
protocol = HTTP2Protocol(self.c)
protocol = HTTP2StateProtocol(self.c)
bytes = protocol._create_body(b'', 1)
assert b''.join(bytes) == b''
def test_create_body_single_frame(self):
protocol = HTTP2Protocol(self.c)
protocol = HTTP2StateProtocol(self.c)
bytes = protocol._create_body(b'foobar', 1)
assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec')
def test_create_body_multiple_frames(self):
protocol = HTTP2Protocol(self.c)
protocol = HTTP2StateProtocol(self.c)
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5
bytes = protocol._create_body(b'foobarmehm42', 1)
assert len(bytes) == 3
@ -289,7 +290,7 @@ class TestCreateBody(object):
assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec')
class TestReadRequest(tservers.ServerTestBase):
class TestReadRequest(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
@ -306,7 +307,7 @@ class TestReadRequest(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True)
protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented)
@ -319,7 +320,7 @@ class TestReadRequest(tservers.ServerTestBase):
assert req.content == b'foobar'
class TestReadRequestRelative(tservers.ServerTestBase):
class TestReadRequestRelative(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
@ -332,7 +333,7 @@ class TestReadRequestRelative(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True)
protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented)
@ -342,7 +343,7 @@ class TestReadRequestRelative(tservers.ServerTestBase):
assert req.path == "*"
class TestReadRequestAbsolute(tservers.ServerTestBase):
class TestReadRequestAbsolute(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
@ -355,7 +356,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True)
protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented)
@ -366,7 +367,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
assert req.port == 22
class TestReadRequestConnect(tservers.ServerTestBase):
class TestReadRequestConnect(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
@ -381,7 +382,7 @@ class TestReadRequestConnect(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True)
protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented)
@ -397,7 +398,7 @@ class TestReadRequestConnect(tservers.ServerTestBase):
assert req.port == 443
class TestReadResponse(tservers.ServerTestBase):
class TestReadResponse(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
@ -413,7 +414,7 @@ class TestReadResponse(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response(NotImplemented, stream_id=42)
@ -426,7 +427,7 @@ class TestReadResponse(tservers.ServerTestBase):
assert resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase):
class TestReadEmptyResponse(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):
self.wfile.write(
@ -439,7 +440,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl()
protocol = HTTP2Protocol(c)
protocol = HTTP2StateProtocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response(NotImplemented, stream_id=42)
@ -456,7 +457,7 @@ class TestAssembleRequest(object):
c = tcp.TCPClient(("127.0.0.1", 0))
def test_request_simple(self):
bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request(
b'',
b'GET',
b'https',
@ -483,12 +484,12 @@ class TestAssembleRequest(object):
None,
)
req.stream_id = 0x42
bytes = HTTP2Protocol(self.c).assemble_request(req)
bytes = HTTP2StateProtocol(self.c).assemble_request(req)
assert len(bytes) == 1
assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec')
def test_request_with_body(self):
bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request(
b'',
b'GET',
b'https',
@ -510,7 +511,7 @@ class TestAssembleResponse(object):
c = tcp.TCPClient(("127.0.0.1", 0))
def test_simple(self):
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
b"HTTP/2.0",
200,
))
@ -524,13 +525,13 @@ class TestAssembleResponse(object):
200,
)
resp.stream_id = 0x42
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp)
bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(resp)
assert len(bytes) == 1
assert bytes[0] ==\
codecs.decode('00000101050000004288', 'hex_codec')
def test_with_body(self):
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
b"HTTP/2.0",
200,
b'',

View File

@ -1,7 +1,8 @@
import logging
import requests
from pathod import test
import tutils
from . import tutils
import requests.packages.urllib3

View File

@ -1,5 +1,6 @@
from pathod import utils
import tutils
from . import tutils
def test_membool():