unify ipv4/ipv6 address handling

This commit is contained in:
Maximilian Hils 2014-01-28 17:28:20 +01:00
parent 94e530ec4f
commit 17f09aa0af
4 changed files with 28 additions and 36 deletions

View File

@ -5,9 +5,9 @@ python:
install:
- "pip install coveralls --use-mirrors"
- "pip install nose-cov --use-mirrors"
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git"
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git@tcp_proxy"
- "pip install -r requirements.txt --use-mirrors"
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git"
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git@tcp_proxy"
# command to run tests, e.g. python setup.py test
script:
- "nosetests --with-cov --cov-report term-missing"

View File

@ -115,7 +115,7 @@ class HTTPResponse(HTTPMessage):
class HTTPRequest(HTTPMessage):
def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content,
timestamp_start, timestamp_end, form_out=None, ip=None):
timestamp_start, timestamp_end, form_out=None):
self.form_in = form_in
self.method = method
self.scheme = scheme
@ -129,7 +129,6 @@ class HTTPRequest(HTTPMessage):
self.timestamp_end = timestamp_end
self.form_out = form_out or self.form_in
self.ip = ip # resolved ip address
assert isinstance(headers, ODictCaseless)
#FIXME: Compatibility Fix
@ -352,7 +351,7 @@ class HTTPHandler(ProtocolHandler):
if request.form_in == "authority":
directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy
if directly_addressed_at_mitmproxy:
self.c.establish_server_connection(request.host, request.port)
self.c.establish_server_connection((request.host, request.port))
self.c.client_conn.wfile.write(
'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n' % self.c.server_version) +
@ -369,7 +368,7 @@ class HTTPHandler(ProtocolHandler):
request.form_out = "origin"
if ((not self.c.server_conn) or
(self.c.server_conn.address != (request.host, request.port))):
self.c.establish_server_connection(request.host, request.port)
self.c.establish_server_connection((request.host, request.port))
else:
raise http.HttpError(400, "Invalid Request")

View File

@ -40,18 +40,13 @@ class ProxyConfig:
class ClientConnection(tcp.BaseHandler):
def __init__(self, client_connection, host, port):
tcp.BaseHandler.__init__(self, client_connection)
self.host, self.port = host, port
def __init__(self, client_connection, address):
tcp.BaseHandler.__init__(self, client_connection, address)
self.timestamp_start = utils.timestamp()
self.timestamp_end = None
self.timestamp_ssl_setup = None
@property
def address(self):
return self.host, self.port
def convert_to_ssl(self, *args, **kwargs):
tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs)
self.timestamp_ssl_setup = utils.timestamp()
@ -62,21 +57,19 @@ class ClientConnection(tcp.BaseHandler):
class ServerConnection(tcp.TCPClient):
def __init__(self, host, port):
tcp.TCPClient.__init__(self, host, port)
def __init__(self, address):
tcp.TCPClient.__init__(self, address)
self.peername = None
self.timestamp_start = None
self.timestamp_end = None
self.timestamp_tcp_setup = None
self.timestamp_ssl_setup = None
@property
def address(self):
return self.host, self.port
def connect(self):
self.timestamp_start = utils.timestamp()
tcp.TCPClient.connect(self)
self.peername = self.connection.getpeername()
self.timestamp_tcp_setup = utils.timestamp()
def establish_ssl(self, clientcerts, sni):
@ -125,7 +118,7 @@ class RequestReplayThread(threading.Thread):
class ConnectionHandler:
def __init__(self, config, client_connection, client_address, server, channel, server_version):
self.config = config
self.client_conn = ClientConnection(client_connection, *client_address)
self.client_conn = ClientConnection(client_connection, client_address)
self.server_conn = None
self.channel, self.server_version = channel, server_version
@ -142,7 +135,7 @@ class ConnectionHandler:
def del_server_connection(self):
if self.server_conn and self.server_conn.connection:
self.server_conn.finish()
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)])
self.log("serverdisconnect", ["%s:%s" % self.server_conn.address])
self.channel.tell("serverdisconnect", self)
self.server_conn = None
self.sni = None
@ -169,7 +162,7 @@ class ConnectionHandler:
self.determine_conntype()
if server_address:
self.establish_server_connection(*server_address)
self.establish_server_connection(server_address)
self._handle_ssl()
while not self.close:
@ -191,7 +184,7 @@ class ConnectionHandler:
Check if we can already identify SSL connections.
"""
if self.config.transparent_proxy:
client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"])
client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"])
elif self.config.reverse_proxy:
client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
# TODO: Make protocol generic (as with transparent proxies)
@ -205,18 +198,18 @@ class ConnectionHandler:
#TODO: Add ruleset to select correct protocol depending on mode/target port etc.
self.conntype = "http"
def establish_server_connection(self, host, port):
def establish_server_connection(self, address):
"""
Establishes a new server connection to the given server
If there is already an existing server connection, it will be killed.
"""
self.del_server_connection()
self.server_conn = ServerConnection(host, port)
self.server_conn = ServerConnection(address)
try:
self.server_conn.connect()
except tcp.NetLibError, v:
raise ProxyError(502, v)
self.log("serverconnect", ["%s:%s" % (host, port)])
self.log("serverconnect", ["%s:%s" % address])
self.channel.tell("serverconnect", self)
def establish_ssl(self, client=False, server=False):
@ -227,7 +220,7 @@ class ConnectionHandler:
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
"""
# TODO: Implement SSL pass-through handling and change conntype
if self.server_conn.host == "ycombinator.com":
if self.server_conn.address.host == "ycombinator.com":
self.conntype = "tcp"
if server:
@ -244,14 +237,14 @@ class ConnectionHandler:
def server_reconnect(self, no_ssl=False):
self.log("server reconnect")
had_ssl, sni = self.server_conn.ssl_established, self.sni
self.establish_server_connection(*self.server_conn.address)
self.establish_server_connection(self.server_conn.address)
if had_ssl and not no_ssl:
self.sni = sni
self.establish_ssl(server=True)
def log(self, msg, subs=()):
msg = [
"%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg)
"%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg)
]
for i in subs:
msg.append(" -> " + i)
@ -263,7 +256,7 @@ class ConnectionHandler:
with open(self.config.certfile, "rb") as f:
return certutils.SSLCert.from_pem(f.read())
else:
host = self.server_conn.host
host = self.server_conn.address.host
sans = []
if not self.config.no_upstream_cert or not self.server_conn.ssl_established:
upstream_cert = self.server_conn.cert
@ -307,14 +300,14 @@ class ProxyServer(tcp.TCPServer):
allow_reuse_address = True
bound = True
def __init__(self, config, port, address='', server_version=version.NAMEVERSION):
def __init__(self, config, port, host='', server_version=version.NAMEVERSION):
"""
Raises ProxyServerError if there's a startup problem.
"""
self.config, self.port, self.address = config, port, address
self.config = config
self.server_version = server_version
try:
tcp.TCPServer.__init__(self, (address, port))
tcp.TCPServer.__init__(self, (host, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.channel = None

View File

@ -46,7 +46,7 @@ class CommonMixin:
assert l.response.code == 304
def test_invalid_http(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
t.connect()
t.wfile.write("invalid\r\n\r\n")
t.wfile.flush()
@ -70,7 +70,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin):
assert "ValueError" in ret.content
def test_invalid_connect(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port)
t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
t.connect()
t.wfile.write("CONNECT invalid\n\n")
t.wfile.flush()