mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
Significantly simplify server connection handling, and test.
This commit is contained in:
parent
705559d65e
commit
0257815141
@ -51,21 +51,22 @@ class ProxyConfig:
|
||||
|
||||
|
||||
class ServerConnection(tcp.TCPClient):
|
||||
def __init__(self, config, host, port):
|
||||
def __init__(self, config, scheme, host, port, sni):
|
||||
tcp.TCPClient.__init__(self, host, port)
|
||||
self.config = config
|
||||
self.scheme, self.sni = scheme, sni
|
||||
self.requestcount = 0
|
||||
|
||||
def connect(self, scheme, sni):
|
||||
def connect(self):
|
||||
tcp.TCPClient.connect(self)
|
||||
if scheme == "https":
|
||||
if self.scheme == "https":
|
||||
clientcert = None
|
||||
if self.config.clientcerts:
|
||||
path = os.path.join(self.config.clientcerts, self.host.encode("idna")) + ".pem"
|
||||
if os.path.exists(path):
|
||||
clientcert = path
|
||||
try:
|
||||
self.convert_to_ssl(cert=clientcert, sni=sni)
|
||||
self.convert_to_ssl(cert=clientcert, sni=self.sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(400, str(v))
|
||||
|
||||
@ -94,8 +95,8 @@ class RequestReplayThread(threading.Thread):
|
||||
def run(self):
|
||||
try:
|
||||
r = self.flow.request
|
||||
server = ServerConnection(self.config, r.host, r.port)
|
||||
server.connect(r.scheme, r.host)
|
||||
server = ServerConnection(self.config, r.scheme, r.host, r.port, r.host)
|
||||
server.connect()
|
||||
server.send(r)
|
||||
httpversion, code, msg, headers, content = http.read_response(
|
||||
server.rfile, r.method, self.config.body_size_limit
|
||||
@ -109,37 +110,40 @@ class RequestReplayThread(threading.Thread):
|
||||
self.channel.ask(err)
|
||||
|
||||
|
||||
class ServerConnectionPool:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.conn = None
|
||||
|
||||
def get_connection(self, scheme, host, port, sni):
|
||||
sc = self.conn
|
||||
if self.conn and (host, port) != (sc.host, sc.port):
|
||||
sc.terminate()
|
||||
self.conn = None
|
||||
if not self.conn:
|
||||
try:
|
||||
self.conn = ServerConnection(self.config, host, port)
|
||||
self.conn.connect(scheme, sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
return self.conn
|
||||
|
||||
def del_connection(self, scheme, host, port):
|
||||
self.conn = None
|
||||
|
||||
|
||||
class ProxyHandler(tcp.BaseHandler):
|
||||
def __init__(self, config, connection, client_address, server, channel, server_version):
|
||||
self.channel, self.server_version = channel, server_version
|
||||
self.config = config
|
||||
self.server_conn_pool = ServerConnectionPool(config)
|
||||
self.proxy_connect_state = None
|
||||
self.sni = None
|
||||
self.server_conn = None
|
||||
tcp.BaseHandler.__init__(self, connection, client_address, server)
|
||||
|
||||
def get_server_connection(self, cc, scheme, host, port, sni):
|
||||
sc = self.server_conn
|
||||
if sc and (scheme, host, port, sni) != (sc.scheme, sc.host, sc.port, sc.sni):
|
||||
sc.terminate()
|
||||
self.server_conn = None
|
||||
self.log(
|
||||
cc,
|
||||
"switching connection", [
|
||||
"%s://%s:%s (sni=%s) -> %s://%s:%s (sni=%s)"%(
|
||||
scheme, host, port, sni,
|
||||
sc.scheme, sc.host, sc.port, sc.sni
|
||||
)
|
||||
]
|
||||
)
|
||||
if not self.server_conn:
|
||||
try:
|
||||
self.server_conn = ServerConnection(self.config, scheme, host, port, sni)
|
||||
self.server_conn.connect()
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
return self.server_conn
|
||||
|
||||
def del_server_connection(self):
|
||||
self.server_conn = None
|
||||
|
||||
def handle(self):
|
||||
cc = flow.ClientConnect(self.client_address)
|
||||
self.log(cc, "connect")
|
||||
@ -190,7 +194,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
# the case, we want to reconnect without sending an error
|
||||
# to the client.
|
||||
while 1:
|
||||
sc = self.server_conn_pool.get_connection(scheme, host, port, host)
|
||||
sc = self.get_server_connection(cc, scheme, host, port, host)
|
||||
sc.send(request)
|
||||
sc.rfile.reset_timestamps()
|
||||
try:
|
||||
@ -200,7 +204,7 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
self.config.body_size_limit
|
||||
)
|
||||
except http.HttpErrorConnClosed, v:
|
||||
self.server_conn_pool.del_connection(scheme, host, port)
|
||||
self.del_server_connection()
|
||||
if sc.requestcount > 1:
|
||||
continue
|
||||
else:
|
||||
|
@ -39,8 +39,8 @@ class TestServerConnection:
|
||||
self.d.shutdown()
|
||||
|
||||
def test_simple(self):
|
||||
sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port)
|
||||
sc.connect("http", "host.com")
|
||||
sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com")
|
||||
sc.connect()
|
||||
r = tutils.treq()
|
||||
r.path = "/p/200:da"
|
||||
sc.send(r)
|
||||
@ -53,36 +53,9 @@ class TestServerConnection:
|
||||
sc.terminate()
|
||||
|
||||
def test_terminate_error(self):
|
||||
sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port)
|
||||
sc.connect("http", "host.com")
|
||||
sc = proxy.ServerConnection(proxy.ProxyConfig(), "http", self.d.IFACE, self.d.port, "host.com")
|
||||
sc.connect()
|
||||
sc.connection = mock.Mock()
|
||||
sc.connection.close = mock.Mock(side_effect=IOError)
|
||||
sc.terminate()
|
||||
|
||||
|
||||
|
||||
def _dummysc(config, host, port):
|
||||
return mock.MagicMock(config=config, host=host, port=port)
|
||||
|
||||
|
||||
def _errsc(config, host, port):
|
||||
m = mock.MagicMock(config=config, host=host, port=port)
|
||||
m.connect = mock.MagicMock(side_effect=tcp.NetLibError())
|
||||
return m
|
||||
|
||||
|
||||
class TestServerConnectionPool:
|
||||
@mock.patch("libmproxy.proxy.ServerConnection", _dummysc)
|
||||
def test_pooling(self):
|
||||
p = proxy.ServerConnectionPool(proxy.ProxyConfig())
|
||||
c = p.get_connection("http", "localhost", 80, "localhost")
|
||||
c2 = p.get_connection("http", "localhost", 80, "localhost")
|
||||
assert c is c2
|
||||
c3 = p.get_connection("http", "foo", 80, "localhost")
|
||||
assert not c is c3
|
||||
|
||||
@mock.patch("libmproxy.proxy.ServerConnection", _errsc)
|
||||
def test_connection_error(self):
|
||||
p = proxy.ServerConnectionPool(proxy.ProxyConfig())
|
||||
tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost")
|
||||
|
||||
|
@ -85,7 +85,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
||||
def test_connection_close(self):
|
||||
# Add a body, so we have a content-length header, which combined with
|
||||
# HTTP1.1 means the connection is kept alive.
|
||||
response = '%s/p/200:b@1'%self.urlbase
|
||||
response = '%s/p/200:b@1'%self.server.urlbase
|
||||
|
||||
# Lets sanity check that the connection does indeed stay open by
|
||||
# issuing two requests over the same connection
|
||||
@ -99,7 +99,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
||||
tutils.raises("disconnect", p.request, "get:'%s'"%response)
|
||||
|
||||
def test_reconnect(self):
|
||||
req = "get:'%s/p/200:b@1:da'"%self.urlbase
|
||||
req = "get:'%s/p/200:b@1:da'"%self.server.urlbase
|
||||
p = self.pathoc()
|
||||
assert p.request(req)
|
||||
# Server has disconnected. Mitmproxy should detect this, and reconnect.
|
||||
@ -107,7 +107,7 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
||||
assert p.request(req)
|
||||
|
||||
# However, if the server disconnects on our first try, it's an error.
|
||||
req = "get:'%s/p/200:b@1:d0'"%self.urlbase
|
||||
req = "get:'%s/p/200:b@1:d0'"%self.server.urlbase
|
||||
p = self.pathoc()
|
||||
tutils.raises("server disconnect", p.request, req)
|
||||
|
||||
@ -118,13 +118,29 @@ class TestHTTP(tservers.HTTPProxTest, SanityMixin):
|
||||
m.side_effect = IOError("error!")
|
||||
tutils.raises("empty reply", self.pathod, "304")
|
||||
|
||||
def test_get_connection_switching(self):
|
||||
def switched(l):
|
||||
for i in l:
|
||||
if "switching" in i:
|
||||
return True
|
||||
req = "get:'%s/p/200:b@1'"
|
||||
p = self.pathoc()
|
||||
assert p.request(req%self.server.urlbase)
|
||||
assert p.request(req%self.server2.urlbase)
|
||||
assert switched(self.proxy.log)
|
||||
|
||||
def test_get_connection_err(self):
|
||||
p = self.pathoc()
|
||||
ret = p.request("get:'http://localhost:0'")
|
||||
assert ret[1] == 502
|
||||
|
||||
|
||||
class TestHTTPS(tservers.HTTPProxTest, SanityMixin):
|
||||
ssl = True
|
||||
clientcerts = True
|
||||
def test_clientcert(self):
|
||||
f = self.pathod("304")
|
||||
assert self.last_log()["request"]["clientcert"]["keyinfo"]
|
||||
assert self.server.last_log()["request"]["clientcert"]["keyinfo"]
|
||||
|
||||
|
||||
class TestReverse(tservers.ReverseProxTest, SanityMixin):
|
||||
@ -211,7 +227,7 @@ class TestKillRequest(tservers.HTTPProxTest):
|
||||
p = self.pathoc()
|
||||
tutils.raises("empty reply", self.pathod, "200")
|
||||
# Nothing should have hit the server
|
||||
assert not self.last_log()
|
||||
assert not self.server.last_log()
|
||||
|
||||
|
||||
class MasterKillResponse(tservers.TestMaster):
|
||||
@ -225,5 +241,5 @@ class TestKillResponse(tservers.HTTPProxTest):
|
||||
p = self.pathoc()
|
||||
tutils.raises("empty reply", self.pathod, "200")
|
||||
# The server should have seen a request
|
||||
assert self.last_log()
|
||||
assert self.server.last_log()
|
||||
|
||||
|
@ -28,6 +28,7 @@ class TestMaster(flow.FlowMaster):
|
||||
state = flow.State()
|
||||
flow.FlowMaster.__init__(self, s, state)
|
||||
self.testq = testq
|
||||
self.log = []
|
||||
|
||||
def handle_request(self, m):
|
||||
flow.FlowMaster.handle_request(self, m)
|
||||
@ -37,6 +38,10 @@ class TestMaster(flow.FlowMaster):
|
||||
flow.FlowMaster.handle_response(self, m)
|
||||
m.reply()
|
||||
|
||||
def handle_log(self, l):
|
||||
self.log.append(l.msg)
|
||||
l.reply()
|
||||
|
||||
|
||||
class ProxyThread(threading.Thread):
|
||||
def __init__(self, tmaster):
|
||||
@ -48,6 +53,10 @@ class ProxyThread(threading.Thread):
|
||||
def port(self):
|
||||
return self.tmaster.server.port
|
||||
|
||||
@property
|
||||
def log(self):
|
||||
return self.tmaster.log
|
||||
|
||||
def run(self):
|
||||
self.tmaster.run()
|
||||
|
||||
@ -61,6 +70,7 @@ class ProxTestBase:
|
||||
def setupAll(cls):
|
||||
cls.tqueue = Queue.Queue()
|
||||
cls.server = libpathod.test.Daemon(ssl=cls.ssl)
|
||||
cls.server2 = libpathod.test.Daemon(ssl=cls.ssl)
|
||||
pconf = cls.get_proxy_config()
|
||||
config = proxy.ProxyConfig(
|
||||
cacert = tutils.test_data.path("data/serverkey.pem"),
|
||||
@ -78,6 +88,7 @@ class ProxTestBase:
|
||||
def teardownAll(cls):
|
||||
cls.proxy.shutdown()
|
||||
cls.server.shutdown()
|
||||
cls.server2.shutdown()
|
||||
|
||||
def setUp(self):
|
||||
self.master.state.clear()
|
||||
@ -95,16 +106,6 @@ class ProxTestBase:
|
||||
(self.scheme, ("127.0.0.1", self.proxy.port))
|
||||
)
|
||||
|
||||
@property
|
||||
def urlbase(self):
|
||||
"""
|
||||
The URL base for the server instance.
|
||||
"""
|
||||
return self.server.urlbase
|
||||
|
||||
def last_log(self):
|
||||
return self.server.last_log()
|
||||
|
||||
|
||||
class HTTPProxTest(ProxTestBase):
|
||||
ssl = None
|
||||
@ -129,7 +130,7 @@ class HTTPProxTest(ProxTestBase):
|
||||
Constructs a pathod request, with the appropriate base and proxy.
|
||||
"""
|
||||
return hurl.get(
|
||||
self.urlbase + "/p/" + spec,
|
||||
self.server.urlbase + "/p/" + spec,
|
||||
proxy=self.proxies,
|
||||
validate_cert=False,
|
||||
#debug=hurl.utils.stdout_debug
|
||||
|
Loading…
Reference in New Issue
Block a user