From bda49dd178fee1361f3585bd7efad67883298e5a Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 1 Feb 2016 19:38:14 +0100 Subject: [PATCH] fix #113, make Reader.peek() work on Python 3 --- netlib/tcp.py | 30 +++++++++++++++++++++++++----- test/test_tcp.py | 2 +- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/netlib/tcp.py b/netlib/tcp.py index 8902b9dca..57a9b737f 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -25,6 +25,10 @@ from netlib.exceptions import InvalidCertificateException, TcpReadIncomplete, Tl version_check.check_pyopenssl_version() +if six.PY2: + socket_fileobject = socket._fileobject +else: + socket_fileobject = socket.SocketIO EINTR = 4 @@ -270,7 +274,7 @@ class Reader(_FileLike): TlsException if there was an error with pyOpenSSL. NotImplementedError if the underlying file object is not a (pyOpenSSL) socket """ - if isinstance(self.o, socket._fileobject): + if isinstance(self.o, socket_fileobject): try: return self.o._sock.recv(length, socket.MSG_PEEK) except socket.error as e: @@ -423,8 +427,17 @@ class _Connection(object): def __init__(self, connection): if connection: self.connection = connection - self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + # Ideally, we would use the Buffered IO in Python 3 by default. + # Unfortunately, the implementation of .peek() is broken for n>1 bytes, + # as it may just return what's left in the buffer and not all the bytes we want. + # As a workaround, we just use unbuffered sockets directly. + # https://mail.python.org/pipermail/python-dev/2009-June/089986.html + if six.PY2: + self.rfile = Reader(self.connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(self.connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(self.connection, "rb")) + self.wfile = Writer(socket.SocketIO(self.connection, "wb")) else: self.connection = None self.rfile = None @@ -663,8 +676,15 @@ class TCPClient(_Connection): connection.connect(self.address()) if not self.source_address: self.source_address = Address(connection.getsockname()) - self.rfile = Reader(connection.makefile('rb', self.rbufsize)) - self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + + # See _Connection.__init__ why we do this dance. + if six.PY2: + self.rfile = Reader(connection.makefile('rb', self.rbufsize)) + self.wfile = Writer(connection.makefile('wb', self.wbufsize)) + else: + self.rfile = Reader(socket.SocketIO(connection, "rb")) + self.wfile = Writer(socket.SocketIO(connection, "wb")) + except (socket.error, IOError) as err: raise TcpException( 'Error connecting to "%s": %s' % diff --git a/test/test_tcp.py b/test/test_tcp.py index a68bf1e6e..20a295ddd 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -723,7 +723,7 @@ class TestPeek(tservers.ServerTestBase): c.wfile.write(testval) c.wfile.flush() - assert c.rfile.peek(4) == "peek"[:4] + assert c.rfile.peek(4) == b"peek"[:4] assert c.rfile.peek(6) == testval