mirror of
https://github.com/Grasscutters/mitmproxy.git
synced 2024-11-23 08:11:00 +00:00
unify ipv4/ipv6 address handling
This commit is contained in:
parent
94e530ec4f
commit
17f09aa0af
@ -5,9 +5,9 @@ python:
|
|||||||
install:
|
install:
|
||||||
- "pip install coveralls --use-mirrors"
|
- "pip install coveralls --use-mirrors"
|
||||||
- "pip install nose-cov --use-mirrors"
|
- "pip install nose-cov --use-mirrors"
|
||||||
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git"
|
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git@tcp_proxy"
|
||||||
- "pip install -r requirements.txt --use-mirrors"
|
- "pip install -r requirements.txt --use-mirrors"
|
||||||
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git"
|
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git@tcp_proxy"
|
||||||
# command to run tests, e.g. python setup.py test
|
# command to run tests, e.g. python setup.py test
|
||||||
script:
|
script:
|
||||||
- "nosetests --with-cov --cov-report term-missing"
|
- "nosetests --with-cov --cov-report term-missing"
|
||||||
|
@ -115,7 +115,7 @@ class HTTPResponse(HTTPMessage):
|
|||||||
|
|
||||||
class HTTPRequest(HTTPMessage):
|
class HTTPRequest(HTTPMessage):
|
||||||
def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content,
|
def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content,
|
||||||
timestamp_start, timestamp_end, form_out=None, ip=None):
|
timestamp_start, timestamp_end, form_out=None):
|
||||||
self.form_in = form_in
|
self.form_in = form_in
|
||||||
self.method = method
|
self.method = method
|
||||||
self.scheme = scheme
|
self.scheme = scheme
|
||||||
@ -129,7 +129,6 @@ class HTTPRequest(HTTPMessage):
|
|||||||
self.timestamp_end = timestamp_end
|
self.timestamp_end = timestamp_end
|
||||||
|
|
||||||
self.form_out = form_out or self.form_in
|
self.form_out = form_out or self.form_in
|
||||||
self.ip = ip # resolved ip address
|
|
||||||
assert isinstance(headers, ODictCaseless)
|
assert isinstance(headers, ODictCaseless)
|
||||||
|
|
||||||
#FIXME: Compatibility Fix
|
#FIXME: Compatibility Fix
|
||||||
@ -352,7 +351,7 @@ class HTTPHandler(ProtocolHandler):
|
|||||||
if request.form_in == "authority":
|
if request.form_in == "authority":
|
||||||
directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy
|
directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy
|
||||||
if directly_addressed_at_mitmproxy:
|
if directly_addressed_at_mitmproxy:
|
||||||
self.c.establish_server_connection(request.host, request.port)
|
self.c.establish_server_connection((request.host, request.port))
|
||||||
self.c.client_conn.wfile.write(
|
self.c.client_conn.wfile.write(
|
||||||
'HTTP/1.1 200 Connection established\r\n' +
|
'HTTP/1.1 200 Connection established\r\n' +
|
||||||
('Proxy-agent: %s\r\n' % self.c.server_version) +
|
('Proxy-agent: %s\r\n' % self.c.server_version) +
|
||||||
@ -369,7 +368,7 @@ class HTTPHandler(ProtocolHandler):
|
|||||||
request.form_out = "origin"
|
request.form_out = "origin"
|
||||||
if ((not self.c.server_conn) or
|
if ((not self.c.server_conn) or
|
||||||
(self.c.server_conn.address != (request.host, request.port))):
|
(self.c.server_conn.address != (request.host, request.port))):
|
||||||
self.c.establish_server_connection(request.host, request.port)
|
self.c.establish_server_connection((request.host, request.port))
|
||||||
else:
|
else:
|
||||||
raise http.HttpError(400, "Invalid Request")
|
raise http.HttpError(400, "Invalid Request")
|
||||||
|
|
||||||
|
@ -40,18 +40,13 @@ class ProxyConfig:
|
|||||||
|
|
||||||
|
|
||||||
class ClientConnection(tcp.BaseHandler):
|
class ClientConnection(tcp.BaseHandler):
|
||||||
def __init__(self, client_connection, host, port):
|
def __init__(self, client_connection, address):
|
||||||
tcp.BaseHandler.__init__(self, client_connection)
|
tcp.BaseHandler.__init__(self, client_connection, address)
|
||||||
self.host, self.port = host, port
|
|
||||||
|
|
||||||
self.timestamp_start = utils.timestamp()
|
self.timestamp_start = utils.timestamp()
|
||||||
self.timestamp_end = None
|
self.timestamp_end = None
|
||||||
self.timestamp_ssl_setup = None
|
self.timestamp_ssl_setup = None
|
||||||
|
|
||||||
@property
|
|
||||||
def address(self):
|
|
||||||
return self.host, self.port
|
|
||||||
|
|
||||||
def convert_to_ssl(self, *args, **kwargs):
|
def convert_to_ssl(self, *args, **kwargs):
|
||||||
tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs)
|
tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs)
|
||||||
self.timestamp_ssl_setup = utils.timestamp()
|
self.timestamp_ssl_setup = utils.timestamp()
|
||||||
@ -62,21 +57,19 @@ class ClientConnection(tcp.BaseHandler):
|
|||||||
|
|
||||||
|
|
||||||
class ServerConnection(tcp.TCPClient):
|
class ServerConnection(tcp.TCPClient):
|
||||||
def __init__(self, host, port):
|
def __init__(self, address):
|
||||||
tcp.TCPClient.__init__(self, host, port)
|
tcp.TCPClient.__init__(self, address)
|
||||||
|
|
||||||
|
self.peername = None
|
||||||
self.timestamp_start = None
|
self.timestamp_start = None
|
||||||
self.timestamp_end = None
|
self.timestamp_end = None
|
||||||
self.timestamp_tcp_setup = None
|
self.timestamp_tcp_setup = None
|
||||||
self.timestamp_ssl_setup = None
|
self.timestamp_ssl_setup = None
|
||||||
|
|
||||||
@property
|
|
||||||
def address(self):
|
|
||||||
return self.host, self.port
|
|
||||||
|
|
||||||
def connect(self):
|
def connect(self):
|
||||||
self.timestamp_start = utils.timestamp()
|
self.timestamp_start = utils.timestamp()
|
||||||
tcp.TCPClient.connect(self)
|
tcp.TCPClient.connect(self)
|
||||||
|
self.peername = self.connection.getpeername()
|
||||||
self.timestamp_tcp_setup = utils.timestamp()
|
self.timestamp_tcp_setup = utils.timestamp()
|
||||||
|
|
||||||
def establish_ssl(self, clientcerts, sni):
|
def establish_ssl(self, clientcerts, sni):
|
||||||
@ -125,7 +118,7 @@ class RequestReplayThread(threading.Thread):
|
|||||||
class ConnectionHandler:
|
class ConnectionHandler:
|
||||||
def __init__(self, config, client_connection, client_address, server, channel, server_version):
|
def __init__(self, config, client_connection, client_address, server, channel, server_version):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.client_conn = ClientConnection(client_connection, *client_address)
|
self.client_conn = ClientConnection(client_connection, client_address)
|
||||||
self.server_conn = None
|
self.server_conn = None
|
||||||
self.channel, self.server_version = channel, server_version
|
self.channel, self.server_version = channel, server_version
|
||||||
|
|
||||||
@ -142,7 +135,7 @@ class ConnectionHandler:
|
|||||||
def del_server_connection(self):
|
def del_server_connection(self):
|
||||||
if self.server_conn and self.server_conn.connection:
|
if self.server_conn and self.server_conn.connection:
|
||||||
self.server_conn.finish()
|
self.server_conn.finish()
|
||||||
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)])
|
self.log("serverdisconnect", ["%s:%s" % self.server_conn.address])
|
||||||
self.channel.tell("serverdisconnect", self)
|
self.channel.tell("serverdisconnect", self)
|
||||||
self.server_conn = None
|
self.server_conn = None
|
||||||
self.sni = None
|
self.sni = None
|
||||||
@ -169,7 +162,7 @@ class ConnectionHandler:
|
|||||||
self.determine_conntype()
|
self.determine_conntype()
|
||||||
|
|
||||||
if server_address:
|
if server_address:
|
||||||
self.establish_server_connection(*server_address)
|
self.establish_server_connection(server_address)
|
||||||
self._handle_ssl()
|
self._handle_ssl()
|
||||||
|
|
||||||
while not self.close:
|
while not self.close:
|
||||||
@ -191,7 +184,7 @@ class ConnectionHandler:
|
|||||||
Check if we can already identify SSL connections.
|
Check if we can already identify SSL connections.
|
||||||
"""
|
"""
|
||||||
if self.config.transparent_proxy:
|
if self.config.transparent_proxy:
|
||||||
client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"])
|
client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"])
|
||||||
elif self.config.reverse_proxy:
|
elif self.config.reverse_proxy:
|
||||||
client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
|
client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
|
||||||
# TODO: Make protocol generic (as with transparent proxies)
|
# TODO: Make protocol generic (as with transparent proxies)
|
||||||
@ -205,18 +198,18 @@ class ConnectionHandler:
|
|||||||
#TODO: Add ruleset to select correct protocol depending on mode/target port etc.
|
#TODO: Add ruleset to select correct protocol depending on mode/target port etc.
|
||||||
self.conntype = "http"
|
self.conntype = "http"
|
||||||
|
|
||||||
def establish_server_connection(self, host, port):
|
def establish_server_connection(self, address):
|
||||||
"""
|
"""
|
||||||
Establishes a new server connection to the given server
|
Establishes a new server connection to the given server
|
||||||
If there is already an existing server connection, it will be killed.
|
If there is already an existing server connection, it will be killed.
|
||||||
"""
|
"""
|
||||||
self.del_server_connection()
|
self.del_server_connection()
|
||||||
self.server_conn = ServerConnection(host, port)
|
self.server_conn = ServerConnection(address)
|
||||||
try:
|
try:
|
||||||
self.server_conn.connect()
|
self.server_conn.connect()
|
||||||
except tcp.NetLibError, v:
|
except tcp.NetLibError, v:
|
||||||
raise ProxyError(502, v)
|
raise ProxyError(502, v)
|
||||||
self.log("serverconnect", ["%s:%s" % (host, port)])
|
self.log("serverconnect", ["%s:%s" % address])
|
||||||
self.channel.tell("serverconnect", self)
|
self.channel.tell("serverconnect", self)
|
||||||
|
|
||||||
def establish_ssl(self, client=False, server=False):
|
def establish_ssl(self, client=False, server=False):
|
||||||
@ -227,7 +220,7 @@ class ConnectionHandler:
|
|||||||
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
|
||||||
"""
|
"""
|
||||||
# TODO: Implement SSL pass-through handling and change conntype
|
# TODO: Implement SSL pass-through handling and change conntype
|
||||||
if self.server_conn.host == "ycombinator.com":
|
if self.server_conn.address.host == "ycombinator.com":
|
||||||
self.conntype = "tcp"
|
self.conntype = "tcp"
|
||||||
|
|
||||||
if server:
|
if server:
|
||||||
@ -244,14 +237,14 @@ class ConnectionHandler:
|
|||||||
def server_reconnect(self, no_ssl=False):
|
def server_reconnect(self, no_ssl=False):
|
||||||
self.log("server reconnect")
|
self.log("server reconnect")
|
||||||
had_ssl, sni = self.server_conn.ssl_established, self.sni
|
had_ssl, sni = self.server_conn.ssl_established, self.sni
|
||||||
self.establish_server_connection(*self.server_conn.address)
|
self.establish_server_connection(self.server_conn.address)
|
||||||
if had_ssl and not no_ssl:
|
if had_ssl and not no_ssl:
|
||||||
self.sni = sni
|
self.sni = sni
|
||||||
self.establish_ssl(server=True)
|
self.establish_ssl(server=True)
|
||||||
|
|
||||||
def log(self, msg, subs=()):
|
def log(self, msg, subs=()):
|
||||||
msg = [
|
msg = [
|
||||||
"%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg)
|
"%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg)
|
||||||
]
|
]
|
||||||
for i in subs:
|
for i in subs:
|
||||||
msg.append(" -> " + i)
|
msg.append(" -> " + i)
|
||||||
@ -263,7 +256,7 @@ class ConnectionHandler:
|
|||||||
with open(self.config.certfile, "rb") as f:
|
with open(self.config.certfile, "rb") as f:
|
||||||
return certutils.SSLCert.from_pem(f.read())
|
return certutils.SSLCert.from_pem(f.read())
|
||||||
else:
|
else:
|
||||||
host = self.server_conn.host
|
host = self.server_conn.address.host
|
||||||
sans = []
|
sans = []
|
||||||
if not self.config.no_upstream_cert or not self.server_conn.ssl_established:
|
if not self.config.no_upstream_cert or not self.server_conn.ssl_established:
|
||||||
upstream_cert = self.server_conn.cert
|
upstream_cert = self.server_conn.cert
|
||||||
@ -307,14 +300,14 @@ class ProxyServer(tcp.TCPServer):
|
|||||||
allow_reuse_address = True
|
allow_reuse_address = True
|
||||||
bound = True
|
bound = True
|
||||||
|
|
||||||
def __init__(self, config, port, address='', server_version=version.NAMEVERSION):
|
def __init__(self, config, port, host='', server_version=version.NAMEVERSION):
|
||||||
"""
|
"""
|
||||||
Raises ProxyServerError if there's a startup problem.
|
Raises ProxyServerError if there's a startup problem.
|
||||||
"""
|
"""
|
||||||
self.config, self.port, self.address = config, port, address
|
self.config = config
|
||||||
self.server_version = server_version
|
self.server_version = server_version
|
||||||
try:
|
try:
|
||||||
tcp.TCPServer.__init__(self, (address, port))
|
tcp.TCPServer.__init__(self, (host, port))
|
||||||
except socket.error, v:
|
except socket.error, v:
|
||||||
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
|
||||||
self.channel = None
|
self.channel = None
|
||||||
|
@ -46,7 +46,7 @@ class CommonMixin:
|
|||||||
assert l.response.code == 304
|
assert l.response.code == 304
|
||||||
|
|
||||||
def test_invalid_http(self):
|
def test_invalid_http(self):
|
||||||
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
|
t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
|
||||||
t.connect()
|
t.connect()
|
||||||
t.wfile.write("invalid\r\n\r\n")
|
t.wfile.write("invalid\r\n\r\n")
|
||||||
t.wfile.flush()
|
t.wfile.flush()
|
||||||
@ -70,7 +70,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin):
|
|||||||
assert "ValueError" in ret.content
|
assert "ValueError" in ret.content
|
||||||
|
|
||||||
def test_invalid_connect(self):
|
def test_invalid_connect(self):
|
||||||
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
|
t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
|
||||||
t.connect()
|
t.connect()
|
||||||
t.wfile.write("CONNECT invalid\n\n")
|
t.wfile.write("CONNECT invalid\n\n")
|
||||||
t.wfile.flush()
|
t.wfile.flush()
|
||||||
|
Loading…
Reference in New Issue
Block a user