mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-12-04 04:37:15 +00:00
Merge branch 'tcp_proxy'
This commit is contained in:
commit
3d52d16e8d
@ -237,7 +237,7 @@ class SSLCert:
|
|||||||
|
|
||||||
|
|
||||||
def get_remote_cert(host, port, sni):
|
def get_remote_cert(host, port, sni):
|
||||||
c = tcp.TCPClient(host, port)
|
c = tcp.TCPClient((host, port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl(sni=sni)
|
c.convert_to_ssl(sni=sni)
|
||||||
return c.cert
|
return c.cert
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import re, copy
|
import re, copy
|
||||||
|
|
||||||
|
|
||||||
def safe_subn(pattern, repl, target, *args, **kwargs):
|
def safe_subn(pattern, repl, target, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
There are Unicode conversion problems with re.subn. We try to smooth
|
There are Unicode conversion problems with re.subn. We try to smooth
|
||||||
@ -98,6 +99,9 @@ class ODict:
|
|||||||
def _get_state(self):
|
def _get_state(self):
|
||||||
return [tuple(i) for i in self.lst]
|
return [tuple(i) for i in self.lst]
|
||||||
|
|
||||||
|
def _load_state(self, state):
|
||||||
|
self.list = [list(i) for i in state]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_state(klass, state):
|
def _from_state(klass, state):
|
||||||
return klass([list(i) for i in state])
|
return klass([list(i) for i in state])
|
||||||
|
232
netlib/tcp.py
232
netlib/tcp.py
@ -173,60 +173,57 @@ class Reader(_FileLike):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class TCPClient:
|
class Address(object):
|
||||||
rbufsize = -1
|
"""
|
||||||
wbufsize = -1
|
This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information.
|
||||||
def __init__(self, host, port, source_address=None, use_ipv6=False):
|
"""
|
||||||
self.host, self.port = host, port
|
def __init__(self, address, use_ipv6=False):
|
||||||
self.source_address = source_address
|
self.address = tuple(address)
|
||||||
self.use_ipv6 = use_ipv6
|
self.use_ipv6 = use_ipv6
|
||||||
self.connection, self.rfile, self.wfile = None, None, None
|
|
||||||
self.cert = None
|
|
||||||
self.ssl_established = False
|
|
||||||
|
|
||||||
def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None):
|
@classmethod
|
||||||
"""
|
def wrap(cls, t):
|
||||||
cert: Path to a file containing both client cert and private key.
|
if isinstance(t, cls):
|
||||||
"""
|
return t
|
||||||
context = SSL.Context(method)
|
else:
|
||||||
if options is not None:
|
return cls(t)
|
||||||
context.set_options(options)
|
|
||||||
if cert:
|
def __call__(self):
|
||||||
|
return self.address
|
||||||
|
|
||||||
|
@property
|
||||||
|
def host(self):
|
||||||
|
return self.address[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def port(self):
|
||||||
|
return self.address[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_ipv6(self):
|
||||||
|
return self.family == socket.AF_INET6
|
||||||
|
|
||||||
|
@use_ipv6.setter
|
||||||
|
def use_ipv6(self, b):
|
||||||
|
self.family = socket.AF_INET6 if b else socket.AF_INET
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
other = Address.wrap(other)
|
||||||
|
return (self.address, self.family) == (other.address, other.family)
|
||||||
|
|
||||||
|
|
||||||
|
class SocketCloseMixin(object):
|
||||||
|
def finish(self):
|
||||||
|
self.finished = True
|
||||||
try:
|
try:
|
||||||
context.use_privatekey_file(cert)
|
if not getattr(self.wfile, "closed", False):
|
||||||
context.use_certificate_file(cert)
|
self.wfile.flush()
|
||||||
except SSL.Error, v:
|
self.close()
|
||||||
raise NetLibError("SSL client certificate error: %s"%str(v))
|
self.wfile.close()
|
||||||
self.connection = SSL.Connection(context, self.connection)
|
self.rfile.close()
|
||||||
self.ssl_established = True
|
except (socket.error, NetLibDisconnect):
|
||||||
if sni:
|
# Remote has disconnected
|
||||||
self.connection.set_tlsext_host_name(sni)
|
pass
|
||||||
self.connection.set_connect_state()
|
|
||||||
try:
|
|
||||||
self.connection.do_handshake()
|
|
||||||
except SSL.Error, v:
|
|
||||||
raise NetLibError("SSL handshake error: %s"%str(v))
|
|
||||||
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
|
|
||||||
self.rfile.set_descriptor(self.connection)
|
|
||||||
self.wfile.set_descriptor(self.connection)
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
try:
|
|
||||||
connection = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
|
|
||||||
if self.source_address:
|
|
||||||
connection.bind(self.source_address)
|
|
||||||
connection.connect((self.host, self.port))
|
|
||||||
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
|
||||||
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
|
||||||
except (socket.error, IOError), err:
|
|
||||||
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
|
|
||||||
self.connection = connection
|
|
||||||
|
|
||||||
def settimeout(self, n):
|
|
||||||
self.connection.settimeout(n)
|
|
||||||
|
|
||||||
def gettimeout(self):
|
|
||||||
return self.connection.gettimeout()
|
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""
|
"""
|
||||||
@ -248,23 +245,80 @@ class TCPClient:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class BaseHandler:
|
class TCPClient(SocketCloseMixin):
|
||||||
|
rbufsize = -1
|
||||||
|
wbufsize = -1
|
||||||
|
def __init__(self, address, source_address=None):
|
||||||
|
self.address = Address.wrap(address)
|
||||||
|
self.source_address = Address.wrap(source_address) if source_address else None
|
||||||
|
self.connection, self.rfile, self.wfile = None, None, None
|
||||||
|
self.cert = None
|
||||||
|
self.ssl_established = False
|
||||||
|
self.sni = None
|
||||||
|
|
||||||
|
def convert_to_ssl(self, cert=None, sni=None, method=TLSv1_METHOD, options=None):
|
||||||
|
"""
|
||||||
|
cert: Path to a file containing both client cert and private key.
|
||||||
|
"""
|
||||||
|
context = SSL.Context(method)
|
||||||
|
if options is not None:
|
||||||
|
context.set_options(options)
|
||||||
|
if cert:
|
||||||
|
try:
|
||||||
|
context.use_privatekey_file(cert)
|
||||||
|
context.use_certificate_file(cert)
|
||||||
|
except SSL.Error, v:
|
||||||
|
raise NetLibError("SSL client certificate error: %s"%str(v))
|
||||||
|
self.connection = SSL.Connection(context, self.connection)
|
||||||
|
self.ssl_established = True
|
||||||
|
if sni:
|
||||||
|
self.sni = sni
|
||||||
|
self.connection.set_tlsext_host_name(sni)
|
||||||
|
self.connection.set_connect_state()
|
||||||
|
try:
|
||||||
|
self.connection.do_handshake()
|
||||||
|
except SSL.Error, v:
|
||||||
|
raise NetLibError("SSL handshake error: %s"%str(v))
|
||||||
|
self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
|
||||||
|
self.rfile.set_descriptor(self.connection)
|
||||||
|
self.wfile.set_descriptor(self.connection)
|
||||||
|
|
||||||
|
def connect(self):
|
||||||
|
try:
|
||||||
|
connection = socket.socket(self.address.family, socket.SOCK_STREAM)
|
||||||
|
if self.source_address:
|
||||||
|
connection.bind(self.source_address())
|
||||||
|
connection.connect(self.address())
|
||||||
|
self.rfile = Reader(connection.makefile('rb', self.rbufsize))
|
||||||
|
self.wfile = Writer(connection.makefile('wb', self.wbufsize))
|
||||||
|
except (socket.error, IOError), err:
|
||||||
|
raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))
|
||||||
|
self.connection = connection
|
||||||
|
|
||||||
|
def settimeout(self, n):
|
||||||
|
self.connection.settimeout(n)
|
||||||
|
|
||||||
|
def gettimeout(self):
|
||||||
|
return self.connection.gettimeout()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseHandler(SocketCloseMixin):
|
||||||
"""
|
"""
|
||||||
The instantiator is expected to call the handle() and finish() methods.
|
The instantiator is expected to call the handle() and finish() methods.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
rbufsize = -1
|
rbufsize = -1
|
||||||
wbufsize = -1
|
wbufsize = -1
|
||||||
def __init__(self, connection, client_address, server):
|
|
||||||
|
def __init__(self, connection, address, server):
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
self.address = Address.wrap(address)
|
||||||
|
self.server = server
|
||||||
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
|
self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
|
||||||
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
|
self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
|
||||||
|
|
||||||
self.client_address = client_address
|
|
||||||
self.server = server
|
|
||||||
self.finished = False
|
self.finished = False
|
||||||
self.ssl_established = False
|
self.ssl_established = False
|
||||||
|
|
||||||
self.clientcert = None
|
self.clientcert = None
|
||||||
|
|
||||||
def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None):
|
def convert_to_ssl(self, cert, key, method=SSLv23_METHOD, options=None, handle_sni=None, request_client_cert=False, cipher_list=None):
|
||||||
@ -318,66 +372,34 @@ class BaseHandler:
|
|||||||
self.rfile.set_descriptor(self.connection)
|
self.rfile.set_descriptor(self.connection)
|
||||||
self.wfile.set_descriptor(self.connection)
|
self.wfile.set_descriptor(self.connection)
|
||||||
|
|
||||||
def finish(self):
|
|
||||||
self.finished = True
|
|
||||||
try:
|
|
||||||
if not getattr(self.wfile, "closed", False):
|
|
||||||
self.wfile.flush()
|
|
||||||
self.close()
|
|
||||||
self.wfile.close()
|
|
||||||
self.rfile.close()
|
|
||||||
except (socket.error, NetLibDisconnect):
|
|
||||||
# Remote has disconnected
|
|
||||||
pass
|
|
||||||
|
|
||||||
def handle(self): # pragma: no cover
|
def handle(self): # pragma: no cover
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def settimeout(self, n):
|
def settimeout(self, n):
|
||||||
self.connection.settimeout(n)
|
self.connection.settimeout(n)
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""
|
|
||||||
Does a hard close of the socket, i.e. a shutdown, followed by a close.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if self.ssl_established:
|
|
||||||
self.connection.shutdown()
|
|
||||||
self.connection.sock_shutdown(socket.SHUT_WR)
|
|
||||||
else:
|
|
||||||
self.connection.shutdown(socket.SHUT_WR)
|
|
||||||
# Section 4.2.2.13 of RFC 1122 tells us that a close() with any
|
|
||||||
# pending readable data could lead to an immediate RST being sent.
|
|
||||||
# http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
|
|
||||||
while self.connection.recv(4096):
|
|
||||||
pass
|
|
||||||
except (socket.error, SSL.Error):
|
|
||||||
# Socket probably already closed
|
|
||||||
pass
|
|
||||||
self.connection.close()
|
|
||||||
|
|
||||||
|
|
||||||
class TCPServer:
|
class TCPServer:
|
||||||
request_queue_size = 20
|
request_queue_size = 20
|
||||||
def __init__(self, server_address, use_ipv6=False):
|
def __init__(self, address):
|
||||||
self.server_address = server_address
|
self.address = Address.wrap(address)
|
||||||
self.use_ipv6 = use_ipv6
|
|
||||||
self.__is_shut_down = threading.Event()
|
self.__is_shut_down = threading.Event()
|
||||||
self.__shutdown_request = False
|
self.__shutdown_request = False
|
||||||
self.socket = socket.socket(socket.AF_INET6 if self.use_ipv6 else socket.AF_INET, socket.SOCK_STREAM)
|
self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
|
||||||
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
self.socket.bind(self.server_address)
|
self.socket.bind(self.address())
|
||||||
self.server_address = self.socket.getsockname()
|
self.address = Address.wrap(self.socket.getsockname())
|
||||||
self.port = self.server_address[1]
|
|
||||||
self.socket.listen(self.request_queue_size)
|
self.socket.listen(self.request_queue_size)
|
||||||
|
|
||||||
def request_thread(self, request, client_address):
|
def connection_thread(self, connection, client_address):
|
||||||
|
client_address = Address(client_address)
|
||||||
try:
|
try:
|
||||||
self.handle_connection(request, client_address)
|
self.handle_client_connection(connection, client_address)
|
||||||
request.close()
|
|
||||||
except:
|
except:
|
||||||
self.handle_error(request, client_address)
|
self.handle_error(connection, client_address)
|
||||||
request.close()
|
finally:
|
||||||
|
connection.close()
|
||||||
|
|
||||||
def serve_forever(self, poll_interval=0.1):
|
def serve_forever(self, poll_interval=0.1):
|
||||||
self.__is_shut_down.clear()
|
self.__is_shut_down.clear()
|
||||||
@ -391,10 +413,10 @@ class TCPServer:
|
|||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
if self.socket in r:
|
if self.socket in r:
|
||||||
request, client_address = self.socket.accept()
|
connection, client_address = self.socket.accept()
|
||||||
t = threading.Thread(
|
t = threading.Thread(
|
||||||
target = self.request_thread,
|
target = self.connection_thread,
|
||||||
args = (request, client_address)
|
args = (connection, client_address)
|
||||||
)
|
)
|
||||||
t.setDaemon(1)
|
t.setDaemon(1)
|
||||||
t.start()
|
t.start()
|
||||||
@ -410,18 +432,18 @@ class TCPServer:
|
|||||||
|
|
||||||
def handle_error(self, request, client_address, fp=sys.stderr):
|
def handle_error(self, request, client_address, fp=sys.stderr):
|
||||||
"""
|
"""
|
||||||
Called when handle_connection raises an exception.
|
Called when handle_client_connection raises an exception.
|
||||||
"""
|
"""
|
||||||
# If a thread has persisted after interpreter exit, the module might be
|
# If a thread has persisted after interpreter exit, the module might be
|
||||||
# none.
|
# none.
|
||||||
if traceback:
|
if traceback:
|
||||||
exc = traceback.format_exc()
|
exc = traceback.format_exc()
|
||||||
print >> fp, '-'*40
|
print >> fp, '-'*40
|
||||||
print >> fp, "Error in processing of request from %s:%s"%client_address
|
print >> fp, "Error in processing of request from %s:%s" % (client_address.host, client_address.port)
|
||||||
print >> fp, exc
|
print >> fp, exc
|
||||||
print >> fp, '-'*40
|
print >> fp, '-'*40
|
||||||
|
|
||||||
def handle_connection(self, request, client_address): # pragma: no cover
|
def handle_client_connection(self, conn, client_address): # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Called after client connection.
|
Called after client connection.
|
||||||
"""
|
"""
|
||||||
|
@ -17,19 +17,18 @@ class ServerTestBase:
|
|||||||
ssl = None
|
ssl = None
|
||||||
handler = None
|
handler = None
|
||||||
addr = ("localhost", 0)
|
addr = ("localhost", 0)
|
||||||
use_ipv6 = False
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setupAll(cls):
|
def setupAll(cls):
|
||||||
cls.q = Queue.Queue()
|
cls.q = Queue.Queue()
|
||||||
s = cls.makeserver()
|
s = cls.makeserver()
|
||||||
cls.port = s.port
|
cls.port = s.address.port
|
||||||
cls.server = ServerThread(s)
|
cls.server = ServerThread(s)
|
||||||
cls.server.start()
|
cls.server.start()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def makeserver(cls):
|
def makeserver(cls):
|
||||||
return TServer(cls.ssl, cls.q, cls.handler, cls.addr, cls.use_ipv6)
|
return TServer(cls.ssl, cls.q, cls.handler, cls.addr)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def teardownAll(cls):
|
def teardownAll(cls):
|
||||||
@ -41,16 +40,16 @@ class ServerTestBase:
|
|||||||
|
|
||||||
|
|
||||||
class TServer(tcp.TCPServer):
|
class TServer(tcp.TCPServer):
|
||||||
def __init__(self, ssl, q, handler_klass, addr, use_ipv6):
|
def __init__(self, ssl, q, handler_klass, addr):
|
||||||
"""
|
"""
|
||||||
ssl: A {cert, key, v3_only} dict.
|
ssl: A {cert, key, v3_only} dict.
|
||||||
"""
|
"""
|
||||||
tcp.TCPServer.__init__(self, addr, use_ipv6=use_ipv6)
|
tcp.TCPServer.__init__(self, addr)
|
||||||
self.ssl, self.q = ssl, q
|
self.ssl, self.q = ssl, q
|
||||||
self.handler_klass = handler_klass
|
self.handler_klass = handler_klass
|
||||||
self.last_handler = None
|
self.last_handler = None
|
||||||
|
|
||||||
def handle_connection(self, request, client_address):
|
def handle_client_connection(self, request, client_address):
|
||||||
h = self.handler_klass(request, client_address, self)
|
h = self.handler_klass(request, client_address, self)
|
||||||
self.last_handler = h
|
self.last_handler = h
|
||||||
if self.ssl:
|
if self.ssl:
|
||||||
|
@ -1,17 +1,22 @@
|
|||||||
import cStringIO, urllib, time, traceback
|
import cStringIO, urllib, time, traceback
|
||||||
import odict
|
import odict, tcp
|
||||||
|
|
||||||
|
|
||||||
class ClientConn:
|
class ClientConn:
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
self.address = address
|
self.address = tcp.Address.wrap(address)
|
||||||
|
|
||||||
|
|
||||||
|
class Flow:
|
||||||
|
def __init__(self, client_conn):
|
||||||
|
self.client_conn = client_conn
|
||||||
|
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
def __init__(self, client_conn, scheme, method, path, headers, content):
|
def __init__(self, client_conn, scheme, method, path, headers, content):
|
||||||
self.scheme, self.method, self.path = scheme, method, path
|
self.scheme, self.method, self.path = scheme, method, path
|
||||||
self.headers, self.content = headers, content
|
self.headers, self.content = headers, content
|
||||||
self.client_conn = client_conn
|
self.flow = Flow(client_conn)
|
||||||
|
|
||||||
|
|
||||||
def date_time_string():
|
def date_time_string():
|
||||||
@ -60,8 +65,8 @@ class WSGIAdaptor:
|
|||||||
'SERVER_PROTOCOL': "HTTP/1.1",
|
'SERVER_PROTOCOL': "HTTP/1.1",
|
||||||
}
|
}
|
||||||
environ.update(extra)
|
environ.update(extra)
|
||||||
if request.client_conn.address:
|
if request.flow.client_conn.address:
|
||||||
environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address
|
environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.flow.client_conn.address()
|
||||||
|
|
||||||
for key, value in request.headers.items():
|
for key, value in request.headers.items():
|
||||||
key = 'HTTP_' + key.upper().replace('-', '_')
|
key = 'HTTP_' + key.upper().replace('-', '_')
|
||||||
|
@ -223,7 +223,7 @@ class TestReadResponseNoContentLength(test.ServerTestBase):
|
|||||||
handler = NoContentLengthHTTPHandler
|
handler = NoContentLengthHTTPHandler
|
||||||
|
|
||||||
def test_no_content_length(self):
|
def test_no_content_length(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
|
httpversion, code, msg, headers, content = http.read_response(c.rfile, "GET", None)
|
||||||
assert content == "bar\r\n\r\n"
|
assert content == "bar\r\n\r\n"
|
||||||
|
@ -73,7 +73,7 @@ class TestServer(test.ServerTestBase):
|
|||||||
handler = EchoHandler
|
handler = EchoHandler
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
testval = "echo!\n"
|
testval = "echo!\n"
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.wfile.write(testval)
|
c.wfile.write(testval)
|
||||||
c.wfile.flush()
|
c.wfile.flush()
|
||||||
@ -88,7 +88,7 @@ class TestServerBind(test.ServerTestBase):
|
|||||||
for i in range(20):
|
for i in range(20):
|
||||||
random_port = random.randrange(1024, 65535)
|
random_port = random.randrange(1024, 65535)
|
||||||
try:
|
try:
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port, source_address=("127.0.0.1", random_port))
|
c = tcp.TCPClient(("127.0.0.1", self.port), source_address=("127.0.0.1", random_port))
|
||||||
c.connect()
|
c.connect()
|
||||||
assert c.rfile.readline() == str(("127.0.0.1", random_port))
|
assert c.rfile.readline() == str(("127.0.0.1", random_port))
|
||||||
return
|
return
|
||||||
@ -98,11 +98,11 @@ class TestServerBind(test.ServerTestBase):
|
|||||||
|
|
||||||
class TestServerIPv6(test.ServerTestBase):
|
class TestServerIPv6(test.ServerTestBase):
|
||||||
handler = EchoHandler
|
handler = EchoHandler
|
||||||
use_ipv6 = True
|
addr = tcp.Address(("localhost", 0), use_ipv6=True)
|
||||||
|
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
testval = "echo!\n"
|
testval = "echo!\n"
|
||||||
c = tcp.TCPClient("::1", self.port, use_ipv6=True)
|
c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.wfile.write(testval)
|
c.wfile.write(testval)
|
||||||
c.wfile.flush()
|
c.wfile.flush()
|
||||||
@ -127,7 +127,7 @@ class TestFinishFail(test.ServerTestBase):
|
|||||||
handler = FinishFailHandler
|
handler = FinishFailHandler
|
||||||
def test_disconnect_in_finish(self):
|
def test_disconnect_in_finish(self):
|
||||||
testval = "echo!\n"
|
testval = "echo!\n"
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.wfile.write("foo\n")
|
c.wfile.write("foo\n")
|
||||||
c.wfile.flush()
|
c.wfile.flush()
|
||||||
@ -137,7 +137,7 @@ class TestDisconnect(test.ServerTestBase):
|
|||||||
handler = EchoHandler
|
handler = EchoHandler
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
testval = "echo!\n"
|
testval = "echo!\n"
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.wfile.write(testval)
|
c.wfile.write(testval)
|
||||||
c.wfile.flush()
|
c.wfile.flush()
|
||||||
@ -153,7 +153,7 @@ class TestServerSSL(test.ServerTestBase):
|
|||||||
v3_only = False
|
v3_only = False
|
||||||
)
|
)
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL)
|
c.convert_to_ssl(sni="foo.com", options=tcp.OP_ALL)
|
||||||
testval = "echo!\n"
|
testval = "echo!\n"
|
||||||
@ -174,7 +174,7 @@ class TestSSLv3Only(test.ServerTestBase):
|
|||||||
v3_only = True
|
v3_only = True
|
||||||
)
|
)
|
||||||
def test_failure(self):
|
def test_failure(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
|
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
|
||||||
|
|
||||||
@ -188,13 +188,13 @@ class TestSSLClientCert(test.ServerTestBase):
|
|||||||
v3_only = False
|
v3_only = False
|
||||||
)
|
)
|
||||||
def test_clientcert(self):
|
def test_clientcert(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem"))
|
c.convert_to_ssl(cert=tutils.test_data.path("data/clientcert/client.pem"))
|
||||||
assert c.rfile.readline().strip() == "1"
|
assert c.rfile.readline().strip() == "1"
|
||||||
|
|
||||||
def test_clientcert_err(self):
|
def test_clientcert_err(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
tutils.raises(
|
tutils.raises(
|
||||||
tcp.NetLibError,
|
tcp.NetLibError,
|
||||||
@ -212,9 +212,10 @@ class TestSNI(test.ServerTestBase):
|
|||||||
v3_only = False
|
v3_only = False
|
||||||
)
|
)
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl(sni="foo.com")
|
c.convert_to_ssl(sni="foo.com")
|
||||||
|
assert c.sni == "foo.com"
|
||||||
assert c.rfile.readline() == "foo.com"
|
assert c.rfile.readline() == "foo.com"
|
||||||
|
|
||||||
|
|
||||||
@ -228,7 +229,7 @@ class TestClientCipherList(test.ServerTestBase):
|
|||||||
cipher_list = 'RC4-SHA'
|
cipher_list = 'RC4-SHA'
|
||||||
)
|
)
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl(sni="foo.com")
|
c.convert_to_ssl(sni="foo.com")
|
||||||
assert c.rfile.readline() == "['RC4-SHA']"
|
assert c.rfile.readline() == "['RC4-SHA']"
|
||||||
@ -243,7 +244,7 @@ class TestSSLDisconnect(test.ServerTestBase):
|
|||||||
v3_only = False
|
v3_only = False
|
||||||
)
|
)
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl()
|
c.convert_to_ssl()
|
||||||
# Excercise SSL.ZeroReturnError
|
# Excercise SSL.ZeroReturnError
|
||||||
@ -255,7 +256,7 @@ class TestSSLDisconnect(test.ServerTestBase):
|
|||||||
|
|
||||||
class TestDisconnect(test.ServerTestBase):
|
class TestDisconnect(test.ServerTestBase):
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.rfile.read(10)
|
c.rfile.read(10)
|
||||||
c.wfile.write("foo")
|
c.wfile.write("foo")
|
||||||
@ -266,7 +267,7 @@ class TestDisconnect(test.ServerTestBase):
|
|||||||
class TestServerTimeOut(test.ServerTestBase):
|
class TestServerTimeOut(test.ServerTestBase):
|
||||||
handler = TimeoutHandler
|
handler = TimeoutHandler
|
||||||
def test_timeout(self):
|
def test_timeout(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
assert self.last_handler.timeout
|
assert self.last_handler.timeout
|
||||||
@ -275,7 +276,7 @@ class TestServerTimeOut(test.ServerTestBase):
|
|||||||
class TestTimeOut(test.ServerTestBase):
|
class TestTimeOut(test.ServerTestBase):
|
||||||
handler = HangHandler
|
handler = HangHandler
|
||||||
def test_timeout(self):
|
def test_timeout(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.settimeout(0.1)
|
c.settimeout(0.1)
|
||||||
assert c.gettimeout() == 0.1
|
assert c.gettimeout() == 0.1
|
||||||
@ -291,7 +292,7 @@ class TestSSLTimeOut(test.ServerTestBase):
|
|||||||
v3_only = False
|
v3_only = False
|
||||||
)
|
)
|
||||||
def test_timeout_client(self):
|
def test_timeout_client(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", self.port)
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
c.convert_to_ssl()
|
c.convert_to_ssl()
|
||||||
c.settimeout(0.1)
|
c.settimeout(0.1)
|
||||||
@ -300,7 +301,7 @@ class TestSSLTimeOut(test.ServerTestBase):
|
|||||||
|
|
||||||
class TestTCPClient:
|
class TestTCPClient:
|
||||||
def test_conerr(self):
|
def test_conerr(self):
|
||||||
c = tcp.TCPClient("127.0.0.1", 0)
|
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||||
tutils.raises(tcp.NetLibError, c.connect)
|
tutils.raises(tcp.NetLibError, c.connect)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user