organize exceptions, improve content-length handling

This commit is contained in:
Maximilian Hils 2015-09-17 02:14:14 +02:00
parent e1659f3fcf
commit dad9f06cb9
9 changed files with 130 additions and 92 deletions

View File

@ -16,7 +16,7 @@ class NetlibException(Exception):
super(NetlibException, self).__init__(message) super(NetlibException, self).__init__(message)
class ReadDisconnect(object): class Disconnect(object):
"""Immediate EOF""" """Immediate EOF"""
@ -24,9 +24,35 @@ class HttpException(NetlibException):
pass pass
class HttpReadDisconnect(HttpException, ReadDisconnect): class HttpReadDisconnect(HttpException, Disconnect):
pass pass
class HttpSyntaxException(HttpException): class HttpSyntaxException(HttpException):
pass pass
class TcpException(NetlibException):
pass
class TcpDisconnect(TcpException, Disconnect):
pass
class TcpReadIncomplete(TcpException):
pass
class TcpTimeout(TcpException):
pass
class TlsException(NetlibException):
pass
class InvalidCertificateException(TlsException):
pass

View File

@ -95,9 +95,9 @@ def _assemble_response_headers(response, preserve_transfer_encoding=False):
if not preserve_transfer_encoding: if not preserve_transfer_encoding:
headers.pop(b"Transfer-Encoding", None) headers.pop(b"Transfer-Encoding", None)
# If body is defined (i.e. not None or CONTENT_MISSING), we always # If body is defined (i.e. not None or CONTENT_MISSING),
# add a content-length header. # we now need to set a content-length header.
if response.body or response.body == b"": if response.body or response.body == b"":
headers[b"Content-Length"] = str(len(response.body)).encode("ascii") headers[b"Content-Length"] = str(len(response.body)).encode("ascii")
return bytes(headers) return bytes(headers)

View File

@ -4,15 +4,14 @@ import sys
import re import re
from ... import utils from ... import utils
from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
from .. import Request, Response, Headers from .. import Request, Response, Headers
from netlib.tcp import NetLibDisconnect
def read_request(rfile, body_size_limit=None): def read_request(rfile, body_size_limit=None):
request = read_request_head(rfile) request = read_request_head(rfile)
expected_body_size = expected_http_body_size(request) expected_body_size = expected_http_body_size(request)
request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
request.timestamp_end = time.time() request.timestamp_end = time.time()
return request return request
@ -51,7 +50,7 @@ def read_request_head(rfile):
def read_response(rfile, request, body_size_limit=None): def read_response(rfile, request, body_size_limit=None):
response = read_response_head(rfile) response = read_response_head(rfile)
expected_body_size = expected_http_body_size(request, response) expected_body_size = expected_http_body_size(request, response)
response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit))
response.timestamp_end = time.time() response.timestamp_end = time.time()
return response return response
@ -215,7 +214,7 @@ def _get_first_line(rfile):
if line == b"\r\n" or line == b"\n": if line == b"\r\n" or line == b"\n":
# Possible leftover from previous message # Possible leftover from previous message
line = rfile.readline() line = rfile.readline()
except NetLibDisconnect: except TcpDisconnect:
raise HttpReadDisconnect() raise HttpReadDisconnect()
if not line: if not line:
raise HttpReadDisconnect() raise HttpReadDisconnect()

View File

@ -231,7 +231,7 @@ class Request(object):
self.path = path self.path = path
self.httpversion = httpversion self.httpversion = httpversion
self.headers = headers self.headers = headers
self.body = body self._body = body
self.timestamp_start = timestamp_start self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end self.timestamp_end = timestamp_end
self.form_out = form_out or form_in self.form_out = form_out or form_in
@ -452,6 +452,16 @@ class Request(object):
raise ValueError("Invalid URL: %s" % url) raise ValueError("Invalid URL: %s" % url)
self.scheme, self.host, self.port, self.path = parts self.scheme, self.host, self.port, self.path = parts
@property
def body(self):
return self._body
@body.setter
def body(self, body):
self._body = body
if isinstance(body, bytes):
self.headers["Content-Length"] = str(len(body)).encode()
@property @property
def content(self): # pragma: no cover def content(self): # pragma: no cover
# TODO: remove deprecated getter # TODO: remove deprecated getter
@ -488,7 +498,7 @@ class Response(object):
self.status_code = status_code self.status_code = status_code
self.msg = msg self.msg = msg
self.headers = headers self.headers = headers
self.body = body self._body = body
self.timestamp_start = timestamp_start self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end self.timestamp_end = timestamp_end
@ -551,6 +561,16 @@ class Response(object):
) )
self.headers.set_all("Set-Cookie", values) self.headers.set_all("Set-Cookie", values)
@property
def body(self):
return self._body
@body.setter
def body(self, body):
self._body = body
if isinstance(body, bytes):
self.headers["Content-Length"] = str(len(body)).encode()
@property @property
def content(self): # pragma: no cover def content(self): # pragma: no cover
# TODO: remove deprecated getter # TODO: remove deprecated getter

