PagerMaid-Pyro/pagermaid/common/plugin.py

319 lines
11 KiB
Python
Raw Normal View History

2023-01-31 16:24:56 +00:00
import contextlib
import json
import os
from pathlib import Path
from typing import Optional, List, Tuple, Dict
2023-01-31 16:24:56 +00:00
from pydantic import BaseModel, ValidationError
2023-01-31 16:24:56 +00:00
from pagermaid import Config, logs
from pagermaid.enums import Message
2023-01-31 16:24:56 +00:00
from pagermaid.common.cache import cache
from pagermaid.modules import plugin_list as active_plugins
2023-01-31 16:24:56 +00:00
from pagermaid.utils import client
from pagermaid.services import sqlite
2023-01-31 16:24:56 +00:00
2023-03-12 03:56:01 +00:00
# Directory that holds the plugin files; a relative path, so it is resolved
# against the current working directory of the process.
plugins_path = Path("plugins")
2023-01-31 16:24:56 +00:00
class LocalPlugin(BaseModel):
    """A plugin file stored in the local plugins directory.

    An enabled plugin lives at ``<name>.py``; a disabled one lives at
    ``<name>.py.disabled``. Enabling and disabling is a rename between
    those two paths.
    """

    # Plugin name, i.e. the file name without its ``.py`` suffix.
    name: str
    # Enabled/disabled flag supplied by whoever builds the model.
    status: bool
    # Whether the plugin is recorded as installed.
    installed: bool = False
    # Locally recorded version, if any. The explicit default keeps the field
    # optional on every pydantic major version (a bare ``Optional[...]``
    # annotation is a required field in pydantic v2).
    version: Optional[float] = None

    @property
    def normal_path(self) -> Path:
        """Path of the plugin file while it is enabled."""
        return plugins_path / f"{self.name}.py"

    @property
    def disabled_path(self) -> Path:
        """Path of the plugin file while it is disabled."""
        return plugins_path / f"{self.name}.py.disabled"

    @property
    def load_status(self) -> bool:
        """Plugin load status: whether the name is in ``active_plugins``."""
        return self.name in active_plugins

    def remove(self) -> None:
        """Delete both the enabled and the disabled file; missing files are ignored."""
        with contextlib.suppress(FileNotFoundError):
            os.remove(self.normal_path)
        with contextlib.suppress(FileNotFoundError):
            os.remove(self.disabled_path)

    def enable(self) -> bool:
        """Rename the disabled file to the enabled path.

        Returns:
            True on success, False if the rename failed (for example because
            there is no disabled file).
        """
        try:
            os.rename(self.disabled_path, self.normal_path)
        except OSError:
            # Only filesystem failures are an expected "could not enable";
            # anything else is a programming error and should surface.
            return False
        return True

    def disable(self) -> bool:
        """Rename the enabled file to the disabled path.

        Returns:
            True on success, False if the rename failed (for example because
            there is no enabled file).
        """
        try:
            os.rename(self.normal_path, self.disabled_path)
        except OSError:
            # See enable(): only filesystem failures map to False.
            return False
        return True
class RemotePlugin(LocalPlugin):
    """A plugin advertised by a remote plugin list.

    Extends :class:`LocalPlugin` with the descriptive fields carried by the
    remote list entry and adds the ability to download the plugin.
    """

    section: str
    maintainer: str
    size: str
    supported: bool
    des: str = ""
    des_short: str = ""

    async def install(self) -> bool:
        """Download ``<name>/main.py`` and store it as the enabled plugin file.

        Any existing enabled or disabled copy is removed before the new file
        is written.

        Returns:
            True if the download answered with HTTP 200 and the file was
            written, False for any other status code.
        """
        # NOTE(review): the download always goes to Config.GIT_SOURCE, even
        # for a plugin listed by another remote - confirm this is intended.
        response = await client.get(f"{Config.GIT_SOURCE}{self.name}/main.py")
        if response.status_code != 200:
            return False

        self.remove()
        target = plugins_path / f"{self.name}.py"
        with open(target, mode="wb") as plugin_file:
            plugin_file.write(response.text.encode("utf-8"))
        return True
