mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-26 18:18:25 +00:00
100% test coverage 🎉
This commit is contained in:
parent
63fb433690
commit
da1eb94ccd
@ -39,6 +39,9 @@ class SSLKeyLogger(object):
|
||||
if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1:
|
||||
with self.lock:
|
||||
if not self.f:
|
||||
d = os.path.dirname(self.filename)
|
||||
if not os.path.isdir(d):
|
||||
os.makedirs(d)
|
||||
self.f = open(self.filename, "ab")
|
||||
self.f.write("\r\n")
|
||||
client_random = connection.client_random().encode("hex")
|
||||
@ -51,11 +54,13 @@ class SSLKeyLogger(object):
|
||||
if self.f:
|
||||
self.f.close()
|
||||
|
||||
_logfile = os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")
|
||||
if _logfile:
|
||||
log_ssl_key = SSLKeyLogger(_logfile)
|
||||
else:
|
||||
log_ssl_key = False
|
||||
@staticmethod
|
||||
def create_logfun(filename):
|
||||
if filename:
|
||||
return SSLKeyLogger(filename)
|
||||
return False
|
||||
|
||||
log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE"))
|
||||
|
||||
|
||||
class _FileLike:
|
||||
@ -161,9 +166,9 @@ class Reader(_FileLike):
|
||||
except SSL.SysCallError as e:
|
||||
if e.args == (-1, 'Unexpected EOF'):
|
||||
break
|
||||
raise NetLibDisconnect
|
||||
except SSL.Error, v:
|
||||
raise NetLibSSLError(v.message)
|
||||
raise NetLibSSLError(e.message)
|
||||
except SSL.Error as e:
|
||||
raise NetLibSSLError(e.message)
|
||||
self.first_byte_timestamp = self.first_byte_timestamp or time.time()
|
||||
if not data:
|
||||
break
|
||||
@ -179,10 +184,7 @@ class Reader(_FileLike):
|
||||
while True:
|
||||
if size is not None and bytes_read >= size:
|
||||
break
|
||||
try:
|
||||
ch = self.read(1)
|
||||
except NetLibDisconnect:
|
||||
break
|
||||
ch = self.read(1)
|
||||
bytes_read += 1
|
||||
if not ch:
|
||||
break
|
||||
|
@ -75,7 +75,8 @@ class TServer(tcp.TCPServer):
|
||||
handle_sni = getattr(h, "handle_sni", None),
|
||||
request_client_cert = self.ssl["request_client_cert"],
|
||||
cipher_list = self.ssl.get("cipher_list", None),
|
||||
dhparams = self.ssl.get("dhparams", None)
|
||||
dhparams = self.ssl.get("dhparams", None),
|
||||
chain_file = self.ssl.get("chain_file", None)
|
||||
)
|
||||
h.handle()
|
||||
h.finish()
|
||||
|
@ -80,7 +80,7 @@ class TestCertStore:
|
||||
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
|
||||
assert not ca1.default_ca.get_serial_number() == ca2.default_ca.get_serial_number()
|
||||
|
||||
dc = ca2.get_cert("foo.com", [])
|
||||
dc = ca2.get_cert("foo.com", ["sans.example.com"])
|
||||
dcp = os.path.join(d, "dc")
|
||||
f = open(dcp, "wb")
|
||||
f.write(dc[0].to_pem())
|
||||
@ -118,31 +118,34 @@ class TestSSLCert:
|
||||
def test_simple(self):
|
||||
with open(tutils.test_data.path("data/text_cert"), "rb") as f:
|
||||
d = f.read()
|
||||
c = certutils.SSLCert.from_pem(d)
|
||||
assert c.cn == "google.com"
|
||||
assert len(c.altnames) == 436
|
||||
c1 = certutils.SSLCert.from_pem(d)
|
||||
assert c1.cn == "google.com"
|
||||
assert len(c1.altnames) == 436
|
||||
|
||||
with open(tutils.test_data.path("data/text_cert_2"), "rb") as f:
|
||||
d = f.read()
|
||||
c = certutils.SSLCert.from_pem(d)
|
||||
assert c.cn == "www.inode.co.nz"
|
||||
assert len(c.altnames) == 2
|
||||
assert c.digest("sha1")
|
||||
assert c.notbefore
|
||||
assert c.notafter
|
||||
assert c.subject
|
||||
assert c.keyinfo == ("RSA", 2048)
|
||||
assert c.serial
|
||||
assert c.issuer
|
||||
assert c.to_pem()
|
||||
c.has_expired
|
||||
c2 = certutils.SSLCert.from_pem(d)
|
||||
assert c2.cn == "www.inode.co.nz"
|
||||
assert len(c2.altnames) == 2
|
||||
assert c2.digest("sha1")
|
||||
assert c2.notbefore
|
||||
assert c2.notafter
|
||||
assert c2.subject
|
||||
assert c2.keyinfo == ("RSA", 2048)
|
||||
assert c2.serial
|
||||
assert c2.issuer
|
||||
assert c2.to_pem()
|
||||
assert c2.has_expired is not None
|
||||
|
||||
assert not c1 == c2
|
||||
assert c1 != c2
|
||||
|
||||
def test_err_broken_sans(self):
|
||||
with open(tutils.test_data.path("data/text_cert_weird1"), "rb") as f:
|
||||
d = f.read()
|
||||
c = certutils.SSLCert.from_pem(d)
|
||||
# This breaks unless we ignore a decoding error.
|
||||
c.altnames
|
||||
assert c.altnames is not None
|
||||
|
||||
def test_der(self):
|
||||
with open(tutils.test_data.path("data/dercert"), "rb") as f:
|
||||
|
@ -325,6 +325,12 @@ def test_parse_url():
|
||||
assert po == 80
|
||||
assert pa == "/bar"
|
||||
|
||||
s, h, po, pa = http.parse_url("http://user:pass@foo/bar")
|
||||
assert s == "http"
|
||||
assert h == "foo"
|
||||
assert po == 80
|
||||
assert pa == "/bar"
|
||||
|
||||
s, h, po, pa = http.parse_url("http://foo")
|
||||
assert pa == "/"
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from cStringIO import StringIO
|
||||
import socket
|
||||
import mock
|
||||
from nose.plugins.skip import SkipTest
|
||||
from netlib import socks, tcp
|
||||
import tutils
|
||||
@ -81,4 +82,15 @@ def test_message_unknown_atyp():
|
||||
tutils.raises(socks.SocksError, socks.Message.from_file, raw)
|
||||
|
||||
m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050)))
|
||||
tutils.raises(socks.SocksError, m.to_file, StringIO())
|
||||
tutils.raises(socks.SocksError, m.to_file, StringIO())
|
||||
|
||||
def test_read():
|
||||
cs = StringIO("1234")
|
||||
assert socks._read(cs, 3) == "123"
|
||||
|
||||
cs = StringIO("123")
|
||||
tutils.raises(socks.SocksError, socks._read, cs, 4)
|
||||
|
||||
cs = mock.Mock()
|
||||
cs.read = mock.Mock(side_effect=socket.error)
|
||||
tutils.raises(socks.SocksError, socks._read, cs, 4)
|
103
test/test_tcp.py
103
test/test_tcp.py
@ -1,4 +1,5 @@
|
||||
import cStringIO, Queue, time, socket, random
|
||||
import os
|
||||
from netlib import tcp, certutils, test, certffi
|
||||
import mock
|
||||
import tutils
|
||||
@ -71,30 +72,6 @@ class TestServerIPv6(test.ServerTestBase):
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
|
||||
class FinishFailHandler(tcp.BaseHandler):
|
||||
def handle(self):
|
||||
v = self.rfile.readline()
|
||||
self.wfile.write(v)
|
||||
self.wfile.flush()
|
||||
self.wfile.close()
|
||||
self.rfile.close()
|
||||
self.close = mock.MagicMock(side_effect=socket.error)
|
||||
|
||||
|
||||
class TestFinishFail(test.ServerTestBase):
|
||||
"""
|
||||
This tests a difficult-to-trigger exception in the .finish() method of
|
||||
the handler.
|
||||
"""
|
||||
handler = FinishFailHandler
|
||||
def test_disconnect_in_finish(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.wfile.write("foo\n")
|
||||
c.wfile.flush()
|
||||
c.rfile.read(4)
|
||||
|
||||
class TestDisconnect(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
def test_echo(self):
|
||||
@ -111,6 +88,20 @@ class HardDisconnectHandler(tcp.BaseHandler):
|
||||
self.connection.close()
|
||||
|
||||
|
||||
class TestFinishFail(test.ServerTestBase):
|
||||
"""
|
||||
This tests a difficult-to-trigger exception in the .finish() method of
|
||||
the handler.
|
||||
"""
|
||||
handler = EchoHandler
|
||||
def test_disconnect_in_finish(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.wfile.write("foo\n")
|
||||
c.wfile.flush = mock.Mock(side_effect=tcp.NetLibDisconnect)
|
||||
c.finish()
|
||||
|
||||
|
||||
class TestServerSSL(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
@ -118,7 +109,8 @@ class TestServerSSL(test.ServerTestBase):
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = "AES256-SHA"
|
||||
cipher_list = "AES256-SHA",
|
||||
chain_file=tutils.test_data.path("data/server.crt")
|
||||
)
|
||||
def test_echo(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
@ -150,7 +142,7 @@ class TestSSLv3Only(test.ServerTestBase):
|
||||
def test_failure(self):
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com", method=tcp.TLSv1_METHOD)
|
||||
tutils.raises(tcp.NetLibError, c.convert_to_ssl, sni="foo.com")
|
||||
|
||||
|
||||
class TestSSLClientCert(test.ServerTestBase):
|
||||
@ -385,6 +377,11 @@ class TestDHParams(test.ServerTestBase):
|
||||
ret = c.get_current_cipher()
|
||||
assert ret[0] == "DHE-RSA-AES256-SHA"
|
||||
|
||||
def test_create_dhparams(self):
|
||||
with tutils.tmpdir() as d:
|
||||
filename = os.path.join(d, "dhparam.pem")
|
||||
certutils.CertStore.load_dhparam(filename)
|
||||
assert os.path.exists(filename)
|
||||
|
||||
|
||||
class TestPrivkeyGen(test.ServerTestBase):
|
||||
@ -527,12 +524,22 @@ class TestFileLike:
|
||||
assert s.first_byte_timestamp == expected
|
||||
|
||||
def test_read_ssl_error(self):
|
||||
s = cStringIO.StringIO("foobar\nfoobar")
|
||||
s = mock.MagicMock()
|
||||
s.read = mock.MagicMock(side_effect=SSL.Error())
|
||||
s = tcp.Reader(s)
|
||||
tutils.raises(tcp.NetLibSSLError, s.read, 1)
|
||||
|
||||
def test_read_syscall_ssl_error(self):
|
||||
s = mock.MagicMock()
|
||||
s.read = mock.MagicMock(side_effect=SSL.SysCallError())
|
||||
s = tcp.Reader(s)
|
||||
tutils.raises(tcp.NetLibSSLError, s.read, 1)
|
||||
|
||||
def test_reader_readline_disconnect(self):
|
||||
o = mock.MagicMock()
|
||||
o.read = mock.MagicMock(side_effect=socket.error)
|
||||
s = tcp.Reader(o)
|
||||
tutils.raises(tcp.NetLibDisconnect, s.readline, 10)
|
||||
|
||||
class TestAddress:
|
||||
def test_simple(self):
|
||||
@ -542,3 +549,45 @@ class TestAddress:
|
||||
assert not a == b
|
||||
c = tcp.Address("localhost", True)
|
||||
assert a == c
|
||||
assert not a != c
|
||||
assert repr(a)
|
||||
|
||||
|
||||
class TestServer(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
def test_echo(self):
|
||||
testval = "echo!\n"
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.wfile.write(testval)
|
||||
c.wfile.flush()
|
||||
assert c.rfile.readline() == testval
|
||||
|
||||
class TestSSLKeyLogger(test.ServerTestBase):
|
||||
handler = EchoHandler
|
||||
ssl = dict(
|
||||
cert = tutils.test_data.path("data/server.crt"),
|
||||
key = tutils.test_data.path("data/server.key"),
|
||||
request_client_cert = False,
|
||||
v3_only = False,
|
||||
cipher_list = "AES256-SHA"
|
||||
)
|
||||
|
||||
def test_log(self):
|
||||
_logfun = tcp.log_ssl_key
|
||||
|
||||
with tutils.tmpdir() as d:
|
||||
logfile = os.path.join(d, "foo", "bar", "logfile")
|
||||
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
|
||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||
c.connect()
|
||||
c.convert_to_ssl()
|
||||
tcp.log_ssl_key.close()
|
||||
with open(logfile, "rb") as f:
|
||||
assert f.read().count("CLIENT_RANDOM") == 2
|
||||
|
||||
tcp.log_ssl_key = _logfun
|
||||
|
||||
def test_create_logfun(self):
|
||||
assert isinstance(tcp.SSLKeyLogger.create_logfun("test"), tcp.SSLKeyLogger)
|
||||
assert not tcp.SSLKeyLogger.create_logfun(False)
|
Loading…
Reference in New Issue
Block a user