diff --git a/netlib/exceptions.py b/netlib/exceptions.py index e13af4734..e30235af2 100644 --- a/netlib/exceptions.py +++ b/netlib/exceptions.py @@ -16,7 +16,7 @@ class NetlibException(Exception): super(NetlibException, self).__init__(message) -class ReadDisconnect(object): +class Disconnect(object): """Immediate EOF""" @@ -24,9 +24,35 @@ class HttpException(NetlibException): pass -class HttpReadDisconnect(HttpException, ReadDisconnect): +class HttpReadDisconnect(HttpException, Disconnect): pass class HttpSyntaxException(HttpException): pass + + +class TcpException(NetlibException): + pass + + +class TcpDisconnect(TcpException, Disconnect): + pass + + + + +class TcpReadIncomplete(TcpException): + pass + + +class TcpTimeout(TcpException): + pass + + +class TlsException(NetlibException): + pass + + +class InvalidCertificateException(TlsException): + pass diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index ace25d796..33b9ef25d 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -95,9 +95,9 @@ def _assemble_response_headers(response, preserve_transfer_encoding=False): if not preserve_transfer_encoding: headers.pop(b"Transfer-Encoding", None) - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == b"": - headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + # If body is defined (i.e. not None or CONTENT_MISSING), + # we now need to set a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") return bytes(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 62025d15a..7f2b7bab2 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -4,15 +4,14 @@ import sys import re from ... import utils -from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException, TcpDisconnect from .. import Request, Response, Headers -from netlib.tcp import NetLibDisconnect def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) expected_body_size = expected_http_body_size(request) - request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -51,7 +50,7 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) expected_body_size = expected_http_body_size(request, response) - response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response @@ -215,7 +214,7 @@ def _get_first_line(rfile): if line == b"\r\n" or line == b"\n": # Possible leftover from previous message line = rfile.readline() - except NetLibDisconnect: + except TcpDisconnect: raise HttpReadDisconnect() if not line: raise HttpReadDisconnect() diff --git a/netlib/http/models.py b/netlib/http/models.py index 2d09535c8..b4446ecb1 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -231,7 +231,7 @@ class Request(object): self.path = path self.httpversion = httpversion self.headers = headers - self.body = body + self._body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end self.form_out = form_out or form_in @@ -452,6 +452,16 @@ class Request(object): raise ValueError("Invalid URL: %s" % url) self.scheme, self.host, self.port, self.path = parts + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers["Content-Length"] = str(len(body)).encode() + @property def content(self): # pragma: no cover # TODO: remove deprecated getter @@ -488,7 +498,7 @@ class Response(object): self.status_code = status_code self.msg = msg self.headers = headers - self.body = body + self._body = body self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end @@ -551,6 +561,16 @@ class Response(object): ) self.headers.set_all("Set-Cookie", values) + @property + def body(self): + return self._body + + @body.setter + def body(self, body): + self._body = body + if isinstance(body, bytes): + self.headers["Content-Length"] = str(len(body)).encode() + @property def content(self): # pragma: no cover # TODO: remove deprecated getter diff --git a/netlib/tcp.py b/netlib/tcp.py index 1eb417b44..707e11e0f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -16,6 +16,9 @@ from . import certutils, version_check # This is a rather hackish way to make sure that # the latest version of pyOpenSSL is actually installed. +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ + TcpTimeout, TcpDisconnect, TcpException + version_check.check_pyopenssl_version() @@ -24,11 +27,17 @@ EINTR = 4 # To enable all SSL methods use: SSLv23 # then add options to disable certain methods # https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 +SSL_BASIC_OPTIONS = ( + SSL.OP_CIPHER_SERVER_PREFERENCE +) +if hasattr(SSL, "OP_NO_COMPRESSION"): + SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION + SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD SSL_DEFAULT_OPTIONS = ( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | - SSL.OP_CIPHER_SERVER_PREFERENCE + SSL_BASIC_OPTIONS ) if hasattr(SSL, "OP_NO_COMPRESSION"): SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION @@ -39,42 +48,17 @@ Don't ask... https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 """ sslversion_choices = { - "all": (SSL.SSLv23_METHOD, 0), + "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ # TLSv1_METHOD would be TLS 1.0 only - "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)), - "SSLv2": (SSL.SSLv2_METHOD, 0), - "SSLv3": (SSL.SSLv3_METHOD, 0), - "TLSv1": (SSL.TLSv1_METHOD, 0), - "TLSv1_1": (SSL.TLSv1_1_METHOD, 0), - "TLSv1_2": (SSL.TLSv1_2_METHOD, 0), + "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), + "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), + "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), + "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), + "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), } - -class NetLibError(Exception): - pass - - -class NetLibDisconnect(NetLibError): - pass - - -class NetLibIncomplete(NetLibError): - pass - - -class NetLibTimeout(NetLibError): - pass - - -class NetLibSSLError(NetLibError): - pass - - -class NetLibInvalidCertificateError(NetLibSSLError): - pass - - class SSLKeyLogger(object): def __init__(self, filename): @@ -168,17 +152,17 @@ class Writer(_FileLike): def flush(self): """ - May raise NetLibDisconnect + May raise TcpDisconnect """ if hasattr(self.o, "flush"): try: self.o.flush() except (socket.error, IOError) as v: - raise NetLibDisconnect(str(v)) + raise TcpDisconnect(str(v)) def write(self, v): """ - May raise NetLibDisconnect + May raise TcpDisconnect """ if v: self.first_byte_timestamp = self.first_byte_timestamp or time.time() @@ -191,7 +175,7 @@ class Writer(_FileLike): self.add_log(v[:r]) return r except (SSL.Error, socket.error) as e: - raise NetLibDisconnect(str(e)) + raise TcpDisconnect(str(e)) class Reader(_FileLike): @@ -210,23 +194,29 @@ class Reader(_FileLike): try: data = self.o.read(rlen) except SSL.ZeroReturnError: + # TLS connection was shut down cleanly break - except SSL.WantReadError: + except (SSL.WantWriteError, SSL.WantReadError): + # From the OpenSSL docs: + # If the underlying BIO is non-blocking, SSL_read() will also return when the + # underlying BIO could not satisfy the needs of SSL_read() to continue the + # operation. In this case a call to SSL_get_error with the return value of + # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. if (time.time() - start) < self.o.gettimeout(): time.sleep(0.1) continue else: - raise NetLibTimeout + raise TcpTimeout() except socket.timeout: - raise NetLibTimeout - except socket.error: - raise NetLibDisconnect + raise TcpTimeout() + except socket.error as e: + raise TcpDisconnect(str(e)) except SSL.SysCallError as e: if e.args == (-1, 'Unexpected EOF'): break - raise NetLibSSLError(e.message) + raise TlsException(e.message) except SSL.Error as e: - raise NetLibSSLError(e.message) + raise TlsException(e.message) self.first_byte_timestamp = self.first_byte_timestamp or time.time() if not data: break @@ -260,9 +250,9 @@ class Reader(_FileLike): result = self.read(length) if length != -1 and len(result) != length: if not result: - raise NetLibDisconnect() + raise TcpDisconnect() else: - raise NetLibIncomplete( + raise TcpReadIncomplete( "Expected %s bytes, got %s" % (length, len(result)) ) return result @@ -275,15 +265,15 @@ class Reader(_FileLike): Up to the next N bytes if peeking is successful. Raises: - NetLibError if there was an error with the socket - NetLibSSLError if there was an error with pyOpenSSL. + TcpException if there was an error with the socket + TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ if isinstance(self.o, socket._fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: - raise NetLibError(repr(e)) + raise TcpException(repr(e)) elif isinstance(self.o, SSL.Connection): try: if tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) > (0, 15): @@ -296,7 +286,7 @@ class Reader(_FileLike): self.o._raise_ssl_error(self.o._ssl, result) return SSL._ffi.buffer(buf, result)[:] except SSL.Error as e: - six.reraise(NetLibSSLError, NetLibSSLError(str(e)), sys.exc_info()[2]) + six.reraise(TlsException, TlsException(str(e)), sys.exc_info()[2]) else: raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") @@ -461,7 +451,7 @@ class _Connection(object): try: self.wfile.flush() self.wfile.close() - except NetLibDisconnect: + except TcpDisconnect: pass self.rfile.close() @@ -525,7 +515,7 @@ class _Connection(object): # TODO: maybe change this to with newer pyOpenSSL APIs context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: - raise NetLibError("SSL cipher specification error: %s" % str(v)) + raise TlsException("SSL cipher specification error: %s" % str(v)) # SSLKEYLOGFILE if log_ssl_key: @@ -546,7 +536,7 @@ class _Connection(object): elif alpn_select_callback is not None and alpn_select is None: context.set_alpn_select_callback(alpn_select_callback) elif alpn_select_callback is not None and alpn_select is not None: - raise NetLibError("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") + raise TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") return context @@ -594,7 +584,7 @@ class TCPClient(_Connection): context.use_privatekey_file(cert) context.use_certificate_file(cert) except SSL.Error as v: - raise NetLibError("SSL client certificate error: %s" % str(v)) + raise TlsException("SSL client certificate error: %s" % str(v)) return context def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): @@ -618,15 +608,15 @@ class TCPClient(_Connection): self.connection.do_handshake() except SSL.Error as v: if self.ssl_verification_error: - raise NetLibInvalidCertificateError("SSL handshake error: %s" % repr(v)) + raise InvalidCertificateException("SSL handshake error: %s" % repr(v)) else: - raise NetLibError("SSL handshake error: %s" % repr(v)) + raise TlsException("SSL handshake error: %s" % repr(v)) # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on # certificate validation failure verification_mode = sslctx_kwargs.get('verify_options', None) if self.ssl_verification_error is not None and verification_mode == SSL.VERIFY_PEER: - raise NetLibInvalidCertificateError("SSL handshake error: certificate verify failed") + raise InvalidCertificateException("SSL handshake error: certificate verify failed") self.ssl_established = True self.cert = certutils.SSLCert(self.connection.get_peer_certificate()) @@ -644,7 +634,7 @@ class TCPClient(_Connection): self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize)) except (socket.error, IOError) as err: - raise NetLibError( + raise TcpException( 'Error connecting to "%s": %s' % (self.address.host, err)) self.connection = connection @@ -750,7 +740,7 @@ class BaseHandler(_Connection): try: self.connection.do_handshake() except SSL.Error as v: - raise NetLibError("SSL handshake error: %s" % repr(v)) + raise TlsException("SSL handshake error: %s" % repr(v)) self.ssl_established = True self.rfile.set_descriptor(self.connection) self.wfile.set_descriptor(self.connection) diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index a369eb492..598b5cd7a 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -2,6 +2,7 @@ import OpenSSL import mock from netlib import tcp, http, tutils +from netlib.exceptions import TcpDisconnect from netlib.http import Headers from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.frame import * @@ -127,7 +128,7 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase): protocol.perform_server_connection_preface() assert protocol.connection_preface_performed - tutils.raises(tcp.NetLibDisconnect, protocol.perform_server_connection_preface, force=True) + tutils.raises(TcpDisconnect, protocol.perform_server_connection_preface, force=True) class TestPerformClientConnectionPreface(tservers.ServerTestBase): diff --git a/test/test_tcp.py b/test/test_tcp.py index 2a5deb2bd..615900ce9 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -12,6 +12,8 @@ import OpenSSL from netlib import tcp, certutils, tutils from . import tservers +from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, TlsException, \ + TcpTimeout, TcpDisconnect, TcpException class EchoHandler(tcp.BaseHandler): @@ -93,7 +95,7 @@ class TestServerBind(tservers.ServerTestBase): c.connect() assert c.rfile.readline() == str(("127.0.0.1", random_port)) return - except tcp.NetLibError: # port probably already in use + except TcpException: # port probably already in use pass @@ -140,7 +142,7 @@ class TestFinishFail(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() c.wfile.write("foo\n") - c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) + c.wfile.flush = mock.Mock(side_effect=TcpDisconnect) c.finish() @@ -180,7 +182,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -224,7 +226,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): c.connect() tutils.raises( - tcp.NetLibInvalidCertificateError, + InvalidCertificateException, c.convert_to_ssl, verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted.pem")) @@ -327,7 +329,7 @@ class TestSSLClientCert(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() tutils.raises( - tcp.NetLibError, + TlsException, c.convert_to_ssl, cert=tutils.test_data.path("data/clientcert/make") ) @@ -432,7 +434,7 @@ class TestSSLDisconnect(tservers.ServerTestBase): # Excercise SSL.ZeroReturnError c.rfile.read(10) c.close() - tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + tutils.raises(TcpDisconnect, c.wfile.write, "foo") tutils.raises(Queue.Empty, self.q.get_nowait) @@ -447,7 +449,7 @@ class TestSSLHardDisconnect(tservers.ServerTestBase): # Exercise SSL.SysCallError c.rfile.read(10) c.close() - tutils.raises(tcp.NetLibDisconnect, c.wfile.write, "foo") + tutils.raises(TcpDisconnect, c.wfile.write, "foo") class TestDisconnect(tservers.ServerTestBase): @@ -470,7 +472,7 @@ class TestServerTimeOut(tservers.ServerTestBase): self.settimeout(0.01) try: self.rfile.read(10) - except tcp.NetLibTimeout: + except TcpTimeout: self.timeout = True def test_timeout(self): @@ -488,7 +490,7 @@ class TestTimeOut(tservers.ServerTestBase): c.connect() c.settimeout(0.1) assert c.gettimeout() == 0.1 - tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + tutils.raises(TcpTimeout, c.rfile.read, 10) class TestALPNClient(tservers.ServerTestBase): @@ -540,7 +542,7 @@ class TestSSLTimeOut(tservers.ServerTestBase): c.connect() c.convert_to_ssl() c.settimeout(0.1) - tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + tutils.raises(TcpTimeout, c.rfile.read, 10) class TestDHParams(tservers.ServerTestBase): @@ -570,7 +572,7 @@ class TestTCPClient: def test_conerr(self): c = tcp.TCPClient(("127.0.0.1", 0)) - tutils.raises(tcp.NetLibError, c.connect) + tutils.raises(TcpException, c.connect) class TestFileLike: @@ -639,7 +641,7 @@ class TestFileLike: o = mock.MagicMock() o.flush = mock.MagicMock(side_effect=socket.error) s.o = o - tutils.raises(tcp.NetLibDisconnect, s.flush) + tutils.raises(TcpDisconnect, s.flush) def test_reader_read_error(self): s = cStringIO.StringIO("foobar\nfoobar") @@ -647,7 +649,7 @@ class TestFileLike: o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) s.o = o - tutils.raises(tcp.NetLibDisconnect, s.read, 10) + tutils.raises(TcpDisconnect, s.read, 10) def test_reset_timestamps(self): s = cStringIO.StringIO("foobar\nfoobar") @@ -678,24 +680,24 @@ class TestFileLike: s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.Error()) s = tcp.Reader(s) - tutils.raises(tcp.NetLibSSLError, s.read, 1) + tutils.raises(TlsException, s.read, 1) def test_read_syscall_ssl_error(self): s = mock.MagicMock() s.read = mock.MagicMock(side_effect=SSL.SysCallError()) s = tcp.Reader(s) - tutils.raises(tcp.NetLibSSLError, s.read, 1) + tutils.raises(TlsException, s.read, 1) def test_reader_readline_disconnect(self): o = mock.MagicMock() o.read = mock.MagicMock(side_effect=socket.error) s = tcp.Reader(o) - tutils.raises(tcp.NetLibDisconnect, s.readline, 10) + tutils.raises(TcpDisconnect, s.readline, 10) def test_reader_incomplete_error(self): s = cStringIO.StringIO("foobar") s = tcp.Reader(s) - tutils.raises(tcp.NetLibIncomplete, s.safe_read, 10) + tutils.raises(TcpReadIncomplete, s.safe_read, 10) class TestAddress: diff --git a/test/test_utils.py b/test/test_utils.py index 8b2ddae4b..eb7aa31a8 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -14,7 +14,7 @@ def test_hexdump(): assert utils.hexdump("one\0" * 10) -def test_cleanBin(): +def test_clean_bin(): assert utils.clean_bin(b"one") == b"one" assert utils.clean_bin(b"\00ne") == b".ne" assert utils.clean_bin(b"\nne") == b"\nne" diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 3fdeb6839..3af5dc9c2 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -176,7 +176,7 @@ class TestBadHandshake(tservers.ServerTestBase): """ handler = BadHandshakeHandler - @raises(tcp.NetLibDisconnect) + @raises(TcpDisconnect) def test(self): client = WebSocketsClient(("127.0.0.1", self.port)) client.connect()