class PluginRemote(BaseModel):
    """A remote plugin source and whether it is currently enabled."""

    # Base URL of the remote plugin source.
    url: str
    # True while the remote is enabled.
    status: bool

    @property
    def text(self) -> str:
        """Return the URL prefixed with a marker that reflects ``status``."""
        # Both branches used to be the same empty string, so the status had
        # no visible effect on the rendered text.
        # NOTE(review): marker glyphs chosen as check/cross - confirm they
        # match what the UI expects.
        marker = "✅" if self.status else "❌"
        return f"{marker} {self.url}"
class PluginRemoteManager:
    """Store and edit the list of remote plugin sources kept in ``sqlite``."""

    def __init__(self):
        # Key under which the remote list is stored in ``sqlite``.
        self.key = "plugins_remotes"

    def get_remotes(self) -> List[PluginRemote]:
        """Return every stored remote as a :class:`PluginRemote`."""
        stored = sqlite.get(self.key, [])
        return [PluginRemote(**entry) for entry in stored]

    def set_remotes(self, remotes: List[PluginRemote]):
        """Replace the stored remote list with ``remotes``."""
        sqlite[self.key] = [remote.dict() for remote in remotes]

    def add_remote(self, remote_url: str) -> bool:
        """Add ``remote_url`` as an enabled remote.

        Returns:
            True if it was added, False if the URL is already stored.
        """
        remotes = self.get_remotes()
        if any(remote.url == remote_url for remote in remotes):
            return False
        remotes.append(PluginRemote(url=remote_url, status=True))
        self.set_remotes(remotes)
        return True

    def remove_remote(self, remote_url: str) -> bool:
        """Remove every stored remote whose URL equals ``remote_url``.

        Returns:
            True if something was removed, False if the URL was not stored.
        """
        remotes = self.get_remotes()
        kept = [remote for remote in remotes if remote.url != remote_url]
        if len(kept) == len(remotes):
            return False
        self.set_remotes(kept)
        return True

    def disable_remote(self, remote_url: str) -> bool:
        """Mark the first remote matching ``remote_url`` as disabled.

        Returns:
            True if a matching remote was found, False otherwise.
        """
        remotes = self.get_remotes()
        for remote in remotes:
            if remote.url == remote_url:
                remote.status = False
                self.set_remotes(remotes)
                return True
        return False

    def enable_remote(self, remote_url: str) -> bool:
        """Mark the first remote matching ``remote_url`` as enabled.

        Returns:
            True if a matching remote was found, False otherwise.
        """
        remotes = self.get_remotes()
        for remote in remotes:
            if remote.url == remote_url:
                remote.status = True
                self.set_remotes(remotes)
                return True
        return False
