Start cleaning up netlib.utils

- Remove http2 functions, move to http2.frame
- Remove Serializable, move to netlib.basetypes
This commit is contained in:
Aldo Cortesi 2016-05-31 17:16:31 +12:00
parent 2f526393d2
commit 08fbe6f111
17 changed files with 157 additions and 155 deletions

View File

@ -1,11 +1,11 @@
import time
from typing import List
from netlib.utils import Serializable
import netlib.basetypes
from .flow import Flow
class TCPMessage(Serializable):
class TCPMessage(netlib.basetypes.Serializable):
def __init__(self, from_client, content, timestamp=None):
self.content = content

View File

@ -14,7 +14,8 @@ from hyperframe.frame import PriorityFrame
from netlib.tcp import ssl_read_select
from netlib.exceptions import HttpException
from netlib.http import Headers
from netlib.utils import http2_read_raw_frame, parse_url
from netlib.utils import parse_url
from netlib.http.http2 import frame
from .base import Layer
from .http import _HttpTransmissionLayer, HttpLayer
@ -233,7 +234,7 @@ class Http2Layer(Layer):
with source_conn.h2.lock:
try:
raw_frame = b''.join(http2_read_raw_frame(source_conn.rfile))
raw_frame = b''.join(frame.http2_read_raw_frame(source_conn.rfile))
except:
# read frame failed: connection closed
self._kill_all_streams()

View File

@ -3,7 +3,7 @@ from __future__ import absolute_import
import six
from typing import List, Any
from netlib.utils import Serializable
import netlib.basetypes
def _is_list(cls):
@ -13,7 +13,7 @@ def _is_list(cls):
return issubclass(cls, List) or is_list_bugfix
class StateObject(Serializable):
class StateObject(netlib.basetypes.Serializable):
"""
An object with serializable state.

33
netlib/basetypes.py Normal file
View File

@ -0,0 +1,33 @@
import six
import abc
@six.add_metaclass(abc.ABCMeta)
class Serializable(object):
"""
Abstract Base Class that defines an API to save an object's state and restore it later on.
"""
@classmethod
@abc.abstractmethod
def from_state(cls, state):
"""
Create a new object from the given state.
"""
raise NotImplementedError()
@abc.abstractmethod
def get_state(self):
"""
Retrieve object state.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_state(self, state):
"""
Set object state to the given state.
"""
raise NotImplementedError()
def copy(self):
return self.from_state(self.get_state())

View File