View File

@ -16,6 +16,9 @@ from . import certutils, version_check
# This is a rather hackish way to make sure that # This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed. # the latest version of pyOpenSSL is actually installed.
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
TcpTimeout, TcpDisconnect, TcpException
version_check.check_pyopenssl_version() version_check.check_pyopenssl_version()
@ -24,11 +27,17 @@ EINTR = 4
# To enable all SSL methods use: SSLv23 # To enable all SSL methods use: SSLv23
# then add options to disable certain methods # then add options to disable certain methods
# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
SSL_BASIC_OPTIONS = (
SSL.OP_CIPHER_SERVER_PREFERENCE
)
if hasattr(SSL, "OP_NO_COMPRESSION"):
SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION
SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD
SSL_DEFAULT_OPTIONS = ( SSL_DEFAULT_OPTIONS = (
SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv2 |
SSL.OP_NO_SSLv3 | SSL.OP_NO_SSLv3 |
SSL.OP_CIPHER_SERVER_PREFERENCE SSL_BASIC_OPTIONS
) )
if hasattr(SSL, "OP_NO_COMPRESSION"): if hasattr(SSL, "OP_NO_COMPRESSION"):
SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
@ -39,42 +48,17 @@ Don't ask...
https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
""" """
sslversion_choices = { sslversion_choices = {
"all": (SSL.SSLv23_METHOD, 0), "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS),
# SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+
# TLSv1_METHOD would be TLS 1.0 only # TLSv1_METHOD would be TLS 1.0 only
"secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)), "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)),
"SSLv2": (SSL.SSLv2_METHOD, 0), "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS),
"SSLv3": (SSL.SSLv3_METHOD, 0), "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS),
"TLSv1": (SSL.TLSv1_METHOD, 0), "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS),
"TLSv1_1": (SSL.TLSv1_1_METHOD, 0), "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS),
"TLSv1_2": (SSL.TLSv1_2_METHOD, 0), "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
} }
class NetLibError(Exception):
pass
class NetLibDisconnect(NetLibError):
pass
class NetLibIncomplete(NetLibError):
pass
class NetLibTimeout(NetLibError):
pass
class NetLibSSLError(NetLibError):
pass
class NetLibInvalidCertificateError(NetLibSSLError):
pass
class SSLKeyLogger(object): class SSLKeyLogger(object):
def __init__(self, filename): def __init__(self, filename):
@ -168,17 +152,17 @@ class Writer(_FileLike):
def flush(self): def flush(self):
""" """
May raise NetLibDisconnect May raise TcpDisconnect
""" """
if hasattr(self.o, "flush"): if hasattr(self.o, "flush"):
try: try:
self.o.flush() self.o.flush()
except (socket.error, IOError) as v: except (socket.error, IOError) as v:
raise NetLibDisconnect(str(v)) raise TcpDisconnect(str(v))
def write(self, v): def write(self, v):
""" """
May raise NetLibDisconnect May raise TcpDisconnect
""" """
if v: if v:
self.first_byte_timestamp = self.first_byte_timestamp or time.time() self.first_byte_timestamp = self.first_byte_timestamp or time.time()
@ -191,7 +175,7 @@ class Writer(_FileLike):
self.add_log(v[:r]) self.add_log(v[:r])
return r return r
except (SSL.Error, socket.error) as e: except (SSL.Error, socket.error) as e:
raise NetLibDisconnect(str(e)) raise TcpDisconnect(str(e))
class Reader(_FileLike): class Reader(_FileLike):
@ -210,23 +194,29 @@ class Reader(_FileLike):
try: try:
data = self.o.read(rlen) data = self.o.read(rlen)
except SSL.ZeroReturnError: except SSL.ZeroReturnError:
# TLS connection was shut down cleanly
break break
except SSL.WantReadError: except (SSL.WantWriteError, SSL.WantReadError):
# From the OpenSSL docs:
# If the underlying BIO is non-blocking, SSL_read() will also return when the
# underlying BIO could not satisfy the needs of SSL_read() to continue the
# operation. In this case a call to SSL_get_error with the return value of
# SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
if (time.time() - start) < self.o.gettimeout(): if (time.time() - start) < self.o.gettimeout():
time.sleep(0.1) time.sleep(0.1)
continue continue
else: else:
raise NetLibTimeout raise TcpTimeout()
except socket.timeout: except socket.timeout:
raise NetLibTimeout raise TcpTimeout()
except socket.error: except socket.error as e:
raise NetLibDisconnect raise TcpDisconnect(str(e))
except SSL.SysCallError as e: except SSL.SysCallError as e:
if e.args == (-1, 'Unexpected EOF'): if e.args == (-1, 'Unexpected EOF'):
break break
raise NetLibSSLError(e.message) raise TlsException(e.message)
except SSL.Error as e: except SSL.Error as e:
raise NetLibSSLError(e.message) raise TlsException(e.message)
self.first_byte_timestamp = self.first_byte_timestamp or time.time() self.first_byte_timestamp = self.first_byte_timestamp or time.time()
if not data: if not data:
break break
@ -260,9 +250,9 @@ class Reader(_FileLike):
result = self.read(length) result = self.read(length)
if length != -1 and len(result) != length: if length != -1 and len(result) != length:
if not result: if not result:
raise NetLibDisconnect() raise TcpDisconnect()
else: else:
raise NetLibIncomplete( raise TcpReadIncomplete(
"Expected %s bytes, got %s" % (length, len(result)) "Expected %s bytes, got %s" % (length, len(result))
) )
return result return result
@ -275,15 +265,15 @@ class Reader(_FileLike):
Up to the next N bytes if peeking is successful. Up to the next N bytes if peeking is successful.
Raises: Raises:
NetLibError if there was an error with the socket TcpException if there was an error with the socket
NetLibSSLError if there was an error with pyOpenSSL. TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a (pyOpenSSL) socket NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
""" """
if isinstance(self.o, socket._fileobject): if isinstance(self.o, socket._fileobject):
try: try:
return self.o._sock.recv(length, socket.MSG_PEEK) return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e: except socket.error as e:
raise NetLibError(repr(e)) raise TcpException(repr(e))
elif isinstance(self.o, SSL.Connection): elif isinstance(self.o, SSL.Connection):
try: try:
if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
@ -296,7 +286,7 @@ class Reader(_FileLike):
self.o._raise_ssl_error(self.o._ssl, result) self.o._raise_ssl_error(self.o._ssl, result)
return SSL._ffi.buffer(buf, result)[:] return SSL._ffi.buffer(buf, result)[:]
except SSL.Error as e: except SSL.Error as e:
six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2]) six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2])
else: else:
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
@ -461,7 +451,7 @@ class _Connection(object):
try: try:
self.wfile.flush() self.wfile.flush()
self.wfile.close() self.wfile.close()
except NetLibDisconnect: except TcpDisconnect:
pass pass
self.rfile.close() self.rfile.close()
@ -525,7 +515,7 @@ class _Connection(object):
# TODO: maybe change this to with newer pyOpenSSL APIs # TODO: maybe change this to with newer pyOpenSSL APIs
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v: except SSL.Error as v:
raise NetLibError("SSL cipher specification error: %s" % str(v)) raise TlsException("SSL cipher specification error: %s" % str(v))
# SSLKEYLOGFILE # SSLKEYLOGFILE
if log_ssl_key: if log_ssl_key:
@ -546,7 +536,7 @@ class _Connection(object):
elif alpn_select_callback is not None and alpn_select is None: elif alpn_select_callback is not None and alpn_select is None:
context.set_alpn_select_callback(alpn_select_callback) context.set_alpn_select_callback(alpn_select_callback)
elif alpn_select_callback is not None and alpn_select is not None: elif alpn_select_callback is not None and alpn_select is not None:
raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
return context return context
@ -594,7 +584,7 @@ class TCPClient(_Connection):
context.use_privatekey_file(cert) context.use_privatekey_file(cert)
context.use_certificate_file(cert) context.use_certificate_file(cert)
except SSL.Error as v: except SSL.Error as v:
raise NetLibError("SSL client certificate error: %s" % str(v)) raise TlsException("SSL client certificate error: %s" % str(v))
return context return context
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
@ -618,15 +608,15 @@ class TCPClient(_Connection):
self.connection.do_handshake() self.connection.do_handshake()
except SSL.Error as v: except SSL.Error as v:
if self.ssl_verification_error: if self.ssl_verification_error:
raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v)) raise InvalidCertificateException("SSL handshake error: %s" % repr(v))
else: else:
raise NetLibError("SSL handshake error: %s" % repr(v)) raise TlsException("SSL handshake error: %s" % repr(v))
# Fix for pre v1.0 OpenSSL, which doesn't throw an exception on # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
# certificate validation failure # certificate validation failure
verification_mode = sslctx_kwargs.get('verify_options', None) verification_mode = sslctx_kwargs.get('verify_options', None)
if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER:
raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed") raise InvalidCertificateException("SSL handshake error: certificate verify failed")
self.ssl_established = True self.ssl_established = True
self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
@ -644,7 +634,7 @@ class TCPClient(_Connection):
self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError) as err: except (socket.error, IOError) as err:
raise NetLibError( raise TcpException(
'Error connecting to "%s": %s' % 'Error connecting to "%s": %s' %
(self.address.host, err)) (self.address.host, err))
self.connection = connection self.connection = connection
@ -750,7 +740,7 @@ class BaseHandler(_Connection):
try: try:
self.connection.do_handshake() self.connection.do_handshake()
except SSL.Error as v: except SSL.Error as v:
raise NetLibError("SSL handshake error: %s" % repr(v)) raise TlsException("SSL handshake error: %s" % repr(v))
self.ssl_established = True self.ssl_established = True
self.rfile.set_descriptor(self.connection) self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection)

