merge smurfix/ipv6, add ipv6 support for TCPServer, add ipv6 test

This commit is contained in:
Maximilian Hils 2013-12-13 15:04:38 +01:00
commit f2e8efdf15
4 changed files with 27 additions and 10 deletions

1
README
View File

@ -1,3 +1,4 @@
[![Build Status](https://travis-ci.org/mitmproxy/netlib.png)](https://travis-ci.org/mitmproxy/netlib) [![Coverage Status](https://coveralls.io/repos/mitmproxy/netlib/badge.png)](https://coveralls.io/r/mitmproxy/netlib)
Netlib is a collection of network utility classes, used by the pathod and Netlib is a collection of network utility classes, used by the pathod and
mitmproxy projects. It differs from other projects in some fundamental mitmproxy projects. It differs from other projects in some fundamental

View File

@ -176,12 +176,13 @@ class Reader(_FileLike):
class TCPClient: class TCPClient:
rbufsize = -1 rbufsize = -1
wbufsize = -1 wbufsize = -1
def __init__(self, host, port, source_address=None): def __init__(self, host, port, source_address=None, use_ipv6=False):
self.host, self.port = host, port self.host, self.port = host, port
self.source_address = source_address
self.use_ipv6 = use_ipv6
self.connection, self.rfile, self.wfile = None, None, None self.connection, self.rfile, self.wfile = None, None, None
self.cert = None self.cert = None
self.ssl_established = False self.ssl_established = False
self.source_address = source_address
def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None): def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None):
""" """
@ -211,11 +212,10 @@ class TCPClient:
def connect(self): def connect(self):
try: try:
addr = socket.gethostbyname(self.host) connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if self.source_address: if self.source_address:
connection.bind(self.source_address) connection.bind(self.source_address)
connection.connect((addr, self.port)) connection.connect((self.host, self.port))
self.rfile = Reader(connection.makefile('rb', self.rbufsize)) self.rfile = Reader(connection.makefile('rb', self.rbufsize))
self.wfile = Writer(connection.makefile('wb', self.wbufsize)) self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except (socket.error, IOError), err: except (socket.error, IOError), err:
@ -359,11 +359,12 @@ class BaseHandler:
class TCPServer: class TCPServer:
request_queue_size = 20 request_queue_size = 20
def __init__(self, server_address): def __init__(self, server_address, use_ipv6=False):
self.server_address = server_address self.server_address = server_address
self.use_ipv6 = use_ipv6
self.__is_shut_down = threading.Event() self.__is_shut_down = threading.Event()
self.__shutdown_request = False self.__shutdown_request = False
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.socket.bind(self.server_address) self.socket.bind(self.server_address)
self.server_address = self.socket.getsockname() self.server_address = self.socket.getsockname()

View File

@ -16,6 +16,9 @@ class ServerThread(threading.Thread):
class ServerTestBase: class ServerTestBase:
ssl = None ssl = None
handler = None handler = None
addr = ("localhost", 0)
use_ipv6 = False
@classmethod @classmethod
def setupAll(cls): def setupAll(cls):
cls.q = Queue.Queue() cls.q = Queue.Queue()
@ -26,7 +29,7 @@ class ServerTestBase:
@classmethod @classmethod
def makeserver(cls): def makeserver(cls):
return TServer(cls.ssl, cls.q, cls.handler) return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6)
@classmethod @classmethod
def teardownAll(cls): def teardownAll(cls):
@ -38,11 +41,11 @@ class ServerTestBase:
class TServer(tcp.TCPServer): class TServer(tcp.TCPServer):
def __init__(self, ssl, q, handler_klass, addr=("127.0.0.1", 0)): def __init__(self, ssl, q, handler_klass, addr, use_ipv6):
""" """
ssl: A {cert, key, v3_only} dict. ssl: A {cert, key, v3_only} dict.
""" """
tcp.TCPServer.__init__(self, addr) tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6)
self.ssl, self.q = ssl, q self.ssl, self.q = ssl, q
self.handler_klass = handler_klass self.handler_klass = handler_klass
self.last_handler = None self.last_handler = None

View File

@ -74,6 +74,18 @@ class TestServer(test.ServerTestBase):
assert c.rfile.readline() == testval assert c.rfile.readline() == testval
class TestServerIPv6(test.ServerTestBase):
handler = EchoHandler
use_ipv6 = True
def test_echo(self):
testval = "echo!\n"
c = tcp.TCPClient("::1", self.port, use_ipv6=True)
c.connect()
c.wfile.write(testval)
c.wfile.flush()
assert c.rfile.readline() == testval
class FinishFailHandler(tcp.BaseHandler): class FinishFailHandler(tcp.BaseHandler):
def handle(self): def handle(self):