organize exceptions, improve content-length handling

This commit is contained in:
Maximilian Hils 2015-09-17 02:14:14 +02:00
parent e1659f3fcf
commit dad9f06cb9
9 changed files with 130 additions and 92 deletions

View File

@ -16,7 +16,7 @@ class NetlibException(Exception):
super(NetlibException, self).__init__(message)
class ReadDisconnect(object):
class Disconnect(object):
"""Immediate EOF"""
@ -24,9 +24,35 @@ class HttpException(NetlibException):
pass
class HttpReadDisconnect(HttpException, ReadDisconnect):
class HttpReadDisconnect(HttpException, Disconnect):
pass
class HttpSyntaxException(HttpException):
pass
class TcpException(NetlibException):
pass
class TcpDisconnect(TcpException, Disconnect):
pass
class TcpReadIncomplete(TcpException):
pass
class TcpTimeout(TcpException):
pass
class TlsException(NetlibException):
pass
class InvalidCertificateException(TlsException):
pass

View File

@ -95,9 +95,9 @@ def _assemble_response_headers(response, preserve_transfer_encoding=False):
if not preserve_transfer_encoding:
headers.pop(b"Transfer-Encoding", None)
# If body is defined (i.e. not None or CONTENT_MISSING), we always
# add a content-length header.
if response.body or response.body == b"":
headers[b"Content-Length"] = str(len(response.body)).encode("ascii")
# If body is defined (i.e. not None or CONTENT_MISSING),
# we now need to set a content-length header.
if response.body or response.body == b"":
headers[b"Content-Length"] = str(len(response.body)).encode("ascii")
return bytes(headers)

View File

@ -4,15 +4,14 @@ import sys
import re
from ... import utils
from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException
from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect
from .. import Request, Response, Headers
from netlib.tcp import NetLibDisconnect
def read_request(rfile, body_size_limit=None):
request = read_request_head(rfile)
expected_body_size = expected_http_body_size(request)
request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
request.timestamp_end = time.time()
return request
@ -51,7 +50,7 @@ def read_request_head(rfile):
def read_response(rfile, request, body_size_limit=None):
response = read_response_head(rfile)
expected_body_size = expected_http_body_size(request, response)
response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit))
response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit))
response.timestamp_end = time.time()
return response
@ -215,7 +214,7 @@ def _get_first_line(rfile):
if line == b"\r\n" or line == b"\n":
# Possible leftover from previous message
line = rfile.readline()
except NetLibDisconnect:
except TcpDisconnect:
raise HttpReadDisconnect()
if not line:
raise HttpReadDisconnect()

View File

@ -231,7 +231,7 @@ class Request(object):
self.path = path
self.httpversion = httpversion
self.headers = headers
self.body = body
self._body = body
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
self.form_out = form_out or form_in
@ -452,6 +452,16 @@ class Request(object):
raise ValueError("Invalid URL: %s" % url)
self.scheme, self.host, self.port, self.path = parts
@property
def body(self):
return self._body
@body.setter
def body(self, body):
self._body = body
if isinstance(body, bytes):
self.headers["Content-Length"] = str(len(body)).encode()
@property
def content(self): # pragma: no cover
# TODO: remove deprecated getter
@ -488,7 +498,7 @@ class Response(object):
self.status_code = status_code
self.msg = msg
self.headers = headers
self.body = body
self._body = body
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
@ -551,6 +561,16 @@ class Response(object):
)
self.headers.set_all("Set-Cookie", values)
@property
def body(self):
return self._body
@body.setter
def body(self, body):
self._body = body
if isinstance(body, bytes):
self.headers["Content-Length"] = str(len(body)).encode()
@property
def content(self): # pragma: no cover
# TODO: remove deprecated getter

View File

