Merge pull request #2619 from mhils/issue-2617

Fix #2617
This commit is contained in:
Maximilian Hils 2017-11-06 11:23:16 +01:00 committed by GitHub
commit 4cb96dedd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 18 deletions

View File

@ -96,15 +96,7 @@ class ServerConnectionMixin:
def __init__(self, server_address=None): def __init__(self, server_address=None):
super().__init__() super().__init__()
self.server_conn = None self.server_conn = self.__make_server_conn(server_address)
if self.config.options.spoof_source_address and self.config.options.upstream_bind_address == '':
self.server_conn = connections.ServerConnection(
server_address, (self.ctx.client_conn.address[0], 0), True)
else:
self.server_conn = connections.ServerConnection(
server_address, (self.config.options.upstream_bind_address, 0),
self.config.options.spoof_source_address
)
self.__check_self_connect() self.__check_self_connect()
@ -125,6 +117,16 @@ class ServerConnectionMixin:
"The proxy shall not connect to itself.".format(repr(address)) "The proxy shall not connect to itself.".format(repr(address))
) )
def __make_server_conn(self, server_address):
if self.config.options.spoof_source_address and self.config.options.upstream_bind_address == '':
return connections.ServerConnection(
server_address, (self.ctx.client_conn.address[0], 0), True)
else:
return connections.ServerConnection(
server_address, (self.config.options.upstream_bind_address, 0),
self.config.options.spoof_source_address
)
def set_server(self, address): def set_server(self, address):
""" """
Sets a new server address. If there is an existing connection, it will be closed. Sets a new server address. If there is an existing connection, it will be closed.
@ -146,11 +148,7 @@ class ServerConnectionMixin:
self.server_conn.close() self.server_conn.close()
self.channel.tell("serverdisconnect", self.server_conn) self.channel.tell("serverdisconnect", self.server_conn)
self.server_conn = connections.ServerConnection( self.server_conn = self.__make_server_conn(address)
address,
(self.server_conn.source_address[0], 0),
self.config.options.spoof_source_address
)
def connect(self): def connect(self):
""" """

View File

@ -165,7 +165,7 @@ class HttpLayer(base.Layer):
def __init__(self, ctx, mode): def __init__(self, ctx, mode):
super().__init__(ctx) super().__init__(ctx)
self.mode = mode self.mode = mode
self.__initial_server_conn = None self.__initial_server_address = None # type: tuple
"Contains the original destination in transparent mode, which needs to be restored" "Contains the original destination in transparent mode, which needs to be restored"
"if an inline script modified the target server for a single http request" "if an inline script modified the target server for a single http request"
# We cannot rely on server_conn.tls_established, # We cannot rely on server_conn.tls_established,
@ -177,7 +177,7 @@ class HttpLayer(base.Layer):
def __call__(self): def __call__(self):
if self.mode == HTTPMode.transparent: if self.mode == HTTPMode.transparent:
self.__initial_server_tls = self.server_tls self.__initial_server_tls = self.server_tls
self.__initial_server_conn = self.server_conn self.__initial_server_address = self.server_conn.address
while True: while True:
flow = http.HTTPFlow( flow = http.HTTPFlow(
self.client_conn, self.client_conn,
@ -313,8 +313,8 @@ class HttpLayer(base.Layer):
# Setting request.host also updates the host header, which we want # Setting request.host also updates the host header, which we want
# to preserve # to preserve
host_header = f.request.host_header host_header = f.request.host_header
f.request.host = self.__initial_server_conn.address[0] f.request.host = self.__initial_server_address[0]
f.request.port = self.__initial_server_conn.address[1] f.request.port = self.__initial_server_address[1]
f.request.host_header = host_header # set again as .host overwrites this. f.request.host_header = host_header # set again as .host overwrites this.
f.request.scheme = "https" if self.__initial_server_tls else "http" f.request.scheme = "https" if self.__initial_server_tls else "http"
self.channel.ask("request", f) self.channel.ask("request", f)