mirror of
https://github.com/LmeSzinc/StarRailCopilot.git
synced 2024-11-16 14:31:16 +00:00
231 lines
7.2 KiB
Python
231 lines
7.2 KiB
Python
"""
|
|
Copied from pywebio.platform.remote_access

* Implementation of remote access
Uses the https://github.com/wang0618/localshare service by running an ssh subprocess inside the PyWebIO application.

The first line the ssh process prints to stdout is the connection info, encoded as JSON.
|
|
|
|
|
|
"""
|
|
|
|
import json
import shlex
import threading
import time
from subprocess import PIPE, Popen
from typing import TYPE_CHECKING, Optional

from module.logger import logger
from module.config.utils import random_id
from module.webui.setting import State

if TYPE_CHECKING:
    from module.webui.utils import TaskHandler
|
|
|
|
_ssh_process: Popen = None
|
|
_ssh_thread: threading.Thread = None
|
|
_ssh_notfound: bool = False
|
|
address: str = None
|
|
|
|
|
|
def am_i_the_only_thread() -> bool:
    """
    Whether the current thread is the only non-daemon thread in the process.

    The current thread is always counted once, even if it is a daemon thread
    itself, so the result is True exactly when no *other* alive non-daemon
    thread exists.

    Returns:
        bool: True if no other alive non-daemon thread exists.
    """
    current = threading.current_thread()
    alive_non_daemonic_thread_cnt = sum(
        1
        for t in threading.enumerate()
        # `Thread.daemon` replaces `Thread.isDaemon()`, deprecated since Python 3.10.
        if t is current or (t.is_alive() and not t.daemon)
    )
    return alive_non_daemonic_thread_cnt == 1
|
|
|
|
|
|
def remote_access_service(
    local_host="127.0.0.1",
    local_port=22367,
    server="app.pywebio.online",
    server_port=1022,
    remote_port="/",
    setup_timeout=60,
):
    """
    Run the ssh reverse tunnel that exposes the local webui via the remote access server.

    Wait at most `setup_timeout` seconds to get the ssh output. If it gets a normal
    output, the connection is successfully established. Otherwise report the error
    and kill the ssh process. After setup, block until the ssh process exits or the
    calling thread is the only remaining non-daemon thread.

    :param local_host: local address the tunnel forwards to
    :param local_port: ssh local listen port
    :param server: ssh server, as used on the ssh command line (e.g. `user@host`)
    :param server_port: ssh server port
    :param remote_port: remote side of the `ssh -R` forwarding spec
    :param setup_timeout: If the service can't setup successfully in `setup_timeout` seconds, then exit.
    """
    global _ssh_process, _ssh_notfound, address

    ssh_bin = State.deploy_config.SSHExecutable
    # NOTE(review): shlex.split applies POSIX rules, so unquoted backslashes in
    # SSHExecutable (e.g. Windows paths) are treated as escape characters.
    # Kept as-is for compatibility with existing deploy configs.
    cmd = f"{ssh_bin} -oStrictHostKeyChecking=no -R {remote_port}:{local_host}:{local_port} -p {server_port} {server} -- --output json"
    args = shlex.split(cmd)
    logger.debug(f"remote access service command: {cmd}")

    if _ssh_process is not None and _ssh_process.poll() is None:
        logger.warning(f"Kill previous ssh process [{_ssh_process.pid}]")
        _ssh_process.kill()
    try:
        _ssh_process = Popen(args, stdout=PIPE, stderr=PIPE)
    except FileNotFoundError:
        logger.critical(
            f"Cannot find SSH executable {ssh_bin}, please install OpenSSH or specify SSHExecutable in deploy.yaml"
        )
        _ssh_notfound = True
        return
    logger.info(f"remote access process pid: {_ssh_process.pid}")

    # Pin this call's process locally. The timeout killer must never kill an
    # ssh process started by a later call, which would replace the global.
    process = _ssh_process
    success = False

    def timeout_killer(wait_sec):
        time.sleep(wait_sec)
        if not success and process.poll() is None:
            logger.info("Connection timeout, kill ssh process")
            process.kill()

    threading.Thread(
        target=timeout_killer, kwargs=dict(wait_sec=setup_timeout), daemon=True
    ).start()

    # ssh output is not guaranteed to be UTF-8 (e.g. localized console messages);
    # replace undecodable bytes instead of crashing the service thread.
    stdout = process.stdout.readline().decode("utf8", errors="replace")
    logger.debug(f"ssh server stdout: {stdout}")
    try:
        connection_info = json.loads(stdout)
    except json.decoder.JSONDecodeError:
        connection_info = None
    # Only a JSON object is a usable reply; anything else is a failed setup.
    success = isinstance(connection_info, dict)
    if not success and process.poll() is None:
        process.kill()

    if success:
        if connection_info.get("status", "fail") != "success":
            logger.info(
                f"Failed to establish remote access, this is the error message from service provider: {connection_info.get('message', '')}"
            )
            new_username = connection_info.get("change_username", None)
            if new_username:
                logger.info(f"Server requested to change username, change it to: {new_username}")
                State.deploy_config.SSHUser = new_username
        else:
            address = connection_info["address"]
            logger.debug(f"Remote access url: {address}")

    # wait ssh or main thread exit
    while not am_i_the_only_thread() and process.poll() is None:
        time.sleep(1)

    if process.poll() is None:  # main thread exit, kill ssh process
        logger.info("App process exit, killing ssh process")
        process.kill()
    else:  # ssh process exit by itself or by timeout killer
        stderr = process.stderr.read().decode("utf8", errors="replace")
        if stderr:
            logger.error(f"PyWebIO application remote access service error: {stderr}")
        else:
            logger.info("PyWebIO application remote access service exit.")
    address = None