@ -16,6 +16,9 @@ from . import certutils, version_check
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
TcpTimeout, TcpDisconnect, TcpException
version_check.check_pyopenssl_version()
@ -24,11 +27,17 @@ EINTR = 4
# To enable all SSL methods use: SSLv23
# then add options to disable certain methods
# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
SSL_BASIC_OPTIONS = (
SSL.OP_CIPHER_SERVER_PREFERENCE
)
if hasattr(SSL, "OP_NO_COMPRESSION"):
SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION
SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD
SSL_DEFAULT_OPTIONS = (
SSL.OP_NO_SSLv2 |
SSL.OP_NO_SSLv3 |
SSL.OP_CIPHER_SERVER_PREFERENCE
SSL_BASIC_OPTIONS
)
if hasattr(SSL, "OP_NO_COMPRESSION"):
SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
@ -39,42 +48,17 @@ Don't ask...
https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
"""
sslversion_choices = {
"all": (SSL.SSLv23_METHOD, 0),
"all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS),
# SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+
# TLSv1_METHOD would be TLS 1.0 only
"secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)),
"SSLv2": (SSL.SSLv2_METHOD, 0),
"SSLv3": (SSL.SSLv3_METHOD, 0),
"TLSv1": (SSL.TLSv1_METHOD, 0),
"TLSv1_1": (SSL.TLSv1_1_METHOD, 0),
"TLSv1_2": (SSL.TLSv1_2_METHOD, 0),
"secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)),
"SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS),
"SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS),
"TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS),
"TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS),
"TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
}
class NetLibError(Exception):
pass
class NetLibDisconnect(NetLibError):
pass
class NetLibIncomplete(NetLibError):
pass
class NetLibTimeout(NetLibError):
pass
class NetLibSSLError(NetLibError):
pass
class NetLibInvalidCertificateError(NetLibSSLError):
pass
class SSLKeyLogger(object):
def __init__(self, filename):
@ -168,17 +152,17 @@ class Writer(_FileLike):
def flush(self):
"""
May raise NetLibDisconnect
May raise TcpDisconnect
"""
if hasattr(self.o, "flush"):
try:
self.o.flush()
except (socket.error, IOError) as v:
raise NetLibDisconnect(str(v))
raise TcpDisconnect(str(v))
def write(self, v):
"""
May raise NetLibDisconnect
May raise TcpDisconnect
"""
if v:
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
@ -191,7 +175,7 @@ class Writer(_FileLike):
self.add_log(v[:r])
return r
except (SSL.Error, socket.error) as e:
raise NetLibDisconnect(str(e))
raise TcpDisconnect(str(e))
class Reader(_FileLike):
@ -210,23 +194,29 @@ class Reader(_FileLike):
try:
data = self.o.read(rlen)
except SSL.ZeroReturnError:
# TLS connection was shut down cleanly
break
except SSL.WantReadError:
except (SSL.WantWriteError, SSL.WantReadError):
# From the OpenSSL docs:
# If the underlying BIO is non-blocking, SSL_read() will also return when the
# underlying BIO could not satisfy the needs of SSL_read() to continue the
# operation. In this case a call to SSL_get_error with the return value of
# SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
if (time.time() - start) < self.o.gettimeout():
time.sleep(0.1)
continue
else:
raise NetLibTimeout
raise TcpTimeout()
except socket.timeout:
raise NetLibTimeout
except socket.error:
raise NetLibDisconnect
raise TcpTimeout()
except socket.error as e:
raise TcpDisconnect(str(e))
except SSL.SysCallError as e:
if e.args == (-1, 'Unexpected EOF'):
break
raise NetLibSSLError(e.message)
raise TlsException(e.message)
except SSL.Error as e:
raise NetLibSSLError(e.message)
raise TlsException(e.message)
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
if not data:
break
@ -260,9 +250,9 @@ class Reader(_FileLike):
result = self.read(length)
if length != -1 and len(result) != length:
if not result:
raise NetLibDisconnect()
raise TcpDisconnect()
else:
raise NetLibIncomplete(
raise TcpReadIncomplete(
"Expected %s bytes, got %s" % (length, len(result))
)
return result
@ -275,15 +265,15 @@ class Reader(_FileLike):
Up to the next N bytes if peeking is successful.
Raises:
NetLibError if there was an error with the socket
NetLibSSLError if there was an error with pyOpenSSL.
TcpException if there was an error with the socket
TlsException if there was an error with pyOpenSSL.
NotImplementedError if the underlying file object is not a (pyOpenSSL) socket
"""
if isinstance(self.o, socket._fileobject):
try:
return self.o._sock.recv(length, socket.MSG_PEEK)
except socket.error as e:
raise NetLibError(repr(e))
raise TcpException(repr(e))
elif isinstance(self.o, SSL.Connection):
try:
if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15):
@ -296,7 +286,7 @@ class Reader(_FileLike):
self.o._raise_ssl_error(self.o._ssl, result)
return SSL._ffi.buffer(buf, result)[:]
except SSL.Error as e:
six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2])
six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2])
else:
raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
@ -461,7 +451,7 @@ class _Connection(object):
try:
self.wfile.flush()
self.wfile.close()
except NetLibDisconnect:
except TcpDisconnect:
pass
self.rfile.close()
@ -525,7 +515,7 @@ class _Connection(object):
# TODO: maybe change this to with newer pyOpenSSL APIs
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
raise NetLibError("SSL cipher specification error: %s" % str(v))
raise TlsException("SSL cipher specification error: %s" % str(v))
# SSLKEYLOGFILE
if log_ssl_key:
@ -546,7 +536,7 @@ class _Connection(object):
elif alpn_select_callback is not None and alpn_select is None:
context.set_alpn_select_callback(alpn_select_callback)
elif alpn_select_callback is not None and alpn_select is not None:
raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
return context
@ -594,7 +584,7 @@ class TCPClient(_Connection):
context.use_privatekey_file(cert)
context.use_certificate_file(cert)
except SSL.Error as v:
raise NetLibError("SSL client certificate error: %s" % str(v))
raise TlsException("SSL client certificate error: %s" % str(v))
return context
def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
@ -618,15 +608,15 @@ class TCPClient(_Connection):
self.connection.do_handshake()
except SSL.Error as v:
if self.ssl_verification_error:
raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v))
raise InvalidCertificateException("SSL handshake error: %s" % repr(v))
else:
raise NetLibError("SSL handshake error: %s" % repr(v))
raise TlsException("SSL handshake error: %s" % repr(v))
# Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
# certificate validation failure
verification_mode = sslctx_kwargs.get('verify_options', None)
if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER:
raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed")
raise InvalidCertificateException("SSL handshake error: certificate verify failed")
self.ssl_established = True
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
@ -644,7 +634,7 @@ class TCPClient(_Connection):
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError) as err:
raise NetLibError(
raise TcpException(
'Error connecting to "%s": %s' %
(self.address.host, err))
self.connection = connection
@ -750,7 +740,7 @@ class BaseHandler(_Connection):
try:
self.connection.do_handshake()
except SSL.Error as v:
raise NetLibError("SSL handshake error: %s" % repr(v))
raise TlsException("SSL handshake error: %s" % repr(v))
self.ssl_established = True
self.rfile.set_descriptor(self.connection)
self.wfile.set_descriptor(self.connection)

View File

@ -2,6 +2,7 @@ import OpenSSL
import mock
from netlib import tcp, http, tutils
from netlib.exceptions import TcpDisconnect
from netlib.http import Headers
from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2.frame import *
@ -127,7 +128,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
protocol.perform_server_connection_preface()
assert protocol.connection_preface_performed
tutils.raises(tcp.NetLibDisconnect, protocol.perform_server_connection_preface, force=True)
tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase):