class PluginManager:
    """Track local plugin files and the plugins offered by remote sources."""

    def __init__(self, remote_manager: PluginRemoteManager):
        self.remote_manager = remote_manager
        # Plugin name -> locally recorded version (mirrors version.json).
        self.version_map = {}
        # Plugin name -> version taken from the last remote list load.
        self.remote_version_map = {}
        self.plugins: List[LocalPlugin] = []
        self.remote_plugins: List[RemotePlugin] = []

    def load_local_version_map(self):
        """Load ``version.json`` from the plugins directory, if it exists."""
        if not os.path.exists(plugins_path / "version.json"):
            return
        with open(plugins_path / "version.json", "r", encoding="utf-8") as f:
            self.version_map = json.load(f)

    def save_local_version_map(self):
        """Write the in-memory version map to ``version.json``."""
        with open(plugins_path / "version.json", "w", encoding="utf-8") as f:
            json.dump(self.version_map, f, indent=4)

    def get_local_version(self, name: str) -> Optional[float]:
        """Return the recorded version of ``name`` as a float.

        A missing or falsy entry (including ``0``) yields ``None``.
        """
        data = self.version_map.get(name)
        return float(data) if data else None

    def set_local_version(self, name: str, version: float) -> None:
        """Record ``version`` for ``name`` and persist the version map."""
        self.version_map[name] = version
        self.save_local_version_map()

    def get_plugin_install_status(self, name: str) -> bool:
        """Return whether ``name`` has an entry in the version map."""
        return name in self.version_map

    @staticmethod
    def get_plugin_load_status(name: str) -> bool:
        """Return whether the enabled file ``<name>.py`` exists."""
        return os.path.exists(plugins_path / f"{name}.py")

    def remove_plugin(self, name: str) -> bool:
        """Delete the files of a known local plugin and forget its version.

        Returns:
            True if ``name`` was among the loaded local plugins, else False.
        """
        plugin = self.get_local_plugin(name)
        if plugin is None:
            return False
        plugin.remove()
        if name in self.version_map:
            self.version_map.pop(name)
            self.save_local_version_map()
        return True

    def enable_plugin(self, name: str) -> bool:
        """Enable a known local plugin; False if unknown or the rename failed."""
        plugin = self.get_local_plugin(name)
        return plugin.enable() if plugin else False

    def disable_plugin(self, name: str) -> bool:
        """Disable a known local plugin; False if unknown or the rename failed."""
        plugin = self.get_local_plugin(name)
        return plugin.disable() if plugin else False

    def load_local_plugins(self) -> List[LocalPlugin]:
        """Rescan the plugins directory and rebuild ``self.plugins``.

        Every ``*.py`` and ``*.py.disabled`` file becomes one
        :class:`LocalPlugin` entry.
        """
        self.load_local_version_map()
        self.plugins = []
        # Use the shared plugins_path constant instead of repeating the
        # directory name, so every method looks at the same directory.
        for file_name in os.listdir(plugins_path):
            # Check the longer suffix first: "x.py.disabled" does not end
            # with ".py", but the order keeps the intent obvious.
            if file_name.endswith(".py.disabled"):
                name = file_name[: -len(".py.disabled")]
            elif file_name.endswith(".py"):
                name = file_name[: -len(".py")]
            else:
                continue
            self.plugins.append(
                LocalPlugin(
                    name=name,
                    installed=self.get_plugin_install_status(name),
                    status=self.get_plugin_load_status(name),
                    version=self.get_local_version(name),
                )
            )
        return self.plugins

    def get_local_plugin(self, name: str) -> Optional[LocalPlugin]:
        """Return the loaded local plugin called ``name``, or ``None``."""
        return next((p for p in self.plugins if p.name == name), None)

    @staticmethod
    async def fetch_remote_url(url: str) -> List[Dict]:
        """Fetch ``<url>list.json`` and return its ``"list"`` entry.

        Raises:
            Exception: whatever the request, status check or JSON decoding
                raised; the failure is logged before it is re-raised.
        """
        try:
            data = await client.get(f"{url}list.json")
            data.raise_for_status()
            return data.json()["list"]
        except Exception as e:
            logs.error(f"获取远程插件列表失败: {e}")
            # Bare raise re-raises the active exception unchanged.
            raise

    async def load_remote_plugins_no_cache(self) -> List[RemotePlugin]:
        """Fetch the plugin lists of all remotes and rebuild the remote state.

        ``Config.GIT_SOURCE`` is queried first, then every stored remote. A
        remote that fails to load is marked disabled, one that loads is
        marked enabled. When several remotes offer the same plugin name the
        first one wins.
        """
        remote_urls = [i.url for i in self.remote_manager.get_remotes()]
        remote_urls.insert(0, Config.GIT_SOURCE)
        plugins: List[RemotePlugin] = []
        seen_names = set()
        for remote in remote_urls:
            try:
                plugin_list = await self.fetch_remote_url(remote)
            except Exception:
                # fetch_remote_url already logged the failure; logging it
                # again here would only duplicate the message.
                self.remote_manager.disable_remote(remote)
                continue
            self.remote_manager.enable_remote(remote)
            for plugin in plugin_list:
                try:
                    plugin_model = RemotePlugin(**plugin, status=False)
                except (ValidationError, TypeError):
                    # TypeError covers entries that are not mappings or that
                    # carry their own "status" key; one bad entry must not
                    # abort the whole load.
                    logs.warning(f"远程插件 {plugin} 信息不完整")
                    continue
                if plugin_model.name in seen_names:
                    continue
                seen_names.add(plugin_model.name)
                plugins.append(plugin_model)
        self.remote_plugins = plugins
        self.remote_version_map = {plugin.name: plugin.version for plugin in plugins}
        return plugins

    @cache()
    async def load_remote_plugins_cache(self) -> List[RemotePlugin]:
        """Same as :meth:`load_remote_plugins_no_cache`, wrapped by ``cache()``."""
        return await self.load_remote_plugins_no_cache()

    async def load_remote_plugins(
        self, enable_cache: bool = True
    ) -> List[RemotePlugin]:
        """Return the remote plugins with ``status`` refreshed from disk.

        Args:
            enable_cache: use the cached loader when True, otherwise fetch
                the remote lists again.
        """
        if enable_cache:
            plugin_list = await self.load_remote_plugins_cache()
        else:
            plugin_list = await self.load_remote_plugins_no_cache()
        for i in plugin_list:
            i.status = self.get_plugin_load_status(i.name)
        return plugin_list

    def get_remote_plugin(self, name: str) -> Optional[RemotePlugin]:
        """Return the loaded remote plugin called ``name``, or ``None``."""
        return next((p for p in self.remote_plugins if p.name == name), None)

    def plugin_need_update(self, name: str) -> bool:
        """Return whether the remote version of ``name`` is newer than the local one.

        A plugin without a local version (``None`` or ``0``) or without a
        remote version is never reported as needing an update.
        """
        local_version = self.get_local_version(name)
        if not local_version:
            return False
        remote_version = self.remote_version_map.get(name)
        if not remote_version:
            return False
        return local_version < remote_version

    async def install_remote_plugin(self, name: str) -> bool:
        """Install the remote plugin ``name`` and record its version.

        Returns:
            True if the plugin is known and its download succeeded.
        """
        plugin = self.get_remote_plugin(name)
        if plugin is None:
            return False
        if await plugin.install():
            self.set_local_version(name, plugin.version)
            return True
        return False

    async def update_remote_plugin(self, name: str) -> bool:
        """Reinstall ``name`` if :meth:`plugin_need_update` says so."""
        if self.plugin_need_update(name):
            return await self.install_remote_plugin(name)
        return False

    async def update_all_remote_plugin(self) -> List[RemotePlugin]:
        """Update every loaded remote plugin; return the ones that were updated."""
        updated_plugins = []
        for i in self.remote_plugins:
            if await self.update_remote_plugin(i.name):
                updated_plugins.append(i)
        return updated_plugins

    @staticmethod
    async def download_from_message(message: Message) -> Optional[str]:
        """Download a plugin from a message.

        The replied-to message is checked first, then the message itself;
        the document must have a file name ending in ``.py``.

        Returns:
            Whatever ``download()`` returns for the matching message, or
            ``None`` when neither message carries a ``.py`` document.
        """
        reply = message.reply_to_message
        file_path = None
        if (
            reply
            and reply.document
            and reply.document.file_name
            and reply.document.file_name.endswith(".py")
        ):
            file_path = await reply.download()
        elif (
            message.document
            and message.document.file_name
            and message.document.file_name.endswith(".py")
        ):
            file_path = await message.download()
        return file_path

    def get_plugins_status(
        self,
    ) -> Tuple[List[str], List[LocalPlugin], List[LocalPlugin]]:
        """Get plugins status.

        Returns:
            A tuple ``(active_plugins, disabled_plugins, inactive_plugins)``:
            the module-level ``active_plugins`` list, the local plugins whose
            ``status`` is true but ``load_status`` is false, and the local
            plugins whose ``status`` is false.
        """
        disabled_plugins = []
        inactive_plugins = []
        for plugin in self.plugins:
            if not plugin.status:
                inactive_plugins.append(plugin)
            elif not plugin.load_status:
                disabled_plugins.append(plugin)
        return active_plugins, disabled_plugins, inactive_plugins
# Module-level instances created at import time; the plugin manager is wired
# to the remote manager defined on the line above it.
plugin_remote_manager = PluginRemoteManager()
plugin_manager = PluginManager(plugin_remote_manager)