@ -12,7 +12,7 @@ from pyasn1.codec.der.decoder import decode
from pyasn1.error import PyAsn1Error
import OpenSSL
from .utils import Serializable
from . import basetypes
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
@ -364,7 +364,7 @@ class _GeneralNames(univ.SequenceOf):
constraint.ValueSizeConstraint(1, 1024)
class SSLCert(Serializable):
class SSLCert(basetypes.Serializable):
def __init__(self, cert):
"""

View File

@ -2,11 +2,12 @@ from __future__ import (absolute_import, print_function, division)
import itertools
import time
import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from ... import utils
from .. import Headers, Response, Request
from hyperframe import frame
from . import frame
class TCPHandler(object):
@ -38,12 +39,12 @@ class HTTP2Protocol(object):
CLIENT_CONNECTION_PREFACE = b'PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n'
HTTP2_DEFAULT_SETTINGS = {
frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
frame.SettingsFrame.ENABLE_PUSH: 1,
frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
hyperframe.frame.SettingsFrame.HEADER_TABLE_SIZE: 4096,
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 1,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: None,
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE: 2 ** 14,
hyperframe.frame.SettingsFrame.MAX_HEADER_LIST_SIZE: None,
}
def __init__(
@ -253,9 +254,9 @@ class HTTP2Protocol(object):
magic = self.tcp_handler.rfile.safe_read(magic_length)
assert magic == self.CLIENT_CONNECTION_PREFACE
frm = frame.SettingsFrame(settings={
frame.SettingsFrame.ENABLE_PUSH: 0,
frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
frm = hyperframe.frame.SettingsFrame(settings={
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 0,
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 1,
})
self.send_frame(frm, hide=True)
self._receive_settings(hide=True)
@ -266,7 +267,7 @@ class HTTP2Protocol(object):
self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE)
self.send_frame(frame.SettingsFrame(), hide=True)
self.send_frame(hyperframe.frame.SettingsFrame(), hide=True)
self._receive_settings(hide=True) # server announces own settings
self._receive_settings(hide=True) # server acks my settings
@ -279,18 +280,18 @@ class HTTP2Protocol(object):
def read_frame(self, hide=False):
while True:
frm = utils.http2_read_frame(self.tcp_handler.rfile)
frm = frame.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
if isinstance(frm, frame.PingFrame):
raw_bytes = frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
if isinstance(frm, hyperframe.frame.PingFrame):
raw_bytes = hyperframe.frame.PingFrame(flags=['ACK'], payload=frm.payload).serialize()
self.tcp_handler.wfile.write(raw_bytes)
self.tcp_handler.wfile.flush()
continue
if isinstance(frm, frame.SettingsFrame) and 'ACK' not in frm.flags:
if isinstance(frm, hyperframe.frame.SettingsFrame) and 'ACK' not in frm.flags:
self._apply_settings(frm.settings, hide)
if isinstance(frm, frame.DataFrame) and frm.flow_controlled_length > 0:
if isinstance(frm, hyperframe.frame.DataFrame) and frm.flow_controlled_length > 0:
self._update_flow_control_window(frm.stream_id, frm.flow_controlled_length)
return frm
@ -302,7 +303,7 @@ class HTTP2Protocol(object):
return True
def _handle_unexpected_frame(self, frm):
if isinstance(frm, frame.SettingsFrame):
if isinstance(frm, hyperframe.frame.SettingsFrame):
return
if self.unhandled_frame_cb:
self.unhandled_frame_cb(frm)
@ -310,7 +311,7 @@ class HTTP2Protocol(object):
def _receive_settings(self, hide=False):
while True:
frm = self.read_frame(hide)
if isinstance(frm, frame.SettingsFrame):
if isinstance(frm, hyperframe.frame.SettingsFrame):
break
else:
self._handle_unexpected_frame(frm)
@ -334,26 +335,26 @@ class HTTP2Protocol(object):
old_value = '-'
self.http2_settings[setting] = value
frm = frame.SettingsFrame(flags=['ACK'])
frm = hyperframe.frame.SettingsFrame(flags=['ACK'])
self.send_frame(frm, hide)
def _update_flow_control_window(self, stream_id, increment):
frm = frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=0, window_increment=increment)
self.send_frame(frm)
frm = frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
frm = hyperframe.frame.WindowUpdateFrame(stream_id=stream_id, window_increment=increment)
self.send_frame(frm)
def _create_headers(self, headers, stream_id, end_stream=True):
def frame_cls(chunks):
for i in chunks:
if i == 0:
yield frame.HeadersFrame, i
yield hyperframe.frame.HeadersFrame, i
else:
yield frame.ContinuationFrame, i
yield hyperframe.frame.ContinuationFrame, i
header_block_fragment = self.encoder.encode(headers.fields)
chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE]
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(header_block_fragment), chunk_size)
frms = [frm_cls(
flags=[],
@ -374,9 +375,9 @@ class HTTP2Protocol(object):
if body is None or len(body) == 0:
return b''
chunk_size = self.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE]
chunk_size = self.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE]
chunks = range(0, len(body), chunk_size)
frms = [frame.DataFrame(
frms = [hyperframe.frame.DataFrame(
flags=[],
stream_id=stream_id,
data=body[i:i + chunk_size]) for i in chunks]
@ -400,7 +401,7 @@ class HTTP2Protocol(object):
while True:
frm = self.read_frame()
if (
(isinstance(frm, frame.HeadersFrame) or isinstance(frm, frame.ContinuationFrame)) and
(isinstance(frm, hyperframe.frame.HeadersFrame) or isinstance(frm, hyperframe.frame.ContinuationFrame)) and
(stream_id is None or frm.stream_id == stream_id)
):
stream_id = frm.stream_id
@ -414,7 +415,7 @@ class HTTP2Protocol(object):
while body_expected:
frm = self.read_frame()
if isinstance(frm, frame.DataFrame) and frm.stream_id == stream_id:
if isinstance(frm, hyperframe.frame.DataFrame) and frm.stream_id == stream_id:
body += frm.data
if 'END_STREAM' in frm.flags:
break

View File

@ -0,0 +1,21 @@
import codecs
import hyperframe
def http2_read_raw_frame(rfile):
header = rfile.safe_read(9)
length = int(codecs.encode(header[:3], 'hex_codec'), 16)
if length == 4740180:
raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20))
body = rfile.safe_read(length)
return [header, body]
def http2_read_frame(rfile):
header, body = http2_read_raw_frame(rfile)
frame, length = hyperframe.frame.Frame.parse_frame_header(header)
frame.parse_body(memoryview(body))
return frame

View File

@ -4,9 +4,8 @@ import warnings
import six
from .headers import Headers
from .. import encoding, utils
from ..utils import always_bytes
from .. import encoding, utils, basetypes
from . import headers
if six.PY2: # pragma: no cover
def _native(x):
@ -20,10 +19,10 @@ else:
return x.decode("utf-8", "surrogateescape")
def _always_bytes(x):
return always_bytes(x, "utf-8", "surrogateescape")
return utils.always_bytes(x, "utf-8", "surrogateescape")
class MessageData(utils.Serializable):
class MessageData(basetypes.Serializable):
def __eq__(self, other):
if isinstance(other, MessageData):
return self.__dict__ == other.__dict__
@ -38,7 +37,7 @@ class MessageData(utils.Serializable):
def set_state(self, state):
for k, v in state.items():
if k == "headers":
v = Headers.from_state(v)
v = headers.Headers.from_state(v)
setattr(self, k, v)
def get_state(self):
@ -48,11 +47,11 @@ class MessageData(utils.Serializable):
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
state["headers"] = headers.Headers.from_state(state["headers"])
return cls(**state)
class Message(utils.Serializable):
class Message(basetypes.Serializable):
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@ -72,7 +71,7 @@ class Message(utils.Serializable):
@classmethod
def from_state(cls, state):
state["headers"] = Headers.from_state(state["headers"])
state["headers"] = headers.Headers.from_state(state["headers"])
return cls(**state)
@property

View File

@ -6,7 +6,7 @@ import six
from six.moves import urllib
from netlib import utils
from netlib.http import cookies
from . import cookies
from .. import encoding
from ..multidict import MultiDictView
from .headers import Headers

View File

@ -9,12 +9,11 @@ except ImportError: # pragma: no cover
from collections import MutableMapping # Workaround for Python < 3.3
import six
from .utils import Serializable
from . import basetypes
@six.add_metaclass(ABCMeta)
class _MultiDict(MutableMapping, Serializable):
class _MultiDict(MutableMapping, basetypes.Serializable):
def __repr__(self):
fields = (
repr(field)

View File

@ -3,10 +3,10 @@ import copy
import six
from .utils import Serializable, safe_subn
from . import basetypes, utils
class ODict(Serializable):
class ODict(basetypes.Serializable):
"""
A dictionary-like object for managing ordered (key, value) data. Think
@ -139,9 +139,9 @@ class ODict(Serializable):
"""
new, count = [], 0
for k, v in self.lst:
k, c = safe_subn(pattern, repl, k, *args, **kwargs)
k, c = utils.safe_subn(pattern, repl, k, *args, **kwargs)
count += c
v, c = safe_subn(pattern, repl, v, *args, **kwargs)
v, c = utils.safe_subn(pattern, repl, v, *args, **kwargs)
count += c
new.append([k, v])
self.lst = new

View File

@ -16,7 +16,7 @@ import six
import OpenSSL
from OpenSSL import SSL
from . import certutils, version_check, utils
from . import certutils, version_check, basetypes
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
@ -302,7 +302,7 @@ class Reader(_FileLike):
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
class Address(utils.Serializable):
class Address(basetypes.Serializable):
"""
This class wraps an IPv4/IPv6 tuple to provide named attributes and