View File

@ -12,6 +12,8 @@ import OpenSSL
from netlib import tcp, certutils, tutils
from . import tservers
from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \
TcpTimeout, TcpDisconnect, TcpException
class EchoHandler(tcp.BaseHandler):
@ -93,7 +95,7 @@ class TestServerBind(tservers.ServerTestBase):
c.connect()
assert c.rfile.readline() == str(("127.0.0.1", random_port))
return
except tcp.NetLibError: # port probably already in use
except TcpException: # port probably already in use
pass
@ -140,7 +142,7 @@ class TestFinishFail(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.wfile.write("foo\n")
c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect)
c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
c.finish()
@ -180,7 +182,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com")
tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@ -224,7 +226,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
c.connect()
tutils.raises(
tcp.NetLibInvalidCertificateError,
InvalidCertificateException,
c.convert_to_ssl,
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem"))
@ -327,7 +329,7 @@ class TestSSLClientCert(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
tutils.raises(
tcp.NetLibError,
TlsException,
c.convert_to_ssl,
cert=tutils.test_data.path("data/clientcert/make")
)
@ -432,7 +434,7 @@ class TestSSLDisconnect(tservers.ServerTestBase):
# Excercise SSL.ZeroReturnError
c.rfile.read(10)
c.close()
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
tutils.raises(TcpDisconnect, c.wfile.write, "foo")
tutils.raises(Queue.Empty, self.q.get_nowait)
@ -447,7 +449,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):
# Exercise SSL.SysCallError
c.rfile.read(10)
c.close()
tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo")
tutils.raises(TcpDisconnect, c.wfile.write, "foo")
class TestDisconnect(tservers.ServerTestBase):
@ -470,7 +472,7 @@ class TestServerTimeOut(tservers.ServerTestBase):
self.settimeout(0.01)
try:
self.rfile.read(10)
except tcp.NetLibTimeout:
except TcpTimeout:
self.timeout = True
def test_timeout(self):
@ -488,7 +490,7 @@ class TestTimeOut(tservers.ServerTestBase):
c.connect()
c.settimeout(0.1)
assert c.gettimeout() == 0.1
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestALPNClient(tservers.ServerTestBase):
@ -540,7 +542,7 @@ class TestSSLTimeOut(tservers.ServerTestBase):
c.connect()
c.convert_to_ssl()
c.settimeout(0.1)
tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10)
tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestDHParams(tservers.ServerTestBase):
@ -570,7 +572,7 @@ class TestTCPClient:
def test_conerr(self):
c = tcp.TCPClient(("127.0.0.1", 0))
tutils.raises(tcp.NetLibError, c.connect)
tutils.raises(TcpException, c.connect)
class TestFileLike:
@ -639,7 +641,7 @@ class TestFileLike:
o = mock.MagicMock()
o.flush = mock.MagicMock(side_effect=socket.error)
s.o = o
tutils.raises(tcp.NetLibDisconnect, s.flush)
tutils.raises(TcpDisconnect, s.flush)
def test_reader_read_error(self):
s = cStringIO.StringIO("foobar\nfoobar")
@ -647,7 +649,7 @@ class TestFileLike:
o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error)
s.o = o
tutils.raises(tcp.NetLibDisconnect, s.read, 10)
tutils.raises(TcpDisconnect, s.read, 10)
def test_reset_timestamps(self):
s = cStringIO.StringIO("foobar\nfoobar")
@ -678,24 +680,24 @@ class TestFileLike:
s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.Error())
s = tcp.Reader(s)
tutils.raises(tcp.NetLibSSLError, s.read, 1)
tutils.raises(TlsException, s.read, 1)
def test_read_syscall_ssl_error(self):
s = mock.MagicMock()
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
s = tcp.Reader(s)
tutils.raises(tcp.NetLibSSLError, s.read, 1)
tutils.raises(TlsException, s.read, 1)
def test_reader_readline_disconnect(self):
o = mock.MagicMock()
o.read = mock.MagicMock(side_effect=socket.error)
s = tcp.Reader(o)
tutils.raises(tcp.NetLibDisconnect, s.readline, 10)
tutils.raises(TcpDisconnect, s.readline, 10)
def test_reader_incomplete_error(self):
s = cStringIO.StringIO("foobar")
s = tcp.Reader(s)
tutils.raises(tcp.NetLibIncomplete, s.safe_read, 10)
tutils.raises(TcpReadIncomplete, s.safe_read, 10)
class TestAddress:

View File

@ -14,7 +14,7 @@ def test_hexdump():
assert utils.hexdump("one\0" * 10)
def test_cleanBin():
def test_clean_bin():
assert utils.clean_bin(b"one") == b"one"
assert utils.clean_bin(b"\00ne") == b".ne"
assert utils.clean_bin(b"\nne") == b"\nne"

View File

@ -176,7 +176,7 @@ class TestBadHandshake(tservers.ServerTestBase):
"""
handler = BadHandshakeHandler
@raises(tcp.NetLibDisconnect)
@raises(TcpDisconnect)
def test(self):
client = WebSocketsClient(("127.0.0.1", self.port))
client.connect()