mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
organize exceptions, improve content-length handling
This commit is contained in:
parent
e1659f3fcf
commit
dad9f06cb9
@ -16,7 +16,7 @@ class NetlibException(Exception):
|
|||||||
super(NetlibException, self).__init__(message)
|
super(NetlibException, self).__init__(message)
|
||||||
|
|
||||||
|
|
||||||
class ReadDisconnect(object):
|
class Disconnect(object):
|
||||||
"""Immediate EOF"""
|
"""Immediate EOF"""
|
||||||
|
|
||||||
|
|
||||||
@ -24,9 +24,35 @@ class HttpException(NetlibException):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HttpReadDisconnect(HttpException, ReadDisconnect):
|
class HttpReadDisconnect(HttpException, Disconnect):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class HttpSyntaxException(HttpException):
|
class HttpSyntaxException(HttpException):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TcpException(NetlibException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TcpDisconnect(TcpException, Disconnect):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TcpReadIncomplete(TcpException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TcpTimeout(TcpException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TlsException(NetlibException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCertificateException(TlsException):
|
||||||
|
pass
|
||||||
|
@ -95,9 +95,9 @@ def _assemble_response_headers(response, preserve_transfer_encoding=False):
|
|||||||
if not preserve_transfer_encoding:
|
if not preserve_transfer_encoding:
|
||||||
headers.pop(b"Transfer-Encoding", None)
|
headers.pop(b"Transfer-Encoding", None)
|
||||||
|
|
||||||
# If body is defined (i.e. not None or CONTENT_MISSING), we always
|
# If body is defined (i.e. not None or CONTENT_MISSING),
|
||||||
# add a content-length header.
|
# we now need to set a content-length header.
|
||||||
if response.body or response.body == b"":
|
if response.body or response.body == b"":
|
||||||
headers[b"Content-Length"] = str(len(response.body)).encode("ascii")
|
headers[b"Content-Length"] = str(len(response.body)).encode("ascii")
|
||||||
|
|
||||||
return bytes(headers)
|
return bytes(headers)
|
||||||
|
@ -4,15 +4,14 @@ import sys
|
|||||||
import re
|
import re
|
||||||
|
|
||||||
from ... import utils
|
from ... import utils
|
||||||
from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException
|
from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
|
||||||
from .. import Request, Response, Headers
|
from .. import Request, Response, Headers
|
||||||
from netlib.tcp import NetLibDisconnect
|
|
||||||
|
|
||||||
|
|
||||||
def read_request(rfile, body_size_limit=None):
|
def read_request(rfile, body_size_limit=None):
|
||||||
request = read_request_head(rfile)
|
request = read_request_head(rfile)
|
||||||
expected_body_size = expected_http_body_size(request)
|
expected_body_size = expected_http_body_size(request)
|
||||||
request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
|
request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
|
||||||
request.timestamp_end = time.time()
|
request.timestamp_end = time.time()
|
||||||
return request
|
return request
|
||||||
|
|
||||||
@ -51,7 +50,7 @@ def read_request_head(rfile):
|
|||||||
def read_response(rfile, request, body_size_limit=None):
|
def read_response(rfile, request, body_size_limit=None):
|
||||||
response = read_response_head(rfile)
|
response = read_response_head(rfile)
|
||||||
expected_body_size = expected_http_body_size(request, response)
|
expected_body_size = expected_http_body_size(request, response)
|
||||||
response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit))
|
response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit))
|
||||||
response.timestamp_end = time.time()
|
response.timestamp_end = time.time()
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -215,7 +214,7 @@ def _get_first_line(rfile):
|
|||||||
if line == b"\r\n" or line == b"\n":
|
if line == b"\r\n" or line == b"\n":
|
||||||
# Possible leftover from previous message
|
# Possible leftover from previous message
|
||||||
line = rfile.readline()
|
line = rfile.readline()
|
||||||
except NetLibDisconnect:
|
except TcpDisconnect:
|
||||||
raise HttpReadDisconnect()
|
raise HttpReadDisconnect()
|
||||||
if not line:
|
if not line:
|
||||||
raise HttpReadDisconnect()
|
raise HttpReadDisconnect()
|
||||||
|
@ -231,7 +231,7 @@ class Request(object):
|
|||||||
self.path = path
|
self.path = path
|
||||||
self.httpversion = httpversion
|
self.httpversion = httpversion
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.body = body
|
self._body = body
|
||||||
self.timestamp_start = timestamp_start
|
self.timestamp_start = timestamp_start
|
||||||
self.timestamp_end = timestamp_end
|
self.timestamp_end = timestamp_end
|
||||||
self.form_out = form_out or form_in
|
self.form_out = form_out or form_in
|
||||||
@ -452,6 +452,16 @@ class Request(object):
|
|||||||
raise ValueError("Invalid URL: %s" % url)
|
raise ValueError("Invalid URL: %s" % url)
|
||||||
self.scheme, self.host, self.port, self.path = parts
|
self.scheme, self.host, self.port, self.path = parts
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self):
|
||||||
|
return self._body
|
||||||
|
|
||||||
|
@body.setter
|
||||||
|
def body(self, body):
|
||||||
|
self._body = body
|
||||||
|
if isinstance(body, bytes):
|
||||||
|
self.headers["Content-Length"] = str(len(body)).encode()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self): # pragma: no cover
|
def content(self): # pragma: no cover
|
||||||
# TODO: remove deprecated getter
|
# TODO: remove deprecated getter
|
||||||
@ -488,7 +498,7 @@ class Response(object):
|
|||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self.msg = msg
|
self.msg = msg
|
||||||
self.headers = headers
|
self.headers = headers
|
||||||
self.body = body
|
self._body = body
|
||||||
self.timestamp_start = timestamp_start
|
self.timestamp_start = timestamp_start
|
||||||
self.timestamp_end = timestamp_end
|
self.timestamp_end = timestamp_end
|
||||||
|
|
||||||
@ -551,6 +561,16 @@ class Response(object):
|
|||||||
)
|
)
|
||||||
self.headers.set_all("Set-Cookie", values)
|
self.headers.set_all("Set-Cookie", values)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self):
|
||||||
|
return self._body
|
||||||
|
|
||||||
|
@body.setter
|
||||||
|
def body(self, body):
|
||||||
|
self._body = body
|
||||||
|
if isinstance(body, bytes):
|
||||||
|
self.headers["Content-Length"] = str(len(body)).encode()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def content(self): # pragma: no cover
|
def content(self): # pragma: no cover
|
||||||
# TODO: remove deprecated getter
|
# TODO: remove deprecated getter
|
||||||
|
108
netlib/tcp.py
108
netlib/tcp.py
@ -16,6 +16,9 @@ from . import certutils, version_check
|
|||||||
|
|
||||||
# This is a rather hackish way to make sure that
|
# This is a rather hackish way to make sure that
|
||||||
# the latest version of pyOpenSSL is actually installed.
|
# the latest version of pyOpenSSL is actually installed.
|
||||||
|
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
|
||||||
|
TcpTimeout, TcpDisconnect, TcpException
|
||||||
|
|
||||||
version_check.check_pyopenssl_version()
|
version_check.check_pyopenssl_version()
|
||||||
|
|
||||||
|
|
||||||
@ -24,11 +27,17 @@ EINTR = 4
|
|||||||
# To enable all SSL methods use: SSLv23
|
# To enable all SSL methods use: SSLv23
|
||||||
# then add options to disable certain methods
|
# then add options to disable certain methods
|
||||||
# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
|
# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
|
||||||
|
SSL_BASIC_OPTIONS = (
|
||||||
|
SSL.OP_CIPHER_SERVER_PREFERENCE
|
||||||
|
)
|
||||||
|
if hasattr(SSL, "OP_NO_COMPRESSION"):
|
||||||
|
SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION
|
||||||
|
|
||||||
SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD
|
SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD
|
||||||
SSL_DEFAULT_OPTIONS = (
|
SSL_DEFAULT_OPTIONS = (
|
||||||
SSL.OP_NO_SSLv2 |
|
SSL.OP_NO_SSLv2 |
|
||||||
SSL.OP_NO_SSLv3 |
|
SSL.OP_NO_SSLv3 |
|
||||||
SSL.OP_CIPHER_SERVER_PREFERENCE
|
SSL_BASIC_OPTIONS
|
||||||
)
|
)
|
||||||
if hasattr(SSL, "OP_NO_COMPRESSION"):
|
if hasattr(SSL, "OP_NO_COMPRESSION"):
|
||||||
SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
|
SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
|
||||||
@ -39,42 +48,17 @@ Don't ask...
|
|||||||
https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
|
https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
|
||||||
"""
|
"""
|
||||||
sslversion_choices = {
|
sslversion_choices = {
|
||||||
"all": (SSL.SSLv23_METHOD, 0),
|
"all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS),
|
||||||
# SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+
|
# SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+
|
||||||
# TLSv1_METHOD would be TLS 1.0 only
|
# TLSv1_METHOD would be TLS 1.0 only
|
||||||
"secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)),
|
"secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)),
|
||||||
"SSLv2": (SSL.SSLv2_METHOD, 0),
|
"SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS),
|
||||||
"SSLv3": (SSL.SSLv3_METHOD, 0),
|
"SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS),
|
||||||
"TLSv1": (SSL.TLSv1_METHOD, 0),
|
"TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS),
|
||||||
"TLSv1_1": (SSL.TLSv1_1_METHOD, 0),
|
"TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS),
|
||||||
"TLSv1_2": (SSL.TLSv1_2_METHOD, 0),
|
"TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class NetLibError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NetLibDisconnect(NetLibError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NetLibIncomplete(NetLibError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NetLibTimeout(NetLibError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NetLibSSLError(NetLibError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class NetLibInvalidCertificateError(NetLibSSLError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SSLKeyLogger(object):
|
class SSLKeyLogger(object):
|
||||||
|
|
||||||
def __init__(self, filename):
|
def __init__(self, filename):
|
||||||
@ -168,17 +152,17 @@ class Writer(_FileLike):
|
|||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
"""
|
"""
|
||||||
May raise NetLibDisconnect
|
May raise TcpDisconnect
|
||||||
"""
|
"""
|
||||||
if hasattr(self.o, "flush"):
|
if hasattr(self.o, "flush"):
|
||||||
try:
|
try:
|
||||||
self.o.flush()
|
self.o.flush()
|
||||||
except (socket.error, IOError) as v:
|
except (socket.error, IOError) as v:
|
||||||
raise NetLibDisconnect(str(v))
|
raise TcpDisconnect(str(v))
|
||||||
|
|
||||||
def write(self, v):
|
def write(self, v):
|
||||||
"""
|
"""
|
||||||
May raise NetLibDisconnect
|
May raise TcpDisconnect
|
||||||
"""
|
"""
|
||||||
if v:
|
if v:
|
||||||
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
||||||
@ -191,7 +175,7 @@ class Writer(_FileLike):
|
|||||||
self.add_log(v[:r])
|
self.add_log(v[:r])
|
||||||
return r
|
return r
|
||||||
except (SSL.Error, socket.error) as e:
|
except (SSL.Error, socket.error) as e:
|
||||||
raise NetLibDisconnect(str(e))
|
raise TcpDisconnect(str(e))
|
||||||
|
|
||||||
|
|
||||||
class Reader(_FileLike):
|
class Reader(_FileLike):
|
||||||
@ -210,23 +194,29 @@ class Reader(_FileLike):
|
|||||||
try:
|
try:
|
||||||
data = self.o.read(rlen)
|
data = self.o.read(rlen)
|
||||||
except SSL.ZeroReturnError:
|
except SSL.ZeroReturnError:
|
||||||
|
# TLS connection was shut down cleanly
|
||||||
break
|
break
|
||||||
except SSL.WantReadError:
|
except (SSL.WantWriteError, SSL.WantReadError):
|
||||||
|
# From the OpenSSL docs:
|
||||||
|
# If the underlying BIO is non-blocking, SSL_read() will also return when the
|
||||||
|
# underlying BIO could not satisfy the needs of SSL_read() to continue the
|
||||||
|
# operation. In this case a call to SSL_get_error with the return value of
|
||||||
|
# SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
|
||||||
if (time.time() - start) < self.o.gettimeout():
|
if (time.time() - start) < self.o.gettimeout():
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
raise NetLibTimeout
|
raise TcpTimeout()
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
raise NetLibTimeout
|
raise TcpTimeout()
|
||||||
except socket.error:
|
except socket.error as e:
|
||||||
raise NetLibDisconnect
|
raise TcpDisconnect(str(e))
|
||||||
except SSL.SysCallError as e:
|
except SSL.SysCallError as e:
|
||||||
if e.args == (-1, 'Unexpected EOF'):
|
if e.args == (-1, 'Unexpected EOF'):
|
||||||
break
|
break
|
||||||
raise NetLibSSLError(e.message)
|
raise TlsException(e.message)
|
||||||
except SSL.Error as e:
|
except SSL.Error as e:
|
||||||
raise NetLibSSLError(e.message)
|
raise TlsException(e.message)
|
||||||
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
||||||
if not data:
|
if not data:
|
||||||
break
|
break
|
||||||
@ -260,9 +250,9 @@ class Reader(_FileLike):
|
|||||||
result = self.read(length)
|
result = self.read(length)
|
||||||
if length != -1 and len(result) != length:
|
if length != -1 and len(result) != length:
|
||||||
if not result:
|
if not result:
|
||||||
raise NetLibDisconnect()
|
raise TcpDisconnect()
|
||||||
else:
|
else:
|
||||||
raise NetLibIncomplete(
|
raise TcpReadIncomplete(
|
||||||
"Expected %s bytes, got %s" % (length, len(result))
|
"Expected %s bytes, got %s" % (length, len(result))
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
@ -275,15 +265,15 @@ class Reader(_FileLike):
|
|||||||
Up to the next N bytes if peeking is successful.
|
Up to the next N bytes if peeking is successful.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NetLibError if there was an error with the socket
|
TcpException if there was an error with the socket
|
||||||
NetLibSSLError if there was an error with pyOpenSSL.
|
TlsException if there was an error with pyOpenSSL.
|
||||||
NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
|
NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
|
||||||
"""
|
"""
|
||||||
if isinstance(self.o, socket._fileobject):
|
if isinstance(self.o, socket._fileobject):
|
||||||
try:
|
try:
|
||||||
return self.o._sock.recv(length, socket.MSG_PEEK)
|
return self.o._sock.recv(length, socket.MSG_PEEK)
|
||||||
except socket.error as e:
|
except socket.error as e:
|
||||||
raise NetLibError(repr(e))
|
raise TcpException(repr(e))
|
||||||
elif isinstance(self.o, SSL.Connection):
|
elif isinstance(self.o, SSL.Connection):
|
||||||
try:
|
try:
|
||||||
if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
|
if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
|
||||||
@ -296,7 +286,7 @@ class Reader(_FileLike):
|
|||||||
self.o._raise_ssl_error(self.o._ssl, result)
|
self.o._raise_ssl_error(self.o._ssl, result)
|
||||||
return SSL._ffi.buffer(buf, result)[:]
|
return SSL._ffi.buffer(buf, result)[:]
|
||||||
except SSL.Error as e:
|
except SSL.Error as e:
|
||||||
six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2])
|
six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2])
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
|
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
|
||||||
|
|
||||||
@ -461,7 +451,7 @@ class _Connection(object):
|
|||||||
try:
|
try:
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
self.wfile.close()
|
self.wfile.close()
|
||||||
except NetLibDisconnect:
|
except TcpDisconnect:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.rfile.close()
|
self.rfile.close()
|
||||||
@ -525,7 +515,7 @@ class _Connection(object):
|
|||||||
# TODO: maybe change this to with newer pyOpenSSL APIs
|
# TODO: maybe change this to with newer pyOpenSSL APIs
|
||||||
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
|
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
|
||||||
except SSL.Error as v:
|
except SSL.Error as v:
|
||||||
raise NetLibError("SSL cipher specification error: %s" % str(v))
|
raise TlsException("SSL cipher specification error: %s" % str(v))
|
||||||
|
|
||||||
# SSLKEYLOGFILE
|
# SSLKEYLOGFILE
|
||||||
if log_ssl_key:
|
if log_ssl_key:
|
||||||
@ -546,7 +536,7 @@ class _Connection(object):
|
|||||||
elif alpn_select_callback is not None and alpn_select is None:
|
elif alpn_select_callback is not None and alpn_select is None:
|
||||||
context.set_alpn_select_callback(alpn_select_callback)
|
context.set_alpn_select_callback(alpn_select_callback)
|
||||||
elif alpn_select_callback is not None and alpn_select is not None:
|
elif alpn_select_callback is not None and alpn_select is not None:
|
||||||
raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
|
raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@ -594,7 +584,7 @@ class TCPClient(_Connection):
|
|||||||
context.use_privatekey_file(cert)
|
context.use_privatekey_file(cert)
|
||||||
context.use_certificate_file(cert)
|
context.use_certificate_file(cert)
|
||||||
except SSL.Error as v:
|
except SSL.Error as v:
|
||||||
raise NetLibError("SSL client certificate error: %s" % str(v))
|
raise TlsException("SSL client certificate error: %s" % str(v))
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
|
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
|
||||||
@ -618,15 +608,15 @@ class TCPClient(_Connection):
|
|||||||
self.connection.do_handshake()
|
self.connection.do_handshake()
|
||||||
except SSL.Error as v:
|
except SSL.Error as v:
|
||||||
if self.ssl_verification_error:
|
if self.ssl_verification_error:
|
||||||
raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v))
|
raise InvalidCertificateException("SSL handshake error: %s" % repr(v))
|
||||||
else:
|
else:
|
||||||
raise NetLibError("SSL handshake error: %s" % repr(v))
|
raise TlsException("SSL handshake error: %s" % repr(v))
|
||||||
|
|
||||||
# Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
|
# Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
|
||||||
# certificate validation failure
|
# certificate validation failure
|
||||||
verification_mode = sslctx_kwargs.get('verify_options', None)
|
verification_mode = sslctx_kwargs.get('verify_options', None)
|
||||||
if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER:
|
if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER:
|
||||||
raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed")
|
raise InvalidCertificateException("SSL handshake error: certificate verify failed")
|
||||||
|
|
||||||
self.ssl_established = True
|
self.ssl_established = True
|
||||||
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
|
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
|
||||||
@ -644,7 +634,7 @@ class TCPClient(_Connection):
|
|||||||
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
||||||
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
||||||
except (socket.error, IOError) as err:
|
except (socket.error, IOError) as err:
|
||||||
raise NetLibError(
|
raise TcpException(
|
||||||
'Error connecting to "%s": %s' %
|
'Error connecting to "%s": %s' %
|
||||||
(self.address.host, err))
|
(self.address.host, err))
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
@ -750,7 +740,7 @@ class BaseHandler(_Connection):
|
|||||||
try:
|
try:
|
||||||
self.connection.do_handshake()
|
self.connection.do_handshake()
|
||||||
except SSL.Error as v:
|
except SSL.Error as v:
|
||||||
raise NetLibError("SSL handshake error: %s" % repr(v))
|
raise TlsException("SSL handshake error: %s" % repr(v))
|
||||||
self.ssl_established = True
|
self.ssl_established = True
|
||||||
self.rfile.set_descriptor(self.connection)
|
self.rfile.set_descriptor(self.connection)
|
||||||
self.wfile.set_descriptor(self.connection)
|
self.wfile.set_descriptor(self.connection)
|
||||||
|
@ -2,6 +2,7 @@ import OpenSSL
|
|||||||
import mock
|
import mock
|
||||||
|
|
||||||
from netlib import tcp, http, tutils
|
from netlib import tcp, http, tutils
|
||||||
|
from netlib.exceptions import TcpDisconnect
|
||||||
from netlib.http import Headers
|
from netlib.http import Headers
|
||||||
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
|
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
|
||||||
from netlib.http.http2.frame import *
|
from netlib.http.http2.frame import *
|
||||||
@ -127,7 +128,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
|
|||||||
protocol.perform_server_connection_preface()
|
protocol.perform_server_connection_preface()
|
||||||
assert protocol.connection_preface_performed
|
assert protocol.connection_preface_performed
|
||||||
|
|
||||||
tutils.raises(tcp.NetLibDisconnect, protocol.perform_server_connection_preface, force=True)
|
tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True)
|
||||||
|
|
||||||
|
|
||||||
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
|
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
|
||||||
|
@ -12,6 +12,8 @@ import OpenSSL
|
|||||||
|
|
||||||
from netlib import tcp, certutils, tutils
|
from netlib import tcp, certutils, tutils
|
||||||
from . import tservers
|
from . import tservers
|
||||||
|
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
|
||||||
|
TcpTimeout, TcpDisconnect, TcpException
|
||||||
|
|
||||||
|
|
||||||
class EchoHandler(tcp.BaseHandler):
|
class EchoHandler(tcp.BaseHandler):
|
||||||
@ -93,7 +95,7 @@ class TestServerBind(tservers.ServerTestBase):
|
|||||||
c.connect()
|
c.connect()
|
||||||
assert c.rfile.readline() == str(("127.0.0.1", random_port))
|
assert c.rfile.readline() == str(("127.0.0.1", random_port))
|
||||||
return
|
return
|
||||||
except tcp.NetLibError: # port probably already in use
|
except TcpException: # port probably already in use
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -140,7 +142,7 @@ class TestFinishFail(tservers.ServerTestBase):
|
|||||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.wfile.write("foo\n")
|
c.wfile.write("foo\n")
|
||||||
c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect)
|
c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
|
||||||
c.finish()
|
c.finish()
|
||||||
|
|
||||||
|
|
||||||
@ -180,7 +182,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
|
|||||||
def test_failure(self):
|
def test_failure(self):
|
||||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com")
|
tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
|
||||||
|
|
||||||
|
|
||||||
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
|
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
|
||||||
@ -224,7 +226,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
|
|||||||
c.connect()
|
c.connect()
|
||||||
|
|
||||||
tutils.raises(
|
tutils.raises(
|
||||||
tcp.NetLibInvalidCertificateError,
|
InvalidCertificateException,
|
||||||
c.convert_to_ssl,
|
c.convert_to_ssl,
|
||||||
verify_options=SSL.VERIFY_PEER,
|
verify_options=SSL.VERIFY_PEER,
|
||||||
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem"))
|
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem"))
|
||||||
@ -327,7 +329,7 @@ class TestSSLClientCert(tservers.ServerTestBase):
|
|||||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
tutils.raises(
|
tutils.raises(
|
||||||
tcp.NetLibError,
|
TlsException,
|
||||||
c.convert_to_ssl,
|
c.convert_to_ssl,
|
||||||
cert=tutils.test_data.path("data/clientcert/make")
|
cert=tutils.test_data.path("data/clientcert/make")
|
||||||
)
|
)
|
||||||
@ -432,7 +434,7 @@ class TestSSLDisconnect(tservers.ServerTestBase):
|
|||||||
# Excercise SSL.ZeroReturnError
|
# Excercise SSL.ZeroReturnError
|
||||||
c.rfile.read(10)
|
c.rfile.read(10)
|
||||||
c.close()
|
c.close()
|
||||||
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
|
tutils.raises(TcpDisconnect, c.wfile.write, "foo")
|
||||||
tutils.raises(Queue.Empty, self.q.get_nowait)
|
tutils.raises(Queue.Empty, self.q.get_nowait)
|
||||||
|
|
||||||
|
|
||||||
@ -447,7 +449,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):
|
|||||||
# Exercise SSL.SysCallError
|
# Exercise SSL.SysCallError
|
||||||
c.rfile.read(10)
|
c.rfile.read(10)
|
||||||
c.close()
|
c.close()
|
||||||
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
|
tutils.raises(TcpDisconnect, c.wfile.write, "foo")
|
||||||
|
|
||||||
|
|
||||||
class TestDisconnect(tservers.ServerTestBase):
|
class TestDisconnect(tservers.ServerTestBase):
|
||||||
@ -470,7 +472,7 @@ class TestServerTimeOut(tservers.ServerTestBase):
|
|||||||
self.settimeout(0.01)
|
self.settimeout(0.01)
|
||||||
try:
|
try:
|
||||||
self.rfile.read(10)
|
self.rfile.read(10)
|
||||||
except tcp.NetLibTimeout:
|
except TcpTimeout:
|
||||||
self.timeout = True
|
self.timeout = True
|
||||||
|
|
||||||
def test_timeout(self):
|
def test_timeout(self):
|
||||||
@ -488,7 +490,7 @@ class TestTimeOut(tservers.ServerTestBase):
|
|||||||
c.connect()
|
c.connect()
|
||||||
c.settimeout(0.1)
|
c.settimeout(0.1)
|
||||||
assert c.gettimeout() == 0.1
|
assert c.gettimeout() == 0.1
|
||||||
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
tutils.raises(TcpTimeout, c.rfile.read, 10)
|
||||||
|
|
||||||
|
|
||||||
class TestALPNClient(tservers.ServerTestBase):
|
class TestALPNClient(tservers.ServerTestBase):
|
||||||
@ -540,7 +542,7 @@ class TestSSLTimeOut(tservers.ServerTestBase):
|
|||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl()
|
c.convert_to_ssl()
|
||||||
c.settimeout(0.1)
|
c.settimeout(0.1)
|
||||||
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
|
tutils.raises(TcpTimeout, c.rfile.read, 10)
|
||||||
|
|
||||||
|
|
||||||
class TestDHParams(tservers.ServerTestBase):
|
class TestDHParams(tservers.ServerTestBase):
|
||||||
@ -570,7 +572,7 @@ class TestTCPClient:
|
|||||||
|
|
||||||
def test_conerr(self):
|
def test_conerr(self):
|
||||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||||
tutils.raises(tcp.NetLibError, c.connect)
|
tutils.raises(TcpException, c.connect)
|
||||||
|
|
||||||
|
|
||||||
class TestFileLike:
|
class TestFileLike:
|
||||||
@ -639,7 +641,7 @@ class TestFileLike:
|
|||||||
o = mock.MagicMock()
|
o = mock.MagicMock()
|
||||||
o.flush = mock.MagicMock(side_effect=socket.error)
|
o.flush = mock.MagicMock(side_effect=socket.error)
|
||||||
s.o = o
|
s.o = o
|
||||||
tutils.raises(tcp.NetLibDisconnect, s.flush)
|
tutils.raises(TcpDisconnect, s.flush)
|
||||||
|
|
||||||
def test_reader_read_error(self):
|
def test_reader_read_error(self):
|
||||||
s = cStringIO.StringIO("foobar\nfoobar")
|
s = cStringIO.StringIO("foobar\nfoobar")
|
||||||
@ -647,7 +649,7 @@ class TestFileLike:
|
|||||||
o = mock.MagicMock()
|
o = mock.MagicMock()
|
||||||
o.read = mock.MagicMock(side_effect=socket.error)
|
o.read = mock.MagicMock(side_effect=socket.error)
|
||||||
s.o = o
|
s.o = o
|
||||||
tutils.raises(tcp.NetLibDisconnect, s.read, 10)
|
tutils.raises(TcpDisconnect, s.read, 10)
|
||||||
|
|
||||||
def test_reset_timestamps(self):
|
def test_reset_timestamps(self):
|
||||||
s = cStringIO.StringIO("foobar\nfoobar")
|
s = cStringIO.StringIO("foobar\nfoobar")
|
||||||
@ -678,24 +680,24 @@ class TestFileLike:
|
|||||||
s = mock.MagicMock()
|
s = mock.MagicMock()
|
||||||
s.read = mock.MagicMock(side_effect=SSL.Error())
|
s.read = mock.MagicMock(side_effect=SSL.Error())
|
||||||
s = tcp.Reader(s)
|
s = tcp.Reader(s)
|
||||||
tutils.raises(tcp.NetLibSSLError, s.read, 1)
|
tutils.raises(TlsException, s.read, 1)
|
||||||
|
|
||||||
def test_read_syscall_ssl_error(self):
|
def test_read_syscall_ssl_error(self):
|
||||||
s = mock.MagicMock()
|
s = mock.MagicMock()
|
||||||
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
|
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
|
||||||
s = tcp.Reader(s)
|
s = tcp.Reader(s)
|
||||||
tutils.raises(tcp.NetLibSSLError, s.read, 1)
|
tutils.raises(TlsException, s.read, 1)
|
||||||
|
|
||||||
def test_reader_readline_disconnect(self):
|
def test_reader_readline_disconnect(self):
|
||||||
o = mock.MagicMock()
|
o = mock.MagicMock()
|
||||||
o.read = mock.MagicMock(side_effect=socket.error)
|
o.read = mock.MagicMock(side_effect=socket.error)
|
||||||
s = tcp.Reader(o)
|
s = tcp.Reader(o)
|
||||||
tutils.raises(tcp.NetLibDisconnect, s.readline, 10)
|
tutils.raises(TcpDisconnect, s.readline, 10)
|
||||||
|
|
||||||
def test_reader_incomplete_error(self):
|
def test_reader_incomplete_error(self):
|
||||||
s = cStringIO.StringIO("foobar")
|
s = cStringIO.StringIO("foobar")
|
||||||
s = tcp.Reader(s)
|
s = tcp.Reader(s)
|
||||||
tutils.raises(tcp.NetLibIncomplete, s.safe_read, 10)
|
tutils.raises(TcpReadIncomplete, s.safe_read, 10)
|
||||||
|
|
||||||
|
|
||||||
class TestAddress:
|
class TestAddress:
|
||||||
|
@ -14,7 +14,7 @@ def test_hexdump():
|
|||||||
assert utils.hexdump("one\0" * 10)
|
assert utils.hexdump("one\0" * 10)
|
||||||
|
|
||||||
|
|
||||||
def test_cleanBin():
|
def test_clean_bin():
|
||||||
assert utils.clean_bin(b"one") == b"one"
|
assert utils.clean_bin(b"one") == b"one"
|
||||||
assert utils.clean_bin(b"\00ne") == b".ne"
|
assert utils.clean_bin(b"\00ne") == b".ne"
|
||||||
assert utils.clean_bin(b"\nne") == b"\nne"
|
assert utils.clean_bin(b"\nne") == b"\nne"
|
||||||
|
@ -176,7 +176,7 @@ class TestBadHandshake(tservers.ServerTestBase):
|
|||||||
"""
|
"""
|
||||||
handler = BadHandshakeHandler
|
handler = BadHandshakeHandler
|
||||||
|
|
||||||
@raises(tcp.NetLibDisconnect)
|
@raises(TcpDisconnect)
|
||||||
def test(self):
|
def test(self):
|
||||||
client = WebSocketsClient(("127.0.0.1", self.port))
|
client = WebSocketsClient(("127.0.0.1", self.port))
|
||||||
client.connect()
|
client.connect()
|
||||||
|
Loading…
Reference in New Issue
Block a user