mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-27 02:24:18 +00:00
add tcp.Address to unify ipv4/ipv6 address handling
This commit is contained in:
parent
71c1017575
commit
763cb90b66
@ -237,7 +237,7 @@ class SSLCert:
|
||||
|
||||
|
||||
def get_remote_cert(host, port, sni):
|
||||
c = tcp.TCPClient(host, port)
|
||||
c = tcp.TCPClient((host, port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(sni=sni)
|
||||
return c.cert
|
||||
|
@ -173,6 +173,35 @@ class Reader(_FileLike):
|
||||
return result
|
||||
|
||||
|
||||
class Address(tuple):
|
||||
"""
|
||||
This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
|
||||
"""
|
||||
def __new__(cls, address, use_ipv6=False):
|
||||
a = super(Address, cls).__new__(cls, tuple(address))
|
||||
a.family = socket.AF_INET6 if use_ipv6 else socket.AF_INET
|
||||
return a
|
||||
|
||||
@classmethod
|
||||
def wrap(cls, t):
|
||||
if isinstance(t, cls):
|
||||
return t
|
||||
else:
|
||||
return cls(t)
|
||||
|
||||
@property
|
||||
def host(self):
|
||||
return self[0]
|
||||
|
||||
@property
|
||||
def port(self):
|
||||
return self[1]
|
||||
|
||||
@property
|
||||
def is_ipv6(self):
|
||||
return self.family == socket.AF_INET6
|
||||
|
||||
|
||||
class SocketCloseMixin:
|
||||
def finish(self):
|
||||
self.finished = True
|
||||
@ -209,10 +238,9 @@ class SocketCloseMixin:
|
||||
class TCPClient(SocketCloseMixin):
|
||||
rbufsize = -1
|
||||
wbufsize = -1
|
||||
def __init__(self, host, port, source_address=None, use_ipv6=False):
|
||||
self.host, self.port = host, port
|
||||
def __init__(self, address, source_address=None):
|
||||
self.address = Address.wrap(address)
|
||||
self.source_address = source_address
|
||||
self.use_ipv6 = use_ipv6
|
||||
self.connection, self.rfile, self.wfile = None, None, None
|
||||
self.cert = None
|
||||
self.ssl_established = False
|
||||
@ -245,14 +273,14 @@ class TCPClient(SocketCloseMixin):
|
||||
|
||||
def connect(self):
|
||||
try:
|
||||
connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
|
||||
connection = socket.socket(self.address.family, socket.SOCK_STREAM)
|
||||
if self.source_address:
|
||||
connection.bind(self.source_address)
|
||||
connection.connect((self.host, self.port))
|
||||
connection.connect(self.address)
|
||||
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
||||
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
||||
except (socket.error, IOError), err:
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
||||
raise NetLibError('Error connecting to "%s": %s' % (self.address[0], err))
|
||||
self.connection = connection
|
||||
|
||||
def settimeout(self, n):
|
||||
@ -269,8 +297,9 @@ class BaseHandler(SocketCloseMixin):
|
||||
"""
|
||||
rbufsize = -1
|
||||
wbufsize = -1
|
||||
def __init__(self, connection):
|
||||
def __init__(self, connection, address):
|
||||
self.connection = connection
|
||||
self.address = Address.wrap(address)
|
||||
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
|
||||
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
|
||||
|
||||
@ -339,19 +368,18 @@ class BaseHandler(SocketCloseMixin):
|
||||
|
||||
class TCPServer:
|
||||
request_queue_size = 20
|
||||
def __init__(self, server_address, use_ipv6=False):
|
||||
self.server_address = server_address
|
||||
self.use_ipv6 = use_ipv6
|
||||
def __init__(self, address):
|
||||
self.address = Address.wrap(address)
|
||||
self.__is_shut_down = threading.Event()
|
||||
self.__shutdown_request = False
|
||||
self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
|
||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
self.socket.bind(self.server_address)
|
||||
self.server_address = self.socket.getsockname()
|
||||
self.port = self.server_address[1]
|
||||
self.socket.bind(self.address)
|
||||
self.address = Address.wrap(self.socket.getsockname())
|
||||
self.socket.listen(self.request_queue_size)
|
||||
|
||||
def connection_thread(self, connection, client_address):
|
||||
client_address = Address(client_address)
|
||||
try:
|
||||
self.handle_client_connection(connection, client_address)
|
||||
except:
|
||||
|
@ -17,19 +17,18 @@ class ServerTestBase:
|
||||
ssl = None
|
||||
handler = None
|
||||
addr = ("localhost", 0)
|
||||
use_ipv6 = False
|
||||
|
||||
@classmethod
|
||||
def setupAll(cls):
|
||||
cls.q = Queue.Queue()
|
||||
s = cls.makeserver()
|
||||
cls.port = s.port
|
||||
cls.port = s.address.port
|
||||
cls.server = ServerThread(s)
|
||||
cls.server.start()
|
||||
|
||||
@classmethod
|
||||
def makeserver(cls):
|
||||
return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6)
|
||||
return TServer(cls.ssl, cls.q, cls.handler, cls.addr)
|
||||
|
||||
@classmethod
|
||||
def teardownAll(cls):
|
||||
@ -41,17 +40,17 @@ class ServerTestBase:
|
||||
|
||||
|
||||
class TServer(tcp.TCPServer):
|
||||
def __init__(self, ssl, q, handler_klass, addr, use_ipv6):
|
||||
def __init__(self, ssl, q, handler_klass, addr):
|
||||
"""
|
||||
ssl: A {cert, key, v3_only} dict.
|
||||
"""
|
||||
tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6)
|
||||
tcp.TCPServer.__init__(self, addr)
|
||||
self.ssl, self.q = ssl, q
|
||||
self.handler_klass = handler_klass
|
||||
self.last_handler = None
|
||||
|
||||
def handle_client_connection(self, request, client_address):
|
||||
h = self.handler_klass(request)
|
||||
h = self.handler_klass(request, client_address)
|
||||
self.last_handler = h
|
||||
if self.ssl:
|
||||
cert = certutils.SSLCert.from_pem(
|
||||
|
@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
|
||||
handler = NoContentLengthHTTPHandler
|
||||
|
||||
def test_no_content_length(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
|
||||
assert content == "bar\r\n\r\n"
|
||||
|
@ -73,7 +73,7 @@ class TestServer(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
@ -88,7 +88,7 @@ class TestServerBind(test.ServerTestBase):
|
||||
for i in range(20):
|
||||
random_port = random.randrange(1024, 65535)
|
||||
try:
|
||||
c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port))
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port))
|
||||
c.connect()
|
||||
assert c.rfile.readline() == str(("127.0.0.1", random_port))
|
||||
return
|
||||
@ -98,11 +98,11 @@ class TestServerBind(test.ServerTestBase):
|
||||
|
||||
class TestServerIPv6(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
use_ipv6 = True
|
||||
addr = tcp.Address(("localhost", 0), use_ipv6=True)
|
||||
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient("::1", self.port, use_ipv6=True)
|
||||
c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True))
|
||||
c.connect()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
@ -127,7 +127,7 @@ class TestFinishFail(test.ServerTestBase):
|
||||
handler = FinishFailHandler
|
||||
def test_disconnect_in_finish(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.wfile.write("foo\n")
|
||||
c.wfile.flush()
|
||||
@ -137,7 +137,7 @@ class TestDisconnect(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
@ -153,7 +153,7 @@ class TestServerSSL(test.ServerTestBase):
|
||||
v3_only = False
|
||||
)
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL)
|
||||
testval = "echo!\n"
|
||||
@ -174,7 +174,7 @@ class TestSSLv3Only(test.ServerTestBase):
|
||||
v3_only = True
|
||||
)
|
||||
def test_failure(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
|
||||
|
||||
@ -188,13 +188,13 @@ class TestSSLClientCert(test.ServerTestBase):
|
||||
v3_only = False
|
||||
)
|
||||
def test_clientcert(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem"))
|
||||
assert c.rfile.readline().strip() == "1"
|
||||
|
||||
def test_clientcert_err(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
tutils.raises(
|
||||
tcp.NetLibError,
|
||||
@ -212,7 +212,7 @@ class TestSNI(test.ServerTestBase):
|
||||
v3_only = False
|
||||
)
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
assert c.rfile.readline() == "foo.com"
|
||||
@ -228,7 +228,7 @@ class TestClientCipherList(test.ServerTestBase):
|
||||
cipher_list = 'RC4-SHA'
|
||||
)
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl(sni="foo.com")
|
||||
assert c.rfile.readline() == "['RC4-SHA']"
|
||||
@ -243,7 +243,7 @@ class TestSSLDisconnect(test.ServerTestBase):
|
||||
v3_only = False
|
||||
)
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
# Excercise SSL.ZeroReturnError
|
||||
@ -255,7 +255,7 @@ class TestSSLDisconnect(test.ServerTestBase):
|
||||
|
||||
class TestDisconnect(test.ServerTestBase):
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.rfile.read(10)
|
||||
c.wfile.write("foo")
|
||||
@ -266,7 +266,7 @@ class TestDisconnect(test.ServerTestBase):
|
||||
class TestServerTimeOut(test.ServerTestBase):
|
||||
handler = TimeoutHandler
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
time.sleep(0.3)
|
||||
assert self.last_handler.timeout
|
||||
@ -275,7 +275,7 @@ class TestServerTimeOut(test.ServerTestBase):
|
||||
class TestTimeOut(test.ServerTestBase):
|
||||
handler = HangHandler
|
||||
def test_timeout(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.settimeout(0.1)
|
||||
assert c.gettimeout() == 0.1
|
||||
@ -291,7 +291,7 @@ class TestSSLTimeOut(test.ServerTestBase):
|
||||
v3_only = False
|
||||
)
|
||||
def test_timeout_client(self):
|
||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
c.settimeout(0.1)
|
||||
@ -300,7 +300,7 @@ class TestSSLTimeOut(test.ServerTestBase):
|
||||
|
||||
class TestTCPClient:
|
||||
def test_conerr(self):
|
||||
c = tcp.TCPClient("127.0.0.1", 0)
|
||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||
tutils.raises(tcp.NetLibError, c.connect)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user