unify ipv4/ipv6 address handling

This commit is contained in:
Maximilian Hils 2014-01-28 17:28:20 +01:00
parent 94e530ec4f
commit 17f09aa0af
4 changed files with 28 additions and 36 deletions

View File

@ -5,9 +5,9 @@ python:
install: install:
- "pip install coveralls --use-mirrors" - "pip install coveralls --use-mirrors"
- "pip install nose-cov --use-mirrors" - "pip install nose-cov --use-mirrors"
- "pip install --upgrade git+https://github.com/mitmproxy/netlib.git" - "pip install --upgrade git+https://github.com/mitmproxy/netlib.git@tcp_proxy"
- "pip install -r requirements.txt --use-mirrors" - "pip install -r requirements.txt --use-mirrors"
- "pip install --upgrade git+https://github.com/mitmproxy/pathod.git" - "pip install --upgrade git+https://github.com/mitmproxy/pathod.git@tcp_proxy"
# command to run tests, e.g. python setup.py test # command to run tests, e.g. python setup.py test
script: script:
- "nosetests --with-cov --cov-report term-missing" - "nosetests --with-cov --cov-report term-missing"

View File

@ -115,7 +115,7 @@ class HTTPResponse(HTTPMessage):
class HTTPRequest(HTTPMessage): class HTTPRequest(HTTPMessage):
def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content, def __init__(self, form_in, method, scheme, host, port, path, httpversion, headers, content,
timestamp_start, timestamp_end, form_out=None, ip=None): timestamp_start, timestamp_end, form_out=None):
self.form_in = form_in self.form_in = form_in
self.method = method self.method = method
self.scheme = scheme self.scheme = scheme
@ -129,7 +129,6 @@ class HTTPRequest(HTTPMessage):
self.timestamp_end = timestamp_end self.timestamp_end = timestamp_end
self.form_out = form_out or self.form_in self.form_out = form_out or self.form_in
self.ip = ip # resolved ip address
assert isinstance(headers, ODictCaseless) assert isinstance(headers, ODictCaseless)
#FIXME: Compatibility Fix #FIXME: Compatibility Fix
@ -352,7 +351,7 @@ class HTTPHandler(ProtocolHandler):
if request.form_in == "authority": if request.form_in == "authority":
directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy directly_addressed_at_mitmproxy = (self.c.mode == "regular") and not self.c.config.forward_proxy
if directly_addressed_at_mitmproxy: if directly_addressed_at_mitmproxy:
self.c.establish_server_connection(request.host, request.port) self.c.establish_server_connection((request.host, request.port))
self.c.client_conn.wfile.write( self.c.client_conn.wfile.write(
'HTTP/1.1 200 Connection established\r\n' + 'HTTP/1.1 200 Connection established\r\n' +
('Proxy-agent: %s\r\n' % self.c.server_version) + ('Proxy-agent: %s\r\n' % self.c.server_version) +
@ -369,7 +368,7 @@ class HTTPHandler(ProtocolHandler):
request.form_out = "origin" request.form_out = "origin"
if ((not self.c.server_conn) or if ((not self.c.server_conn) or
(self.c.server_conn.address != (request.host, request.port))): (self.c.server_conn.address != (request.host, request.port))):
self.c.establish_server_connection(request.host, request.port) self.c.establish_server_connection((request.host, request.port))
else: else:
raise http.HttpError(400, "Invalid Request") raise http.HttpError(400, "Invalid Request")

View File

@ -40,18 +40,13 @@ class ProxyConfig:
class ClientConnection(tcp.BaseHandler): class ClientConnection(tcp.BaseHandler):
def __init__(self, client_connection, host, port): def __init__(self, client_connection, address):
tcp.BaseHandler.__init__(self, client_connection) tcp.BaseHandler.__init__(self, client_connection, address)
self.host, self.port = host, port
self.timestamp_start = utils.timestamp() self.timestamp_start = utils.timestamp()
self.timestamp_end = None self.timestamp_end = None
self.timestamp_ssl_setup = None self.timestamp_ssl_setup = None
@property
def address(self):
return self.host, self.port
def convert_to_ssl(self, *args, **kwargs): def convert_to_ssl(self, *args, **kwargs):
tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs)
self.timestamp_ssl_setup = utils.timestamp() self.timestamp_ssl_setup = utils.timestamp()
@ -62,21 +57,19 @@ class ClientConnection(tcp.BaseHandler):
class ServerConnection(tcp.TCPClient): class ServerConnection(tcp.TCPClient):
def __init__(self, host, port): def __init__(self, address):
tcp.TCPClient.__init__(self, host, port) tcp.TCPClient.__init__(self, address)
self.peername = None
self.timestamp_start = None self.timestamp_start = None
self.timestamp_end = None self.timestamp_end = None
self.timestamp_tcp_setup = None self.timestamp_tcp_setup = None
self.timestamp_ssl_setup = None self.timestamp_ssl_setup = None
@property
def address(self):
return self.host, self.port
def connect(self): def connect(self):
self.timestamp_start = utils.timestamp() self.timestamp_start = utils.timestamp()
tcp.TCPClient.connect(self) tcp.TCPClient.connect(self)
self.peername = self.connection.getpeername()
self.timestamp_tcp_setup = utils.timestamp() self.timestamp_tcp_setup = utils.timestamp()
def establish_ssl(self, clientcerts, sni): def establish_ssl(self, clientcerts, sni):
@ -125,7 +118,7 @@ class RequestReplayThread(threading.Thread):
class ConnectionHandler: class ConnectionHandler:
def __init__(self, config, client_connection, client_address, server, channel, server_version): def __init__(self, config, client_connection, client_address, server, channel, server_version):
self.config = config self.config = config
self.client_conn = ClientConnection(client_connection, *client_address) self.client_conn = ClientConnection(client_connection, client_address)
self.server_conn = None self.server_conn = None
self.channel, self.server_version = channel, server_version self.channel, self.server_version = channel, server_version
@ -140,9 +133,9 @@ class ConnectionHandler:
self.mode = "transparent" self.mode = "transparent"
def del_server_connection(self): def del_server_connection(self):
if self.server_conn and self.server_conn.connection: if self.server_conn and self.server_conn.connection:
self.server_conn.finish() self.server_conn.finish()
self.log("serverdisconnect", ["%s:%s" % (self.server_conn.host, self.server_conn.port)]) self.log("serverdisconnect", ["%s:%s" % self.server_conn.address])
self.channel.tell("serverdisconnect", self) self.channel.tell("serverdisconnect", self)
self.server_conn = None self.server_conn = None
self.sni = None self.sni = None
@ -169,7 +162,7 @@ class ConnectionHandler:
self.determine_conntype() self.determine_conntype()
if server_address: if server_address:
self.establish_server_connection(*server_address) self.establish_server_connection(server_address)
self._handle_ssl() self._handle_ssl()
while not self.close: while not self.close:
@ -191,7 +184,7 @@ class ConnectionHandler:
Check if we can already identify SSL connections. Check if we can already identify SSL connections.
""" """
if self.config.transparent_proxy: if self.config.transparent_proxy:
client_ssl = server_ssl = (self.server_conn.port in self.config.transparent_proxy["sslports"]) client_ssl = server_ssl = (self.server_conn.address.port in self.config.transparent_proxy["sslports"])
elif self.config.reverse_proxy: elif self.config.reverse_proxy:
client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https") client_ssl = server_ssl = (self.config.reverse_proxy[0] == "https")
# TODO: Make protocol generic (as with transparent proxies) # TODO: Make protocol generic (as with transparent proxies)
@ -205,18 +198,18 @@ class ConnectionHandler:
#TODO: Add ruleset to select correct protocol depending on mode/target port etc. #TODO: Add ruleset to select correct protocol depending on mode/target port etc.
self.conntype = "http" self.conntype = "http"
def establish_server_connection(self, host, port): def establish_server_connection(self, address):
""" """
Establishes a new server connection to the given server Establishes a new server connection to the given server
If there is already an existing server connection, it will be killed. If there is already an existing server connection, it will be killed.
""" """
self.del_server_connection() self.del_server_connection()
self.server_conn = ServerConnection(host, port) self.server_conn = ServerConnection(address)
try: try:
self.server_conn.connect() self.server_conn.connect()
except tcp.NetLibError, v: except tcp.NetLibError, v:
raise ProxyError(502, v) raise ProxyError(502, v)
self.log("serverconnect", ["%s:%s" % (host, port)]) self.log("serverconnect", ["%s:%s" % address])
self.channel.tell("serverconnect", self) self.channel.tell("serverconnect", self)
def establish_ssl(self, client=False, server=False): def establish_ssl(self, client=False, server=False):
@ -227,7 +220,7 @@ class ConnectionHandler:
A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening A protocol handler must raise a ConnTypeChanged exception if it detects that this is happening
""" """
# TODO: Implement SSL pass-through handling and change conntype # TODO: Implement SSL pass-through handling and change conntype
if self.server_conn.host == "ycombinator.com": if self.server_conn.address.host == "ycombinator.com":
self.conntype = "tcp" self.conntype = "tcp"
if server: if server:
@ -244,14 +237,14 @@ class ConnectionHandler:
def server_reconnect(self, no_ssl=False): def server_reconnect(self, no_ssl=False):
self.log("server reconnect") self.log("server reconnect")
had_ssl, sni = self.server_conn.ssl_established, self.sni had_ssl, sni = self.server_conn.ssl_established, self.sni
self.establish_server_connection(*self.server_conn.address) self.establish_server_connection(self.server_conn.address)
if had_ssl and not no_ssl: if had_ssl and not no_ssl:
self.sni = sni self.sni = sni
self.establish_ssl(server=True) self.establish_ssl(server=True)
def log(self, msg, subs=()): def log(self, msg, subs=()):
msg = [ msg = [
"%s:%s: %s" % (self.client_conn.host, self.client_conn.port, msg) "%s:%s: %s" % (self.client_conn.address.host, self.client_conn.address.port, msg)
] ]
for i in subs: for i in subs:
msg.append(" -> " + i) msg.append(" -> " + i)
@ -263,7 +256,7 @@ class ConnectionHandler:
with open(self.config.certfile, "rb") as f: with open(self.config.certfile, "rb") as f:
return certutils.SSLCert.from_pem(f.read()) return certutils.SSLCert.from_pem(f.read())
else: else:
host = self.server_conn.host host = self.server_conn.address.host
sans = [] sans = []
if not self.config.no_upstream_cert or not self.server_conn.ssl_established: if not self.config.no_upstream_cert or not self.server_conn.ssl_established:
upstream_cert = self.server_conn.cert upstream_cert = self.server_conn.cert
@ -307,14 +300,14 @@ class ProxyServer(tcp.TCPServer):
allow_reuse_address = True allow_reuse_address = True
bound = True bound = True
def __init__(self, config, port, address='', server_version=version.NAMEVERSION): def __init__(self, config, port, host='', server_version=version.NAMEVERSION):
""" """
Raises ProxyServerError if there's a startup problem. Raises ProxyServerError if there's a startup problem.
""" """
self.config, self.port, self.address = config, port, address self.config = config
self.server_version = server_version self.server_version = server_version
try: try:
tcp.TCPServer.__init__(self, (address, port)) tcp.TCPServer.__init__(self, (host, port))
except socket.error, v: except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror) raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.channel = None self.channel = None

View File

@ -46,7 +46,7 @@ class CommonMixin:
assert l.response.code == 304 assert l.response.code == 304
def test_invalid_http(self): def test_invalid_http(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port) t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
t.connect() t.connect()
t.wfile.write("invalid\r\n\r\n") t.wfile.write("invalid\r\n\r\n")
t.wfile.flush() t.wfile.flush()
@ -70,7 +70,7 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin):
assert "ValueError" in ret.content assert "ValueError" in ret.content
def test_invalid_connect(self): def test_invalid_connect(self):
t = tcp.TCPClient("127.0.0.1", self.proxy.port) t = tcp.TCPClient(("127.0.0.1", self.proxy.address.port))
t.connect() t.connect()
t.wfile.write("CONNECT invalid\n\n") t.wfile.write("CONNECT invalid\n\n")
t.wfile.flush() t.wfile.flush()