mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
commit
4ee1ad88fc
@ -12,7 +12,10 @@ from pyasn1.codec.der.decoder import decode
|
|||||||
from pyasn1.error import PyAsn1Error
|
from pyasn1.error import PyAsn1Error
|
||||||
import OpenSSL
|
import OpenSSL
|
||||||
|
|
||||||
|
from .utils import Serializable
|
||||||
|
|
||||||
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
|
# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
|
||||||
|
|
||||||
DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
|
DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
|
||||||
# Generated with "openssl dhparam". It's too slow to generate this on startup.
|
# Generated with "openssl dhparam". It's too slow to generate this on startup.
|
||||||
DEFAULT_DHPARAM = b"""
|
DEFAULT_DHPARAM = b"""
|
||||||
@ -361,7 +364,7 @@ class _GeneralNames(univ.SequenceOf):
|
|||||||
constraint.ValueSizeConstraint(1, 1024)
|
constraint.ValueSizeConstraint(1, 1024)
|
||||||
|
|
||||||
|
|
||||||
class SSLCert(object):
|
class SSLCert(Serializable):
|
||||||
|
|
||||||
def __init__(self, cert):
|
def __init__(self, cert):
|
||||||
"""
|
"""
|
||||||
@ -375,15 +378,25 @@ class SSLCert(object):
|
|||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
@classmethod
|
def get_state(self):
|
||||||
def from_pem(klass, txt):
|
return self.to_pem()
|
||||||
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
|
|
||||||
return klass(x509)
|
def set_state(self, state):
|
||||||
|
self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_der(klass, der):
|
def from_state(cls, state):
|
||||||
|
cls.from_pem(state)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pem(cls, txt):
|
||||||
|
x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
|
||||||
|
return cls(x509)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_der(cls, der):
|
||||||
pem = ssl.DER_cert_to_PEM_cert(der)
|
pem = ssl.DER_cert_to_PEM_cert(der)
|
||||||
return klass.from_pem(pem)
|
return cls.from_pem(pem)
|
||||||
|
|
||||||
def to_pem(self):
|
def to_pem(self):
|
||||||
return OpenSSL.crypto.dump_certificate(
|
return OpenSSL.crypto.dump_certificate(
|
||||||
|
@ -14,7 +14,7 @@ except ImportError: # pragma: nocover
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from netlib.utils import always_byte_args, always_bytes
|
from netlib.utils import always_byte_args, always_bytes, Serializable
|
||||||
|
|
||||||
if six.PY2: # pragma: nocover
|
if six.PY2: # pragma: nocover
|
||||||
_native = lambda x: x
|
_native = lambda x: x
|
||||||
@ -27,7 +27,7 @@ else:
|
|||||||
_always_byte_args = always_byte_args("utf-8", "surrogateescape")
|
_always_byte_args = always_byte_args("utf-8", "surrogateescape")
|
||||||
|
|
||||||
|
|
||||||
class Headers(MutableMapping):
|
class Headers(MutableMapping, Serializable):
|
||||||
"""
|
"""
|
||||||
Header class which allows both convenient access to individual headers as well as
|
Header class which allows both convenient access to individual headers as well as
|
||||||
direct access to the underlying raw data. Provides a full dictionary interface.
|
direct access to the underlying raw data. Provides a full dictionary interface.
|
||||||
@ -193,11 +193,10 @@ class Headers(MutableMapping):
|
|||||||
def copy(self):
|
def copy(self):
|
||||||
return Headers(copy.copy(self.fields))
|
return Headers(copy.copy(self.fields))
|
||||||
|
|
||||||
# Implement the StateObject protocol from mitmproxy
|
|
||||||
def get_state(self):
|
def get_state(self):
|
||||||
return tuple(tuple(field) for field in self.fields)
|
return tuple(tuple(field) for field in self.fields)
|
||||||
|
|
||||||
def load_state(self, state):
|
def set_state(self, state):
|
||||||
self.fields = [list(field) for field in state]
|
self.fields = [list(field) for field in state]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -4,9 +4,9 @@ import warnings
|
|||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from .headers import Headers
|
||||||
from .. import encoding, utils
|
from .. import encoding, utils
|
||||||
|
|
||||||
|
|
||||||
CONTENT_MISSING = 0
|
CONTENT_MISSING = 0
|
||||||
|
|
||||||
if six.PY2: # pragma: nocover
|
if six.PY2: # pragma: nocover
|
||||||
@ -18,7 +18,7 @@ else:
|
|||||||
_always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape")
|
_always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape")
|
||||||
|
|
||||||
|
|
||||||
class MessageData(object):
|
class MessageData(utils.Serializable):
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if isinstance(other, MessageData):
|
if isinstance(other, MessageData):
|
||||||
return self.__dict__ == other.__dict__
|
return self.__dict__ == other.__dict__
|
||||||
@ -27,8 +27,24 @@ class MessageData(object):
|
|||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def set_state(self, state):
|
||||||
|
for k, v in state.items():
|
||||||
|
if k == "headers":
|
||||||
|
v = Headers.from_state(v)
|
||||||
|
setattr(self, k, v)
|
||||||
|
|
||||||
class Message(object):
|
def get_state(self):
|
||||||
|
state = vars(self).copy()
|
||||||
|
state["headers"] = state["headers"].get_state()
|
||||||
|
return state
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_state(cls, state):
|
||||||
|
state["headers"] = Headers.from_state(state["headers"])
|
||||||
|
return cls(**state)
|
||||||
|
|
||||||
|
|
||||||
|
class Message(utils.Serializable):
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
self.data = data
|
self.data = data
|
||||||
|
|
||||||
@ -40,6 +56,16 @@ class Message(object):
|
|||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
return not self.__eq__(other)
|
return not self.__eq__(other)
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return self.data.get_state()
|
||||||
|
|
||||||
|
def set_state(self, state):
|
||||||
|
self.data.set_state(state)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_state(cls, state):
|
||||||
|
return cls(**state)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def headers(self):
|
def headers(self):
|
||||||
"""
|
"""
|
||||||
|
@ -16,9 +16,8 @@ from .message import Message, _native, _always_bytes, MessageData
|
|||||||
class RequestData(MessageData):
|
class RequestData(MessageData):
|
||||||
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
|
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None,
|
||||||
timestamp_start=None, timestamp_end=None):
|
timestamp_start=None, timestamp_end=None):
|
||||||
if not headers:
|
if not isinstance(headers, Headers):
|
||||||
headers = Headers()
|
headers = Headers(headers)
|
||||||
assert isinstance(headers, Headers)
|
|
||||||
|
|
||||||
self.first_line_format = first_line_format
|
self.first_line_format = first_line_format
|
||||||
self.method = method
|
self.method = method
|
||||||
|
@ -12,9 +12,8 @@ from ..odict import ODict
|
|||||||
class ResponseData(MessageData):
|
class ResponseData(MessageData):
|
||||||
def __init__(self, http_version, status_code, reason=None, headers=None, content=None,
|
def __init__(self, http_version, status_code, reason=None, headers=None, content=None,
|
||||||
timestamp_start=None, timestamp_end=None):
|
timestamp_start=None, timestamp_end=None):
|
||||||
if not headers:
|
if not isinstance(headers, Headers):
|
||||||
headers = Headers()
|
headers = Headers(headers)
|
||||||
assert isinstance(headers, Headers)
|
|
||||||
|
|
||||||
self.http_version = http_version
|
self.http_version = http_version
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
|
@ -3,6 +3,8 @@ import re
|
|||||||
import copy
|
import copy
|
||||||
import six
|
import six
|
||||||
|
|
||||||
|
from .utils import Serializable
|
||||||
|
|
||||||
|
|
||||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -13,7 +15,7 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
|
|||||||
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ODict(object):
|
class ODict(Serializable):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
A dictionary-like object for managing ordered (key, value) data. Think
|
A dictionary-like object for managing ordered (key, value) data. Think
|
||||||
@ -172,12 +174,12 @@ class ODict(object):
|
|||||||
def get_state(self):
|
def get_state(self):
|
||||||
return [tuple(i) for i in self.lst]
|
return [tuple(i) for i in self.lst]
|
||||||
|
|
||||||
def load_state(self, state):
|
def set_state(self, state):
|
||||||
self.lst = [list(i) for i in state]
|
self.lst = [list(i) for i in state]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_state(klass, state):
|
def from_state(cls, state):
|
||||||
return klass([list(i) for i in state])
|
return cls([list(i) for i in state])
|
||||||
|
|
||||||
|
|
||||||
class ODictCaseless(ODict):
|
class ODictCaseless(ODict):
|
||||||
|
@ -16,7 +16,7 @@ import six
|
|||||||
import OpenSSL
|
import OpenSSL
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
from . import certutils, version_check
|
from . import certutils, version_check, utils
|
||||||
|
|
||||||
# This is a rather hackish way to make sure that
|
# This is a rather hackish way to make sure that
|
||||||
# the latest version of pyOpenSSL is actually installed.
|
# the latest version of pyOpenSSL is actually installed.
|
||||||
@ -298,7 +298,7 @@ class Reader(_FileLike):
|
|||||||
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
|
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
|
||||||
|
|
||||||
|
|
||||||
class Address(object):
|
class Address(utils.Serializable):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This class wraps an IPv4/IPv6 tuple to provide named attributes and
|
This class wraps an IPv4/IPv6 tuple to provide named attributes and
|
||||||
@ -309,6 +309,20 @@ class Address(object):
|
|||||||
self.address = tuple(address)
|
self.address = tuple(address)
|
||||||
self.use_ipv6 = use_ipv6
|
self.use_ipv6 = use_ipv6
|
||||||
|
|
||||||
|
def get_state(self):
|
||||||
|
return {
|
||||||
|
"address": self.address,
|
||||||
|
"use_ipv6": self.use_ipv6
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_state(self, state):
|
||||||
|
self.address = state["address"]
|
||||||
|
self.use_ipv6 = state["use_ipv6"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_state(cls, state):
|
||||||
|
return Address(**state)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def wrap(cls, t):
|
def wrap(cls, t):
|
||||||
if isinstance(t, cls):
|
if isinstance(t, cls):
|
||||||
|
@ -1,14 +1,45 @@
|
|||||||
from __future__ import absolute_import, print_function, division
|
from __future__ import absolute_import, print_function, division
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
import string
|
|
||||||
import codecs
|
import codecs
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from six.moves import urllib
|
from six.moves import urllib
|
||||||
import hyperframe
|
import hyperframe
|
||||||
|
|
||||||
|
|
||||||
|
@six.add_metaclass(ABCMeta)
|
||||||
|
class Serializable(object):
|
||||||
|
"""
|
||||||
|
Abstract Base Class that defines an API to save an object's state and restore it later on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def from_state(cls, state):
|
||||||
|
"""
|
||||||
|
Create a new object from the given state.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_state(self):
|
||||||
|
"""
|
||||||
|
Retrieve object state.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_state(self, state):
|
||||||
|
"""
|
||||||
|
Set object state to the given state.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def always_bytes(unicode_or_bytes, *encode_args):
|
def always_bytes(unicode_or_bytes, *encode_args):
|
||||||
if isinstance(unicode_or_bytes, six.text_type):
|
if isinstance(unicode_or_bytes, six.text_type):
|
||||||
return unicode_or_bytes.encode(*encode_args)
|
return unicode_or_bytes.encode(*encode_args)
|
||||||
|
@ -148,5 +148,5 @@ class TestHeaders(object):
|
|||||||
|
|
||||||
headers2 = Headers()
|
headers2 = Headers()
|
||||||
assert headers != headers2
|
assert headers != headers2
|
||||||
headers2.load_state(headers.get_state())
|
headers2.set_state(headers.get_state())
|
||||||
assert headers == headers2
|
assert headers == headers2
|
||||||
|
@ -12,7 +12,7 @@ from .test_message import _test_decoded_attr, _test_passthrough_attr
|
|||||||
|
|
||||||
class TestRequestData(object):
|
class TestRequestData(object):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
with raises(AssertionError):
|
with raises(ValueError if six.PY2 else TypeError):
|
||||||
treq(headers="foobar")
|
treq(headers="foobar")
|
||||||
|
|
||||||
assert isinstance(treq(headers=None).headers, Headers)
|
assert isinstance(treq(headers=None).headers, Headers)
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from __future__ import absolute_import, print_function, division
|
from __future__ import absolute_import, print_function, division
|
||||||
|
|
||||||
|
import six
|
||||||
|
|
||||||
from netlib.http import Headers
|
from netlib.http import Headers
|
||||||
from netlib.odict import ODict, ODictCaseless
|
from netlib.odict import ODict, ODictCaseless
|
||||||
from netlib.tutils import raises, tresp
|
from netlib.tutils import raises, tresp
|
||||||
@ -8,7 +10,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr
|
|||||||
|
|
||||||
class TestResponseData(object):
|
class TestResponseData(object):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
with raises(AssertionError):
|
with raises(ValueError if six.PY2 else TypeError):
|
||||||
tresp(headers="foobar")
|
tresp(headers="foobar")
|
||||||
|
|
||||||
assert isinstance(tresp(headers=None).headers, Headers)
|
assert isinstance(tresp(headers=None).headers, Headers)
|
||||||
|
@ -24,7 +24,7 @@ class TestODict(object):
|
|||||||
nd = odict.ODict.from_state(state)
|
nd = odict.ODict.from_state(state)
|
||||||
assert nd == od
|
assert nd == od
|
||||||
b = odict.ODict()
|
b = odict.ODict()
|
||||||
b.load_state(state)
|
b.set_state(state)
|
||||||
assert b == od
|
assert b == od
|
||||||
|
|
||||||
def test_in_any(self):
|
def test_in_any(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user