From 08fbe6f1118455bc44d05db30b83bdf81feda2a0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 31 May 2016 17:16:31 +1200 Subject: [PATCH] Start cleaning up netlib.utils - Remove http2 functions, move to http2.frame - Remove Serializable, move to netlib.basetypes --- mitmproxy/models/tcp.py | 4 +- mitmproxy/protocol/http2.py | 5 +- mitmproxy/stateobject.py | 4 +- netlib/basetypes.py | 33 ++++++++++++ netlib/certutils.py | 4 +- netlib/http/http2/connections.py | 59 +++++++++++----------- netlib/http/http2/frame.py | 21 ++++++++ netlib/http/message.py | 17 +++---- netlib/http/request.py | 2 +- netlib/multidict.py | 5 +- netlib/odict.py | 8 +-- netlib/tcp.py | 4 +- netlib/utils.py | 52 ------------------- test/mitmproxy/test_protocol_http2.py | 14 ++--- test/netlib/http/http2/test_connections.py | 26 +++++----- test/netlib/test_basetypes.py | 27 ++++++++++ test/netlib/test_utils.py | 27 ---------- 17 files changed, 157 insertions(+), 155 deletions(-) create mode 100644 netlib/basetypes.py create mode 100644 netlib/http/http2/frame.py create mode 100644 test/netlib/test_basetypes.py diff --git a/mitmproxy/models/tcp.py b/mitmproxy/models/tcp.py index b87a74acf..c7cfb9f8a 100644 --- a/mitmproxy/models/tcp.py +++ b/mitmproxy/models/tcp.py @@ -1,11 +1,11 @@ import time from typing import List -from netlib.utils import Serializable +import netlib.basetypes from .flow import Flow -class TCPMessage(Serializable): +class TCPMessage(netlib.basetypes.Serializable): def __init__(self, from_client, content, timestamp=None): self.content = content diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index b41016768..24460ec91 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -14,7 +14,8 @@ from hyperframe.frame import PriorityFrame from netlib.tcp import ssl_read_select from netlib.exceptions import HttpException from netlib.http import Headers -from netlib.utils import http2_read_raw_frame, parse_url +from netlib.utils import parse_url +from netlib.http.http2 import frame from .base import Layer from .http import _HttpTransmissionLayer, HttpLayer @@ -233,7 +234,7 @@ class Http2Layer(Layer): with source_conn.h2.lock: try: - raw_frame = b''.join(http2_read_raw_frame(source_conn.rfile)) + raw_frame = b''.join(frame.http2_read_raw_frame(source_conn.rfile)) except: # read frame failed: connection closed self._kill_all_streams() diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index 765c35d6c..eb57fa00a 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -3,7 +3,7 @@ from __future__ import absolute_import import six from typing import List, Any -from netlib.utils import Serializable +import netlib.basetypes def _is_list(cls): @@ -13,7 +13,7 @@ def _is_list(cls): return issubclass(cls, List) or is_list_bugfix -class StateObject(Serializable): +class StateObject(netlib.basetypes.Serializable): """ An object with serializable state. diff --git a/netlib/basetypes.py b/netlib/basetypes.py new file mode 100644 index 000000000..d03246ff1 --- /dev/null +++ b/netlib/basetypes.py @@ -0,0 +1,33 @@ +import six +import abc + +@six.add_metaclass(abc.ABCMeta) +class Serializable(object): + """ + Abstract Base Class that defines an API to save an object's state and restore it later on. + """ + + @classmethod + @abc.abstractmethod + def from_state(cls, state): + """ + Create a new object from the given state. + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_state(self): + """ + Retrieve object state. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_state(self, state): + """ + Set object state to the given state. + """ + raise NotImplementedError() + + def copy(self): + return self.from_state(self.get_state()) diff --git a/netlib/certutils.py b/netlib/certutils.py index 34e01ed37..4a19d170a 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -12,7 +12,7 @@ from pyasn1.codec.der.decoder import decode from pyasn1.error import PyAsn1Error import OpenSSL -from .utils import Serializable +from . import basetypes # Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815 @@ -364,7 +364,7 @@ class _GeneralNames(univ.SequenceOf): constraint.ValueSizeConstraint(1, 1024) -class SSLCert(Serializable): +class SSLCert(basetypes.Serializable): def __init__(self, cert): """ diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 6b91f2ff6..03f1804b2 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -2,11 +2,12 @@ from __future__ import (absolute_import, print_function, division) import itertools import time +import hyperframe.frame + from hpack.hpack import Encoder, Decoder from ... import utils from .. import Headers, Response, Request - -from hyperframe import frame +from . import frame class TCPHandler(object): @@ -38,12 +39,12 @@ class HTTP2Protocol(object): CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n' HTTP2_DEFAULT_SETTINGS = { - frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, - frame.SettingsFrame.ENABLE_PUSH: 1, - frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, - frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, - frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, + hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096, + hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1, + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None, + hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14, + hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None, } def __init__( @@ -253,9 +254,9 @@ class HTTP2Protocol(object): magic = self.tcp_handler.rfile.safe_read(magic_length) assert magic == self.CLIENT_CONNECTION_PREFACE - frm = frame.SettingsFrame(settings={ - frame.SettingsFrame.ENABLE_PUSH: 0, - frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, + frm = hyperframe.frame.SettingsFrame(settings={ + hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0, + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1, }) self.send_frame(frm, hide=True) self._receive_settings(hide=True) @@ -266,7 +267,7 @@ class HTTP2Protocol(object): self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - self.send_frame(frame.SettingsFrame(), hide=True) + self.send_frame(hyperframe.frame.SettingsFrame(), hide=True) self._receive_settings(hide=True) # server announces own settings self._receive_settings(hide=True) # server acks my settings @@ -279,18 +280,18 @@ class HTTP2Protocol(object): def read_frame(self, hide=False): while True: - frm = utils.http2_read_frame(self.tcp_handler.rfile) + frm = frame.http2_read_frame(self.tcp_handler.rfile) if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) - if isinstance(frm, frame.PingFrame): - raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() + if isinstance(frm, hyperframe.frame.PingFrame): + raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize() self.tcp_handler.wfile.write(raw_bytes) self.tcp_handler.wfile.flush() continue - if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags: + if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags: self._apply_settings(frm.settings, hide) - if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0: + if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0: self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length) return frm @@ -302,7 +303,7 @@ class HTTP2Protocol(object): return True def _handle_unexpected_frame(self, frm): - if isinstance(frm, frame.SettingsFrame): + if isinstance(frm, hyperframe.frame.SettingsFrame): return if self.unhandled_frame_cb: self.unhandled_frame_cb(frm) @@ -310,7 +311,7 @@ class HTTP2Protocol(object): def _receive_settings(self, hide=False): while True: frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): + if isinstance(frm, hyperframe.frame.SettingsFrame): break else: self._handle_unexpected_frame(frm) @@ -334,26 +335,26 @@ class HTTP2Protocol(object): old_value = '-' self.http2_settings[setting] = value - frm = frame.SettingsFrame(flags=['ACK']) + frm = hyperframe.frame.SettingsFrame(flags=['ACK']) self.send_frame(frm, hide) def _update_flow_control_window(self, stream_id, increment): - frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment) + frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment) self.send_frame(frm) - frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) + frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment) self.send_frame(frm) def _create_headers(self, headers, stream_id, end_stream=True): def frame_cls(chunks): for i in chunks: if i == 0: - yield frame.HeadersFrame, i + yield hyperframe.frame.HeadersFrame, i else: - yield frame.ContinuationFrame, i + yield hyperframe.frame.ContinuationFrame, i header_block_fragment = self.encoder.encode(headers.fields) - chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] + chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) frms = [frm_cls( flags=[], @@ -374,9 +375,9 @@ class HTTP2Protocol(object): if body is None or len(body) == 0: return b'' - chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] + chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] chunks = range(0, len(body), chunk_size) - frms = [frame.DataFrame( + frms = [hyperframe.frame.DataFrame( flags=[], stream_id=stream_id, data=body[i:i + chunk_size]) for i in chunks] @@ -400,7 +401,7 @@ class HTTP2Protocol(object): while True: frm = self.read_frame() if ( - (isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and + (isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and (stream_id is None or frm.stream_id == stream_id) ): stream_id = frm.stream_id @@ -414,7 +415,7 @@ class HTTP2Protocol(object): while body_expected: frm = self.read_frame() - if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id: + if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id: body += frm.data if 'END_STREAM' in frm.flags: break diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 000000000..d45be6461 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,21 @@ +import codecs + +import hyperframe + + +def http2_read_raw_frame(rfile): + header = rfile.safe_read(9) + length = int(codecs.encode(header[:3], 'hex_codec'), 16) + + if length == 4740180: + raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) + + body = rfile.safe_read(length) + return [header, body] + + +def http2_read_frame(rfile): + header, body = http2_read_raw_frame(rfile) + frame, length = hyperframe.frame.Frame.parse_frame_header(header) + frame.parse_body(memoryview(body)) + return frame diff --git a/netlib/http/message.py b/netlib/http/message.py index 13d401a74..d9654f26e 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,9 +4,8 @@ import warnings import six -from .headers import Headers -from .. import encoding, utils -from ..utils import always_bytes +from .. import encoding, utils, basetypes +from . import headers if six.PY2: # pragma: no cover def _native(x): @@ -20,10 +19,10 @@ else: return x.decode("utf-8", "surrogateescape") def _always_bytes(x): - return always_bytes(x, "utf-8", "surrogateescape") + return utils.always_bytes(x, "utf-8", "surrogateescape") -class MessageData(utils.Serializable): +class MessageData(basetypes.Serializable): def __eq__(self, other): if isinstance(other, MessageData): return self.__dict__ == other.__dict__ @@ -38,7 +37,7 @@ class MessageData(utils.Serializable): def set_state(self, state): for k, v in state.items(): if k == "headers": - v = Headers.from_state(v) + v = headers.Headers.from_state(v) setattr(self, k, v) def get_state(self): @@ -48,11 +47,11 @@ class MessageData(utils.Serializable): @classmethod def from_state(cls, state): - state["headers"] = Headers.from_state(state["headers"]) + state["headers"] = headers.Headers.from_state(state["headers"]) return cls(**state) -class Message(utils.Serializable): +class Message(basetypes.Serializable): def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -72,7 +71,7 @@ class Message(utils.Serializable): @classmethod def from_state(cls, state): - state["headers"] = Headers.from_state(state["headers"]) + state["headers"] = headers.Headers.from_state(state["headers"]) return cls(**state) @property diff --git a/netlib/http/request.py b/netlib/http/request.py index fa8d54aa5..80a9ae653 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -6,7 +6,7 @@ import six from six.moves import urllib from netlib import utils -from netlib.http import cookies +from . import cookies from .. import encoding from ..multidict import MultiDictView from .headers import Headers diff --git a/netlib/multidict.py b/netlib/multidict.py index f8876cbd5..6139d60ad 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -9,12 +9,11 @@ except ImportError: # pragma: no cover from collections import MutableMapping # Workaround for Python < 3.3 import six - -from .utils import Serializable +from . import basetypes @six.add_metaclass(ABCMeta) -class _MultiDict(MutableMapping, Serializable): +class _MultiDict(MutableMapping, basetypes.Serializable): def __repr__(self): fields = ( repr(field) diff --git a/netlib/odict.py b/netlib/odict.py index 8a638dabc..87887a294 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -3,10 +3,10 @@ import copy import six -from .utils import Serializable, safe_subn +from . import basetypes, utils -class ODict(Serializable): +class ODict(basetypes.Serializable): """ A dictionary-like object for managing ordered (key, value) data. Think @@ -139,9 +139,9 @@ class ODict(Serializable): """ new, count = [], 0 for k, v in self.lst: - k, c = safe_subn(pattern, repl, k, *args, **kwargs) + k, c = utils.safe_subn(pattern, repl, k, *args, **kwargs) count += c - v, c = safe_subn(pattern, repl, v, *args, **kwargs) + v, c = utils.safe_subn(pattern, repl, v, *args, **kwargs) count += c new.append([k, v]) self.lst = new diff --git a/netlib/tcp.py b/netlib/tcp.py index c7231dbb6..5662c9737 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,7 +16,7 @@ import six import OpenSSL from OpenSSL import SSL -from . import certutils, version_check, utils +from . import certutils, version_check, basetypes # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. @@ -302,7 +302,7 @@ class Reader(_FileLike): raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(utils.Serializable): +class Address(basetypes.Serializable): """ This class wraps an IPv4/IPv6 tuple to provide named attributes and diff --git a/netlib/utils.py b/netlib/utils.py index 174f616de..770ad6a6a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -3,46 +3,12 @@ import os.path import re import codecs import unicodedata -from abc import ABCMeta, abstractmethod import importlib import inspect import six from six.moves import urllib -import hyperframe - - -@six.add_metaclass(ABCMeta) -class Serializable(object): - """ - Abstract Base Class that defines an API to save an object's state and restore it later on. - """ - - @classmethod - @abstractmethod - def from_state(cls, state): - """ - Create a new object from the given state. - """ - raise NotImplementedError() - - @abstractmethod - def get_state(self): - """ - Retrieve object state. - """ - raise NotImplementedError() - - @abstractmethod - def set_state(self, state): - """ - Set object state to the given state. - """ - raise NotImplementedError() - - def copy(self): - return self.from_state(self.get_state()) def always_bytes(unicode_or_bytes, *encode_args): @@ -395,24 +361,6 @@ def multipartdecode(headers, content): return [] -def http2_read_raw_frame(rfile): - header = rfile.safe_read(9) - length = int(codecs.encode(header[:3], 'hex_codec'), 16) - - if length == 4740180: - raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20)) - - body = rfile.safe_read(length) - return [header, body] - - -def http2_read_frame(rfile): - header, body = http2_read_raw_frame(rfile) - frame, length = hyperframe.frame.Frame.parse_frame_header(header) - frame.parse_body(memoryview(body)) - return frame - - def safe_subn(pattern, repl, target, *args, **kwargs): """ There are Unicode conversion problems with re.subn. We try to smooth diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index 4a7620147..5ab42caeb 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -13,7 +13,7 @@ from mitmproxy.cmdline import APP_HOST, APP_PORT import netlib from ..netlib import tservers as netlib_tservers -from netlib.utils import http2_read_raw_frame +from netlib.http.http2 import frame from . import tservers @@ -48,7 +48,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): done = False while not done: try: - raw = b''.join(http2_read_raw_frame(self.rfile)) + raw = b''.join(frame.http2_read_raw_frame(self.rfile)) events = h2_conn.receive_data(raw) except: break @@ -200,7 +200,7 @@ class TestSimple(_Http2TestBase, _Http2ServerBase): done = False while not done: try: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + events = h2_conn.receive_data(b''.join(frame.http2_read_raw_frame(client.rfile))) except: break client.wfile.write(h2_conn.data_to_send()) @@ -270,7 +270,7 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase): done = False while not done: try: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + events = h2_conn.receive_data(b''.join(frame.http2_read_raw_frame(client.rfile))) except: break client.wfile.write(h2_conn.data_to_send()) @@ -362,7 +362,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): responses = 0 while not done: try: - raw = b''.join(http2_read_raw_frame(client.rfile)) + raw = b''.join(frame.http2_read_raw_frame(client.rfile)) events = h2_conn.receive_data(raw) except: break @@ -412,7 +412,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): responses = 0 while not done: try: - events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile))) + events = h2_conn.receive_data(b''.join(frame.http2_read_raw_frame(client.rfile))) except: break client.wfile.write(h2_conn.data_to_send()) @@ -479,7 +479,7 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase): done = False while not done: try: - raw = b''.join(http2_read_raw_frame(client.rfile)) + raw = b''.join(frame.http2_read_raw_frame(client.rfile)) h2_conn.receive_data(raw) except: break diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py index 69667d1cb..be68a28cb 100644 --- a/test/netlib/http/http2/test_connections.py +++ b/test/netlib/http/http2/test_connections.py @@ -1,12 +1,12 @@ import mock import codecs -from hyperframe import frame - -from netlib import tcp, http, utils +import hyperframe +from netlib import tcp, http from netlib.tutils import raises from netlib.exceptions import TcpDisconnect from netlib.http.http2.connections import HTTP2Protocol, TCPHandler +from netlib.http.http2 import frame from ... import tservers @@ -111,11 +111,11 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): self.wfile.flush() # check empty settings frame - raw = utils.http2_read_raw_frame(self.rfile) + raw = frame.http2_read_raw_frame(self.rfile) assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec') # check settings acknowledgement - raw = utils.http2_read_raw_frame(self.rfile) + raw = frame.http2_read_raw_frame(self.rfile) assert raw == codecs.decode('000000040100000000', 'hex_codec') # send settings acknowledgement @@ -214,19 +214,19 @@ class TestApplySettings(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol._apply_settings({ - frame.SettingsFrame.ENABLE_PUSH: 'foo', - frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', - frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', + hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo', + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar', + hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef', }) assert c.rfile.safe_read(2) == b"OK" assert protocol.http2_settings[ - frame.SettingsFrame.ENABLE_PUSH] == 'foo' + hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo' assert protocol.http2_settings[ - frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' + hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar' assert protocol.http2_settings[ - frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' + hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef' class TestCreateHeaders(object): @@ -258,7 +258,7 @@ class TestCreateHeaders(object): (b'server', b'version')]) protocol = HTTP2Protocol(self.c) - protocol.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] = 8 + protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8 bytes = protocol._create_headers(headers, 1, end_stream=True) assert len(bytes) == 3 assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec') @@ -281,7 +281,7 @@ class TestCreateBody(object): def test_create_body_multiple_frames(self): protocol = HTTP2Protocol(self.c) - protocol.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] = 5 + protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5 bytes = protocol._create_body(b'foobarmehm42', 1) assert len(bytes) == 3 assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec') diff --git a/test/netlib/test_basetypes.py b/test/netlib/test_basetypes.py new file mode 100644 index 000000000..2a7eea818 --- /dev/null +++ b/test/netlib/test_basetypes.py @@ -0,0 +1,27 @@ +from netlib import basetypes + +class SerializableDummy(basetypes.Serializable): + def __init__(self, i): + self.i = i + + def get_state(self): + return self.i + + def set_state(self, i): + self.i = i + + def from_state(self, state): + return type(self)(state) + + +class TestSerializable: + + def test_copy(self): + a = SerializableDummy(42) + assert a.i == 42 + b = a.copy() + assert b.i == 42 + + a.set_state(1) + assert a.i == 1 + assert b.i == 42 diff --git a/test/netlib/test_utils.py b/test/netlib/test_utils.py index e4c81a482..cd629d777 100644 --- a/test/netlib/test_utils.py +++ b/test/netlib/test_utils.py @@ -144,33 +144,6 @@ def test_parse_content_type(): assert v == ('text', 'html', {'charset': 'UTF-8'}) -class SerializableDummy(utils.Serializable): - def __init__(self, i): - self.i = i - - def get_state(self): - return self.i - - def set_state(self, i): - self.i = i - - def from_state(self, state): - return type(self)(state) - - -class TestSerializable: - - def test_copy(self): - a = SerializableDummy(42) - assert a.i == 42 - b = a.copy() - assert b.i == 42 - - a.set_state(1) - assert a.i == 1 - assert b.i == 42 - - def test_safe_subn(): assert utils.safe_subn("foo", u"bar", "\xc2foo")