View File

@ -3,46 +3,12 @@ import os.path
import re
import codecs
import unicodedata
from abc import ABCMeta, abstractmethod
import importlib
import inspect
import six
from six.moves import urllib
import hyperframe
@six.add_metaclass(ABCMeta)
class Serializable(object):
"""
Abstract Base Class that defines an API to save an object's state and restore it later on.
"""
@classmethod
@abstractmethod
def from_state(cls, state):
"""
Create a new object from the given state.
"""
raise NotImplementedError()
@abstractmethod
def get_state(self):
"""
Retrieve object state.
"""
raise NotImplementedError()
@abstractmethod
def set_state(self, state):
"""
Set object state to the given state.
"""
raise NotImplementedError()
def copy(self):
return self.from_state(self.get_state())
def always_bytes(unicode_or_bytes, *encode_args):
@ -395,24 +361,6 @@ def multipartdecode(headers, content):
return []
def http2_read_raw_frame(rfile):
header = rfile.safe_read(9)
length = int(codecs.encode(header[:3], 'hex_codec'), 16)
if length == 4740180:
raise ValueError("Length field looks more like HTTP/1.1: %s" % rfile.peek(20))
body = rfile.safe_read(length)
return [header, body]
def http2_read_frame(rfile):
header, body = http2_read_raw_frame(rfile)
frame, length = hyperframe.frame.Frame.parse_frame_header(header)
frame.parse_body(memoryview(body))
return frame
def safe_subn(pattern, repl, target, *args, **kwargs):
"""
There are Unicode conversion problems with re.subn. We try to smooth

View File

@ -13,7 +13,7 @@ from mitmproxy.cmdline import APP_HOST, APP_PORT
import netlib
from ..netlib import tservers as netlib_tservers
from netlib.utils import http2_read_raw_frame
from netlib.http.http2 import frame
from . import tservers
@ -48,7 +48,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
done = False
while not done:
try:
raw = b''.join(http2_read_raw_frame(self.rfile))
raw = b''.join(frame.http2_read_raw_frame(self.rfile))
events = h2_conn.receive_data(raw)
except:
break
@ -200,7 +200,7 @@ class TestSimple(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
events = h2_conn.receive_data(b''.join(frame.http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -270,7 +270,7 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
events = h2_conn.receive_data(b''.join(frame.http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -362,7 +362,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
responses = 0
while not done:
try:
raw = b''.join(http2_read_raw_frame(client.rfile))
raw = b''.join(frame.http2_read_raw_frame(client.rfile))
events = h2_conn.receive_data(raw)
except:
break
@ -412,7 +412,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
responses = 0
while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
events = h2_conn.receive_data(b''.join(frame.http2_read_raw_frame(client.rfile)))
except:
break
client.wfile.write(h2_conn.data_to_send())
@ -479,7 +479,7 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
try:
raw = b''.join(http2_read_raw_frame(client.rfile))
raw = b''.join(frame.http2_read_raw_frame(client.rfile))
h2_conn.receive_data(raw)
except:
break

View File

@ -1,12 +1,12 @@
import mock
import codecs
from hyperframe import frame
from netlib import tcp, http, utils
import hyperframe
from netlib import tcp, http
from netlib.tutils import raises
from netlib.exceptions import TcpDisconnect
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2 import frame
from ... import tservers
@ -111,11 +111,11 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
self.wfile.flush()
# check empty settings frame
raw = utils.http2_read_raw_frame(self.rfile)
raw = frame.http2_read_raw_frame(self.rfile)
assert raw == codecs.decode('00000c040000000000000200000000000300000001', 'hex_codec')
# check settings acknowledgement
raw = utils.http2_read_raw_frame(self.rfile)
raw = frame.http2_read_raw_frame(self.rfile)
assert raw == codecs.decode('000000040100000000', 'hex_codec')
# send settings acknowledgement
@ -214,19 +214,19 @@ class TestApplySettings(tservers.ServerTestBase):
protocol = HTTP2Protocol(c)
protocol._apply_settings({
frame.SettingsFrame.ENABLE_PUSH: 'foo',
frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
})
assert c.rfile.safe_read(2) == b"OK"
assert protocol.http2_settings[
frame.SettingsFrame.ENABLE_PUSH] == 'foo'
hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
assert protocol.http2_settings[
frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
assert protocol.http2_settings[
frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
class TestCreateHeaders(object):
@ -258,7 +258,7 @@ class TestCreateHeaders(object):
(b'server', b'version')])
protocol = HTTP2Protocol(self.c)
protocol.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] = 8
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 8
bytes = protocol._create_headers(headers, 1, end_stream=True)
assert len(bytes) == 3
assert bytes[0] == codecs.decode('000008010100000001828487408294e783', 'hex_codec')
@ -281,7 +281,7 @@ class TestCreateBody(object):
def test_create_body_multiple_frames(self):
protocol = HTTP2Protocol(self.c)
protocol.http2_settings[frame.SettingsFrame.MAX_FRAME_SIZE] = 5
protocol.http2_settings[hyperframe.frame.SettingsFrame.MAX_FRAME_SIZE] = 5
bytes = protocol._create_body(b'foobarmehm42', 1)
assert len(bytes) == 3
assert bytes[0] == codecs.decode('000005000000000001666f6f6261', 'hex_codec')

View File

@ -0,0 +1,27 @@
from netlib import basetypes
class SerializableDummy(basetypes.Serializable):
def __init__(self, i):
self.i = i
def get_state(self):
return self.i
def set_state(self, i):
self.i = i
def from_state(self, state):
return type(self)(state)
class TestSerializable:
def test_copy(self):
a = SerializableDummy(42)
assert a.i == 42
b = a.copy()
assert b.i == 42
a.set_state(1)
assert a.i == 1
assert b.i == 42

View File

@ -144,33 +144,6 @@ def test_parse_content_type():
assert v == ('text', 'html', {'charset': 'UTF-8'})
class SerializableDummy(utils.Serializable):
def __init__(self, i):
self.i = i
def get_state(self):
return self.i
def set_state(self, i):
self.i = i
def from_state(self, state):
return type(self)(state)
class TestSerializable:
def test_copy(self):
a = SerializableDummy(42)
assert a.i == 42
b = a.copy()
assert b.i == 42
a.set_state(1)
assert a.i == 1
assert b.i == 42
def test_safe_subn():
assert utils.safe_subn("foo", u"bar", "\xc2foo")