mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
Certificate flags
This commit is contained in:
parent
2a12aa3c47
commit
f5cc63d653
1
.gitignore
vendored
1
.gitignore
vendored
@ -8,3 +8,4 @@ MANIFEST
|
|||||||
*.swo
|
*.swo
|
||||||
.coverage
|
.coverage
|
||||||
.idea
|
.idea
|
||||||
|
__pycache__
|
||||||
|
36
netlib/certffi.py
Normal file
36
netlib/certffi.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import cffi
|
||||||
|
import OpenSSL
|
||||||
|
xffi = cffi.FFI()
|
||||||
|
xffi.cdef ("""
|
||||||
|
struct rsa_meth_st {
|
||||||
|
int flags;
|
||||||
|
...;
|
||||||
|
};
|
||||||
|
struct rsa_st {
|
||||||
|
int pad;
|
||||||
|
long version;
|
||||||
|
struct rsa_meth_st *meth;
|
||||||
|
...;
|
||||||
|
};
|
||||||
|
""")
|
||||||
|
xffi.verify(
|
||||||
|
"""#include <openssl/rsa.h>""",
|
||||||
|
extra_compile_args=['-w']
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle(privkey):
|
||||||
|
new = xffi.new("struct rsa_st*")
|
||||||
|
newbuf = xffi.buffer(new)
|
||||||
|
rsa = OpenSSL.SSL._lib.EVP_PKEY_get1_RSA(privkey._pkey)
|
||||||
|
oldbuf = OpenSSL.SSL._ffi.buffer(rsa)
|
||||||
|
newbuf[:] = oldbuf[:]
|
||||||
|
return new
|
||||||
|
|
||||||
|
def set_flags(privkey, val):
|
||||||
|
hdl = handle(privkey)
|
||||||
|
hdl.meth.flags = val
|
||||||
|
return privkey
|
||||||
|
|
||||||
|
def get_flags(privkey):
|
||||||
|
hdl = handle(privkey)
|
||||||
|
return hdl.meth.flags
|
@ -111,6 +111,7 @@ class DNTree:
|
|||||||
return current.value
|
return current.value
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CertStore:
|
class CertStore:
|
||||||
"""
|
"""
|
||||||
Implements an in-memory certificate store.
|
Implements an in-memory certificate store.
|
||||||
@ -222,6 +223,11 @@ class CertStore:
|
|||||||
c = (c, None)
|
c = (c, None)
|
||||||
return (c[0], c[1] or self.privkey)
|
return (c[0], c[1] or self.privkey)
|
||||||
|
|
||||||
|
def gen_pkey(self, cert):
|
||||||
|
import certffi
|
||||||
|
certffi.set_flags(self.privkey, 1)
|
||||||
|
return self.privkey
|
||||||
|
|
||||||
|
|
||||||
class _GeneralName(univ.Choice):
|
class _GeneralName(univ.Choice):
|
||||||
# We are only interested in dNSNames. We use a default handler to ignore
|
# We are only interested in dNSNames. We use a default handler to ignore
|
||||||
@ -326,6 +332,7 @@ class SSLCert:
|
|||||||
return altnames
|
return altnames
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def get_remote_cert(host, port, sni):
|
def get_remote_cert(host, port, sni):
|
||||||
c = tcp.TCPClient((host, port))
|
c = tcp.TCPClient((host, port))
|
||||||
c.connect()
|
c.connect()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from netlib import certutils
|
from netlib import certutils, certffi
|
||||||
import OpenSSL
|
import OpenSSL
|
||||||
import tutils
|
import tutils
|
||||||
|
|
||||||
@ -83,6 +83,16 @@ class TestCertStore:
|
|||||||
ret = ca1.get_cert("foo.com", [])
|
ret = ca1.get_cert("foo.com", [])
|
||||||
assert ret[0].serial == dc[0].serial
|
assert ret[0].serial == dc[0].serial
|
||||||
|
|
||||||
|
def test_gen_pkey(self):
|
||||||
|
try:
|
||||||
|
with tutils.tmpdir() as d:
|
||||||
|
ca1 = certutils.CertStore.from_store(os.path.join(d, "ca1"), "test")
|
||||||
|
ca2 = certutils.CertStore.from_store(os.path.join(d, "ca2"), "test")
|
||||||
|
cert = ca1.get_cert("foo.com", [])
|
||||||
|
assert certffi.get_flags(ca2.gen_pkey(cert[0])) == 1
|
||||||
|
finally:
|
||||||
|
certffi.set_flags(ca2.privkey, 0)
|
||||||
|
|
||||||
|
|
||||||
class TestDummyCert:
|
class TestDummyCert:
|
||||||
def test_with_ca(self):
|
def test_with_ca(self):
|
||||||
@ -125,3 +135,5 @@ class TestSSLCert:
|
|||||||
d = file(tutils.test_data.path("data/dercert"),"rb").read()
|
d = file(tutils.test_data.path("data/dercert"),"rb").read()
|
||||||
s = certutils.SSLCert.from_der(d)
|
s = certutils.SSLCert.from_der(d)
|
||||||
assert s.cn
|
assert s.cn
|
||||||
|
|
||||||
|
|
||||||
|
127
test/test_tcp.py
127
test/test_tcp.py
@ -4,16 +4,6 @@ import mock
|
|||||||
import tutils
|
import tutils
|
||||||
from OpenSSL import SSL
|
from OpenSSL import SSL
|
||||||
|
|
||||||
class SNIHandler(tcp.BaseHandler):
|
|
||||||
sni = None
|
|
||||||
def handle_sni(self, connection):
|
|
||||||
self.sni = connection.get_servername()
|
|
||||||
|
|
||||||
def handle(self):
|
|
||||||
self.wfile.write(self.sni)
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class EchoHandler(tcp.BaseHandler):
|
class EchoHandler(tcp.BaseHandler):
|
||||||
sni = None
|
sni = None
|
||||||
def handle_sni(self, connection):
|
def handle_sni(self, connection):
|
||||||
@ -25,58 +15,19 @@ class EchoHandler(tcp.BaseHandler):
|
|||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
class ClientPeernameHandler(tcp.BaseHandler):
|
|
||||||
def handle(self):
|
|
||||||
self.wfile.write(str(self.connection.getpeername()))
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class CertHandler(tcp.BaseHandler):
|
|
||||||
sni = None
|
|
||||||
def handle_sni(self, connection):
|
|
||||||
self.sni = connection.get_servername()
|
|
||||||
|
|
||||||
def handle(self):
|
|
||||||
self.wfile.write("%s\n"%self.clientcert.serial)
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class ClientCipherListHandler(tcp.BaseHandler):
|
class ClientCipherListHandler(tcp.BaseHandler):
|
||||||
sni = None
|
sni = None
|
||||||
|
|
||||||
def handle(self):
|
def handle(self):
|
||||||
self.wfile.write("%s"%self.connection.get_cipher_list())
|
self.wfile.write("%s"%self.connection.get_cipher_list())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
|
||||||
|
|
||||||
class CurrentCipherHandler(tcp.BaseHandler):
|
|
||||||
sni = None
|
|
||||||
def handle(self):
|
|
||||||
self.wfile.write("%s"%str(self.get_current_cipher()))
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
|
|
||||||
class DisconnectHandler(tcp.BaseHandler):
|
|
||||||
def handle(self):
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
|
|
||||||
class HangHandler(tcp.BaseHandler):
|
class HangHandler(tcp.BaseHandler):
|
||||||
def handle(self):
|
def handle(self):
|
||||||
while 1:
|
while 1:
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
class TimeoutHandler(tcp.BaseHandler):
|
|
||||||
def handle(self):
|
|
||||||
self.timeout = False
|
|
||||||
self.settimeout(0.01)
|
|
||||||
try:
|
|
||||||
self.rfile.read(10)
|
|
||||||
except tcp.NetLibTimeout:
|
|
||||||
self.timeout = True
|
|
||||||
|
|
||||||
|
|
||||||
class TestServer(test.ServerTestBase):
|
class TestServer(test.ServerTestBase):
|
||||||
handler = EchoHandler
|
handler = EchoHandler
|
||||||
def test_echo(self):
|
def test_echo(self):
|
||||||
@ -89,7 +40,10 @@ class TestServer(test.ServerTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class TestServerBind(test.ServerTestBase):
|
class TestServerBind(test.ServerTestBase):
|
||||||
handler = ClientPeernameHandler
|
class handler(tcp.BaseHandler):
|
||||||
|
def handle(self):
|
||||||
|
self.wfile.write(str(self.connection.getpeername()))
|
||||||
|
self.wfile.flush()
|
||||||
|
|
||||||
def test_bind(self):
|
def test_bind(self):
|
||||||
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """
|
""" Test to bind to a given random port. Try again if the random port turned out to be blocked. """
|
||||||
@ -198,7 +152,14 @@ class TestSSLv3Only(test.ServerTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class TestSSLClientCert(test.ServerTestBase):
|
class TestSSLClientCert(test.ServerTestBase):
|
||||||
handler = CertHandler
|
class handler(tcp.BaseHandler):
|
||||||
|
sni = None
|
||||||
|
def handle_sni(self, connection):
|
||||||
|
self.sni = connection.get_servername()
|
||||||
|
|
||||||
|
def handle(self):
|
||||||
|
self.wfile.write("%s\n"%self.clientcert.serial)
|
||||||
|
self.wfile.flush()
|
||||||
ssl = dict(
|
ssl = dict(
|
||||||
cert = tutils.test_data.path("data/server.crt"),
|
cert = tutils.test_data.path("data/server.crt"),
|
||||||
key = tutils.test_data.path("data/server.key"),
|
key = tutils.test_data.path("data/server.key"),
|
||||||
@ -222,7 +183,15 @@ class TestSSLClientCert(test.ServerTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class TestSNI(test.ServerTestBase):
|
class TestSNI(test.ServerTestBase):
|
||||||
handler = SNIHandler
|
class handler(tcp.BaseHandler):
|
||||||
|
sni = None
|
||||||
|
def handle_sni(self, connection):
|
||||||
|
self.sni = connection.get_servername()
|
||||||
|
|
||||||
|
def handle(self):
|
||||||
|
self.wfile.write(self.sni)
|
||||||
|
self.wfile.flush()
|
||||||
|
|
||||||
ssl = dict(
|
ssl = dict(
|
||||||
cert = tutils.test_data.path("data/server.crt"),
|
cert = tutils.test_data.path("data/server.crt"),
|
||||||
key = tutils.test_data.path("data/server.key"),
|
key = tutils.test_data.path("data/server.key"),
|
||||||
@ -254,7 +223,11 @@ class TestServerCipherList(test.ServerTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class TestServerCurrentCipher(test.ServerTestBase):
|
class TestServerCurrentCipher(test.ServerTestBase):
|
||||||
handler = CurrentCipherHandler
|
class handler(tcp.BaseHandler):
|
||||||
|
sni = None
|
||||||
|
def handle(self):
|
||||||
|
self.wfile.write("%s"%str(self.get_current_cipher()))
|
||||||
|
self.wfile.flush()
|
||||||
ssl = dict(
|
ssl = dict(
|
||||||
cert = tutils.test_data.path("data/server.crt"),
|
cert = tutils.test_data.path("data/server.crt"),
|
||||||
key = tutils.test_data.path("data/server.key"),
|
key = tutils.test_data.path("data/server.key"),
|
||||||
@ -300,7 +273,9 @@ class TestClientCipherListError(test.ServerTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class TestSSLDisconnect(test.ServerTestBase):
|
class TestSSLDisconnect(test.ServerTestBase):
|
||||||
handler = DisconnectHandler
|
class handler(tcp.BaseHandler):
|
||||||
|
def handle(self):
|
||||||
|
self.close()
|
||||||
ssl = dict(
|
ssl = dict(
|
||||||
cert = tutils.test_data.path("data/server.crt"),
|
cert = tutils.test_data.path("data/server.crt"),
|
||||||
key = tutils.test_data.path("data/server.key"),
|
key = tutils.test_data.path("data/server.key"),
|
||||||
@ -329,7 +304,15 @@ class TestDisconnect(test.ServerTestBase):
|
|||||||
|
|
||||||
|
|
||||||
class TestServerTimeOut(test.ServerTestBase):
|
class TestServerTimeOut(test.ServerTestBase):
|
||||||
handler = TimeoutHandler
|
class handler(tcp.BaseHandler):
|
||||||
|
def handle(self):
|
||||||
|
self.timeout = False
|
||||||
|
self.settimeout(0.01)
|
||||||
|
try:
|
||||||
|
self.rfile.read(10)
|
||||||
|
except tcp.NetLibTimeout:
|
||||||
|
self.timeout = True
|
||||||
|
|
||||||
def test_timeout(self):
|
def test_timeout(self):
|
||||||
c = tcp.TCPClient(("127.0.0.1", self.port))
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
c.connect()
|
c.connect()
|
||||||
@ -383,6 +366,40 @@ class TestDHParams(test.ServerTestBase):
|
|||||||
assert ret[0] == "DHE-RSA-AES256-SHA"
|
assert ret[0] == "DHE-RSA-AES256-SHA"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrivkeyGen(test.ServerTestBase):
|
||||||
|
class handler(tcp.BaseHandler):
|
||||||
|
def handle(self):
|
||||||
|
with tutils.tmpdir() as d:
|
||||||
|
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||||
|
ca2 = certutils.CertStore.from_store(d, "test3")
|
||||||
|
cert, _ = ca1.get_cert("foo.com", [])
|
||||||
|
key = ca2.gen_pkey(cert)
|
||||||
|
self.convert_to_ssl(cert, key)
|
||||||
|
|
||||||
|
def test_privkey(self):
|
||||||
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
|
c.connect()
|
||||||
|
tutils.raises("bad record mac", c.convert_to_ssl)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPrivkeyGenNoFlags(test.ServerTestBase):
|
||||||
|
class handler(tcp.BaseHandler):
|
||||||
|
def handle(self):
|
||||||
|
with tutils.tmpdir() as d:
|
||||||
|
ca1 = certutils.CertStore.from_store(d, "test2")
|
||||||
|
ca2 = certutils.CertStore.from_store(d, "test3")
|
||||||
|
cert, _ = ca1.get_cert("foo.com", [])
|
||||||
|
certffi.set_flags(ca2.privkey, 0)
|
||||||
|
self.convert_to_ssl(cert, ca2.privkey)
|
||||||
|
|
||||||
|
def test_privkey(self):
|
||||||
|
c = tcp.TCPClient(("127.0.0.1", self.port))
|
||||||
|
c.connect()
|
||||||
|
tutils.raises("unexpected eof", c.convert_to_ssl)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestTCPClient:
|
class TestTCPClient:
|
||||||
def test_conerr(self):
|
def test_conerr(self):
|
||||||
c = tcp.TCPClient(("127.0.0.1", 0))
|
c = tcp.TCPClient(("127.0.0.1", 0))
|
||||||
|
Loading…
Reference in New Issue
Block a user