mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2025-01-30 14:58:38 +00:00
Refactor to prepare for SNI fixes.
This commit is contained in:
parent
d0639e8925
commit
705559d65e
@ -50,36 +50,13 @@ class ProxyConfig:
|
||||
self.certstore = certutils.CertStore(certdir)
|
||||
|
||||
|
||||
class RequestReplayThread(threading.Thread):
|
||||
def __init__(self, config, flow, masterq):
|
||||
self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
r = self.flow.request
|
||||
server = ServerConnection(self.config, r.host, r.port)
|
||||
server.connect(r.scheme)
|
||||
server.send(r)
|
||||
httpversion, code, msg, headers, content = http.read_response(
|
||||
server.rfile, r.method, self.config.body_size_limit
|
||||
)
|
||||
response = flow.Response(
|
||||
self.flow.request, httpversion, code, msg, headers, content, server.cert
|
||||
)
|
||||
self.channel.ask(response)
|
||||
except (ProxyError, http.HttpError, tcp.NetLibError), v:
|
||||
err = flow.Error(self.flow.request, str(v))
|
||||
self.channel.ask(err)
|
||||
|
||||
|
||||
class ServerConnection(tcp.TCPClient):
|
||||
def __init__(self, config, host, port):
|
||||
tcp.TCPClient.__init__(self, host, port)
|
||||
self.config = config
|
||||
self.requestcount = 0
|
||||
|
||||
def connect(self, scheme):
|
||||
def connect(self, scheme, sni):
|
||||
tcp.TCPClient.connect(self)
|
||||
if scheme == "https":
|
||||
clientcert = None
|
||||
@ -88,7 +65,7 @@ class ServerConnection(tcp.TCPClient):
|
||||
if os.path.exists(path):
|
||||
clientcert = path
|
||||
try:
|
||||
self.convert_to_ssl(clientcert=clientcert, sni=self.host)
|
||||
self.convert_to_ssl(cert=clientcert, sni=sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(400, str(v))
|
||||
|
||||
@ -109,12 +86,35 @@ class ServerConnection(tcp.TCPClient):
|
||||
pass
|
||||
|
||||
|
||||
class RequestReplayThread(threading.Thread):
|
||||
def __init__(self, config, flow, masterq):
|
||||
self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
|
||||
threading.Thread.__init__(self)
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
r = self.flow.request
|
||||
server = ServerConnection(self.config, r.host, r.port)
|
||||
server.connect(r.scheme, r.host)
|
||||
server.send(r)
|
||||
httpversion, code, msg, headers, content = http.read_response(
|
||||
server.rfile, r.method, self.config.body_size_limit
|
||||
)
|
||||
response = flow.Response(
|
||||
self.flow.request, httpversion, code, msg, headers, content, server.cert
|
||||
)
|
||||
self.channel.ask(response)
|
||||
except (ProxyError, http.HttpError, tcp.NetLibError), v:
|
||||
err = flow.Error(self.flow.request, str(v))
|
||||
self.channel.ask(err)
|
||||
|
||||
|
||||
class ServerConnectionPool:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.conn = None
|
||||
|
||||
def get_connection(self, scheme, host, port):
|
||||
def get_connection(self, scheme, host, port, sni):
|
||||
sc = self.conn
|
||||
if self.conn and (host, port) != (sc.host, sc.port):
|
||||
sc.terminate()
|
||||
@ -122,7 +122,7 @@ class ServerConnectionPool:
|
||||
if not self.conn:
|
||||
try:
|
||||
self.conn = ServerConnection(self.config, host, port)
|
||||
self.conn.connect(scheme)
|
||||
self.conn.connect(scheme, sni)
|
||||
except tcp.NetLibError, v:
|
||||
raise ProxyError(502, v)
|
||||
return self.conn
|
||||
@ -190,18 +190,18 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
# the case, we want to reconnect without sending an error
|
||||
# to the client.
|
||||
while 1:
|
||||
sc = self.server_conn_pool.get_connection(scheme, host, port, host)
|
||||
sc.send(request)
|
||||
sc.rfile.reset_timestamps()
|
||||
try:
|
||||
sc = self.server_conn_pool.get_connection(scheme, host, port)
|
||||
sc.send(request)
|
||||
sc.rfile.reset_timestamps()
|
||||
httpversion, code, msg, headers, content = http.read_response(
|
||||
sc.rfile,
|
||||
request.method,
|
||||
self.config.body_size_limit
|
||||
)
|
||||
except http.HttpErrorConnClosed, v:
|
||||
self.server_conn_pool.del_connection(scheme, host, port)
|
||||
if sc.requestcount > 1:
|
||||
self.server_conn_pool.del_connection(scheme, host, port)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
@ -324,25 +324,6 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
self.rfile.first_byte_timestamp, utils.timestamp()
|
||||
)
|
||||
|
||||
def read_request_reverse(self, client_conn):
|
||||
line = self.get_line(self.rfile)
|
||||
if line == "":
|
||||
return None
|
||||
scheme, host, port = self.config.reverse_proxy
|
||||
r = http.parse_init_http(line)
|
||||
if not r:
|
||||
raise ProxyError(400, "Bad HTTP request line: %s"%repr(line))
|
||||
method, path, httpversion = r
|
||||
headers = self.read_headers(authenticate=False)
|
||||
content = http.read_http_body_request(
|
||||
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
|
||||
)
|
||||
return flow.Request(
|
||||
client_conn, httpversion, host, port, "http", method, path, headers, content,
|
||||
self.rfile.first_byte_timestamp, utils.timestamp()
|
||||
)
|
||||
|
||||
|
||||
def read_request_proxy(self, client_conn):
|
||||
line = self.get_line(self.rfile)
|
||||
if line == "":
|
||||
@ -398,6 +379,24 @@ class ProxyHandler(tcp.BaseHandler):
|
||||
self.rfile.first_byte_timestamp, utils.timestamp()
|
||||
)
|
||||
|
||||
def read_request_reverse(self, client_conn):
|
||||
line = self.get_line(self.rfile)
|
||||
if line == "":
|
||||
return None
|
||||
scheme, host, port = self.config.reverse_proxy
|
||||
r = http.parse_init_http(line)
|
||||
if not r:
|
||||
raise ProxyError(400, "Bad HTTP request line: %s"%repr(line))
|
||||
method, path, httpversion = r
|
||||
headers = self.read_headers(authenticate=False)
|
||||
content = http.read_http_body_request(
|
||||
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
|
||||
)
|
||||
return flow.Request(
|
||||
client_conn, httpversion, host, port, "http", method, path, headers, content,
|
||||
self.rfile.first_byte_timestamp, utils.timestamp()
|
||||
)
|
||||
|
||||
def read_request(self, client_conn):
|
||||
self.rfile.reset_timestamps()
|
||||
if self.config.transparent_proxy:
|
||||
|
@ -40,7 +40,7 @@ class TestServerConnection:
|
||||
|
||||
def test_simple(self):
|
||||
sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port)
|
||||
sc.connect("http")
|
||||
sc.connect("http", "host.com")
|
||||
r = tutils.treq()
|
||||
r.path = "/p/200:da"
|
||||
sc.send(r)
|
||||
@ -54,7 +54,7 @@ class TestServerConnection:
|
||||
|
||||
def test_terminate_error(self):
|
||||
sc = proxy.ServerConnection(proxy.ProxyConfig(), self.d.IFACE, self.d.port)
|
||||
sc.connect("http")
|
||||
sc.connect("http", "host.com")
|
||||
sc.connection = mock.Mock()
|
||||
sc.connection.close = mock.Mock(side_effect=IOError)
|
||||
sc.terminate()
|
||||
@ -75,14 +75,14 @@ class TestServerConnectionPool:
|
||||
@mock.patch("libmproxy.proxy.ServerConnection", _dummysc)
|
||||
def test_pooling(self):
|
||||
p = proxy.ServerConnectionPool(proxy.ProxyConfig())
|
||||
c = p.get_connection("http", "localhost", 80)
|
||||
c2 = p.get_connection("http", "localhost", 80)
|
||||
c = p.get_connection("http", "localhost", 80, "localhost")
|
||||
c2 = p.get_connection("http", "localhost", 80, "localhost")
|
||||
assert c is c2
|
||||
c3 = p.get_connection("http", "foo", 80)
|
||||
c3 = p.get_connection("http", "foo", 80, "localhost")
|
||||
assert not c is c3
|
||||
|
||||
@mock.patch("libmproxy.proxy.ServerConnection", _errsc)
|
||||
def test_connection_error(self):
|
||||
p = proxy.ServerConnectionPool(proxy.ProxyConfig())
|
||||
tutils.raises("502", p.get_connection, "http", "localhost", 80)
|
||||
tutils.raises("502", p.get_connection, "http", "localhost", 80, "localhost")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user