View File

@ -2,6 +2,7 @@ import OpenSSL
import mock import mock
from netlib import tcp, http, tutils from netlib import tcp, http, tutils
from netlib.exceptions import TcpDisconnect
from netlib.http import Headers from netlib.http import Headers
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2.frame import * from netlib.http.http2.frame import *
@ -127,7 +128,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
protocol.perform_server_connection_preface() protocol.perform_server_connection_preface()
assert protocol.connection_preface_performed assert protocol.connection_preface_performed
tutils.raises(tcp.NetLibDisconnect, protocol.perform_server_connection_preface, force=True) tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase): class TestPerformClientConnectionPreface(tservers.ServerTestBase):

View File

@ -12,6 +12,8 @@ import OpenSSL
from netlib import tcp, certutils, tutils from netlib import tcp, certutils, tutils
from . import tservers from . import tservers
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
TcpTimeout, TcpDisconnect, TcpException
class EchoHandler(tcp.BaseHandler): class EchoHandler(tcp.BaseHandler):
@ -93,7 +95,7 @@ class TestServerBind(tservers.ServerTestBase):
c.connect() c.connect()
assert c.rfile.readline() == str(("127.0.0.1", random_port)) assert c.rfile.readline() == str(("127.0.0.1", random_port))
return return
except tcp.NetLibError: # port probably already in use except TcpException: # port probably already in use
pass pass
@ -140,7 +142,7 @@ class TestFinishFail(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() c.connect()
c.wfile.write("foo\n") c.wfile.write("foo\n")
c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
c.finish() c.finish()
@ -180,7 +182,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self): def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() c.connect()
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@ -224,7 +226,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
c.connect() c.connect()
tutils.raises( tutils.raises(
tcp.NetLibInvalidCertificateError, InvalidCertificateException,
c.convert_to_ssl, c.convert_to_ssl,
verify_options=SSL.VERIFY_PEER, verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem"))
@ -327,7 +329,7 @@ class TestSSLClientCert(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port)) c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect() c.connect()
tutils.raises( tutils.raises(
tcp.NetLibError, TlsException,
c.convert_to_ssl, c.convert_to_ssl,
cert=tutils.test_data.path("data/clientcert/make") cert=tutils.test_data.path("data/clientcert/make")
) )
@ -432,7 +434,7 @@ class TestSSLDisconnect(tservers.ServerTestBase):
# Excercise SSL.ZeroReturnError # Excercise SSL.ZeroReturnError
c.rfile.read(10) c.rfile.read(10)
c.close() c.close()
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") tutils.raises(TcpDisconnect, c.wfile.write, "foo")
tutils.raises(Queue.Empty, self.q.get_nowait) tutils.raises(Queue.Empty, self.q.get_nowait)
@ -447,7 +449,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):
# Exercise SSL.SysCallError # Exercise SSL.SysCallError
c.rfile.read(10) c.rfile.read(10)
c.close() c.close()
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") tutils.raises(TcpDisconnect, c.wfile.write, "foo")
class TestDisconnect(tservers.ServerTestBase): class TestDisconnect(tservers.ServerTestBase):
@ -470,7 +472,7 @@ class TestServerTimeOut(tservers.ServerTestBase):
self.settimeout(0.01) self.settimeout(0.01)
try: try:
self.rfile.read(10) self.rfile.read(10)
except tcp.NetLibTimeout: except TcpTimeout:
self.timeout = True self.timeout = True
def test_timeout(self): def test_timeout(self):
@ -488,7 +490,7 @@ class TestTimeOut(tservers.ServerTestBase):
c.connect() c.connect()
c.settimeout(0.1) c.settimeout(0.1)
assert c.gettimeout() == 0.1 assert c.gettimeout() == 0.1
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestALPNClient(tservers.ServerTestBase): class TestALPNClient(tservers.ServerTestBase):
@ -540,7 +542,7 @@ class TestSSLTimeOut(tservers.ServerTestBase):
c.connect() c.connect()
c.convert_to_ssl() c.convert_to_ssl()
c.settimeout(0.1) c.settimeout(0.1)
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestDHParams(tservers.ServerTestBase): class TestDHParams(tservers.ServerTestBase):
@ -570,7 +572,7 @@ class TestTCPClient:
def test_conerr(self): def test_conerr(self):
c = tcp.TCPClient(("127.0.0.1", 0)) c = tcp.TCPClient(("127.0.0.1", 0))
tutils.raises(tcp.NetLibError, c.connect) tutils.raises(TcpException, c.connect)
class TestFileLike: class TestFileLike:
@ -639,7 +641,7 @@ class TestFileLike:
o = mock.MagicMock() o = mock.MagicMock()
o.flush = mock.MagicMock(side_effect=socket.error) o.flush = mock.MagicMock(side_effect=socket.error)
s.o = o s.o = o
tutils.raises(tcp.NetLibDisconnect, s.flush) tutils.raises(TcpDisconnect, s.flush)
def test_reader_read_error(self): def test_reader_read_error(self):
s = cStringIO.StringIO("foobar\nfoobar") s = cStringIO.StringIO("foobar\nfoobar")
@ -647,7 +649,7 @@ class TestFileLike:
o = mock.MagicMock() o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error) o.read = mock.MagicMock(side_effect=socket.error)
s.o = o s.o = o
tutils.raises(tcp.NetLibDisconnect, s.read, 10) tutils.raises(TcpDisconnect, s.read, 10)
def test_reset_timestamps(self): def test_reset_timestamps(self):
s = cStringIO.StringIO("foobar\nfoobar") s = cStringIO.StringIO("foobar\nfoobar")
@ -678,24 +680,24 @@ class TestFileLike:
s = mock.MagicMock() s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.Error()) s.read = mock.MagicMock(side_effect=SSL.Error())
s = tcp.Reader(s) s = tcp.Reader(s)
tutils.raises(tcp.NetLibSSLError, s.read, 1) tutils.raises(TlsException, s.read, 1)
def test_read_syscall_ssl_error(self): def test_read_syscall_ssl_error(self):
s = mock.MagicMock() s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.SysCallError()) s.read = mock.MagicMock(side_effect=SSL.SysCallError())
s = tcp.Reader(s) s = tcp.Reader(s)
tutils.raises(tcp.NetLibSSLError, s.read, 1) tutils.raises(TlsException, s.read, 1)
def test_reader_readline_disconnect(self): def test_reader_readline_disconnect(self):
o = mock.MagicMock() o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error) o.read = mock.MagicMock(side_effect=socket.error)
s = tcp.Reader(o) s = tcp.Reader(o)
tutils.raises(tcp.NetLibDisconnect, s.readline, 10) tutils.raises(TcpDisconnect, s.readline, 10)
def test_reader_incomplete_error(self): def test_reader_incomplete_error(self):
s = cStringIO.StringIO("foobar") s = cStringIO.StringIO("foobar")
s = tcp.Reader(s) s = tcp.Reader(s)
tutils.raises(tcp.NetLibIncomplete, s.safe_read, 10) tutils.raises(TcpReadIncomplete, s.safe_read, 10)
class TestAddress: class TestAddress:

View File

@ -14,7 +14,7 @@ def test_hexdump():
assert utils.hexdump("one\0" * 10) assert utils.hexdump("one\0" * 10)
def test_cleanBin(): def test_clean_bin():
assert utils.clean_bin(b"one") == b"one" assert utils.clean_bin(b"one") == b"one"
assert utils.clean_bin(b"\00ne") == b".ne" assert utils.clean_bin(b"\00ne") == b".ne"
assert utils.clean_bin(b"\nne") == b"\nne" assert utils.clean_bin(b"\nne") == b"\nne"

View File

@ -176,7 +176,7 @@ class TestBadHandshake(tservers.ServerTestBase):
""" """
handler = BadHandshakeHandler handler = BadHandshakeHandler
@raises(tcp.NetLibDisconnect) @raises(TcpDisconnect)
def test(self): def test(self):
client = WebSocketsClient(("127.0.0.1", self.port)) client = WebSocketsClient(("127.0.0.1", self.port))
client.connect() client.connect()