|
|
|
|
|
|
def start_remote_access_service_(**kwargs):
    """
    Thread entry point: run `remote_access_service` and always clean up.

    KeyboardInterrupt is ignored and any other exception is logged, so nothing
    propagates out of the thread. A still-running ssh process is killed on exit.

    :param kwargs: forwarded unchanged to `remote_access_service`
    """
    logger.info("Start remote access service")
    try:
        remote_access_service(**kwargs)
    except KeyboardInterrupt:  # ignore KeyboardInterrupt
        pass
    except Exception as e:
        logger.exception(e)
    finally:
        # Only kill (and log) when the process is still running, so a normal
        # shutdown is not reported as an exception.
        if _ssh_process is not None and _ssh_process.poll() is None:
            logger.info("Killing remaining ssh process")
            _ssh_process.kill()
        logger.info("Exit remote access service thread")
|
|
|
|
|
|
class ParseError(Exception):
    """Raised when the SSH server setting of the deploy config cannot be parsed."""
|
|
|
|
|
|
def start_remote_access_service(**kwargs):
    """
    Build the ssh tunnel settings from the deploy config and start the service thread.

    Explicit keyword arguments take precedence over values derived from the
    deploy config. If SSHUser is unset, a random one is generated and written
    back to the deploy config.

    :param kwargs: optional overrides forwarded to `remote_access_service`
    :return: the started, non-daemon service thread
    :raises ParseError: if SSHServer is not in `host:port` form
    """
    global _ssh_thread

    try:
        server, server_port = State.deploy_config.SSHServer.split(":")
    except (ValueError, AttributeError) as e:
        raise ParseError(
            f"Failed to parse SSH server [{State.deploy_config.SSHServer}]"
        ) from e

    # ssh needs a concrete loopback address, not a wildcard bind address.
    if State.deploy_config.WebuiHost == "0.0.0.0":
        local_host = "127.0.0.1"
    elif State.deploy_config.WebuiHost == "::":
        local_host = "[::1]"
    else:
        local_host = State.deploy_config.WebuiHost

    # An empty username would produce the invalid target "@host", so treat it as unset.
    if not State.deploy_config.SSHUser:
        logger.info("SSHUser is not set, generate a random one")
        State.deploy_config.SSHUser = random_id(24)

    server = f"{State.deploy_config.SSHUser}@{server}"
    kwargs.setdefault("server", server)
    kwargs.setdefault("server_port", server_port)
    kwargs.setdefault("local_host", local_host)
    kwargs.setdefault("local_port", State.deploy_config.WebuiPort)

    _ssh_thread = threading.Thread(
        target=start_remote_access_service_,
        kwargs=kwargs,
        daemon=False,
    )
    _ssh_thread.start()
    return _ssh_thread
|
|
|
|
|
|
class RemoteAccess:
    """Static helpers to supervise and query the remote access (ssh tunnel) service."""

    @staticmethod
    def keep_ssh_alive():
        """
        Generator task that restarts the remote access service whenever its thread is gone.

        The first value sent in is the task handler. After that, each resume checks
        the service thread once and yields. On a ParseError the error is logged and
        the task removes itself through the handler.
        """
        task_handler: TaskHandler
        task_handler = yield
        while True:
            thread_running = _ssh_thread is not None and _ssh_thread.is_alive()
            if not thread_running:
                logger.info("Remote access service is not running, starting now")
                try:
                    start_remote_access_service()
                except ParseError as e:
                    logger.exception(e)
                    task_handler.remove_current_task()
            yield

    @staticmethod
    def kill_ssh_process():
        """Kill the ssh subprocess, but only while the service is alive."""
        if not RemoteAccess.is_alive():
            return
        _ssh_process.kill()

    @staticmethod
    def is_alive():
        """Return True when both the service thread and its ssh subprocess are running."""
        if _ssh_thread is None or not _ssh_thread.is_alive():
            return False
        return _ssh_process is not None and _ssh_process.poll() is None

    @staticmethod
    def get_state():
        """
        Return a numeric state code:

        1 -- alive and a remote access address is known
        2 -- alive but no address yet
        3 -- not alive and the SSH executable was not found
        0 -- otherwise (not alive)
        """
        if RemoteAccess.is_alive():
            return 2 if address is None else 1
        return 3 if _ssh_notfound else 0

    @staticmethod
    def get_entry_point():
        """Return the remote access address while the service is alive, else None."""
        if not RemoteAccess.is_alive():
            return None
        return address
|