Don't allow server.via change for live connections (#4841)

* don't allow `server.via` change for live connections

* return early if no tls context was set
This commit is contained in:
Maximilian Hils 2021-10-05 21:19:51 +02:00 committed by GitHub
commit aa2f935dbb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 7 deletions

View File

@ -291,12 +291,12 @@ class Server(Connection):
return f"Server({human.format_address(self.address)}, state={self.state.name.lower()}{tls_state}{local_port})" return f"Server({human.format_address(self.address)}, state={self.state.name.lower()}{tls_state}{local_port})"
def __setattr__(self, name, value): def __setattr__(self, name, value):
if name == "address": if name in ("address", "via"):
connection_open = self.__dict__.get("state", ConnectionState.CLOSED) is ConnectionState.OPEN connection_open = self.__dict__.get("state", ConnectionState.CLOSED) is ConnectionState.OPEN
# assigning the current value is okay, that may be an artifact of calling .set_state(). # assigning the current value is okay, that may be an artifact of calling .set_state().
address_changed = self.__dict__.get("address") != value attr_changed = self.__dict__.get(name) != value
if connection_open and address_changed: if connection_open and attr_changed:
raise RuntimeError("Cannot change server address on open connection.") raise RuntimeError(f"Cannot change server.{name} on open connection.")
return super().__setattr__(name, value) return super().__setattr__(name, value)
def get_state(self): def get_state(self):

View File

@ -177,6 +177,7 @@ class _TLSLayer(tunnel.TunnelLayer):
if not tls_start.ssl_conn: if not tls_start.ssl_conn:
yield commands.Log("No TLS context was provided, failing connection.", "error") yield commands.Log("No TLS context was provided, failing connection.", "error")
yield commands.CloseConnection(self.conn) yield commands.CloseConnection(self.conn)
return
assert tls_start.ssl_conn assert tls_start.ssl_conn
self.tls = tls_start.ssl_conn self.tls = tls_start.ssl_conn

View File

@ -706,10 +706,15 @@ def test_upstream_proxy(tctx, redirect, scheme):
assert playbook assert playbook
if redirect == "change-proxy": if redirect == "change-destination":
assert server2().address == ("other-proxy", 1234) assert flow().server_conn.address[0] == "other-server"
else: else:
assert server2().address == ("proxy", 8080) assert flow().server_conn.address[0] == "example.com"
if redirect == "change-proxy":
assert server2().address == flow().server_conn.via.address == ("other-proxy", 1234)
else:
assert server2().address == flow().server_conn.via.address == ("proxy", 8080)
assert ( assert (
playbook playbook