move custom HTTP/2 stack from netlib to pathod

This commit is contained in:
Thomas Kriechbaumer 2016-06-17 14:15:48 +02:00
parent fcf5dc8728
commit eb3ed87100
19 changed files with 533 additions and 519 deletions

View File

@ -1,8 +1,6 @@
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from .connections import HTTP2Protocol
from netlib.http.http2 import framereader from netlib.http.http2 import framereader
__all__ = [ __all__ = [
"HTTP2Protocol",
"framereader", "framereader",
] ]

View File

@ -1,432 +0,0 @@
from __future__ import (absolute_import, print_function, division)
import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from netlib import utils, strutils
from netlib.http import url
import netlib.http.headers
import netlib.http.response
import netlib.http.request
from netlib.http.http2 import framereader
class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
class HTTP2Protocol(object):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
PROTOCOL_ERROR=0x1,
INTERNAL_ERROR=0x2,
FLOW_CONTROL_ERROR=0x3,
SETTINGS_TIMEOUT=0x4,
STREAM_CLOSED=0x5,
FRAME_SIZE_ERROR=0x6,
REFUSED_STREAM=0x7,
CANCEL=0x8,
COMPRESSION_ERROR=0x9,
CONNECT_ERROR=0xa,
ENHANCE_YOUR_CALM=0xb,
INADEQUATE_SECURITY=0xc,
HTTP_1_1_REQUIRED=0xd
)
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
self,
tcp_handler=None,
rfile=None,
wfile=None,
is_server=False,
dump_frames=False,
encoder=None,
decoder=None,
unhandled_frame_cb=None,
):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
self.is_server = is_server
self.dump_frames = dump_frames
self.encoder = encoder or Encoder()
self.decoder = decoder or Decoder()
self.unhandled_frame_cb = unhandled_frame_cb
self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.connection_preface_performed = False
def read_request(
self,
__rfile,
include_body=True,
body_size_limit=None,
allow_empty=False,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
authority = headers.get(':authority', b'')
method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/')
headers.clear(":method")
headers.clear(":scheme")
headers.clear(":path")
host = None
port = None
if path == '*' or path.startswith("/"):
first_line_format = "relative"
elif method == 'CONNECT':
first_line_format = "authority"
if ":" in authority:
host, port = authority.split(":", 1)
else:
host = authority
else:
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
scheme, host, port, _ = url.parse(path)
scheme = scheme.decode('ascii')
host = host.decode('ascii')
if host is None:
host = 'localhost'
if port is None:
port = 80 if scheme == 'http' else 443
port = int(port)
request = netlib.http.request.Request(
first_line_format,
method.encode('ascii'),
scheme.encode('ascii'),
host.encode('ascii'),
port,
path.encode('ascii'),
b"HTTP/2.0",
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def read_response(
self,
__rfile,
request_method=b'',
body_size_limit=None,
include_body=True,
stream_id=None,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
stream_id=stream_id,
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = netlib.http.response.Response(
b"HTTP/2.0",
int(headers.get(':status', 502)),
b'',
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def assemble(self, message):
if isinstance(message, netlib.http.request.Request):
return self.assemble_request(message)
elif isinstance(message, netlib.http.response.Response):
return self.assemble_response(message)
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
assert isinstance(request, netlib.http.request.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = request.headers.copy()
if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii'))
headers.insert(0, b':scheme', request.scheme.encode('ascii'))
headers.insert(0, b':path', request.path.encode('ascii'))
headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
assert isinstance(response, netlib.http.response.Response)
headers = response.headers.copy()
if ':status' not in headers:
headers.insert(0, b':status', strutils.always_bytes(response.status_code))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
self._create_body(response.body, stream_id),
))
def perform_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
if self.is_server:
self.perform_server_connection_preface(force)
else:
self.perform_client_connection_preface(force)
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
frm = hyperframe.frame.SettingsFrame(settings={
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
def perform_client_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
def send_frame(self, frm, hide=False):
raw_bytes = frm.serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
def read_frame(self, hide=False):
while True:
frm = framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != b'h2':
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
def _handle_unexpected_frame(self, frm):
if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
def _next_stream_id(self):
if self.current_stream_id is None:
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
if not old_value:
old_value = '-'
self.http2_settings[setting] = value
frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
yield hyperframe.frame.HeadersFrame, i
else:
yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
stream_id=stream_id,
data=header_block_fragment[i:i + chunk_size]) for frm_cls, i in frame_cls(chunks)]
frms[-1].flags.add('END_HEADERS')
if end_stream:
frms[0].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _create_body(self, body, stream_id):
if body is None or len(body) == 0:
return b''
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
data=body[i:i + chunk_size]) for i in chunks]
frms[-1].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _receive_transmission(self, stream_id=None, include_body=True):
if not include_body:
raise NotImplementedError()
body_expected = True
header_blocks = b''
body = b''
while True:
frm = self.read_frame()
if (
(isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
header_blocks += frm.data
if 'END_STREAM' in frm.flags:
body_expected = False
if 'END_HEADERS' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
while body_expected:
frm = self.read_frame()
if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
headers = netlib.http.headers.Headers(
(k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
return stream_id, headers, body

View File

@ -11,18 +11,18 @@ import time
import OpenSSL.crypto import OpenSSL.crypto
import six import six
import logging
from netlib.tutils import treq
from netlib import strutils
from netlib import tcp, certutils, websockets, socks from netlib import tcp, certutils, websockets, socks
from netlib import exceptions from netlib import exceptions
from netlib.http import http1 from netlib.http import http1
from netlib.http import http2
from netlib import basethread from netlib import basethread
from pathod import log, language from . import log, language
from .protocols import http2
import logging
from netlib.tutils import treq
from netlib import strutils
logging.getLogger("hpack").setLevel(logging.WARNING) logging.getLogger("hpack").setLevel(logging.WARNING)
@ -227,7 +227,7 @@ class Pathoc(tcp.TCPClient):
"Pathoc might not be working as expected without ALPN.", "Pathoc might not be working as expected without ALPN.",
timestamp=False timestamp=False
) )
self.protocol = http2.HTTP2Protocol(self, dump_frames=self.http2_framedump) self.protocol = http2.HTTP2StateProtocol(self, dump_frames=self.http2_framedump)
else: else:
self.protocol = http1 self.protocol = http1

View File

@ -1,12 +1,445 @@
from netlib.http import http2 from __future__ import (absolute_import, print_function, division)
import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from netlib import utils, strutils
from netlib.http import url
from netlib.http.http2 import framereader
import netlib.http.headers
import netlib.http.response
import netlib.http.request
from .. import language from .. import language
class HTTP2Protocol: class TCPHandler(object):
def __init__(self, rfile, wfile=None):
self.rfile = rfile
self.wfile = wfile
class HTTP2StateProtocol(object):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
PROTOCOL_ERROR=0x1,
INTERNAL_ERROR=0x2,
FLOW_CONTROL_ERROR=0x3,
SETTINGS_TIMEOUT=0x4,
STREAM_CLOSED=0x5,
FRAME_SIZE_ERROR=0x6,
REFUSED_STREAM=0x7,
CANCEL=0x8,
COMPRESSION_ERROR=0x9,
CONNECT_ERROR=0xa,
ENHANCE_YOUR_CALM=0xb,
INADEQUATE_SECURITY=0xc,
HTTP_1_1_REQUIRED=0xd
)
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
self,
tcp_handler=None,
rfile=None,
wfile=None,
is_server=False,
dump_frames=False,
encoder=None,
decoder=None,
unhandled_frame_cb=None,
):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
self.is_server = is_server
self.dump_frames = dump_frames
self.encoder = encoder or Encoder()
self.decoder = decoder or Decoder()
self.unhandled_frame_cb = unhandled_frame_cb
self.http2_settings = self.HTTP2_DEFAULT_SETTINGS.copy()
self.current_stream_id = None
self.connection_preface_performed = False
def read_request(
self,
__rfile,
include_body=True,
body_size_limit=None,
allow_empty=False,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
timestamp_end = time.time()
authority = headers.get(':authority', b'')
method = headers.get(':method', 'GET')
scheme = headers.get(':scheme', 'https')
path = headers.get(':path', '/')
headers.clear(":method")
headers.clear(":scheme")
headers.clear(":path")
host = None
port = None
if path == '*' or path.startswith("/"):
first_line_format = "relative"
elif method == 'CONNECT':
first_line_format = "authority"
if ":" in authority:
host, port = authority.split(":", 1)
else:
host = authority
else:
first_line_format = "absolute"
# FIXME: verify if path or :host contains what we need
scheme, host, port, _ = url.parse(path)
scheme = scheme.decode('ascii')
host = host.decode('ascii')
if host is None:
host = 'localhost'
if port is None:
port = 80 if scheme == 'http' else 443
port = int(port)
request = netlib.http.request.Request(
first_line_format,
method.encode('ascii'),
scheme.encode('ascii'),
host.encode('ascii'),
port,
path.encode('ascii'),
b"HTTP/2.0",
headers,
body,
timestamp_start,
timestamp_end,
)
request.stream_id = stream_id
return request
def read_response(
self,
__rfile,
request_method=b'',
body_size_limit=None,
include_body=True,
stream_id=None,
):
if body_size_limit is not None:
raise NotImplementedError()
self.perform_connection_preface()
timestamp_start = time.time()
if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
self.tcp_handler.rfile.reset_timestamps()
stream_id, headers, body = self._receive_transmission(
stream_id=stream_id,
include_body=include_body,
)
if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
# more accurate timestamp_start
timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
if include_body:
timestamp_end = time.time()
else:
timestamp_end = None
response = netlib.http.response.Response(
b"HTTP/2.0",
int(headers.get(':status', 502)),
b'',
headers,
body,
timestamp_start=timestamp_start,
timestamp_end=timestamp_end,
)
response.stream_id = stream_id
return response
def assemble(self, message):
if isinstance(message, netlib.http.request.Request):
return self.assemble_request(message)
elif isinstance(message, netlib.http.response.Response):
return self.assemble_response(message)
else:
raise ValueError("HTTP message not supported.")
def assemble_request(self, request):
assert isinstance(request, netlib.http.request.Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
authority += ":%d" % self.tcp_handler.address.port
headers = request.headers.copy()
if ':authority' not in headers:
headers.insert(0, b':authority', authority.encode('ascii'))
headers.insert(0, b':scheme', request.scheme.encode('ascii'))
headers.insert(0, b':path', request.path.encode('ascii'))
headers.insert(0, b':method', request.method.encode('ascii'))
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)),
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
assert isinstance(response, netlib.http.response.Response)
headers = response.headers.copy()
if ':status' not in headers:
headers.insert(0, b':status', strutils.always_bytes(response.status_code))
if hasattr(response, 'stream_id'):
stream_id = response.stream_id
else:
stream_id = self._next_stream_id()
return list(itertools.chain(
self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)),
self._create_body(response.body, stream_id),
))
def perform_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
if self.is_server:
self.perform_server_connection_preface(force)
else:
self.perform_client_connection_preface(force)
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
magic_length = len(self.CLIENT_CONNECTION_PREFACE)
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
frm = hyperframe.frame.SettingsFrame(settings={
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
def perform_client_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
self.connection_preface_performed = True
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
def send_frame(self, frm, hide=False):
raw_bytes = frm.serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable(">>"))
def read_frame(self, hide=False):
while True:
frm = framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
if alp != b'h2':
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALPN value: %s" % alp)
return True
def _handle_unexpected_frame(self, frm):
if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
def _next_stream_id(self):
if self.current_stream_id is None:
if self.is_server:
# servers must use even stream ids
self.current_stream_id = 2
else:
# clients must use odd stream ids
self.current_stream_id = 1
else:
self.current_stream_id += 2
return self.current_stream_id
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
if not old_value:
old_value = '-'
self.http2_settings[setting] = value
frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
yield hyperframe.frame.HeadersFrame, i
else:
yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
stream_id=stream_id,
data=header_block_fragment[i:i + chunk_size]) for frm_cls, i in frame_cls(chunks)]
frms[-1].flags.add('END_HEADERS')
if end_stream:
frms[0].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _create_body(self, body, stream_id):
if body is None or len(body) == 0:
return b''
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
data=body[i:i + chunk_size]) for i in chunks]
frms[-1].flags.add('END_STREAM')
if self.dump_frames: # pragma no cover
for frm in frms:
print(frm.human_readable(">>"))
return [frm.serialize() for frm in frms]
def _receive_transmission(self, stream_id=None, include_body=True):
if not include_body:
raise NotImplementedError()
body_expected = True
header_blocks = b''
body = b''
while True:
frm = self.read_frame()
if (
(isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
header_blocks += frm.data
if 'END_STREAM' in frm.flags:
body_expected = False
if 'END_HEADERS' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
while body_expected:
frm = self.read_frame()
if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break
else:
self._handle_unexpected_frame(frm)
headers = netlib.http.headers.Headers(
(k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
)
return stream_id, headers, body
class HTTP2Protocol(object):
def __init__(self, pathod_handler): def __init__(self, pathod_handler):
self.pathod_handler = pathod_handler self.pathod_handler = pathod_handler
self.wire_protocol = http2.HTTP2Protocol( self.wire_protocol = HTTP2StateProtocol(
self.pathod_handler, is_server=True, dump_frames=self.pathod_handler.http2_framedump self.pathod_handler, is_server=True, dump_frames=self.pathod_handler.http2_framedump
) )

View File

@ -0,0 +1 @@
# foobar

1
test/pathod/__init__.py Normal file
View File

@ -0,0 +1 @@
from __future__ import (print_function, absolute_import, division)

View File

@ -1,11 +1,10 @@
from six import BytesIO from six import BytesIO
from pathod.language import actions from pathod.language import actions, parse_pathoc, parse_pathod, serve
from pathod import language
def parse_request(s): def parse_request(s):
return next(language.parse_pathoc(s)) return next(parse_pathoc(s))
def test_unique_name(): def test_unique_name():
@ -16,9 +15,9 @@ def test_unique_name():
class TestDisconnects: class TestDisconnects:
def test_parse_pathod(self): def test_parse_pathod(self):
a = next(language.parse_pathod("400:d0")).actions[0] a = next(parse_pathod("400:d0")).actions[0]
assert a.spec() == "d0" assert a.spec() == "d0"
a = next(language.parse_pathod("400:dr")).actions[0] a = next(parse_pathod("400:dr")).actions[0]
assert a.spec() == "dr" assert a.spec() == "dr"
def test_at(self): def test_at(self):
@ -42,12 +41,12 @@ class TestDisconnects:
class TestInject: class TestInject:
def test_parse_pathod(self): def test_parse_pathod(self):
a = next(language.parse_pathod("400:ir,@100")).actions[0] a = next(parse_pathod("400:ir,@100")).actions[0]
assert a.offset == "r" assert a.offset == "r"
assert a.value.datatype == "bytes" assert a.value.datatype == "bytes"
assert a.value.usize == 100 assert a.value.usize == 100
a = next(language.parse_pathod("400:ia,@100")).actions[0] a = next(parse_pathod("400:ia,@100")).actions[0]
assert a.offset == "a" assert a.offset == "a"
def test_at(self): def test_at(self):
@ -62,8 +61,8 @@ class TestInject:
def test_serve(self): def test_serve(self):
s = BytesIO() s = BytesIO()
r = next(language.parse_pathod("400:i0,'foo'")) r = next(parse_pathod("400:i0,'foo'"))
assert language.serve(r, s, {}) assert serve(r, s, {})
def test_spec(self): def test_spec(self):
e = actions.InjectAt.expr() e = actions.InjectAt.expr()
@ -96,7 +95,7 @@ class TestPauses:
assert v.offset == "a" assert v.offset == "a"
def test_request(self): def test_request(self):
r = next(language.parse_pathod('400:p10,10')) r = next(parse_pathod('400:p10,10'))
assert r.actions[0].spec() == "p10,10" assert r.actions[0].spec() == "p10,10"
def test_spec(self): def test_spec(self):

View File

@ -1,7 +1,8 @@
import os import os
from pathod import language from pathod import language
from pathod.language import base, exceptions from pathod.language import base, exceptions
import tutils
from . import tutils
def parse_request(s): def parse_request(s):

View File

@ -1,7 +1,7 @@
import os import os
from pathod.language import generators from pathod.language import generators
import tutils from . import tutils
def test_randomgenerator(): def test_randomgenerator():

View File

@ -1,7 +1,8 @@
from six import BytesIO from six import BytesIO
from pathod import language from pathod import language
from pathod.language import http, base from pathod.language import http, base
import tutils
from . import tutils
def parse_request(s): def parse_request(s):

View File

@ -1,12 +1,13 @@
from six import BytesIO from six import BytesIO
import netlib
from netlib import tcp from netlib import tcp
from netlib.http import user_agents from netlib.http import user_agents
from pathod import language from pathod import language
from pathod.language import http2 from pathod.language import http2
import tutils from pathod.protocols.http2 import HTTP2StateProtocol
from . import tutils
def parse_request(s): def parse_request(s):
@ -20,7 +21,7 @@ def parse_response(s):
def default_settings(): def default_settings():
return language.Settings( return language.Settings(
request_host="foo.com", request_host="foo.com",
protocol=netlib.http.http2.HTTP2Protocol(tcp.TCPClient(('localhost', 1234))) protocol=HTTP2StateProtocol(tcp.TCPClient(('localhost', 1234)))
) )

View File

@ -2,7 +2,8 @@
from pathod import language from pathod import language
from pathod.language import websockets from pathod.language import websockets
import netlib.websockets import netlib.websockets
import tutils
from . import tutils
def parse_request(s): def parse_request(s):

View File

@ -5,11 +5,13 @@ from mock import Mock
from netlib import http from netlib import http
from netlib import tcp from netlib import tcp
from netlib.exceptions import NetlibException from netlib.exceptions import NetlibException
from netlib.http import http1, http2 from netlib.http import http1
from netlib.tutils import raises
from pathod import pathoc, language from pathod import pathoc, language
from netlib.tutils import raises from pathod.protocols.http2 import HTTP2StateProtocol
import tutils
from . import tutils
def test_response(): def test_response():
@ -219,7 +221,7 @@ class TestDaemonHTTP2(PathocTestDaemon):
ssl=True, ssl=True,
use_http2=True, use_http2=True,
) )
assert isinstance(c.protocol, http2.HTTP2Protocol) assert isinstance(c.protocol, HTTP2StateProtocol)
c = pathoc.Pathoc( c = pathoc.Pathoc(
("127.0.0.1", self.d.port), ("127.0.0.1", self.d.port),

View File

@ -1,8 +1,10 @@
from pathod import pathoc_cmdline as cmdline
import tutils
from six.moves import cStringIO as StringIO from six.moves import cStringIO as StringIO
import mock import mock
from pathod import pathoc_cmdline as cmdline
from . import tutils
@mock.patch("argparse.ArgumentParser.error") @mock.patch("argparse.ArgumentParser.error")
def test_pathoc(perror): def test_pathoc(perror):

View File

@ -3,7 +3,8 @@ from six.moves import cStringIO as StringIO
from pathod import pathod from pathod import pathod
from netlib import tcp from netlib import tcp
from netlib.exceptions import HttpException, TlsException from netlib.exceptions import HttpException, TlsException
import tutils
from . import tutils
class TestPathod(object): class TestPathod(object):

View File

@ -1,7 +1,9 @@
from pathod import pathod_cmdline as cmdline
import tutils
import mock import mock
from pathod import pathod_cmdline as cmdline
from . import tutils
def test_parse_anchor_spec(): def test_parse_anchor_spec():
assert cmdline.parse_anchor_spec("foo=200") == ("foo", "200") assert cmdline.parse_anchor_spec("foo=200") == ("foo", "200")

View File

@ -5,21 +5,22 @@ import hyperframe
from netlib import tcp, http from netlib import tcp, http
from netlib.tutils import raises from netlib.tutils import raises
from netlib.exceptions import TcpDisconnect from netlib.exceptions import TcpDisconnect
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2 import framereader from netlib.http.http2 import framereader
from ... import tservers from ..netlib import tservers as netlib_tservers
from pathod.protocols.http2 import HTTP2StateProtocol, TCPHandler
class TestTCPHandlerWrapper: class TestTCPHandlerWrapper:
def test_wrapped(self): def test_wrapped(self):
h = TCPHandler(rfile='foo', wfile='bar') h = TCPHandler(rfile='foo', wfile='bar')
p = HTTP2Protocol(h) p = HTTP2StateProtocol(h)
assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar' assert p.tcp_handler.wfile == 'bar'
def test_direct(self): def test_direct(self):
p = HTTP2Protocol(rfile='foo', wfile='bar') p = HTTP2StateProtocol(rfile='foo', wfile='bar')
assert isinstance(p.tcp_handler, TCPHandler) assert isinstance(p.tcp_handler, TCPHandler)
assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar' assert p.tcp_handler.wfile == 'bar'
@ -36,10 +37,10 @@ class EchoHandler(tcp.BaseHandler):
class TestProtocol: class TestProtocol:
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface")
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface")
def test_perform_connection_preface(self, mock_client_method, mock_server_method): def test_perform_connection_preface(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=False) protocol = HTTP2StateProtocol(is_server=False)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
protocol.perform_connection_preface() protocol.perform_connection_preface()
@ -50,10 +51,10 @@ class TestProtocol:
assert mock_client_method.called assert mock_client_method.called
assert not mock_server_method.called assert not mock_server_method.called
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_server_connection_preface")
@mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") @mock.patch("pathod.protocols.http2.HTTP2StateProtocol.perform_client_connection_preface")
def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=True) protocol = HTTP2StateProtocol(is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
protocol.perform_connection_preface() protocol.perform_connection_preface()
@ -65,7 +66,7 @@ class TestProtocol:
assert mock_server_method.called assert mock_server_method.called
class TestCheckALPNMatch(tservers.ServerTestBase): class TestCheckALPNMatch(netlib_tservers.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(
alpn_select=b'h2', alpn_select=b'h2',
@ -77,11 +78,11 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl(alpn_protos=[b'h2']) c.convert_to_ssl(alpn_protos=[b'h2'])
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
assert protocol.check_alpn() assert protocol.check_alpn()
class TestCheckALPNMismatch(tservers.ServerTestBase): class TestCheckALPNMismatch(netlib_tservers.ServerTestBase):
handler = EchoHandler handler = EchoHandler
ssl = dict( ssl = dict(
alpn_select=None, alpn_select=None,
@ -93,12 +94,12 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl(alpn_protos=[b'h2']) c.convert_to_ssl(alpn_protos=[b'h2'])
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
with raises(NotImplementedError): with raises(NotImplementedError):
protocol.check_alpn() protocol.check_alpn()
class TestPerformServerConnectionPreface(tservers.ServerTestBase): class TestPerformServerConnectionPreface(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
@ -125,7 +126,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self): def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
assert not protocol.connection_preface_performed assert not protocol.connection_preface_performed
protocol.perform_server_connection_preface() protocol.perform_server_connection_preface()
@ -135,12 +136,12 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
protocol.perform_server_connection_preface(force=True) protocol.perform_server_connection_preface(force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase): class TestPerformClientConnectionPreface(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
# check magic # check magic
assert self.rfile.read(24) == HTTP2Protocol.CLIENT_CONNECTION_PREFACE assert self.rfile.read(24) == HTTP2StateProtocol.CLIENT_CONNECTION_PREFACE
# check empty settings frame # check empty settings frame
assert self.rfile.read(9) ==\ assert self.rfile.read(9) ==\
@ -161,7 +162,7 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self): def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
assert not protocol.connection_preface_performed assert not protocol.connection_preface_performed
protocol.perform_client_connection_preface() protocol.perform_client_connection_preface()
@ -170,7 +171,7 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
class TestClientStreamIds(object): class TestClientStreamIds(object):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
def test_client_stream_ids(self): def test_client_stream_ids(self):
assert self.protocol.current_stream_id is None assert self.protocol.current_stream_id is None
@ -182,9 +183,9 @@ class TestClientStreamIds(object):
assert self.protocol.current_stream_id == 5 assert self.protocol.current_stream_id == 5
class TestServerStreamIds(object): class TestserverstreamIds(object):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2StateProtocol(c, is_server=True)
def test_server_stream_ids(self): def test_server_stream_ids(self):
assert self.protocol.current_stream_id is None assert self.protocol.current_stream_id is None
@ -196,7 +197,7 @@ class TestServerStreamIds(object):
assert self.protocol.current_stream_id == 6 assert self.protocol.current_stream_id == 6
class TestApplySettings(tservers.ServerTestBase): class TestApplySettings(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
# check settings acknowledgement # check settings acknowledgement
@ -211,7 +212,7 @@ class TestApplySettings(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
protocol._apply_settings({ protocol._apply_settings({
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo', hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
@ -239,12 +240,12 @@ class TestCreateHeaders(object):
(b':scheme', b'https'), (b':scheme', b'https'),
(b'foo', b'bar')]) (b'foo', b'bar')])
bytes = HTTP2Protocol(self.c)._create_headers( bytes = HTTP2StateProtocol(self.c)._create_headers(
headers, 1, end_stream=True) headers, 1, end_stream=True)
assert b''.join(bytes) ==\ assert b''.join(bytes) ==\
codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') codecs.decode('000014010500000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec')
bytes = HTTP2Protocol(self.c)._create_headers( bytes = HTTP2StateProtocol(self.c)._create_headers(
headers, 1, end_stream=False) headers, 1, end_stream=False)
assert b''.join(bytes) ==\ assert b''.join(bytes) ==\
codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec') codecs.decode('000014010400000001824488355217caf3a69a3f87408294e7838c767f', 'hex_codec')
@ -257,7 +258,7 @@ class TestCreateHeaders(object):
(b'foo', b'bar'), (b'foo', b'bar'),
(b'server', b'version')]) (b'server', b'version')])
protocol = HTTP2Protocol(self.c) protocol = HTTP2StateProtocol(self.c)
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8 protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8
bytes = protocol._create_headers(headers, 1, end_stream=True) bytes = protocol._create_headers(headers, 1, end_stream=True)
assert len(bytes) == 3 assert len(bytes) == 3
@ -270,17 +271,17 @@ class TestCreateBody(object):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))
def test_create_body_empty(self): def test_create_body_empty(self):
protocol = HTTP2Protocol(self.c) protocol = HTTP2StateProtocol(self.c)
bytes = protocol._create_body(b'', 1) bytes = protocol._create_body(b'', 1)
assert b''.join(bytes) == b'' assert b''.join(bytes) == b''
def test_create_body_single_frame(self): def test_create_body_single_frame(self):
protocol = HTTP2Protocol(self.c) protocol = HTTP2StateProtocol(self.c)
bytes = protocol._create_body(b'foobar', 1) bytes = protocol._create_body(b'foobar', 1)
assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec') assert b''.join(bytes) == codecs.decode('000006000100000001666f6f626172', 'hex_codec')
def test_create_body_multiple_frames(self): def test_create_body_multiple_frames(self):
protocol = HTTP2Protocol(self.c) protocol = HTTP2StateProtocol(self.c)
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5 protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5
bytes = protocol._create_body(b'foobarmehm42', 1) bytes = protocol._create_body(b'foobarmehm42', 1)
assert len(bytes) == 3 assert len(bytes) == 3
@ -289,7 +290,7 @@ class TestCreateBody(object):
assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec') assert bytes[2] == codecs.decode('0000020001000000013432', 'hex_codec')
class TestReadRequest(tservers.ServerTestBase): class TestReadRequest(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
@ -306,7 +307,7 @@ class TestReadRequest(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
@ -319,7 +320,7 @@ class TestReadRequest(tservers.ServerTestBase):
assert req.content == b'foobar' assert req.content == b'foobar'
class TestReadRequestRelative(tservers.ServerTestBase): class TestReadRequestRelative(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
self.wfile.write( self.wfile.write(
@ -332,7 +333,7 @@ class TestReadRequestRelative(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
@ -342,7 +343,7 @@ class TestReadRequestRelative(tservers.ServerTestBase):
assert req.path == "*" assert req.path == "*"
class TestReadRequestAbsolute(tservers.ServerTestBase): class TestReadRequestAbsolute(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
self.wfile.write( self.wfile.write(
@ -355,7 +356,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
@ -366,7 +367,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
assert req.port == 22 assert req.port == 22
class TestReadRequestConnect(tservers.ServerTestBase): class TestReadRequestConnect(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
self.wfile.write( self.wfile.write(
@ -381,7 +382,7 @@ class TestReadRequestConnect(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c, is_server=True) protocol = HTTP2StateProtocol(c, is_server=True)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
req = protocol.read_request(NotImplemented) req = protocol.read_request(NotImplemented)
@ -397,7 +398,7 @@ class TestReadRequestConnect(tservers.ServerTestBase):
assert req.port == 443 assert req.port == 443
class TestReadResponse(tservers.ServerTestBase): class TestReadResponse(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
self.wfile.write( self.wfile.write(
@ -413,7 +414,7 @@ class TestReadResponse(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
resp = protocol.read_response(NotImplemented, stream_id=42) resp = protocol.read_response(NotImplemented, stream_id=42)
@ -426,7 +427,7 @@ class TestReadResponse(tservers.ServerTestBase):
assert resp.timestamp_end assert resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase): class TestReadEmptyResponse(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler): class handler(tcp.BaseHandler):
def handle(self): def handle(self):
self.wfile.write( self.wfile.write(
@ -439,7 +440,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect(): with c.connect():
c.convert_to_ssl() c.convert_to_ssl()
protocol = HTTP2Protocol(c) protocol = HTTP2StateProtocol(c)
protocol.connection_preface_performed = True protocol.connection_preface_performed = True
resp = protocol.read_response(NotImplemented, stream_id=42) resp = protocol.read_response(NotImplemented, stream_id=42)
@ -456,7 +457,7 @@ class TestAssembleRequest(object):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))
def test_request_simple(self): def test_request_simple(self):
bytes = HTTP2Protocol(self.c).assemble_request(http.Request( bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request(
b'', b'',
b'GET', b'GET',
b'https', b'https',
@ -483,12 +484,12 @@ class TestAssembleRequest(object):
None, None,
) )
req.stream_id = 0x42 req.stream_id = 0x42
bytes = HTTP2Protocol(self.c).assemble_request(req) bytes = HTTP2StateProtocol(self.c).assemble_request(req)
assert len(bytes) == 1 assert len(bytes) == 1
assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec') assert bytes[0] == codecs.decode('00000d0105000000428284874188089d5c0b8170dc07', 'hex_codec')
def test_request_with_body(self): def test_request_with_body(self):
bytes = HTTP2Protocol(self.c).assemble_request(http.Request( bytes = HTTP2StateProtocol(self.c).assemble_request(http.Request(
b'', b'',
b'GET', b'GET',
b'https', b'https',
@ -510,7 +511,7 @@ class TestAssembleResponse(object):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))
def test_simple(self): def test_simple(self):
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
b"HTTP/2.0", b"HTTP/2.0",
200, 200,
)) ))
@ -524,13 +525,13 @@ class TestAssembleResponse(object):
200, 200,
) )
resp.stream_id = 0x42 resp.stream_id = 0x42
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp) bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(resp)
assert len(bytes) == 1 assert len(bytes) == 1
assert bytes[0] ==\ assert bytes[0] ==\
codecs.decode('00000101050000004288', 'hex_codec') codecs.decode('00000101050000004288', 'hex_codec')
def test_with_body(self): def test_with_body(self):
bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response( bytes = HTTP2StateProtocol(self.c, is_server=True).assemble_response(http.Response(
b"HTTP/2.0", b"HTTP/2.0",
200, 200,
b'', b'',

View File

@ -1,7 +1,8 @@
import logging import logging
import requests import requests
from pathod import test from pathod import test
import tutils
from . import tutils
import requests.packages.urllib3 import requests.packages.urllib3

View File

@ -1,5 +1,6 @@
from pathod import utils from pathod import utils
import tutils
from . import tutils
def test_membool(): def test_membool():