import contextlib
import datetime
from typing import Optional, List

import pytz

from pagermaid.config import Config
from pagermaid.dependence import scheduler, sqlite
from pagermaid.enums import Message
from pagermaid.listener import listener
from pagermaid.services import bot
from pagermaid.utils import alias_command


class SendTask:
    """A single scheduled message, persisted in sqlite under ``sendat_tasks``.

    How the task is scheduled (see ``SendTasks.register_single_task``):

    * ``interval`` and not ``cron``: repeat every ``hour``/``minute``/``second``.
    * ``interval`` and ``cron``: repeat at the wall-clock time ``hour:minute:second``.
    * not ``interval``: fire once at the next ``hour:minute:second``.

    ``time_limit`` is the number of sends left; ``-1`` means unlimited and
    ``0`` means the task is finished and will not be scheduled.
    """

    task_id: Optional[int]
    cid: int  # chat id the message is sent to
    msg: str
    interval: bool
    cron: bool
    pause: bool
    time_limit: int
    hour: str = "0"
    minute: str = "0"
    second: str = "0"

    def __init__(
        self,
        task_id: Optional[int] = None,
        cid: int = 0,
        msg: str = "",
        interval: bool = False,
        cron: bool = False,
        pause: bool = False,
        time_limit: int = -1,
        hour: str = "0",
        minute: str = "0",
        second: str = "0",
    ):
        self.task_id = task_id
        self.cid = cid
        self.msg = msg
        self.interval = interval
        self.cron = cron
        self.pause = pause
        self.time_limit = time_limit
        self.hour = hour
        self.minute = minute
        self.second = second

    @property
    def job_id(self) -> str:
        """Id (and name) of this task's APScheduler job."""
        return f"sendat|{self.cid}|{self.task_id}"

    def reduce_time(self):
        """Consume one send of a limited task and persist the new count."""
        if self.time_limit > 0:
            self.time_limit -= 1
            self.save_to_file()

    def export(self) -> dict:
        """Return the task as a plain dict; its keys match ``__init__``."""
        return {
            "task_id": self.task_id,
            "cid": self.cid,
            "msg": self.msg,
            "interval": self.interval,
            "cron": self.cron,
            "pause": self.pause,
            "time_limit": self.time_limit,
            "hour": self.hour,
            "minute": self.minute,
            "second": self.second,
        }

    def get_job(self):
        """Return this task's scheduler job, or ``None`` if not scheduled."""
        return scheduler.get_job(self.job_id)

    def remove_job(self):
        """Unschedule this task if it currently has a job."""
        if self.get_job():
            scheduler.remove_job(self.job_id)

    def export_str(self, show_all: bool = False) -> str:
        """Render ``id - kind - next run - [chat id -] message`` for listings."""
        job = self.get_job()
        # A paused scheduler job has no next run time; treat it as not running.
        next_run: Optional[datetime.datetime] = job.next_run_time if job else None
        parts = [
            str(self.task_id),
            "循环任务" if self.interval else "单次任务",
            next_run.strftime("%Y-%m-%d %H:%M:%S") if next_run else "未运行",
        ]
        if show_all:
            parts.append(str(self.cid))
        parts.append(self.msg)
        return " - ".join(parts)

    @staticmethod
    def check_time(
        time: str, min_value: Optional[int] = None, max_value: Optional[int] = None
    ) -> str:
        """Validate that *time* is a non-negative integer within the bounds.

        Returns *time* unchanged.

        Raises:
            ValueError: if *time* is not an integer or is out of range.
        """
        value = int(time)
        if max_value is not None and value > max_value:
            raise ValueError(f"Time value {time} is too large")
        if min_value is not None and value < min_value:
            raise ValueError(f"Time value {time} is too small")
        if value < 0:
            raise ValueError(f"Time value {time} is too small")
        return time

    def save_to_file(self):
        """Insert or replace this task's record in the persisted task list."""
        data = sqlite.get("sendat_tasks", [])
        record = self.export()
        for index, item in enumerate(data):
            if item["task_id"] == self.task_id:
                data[index] = record  # replace in place to keep list order
                break
        else:
            data.append(record)
        sqlite["sendat_tasks"] = data

    @staticmethod
    def parse_date(date: str):
        """Validate an ``HH:MM:SS`` string (raises ``ValueError`` otherwise)."""
        datetime.datetime.strptime(date, "%H:%M:%S")

    def parse_task(self, text: str):
        """Fill this task from ``"<schedule> | <message>"``.

        ``<schedule>`` is an optional ``every`` keyword followed by
        ``<value> <unit>`` pairs. The unit is one of ``seconds``, ``minutes``,
        ``hours``, ``times`` (send limit) or ``date`` (value is ``HH:MM:SS``).

        Raises:
            ValueError: if the message is empty or the schedule is malformed.
        """
        schedule, _, body = text.partition("|")
        self.msg = body.strip()
        if not self.msg:
            raise ValueError("No message provided")
        schedule = schedule.strip()
        if "every" in schedule:
            self.interval = True
            schedule = schedule.replace("every", "").strip()
        # split() with no argument tolerates repeated whitespace.
        tokens = schedule.split()
        if len(tokens) % 2:
            raise ValueError("Invalid task format")
        format_right = False
        no_date = True
        for value, unit in zip(tokens[::2], tokens[1::2]):
            if unit == "seconds":
                format_right = True
                self.second = self.check_time(value, 0, 60)
            elif unit == "minutes":
                format_right = True
                self.minute = self.check_time(value, 0, 60)
            elif unit == "hours":
                format_right = True
                self.hour = self.check_time(value, 0, 168)
            elif unit == "times":
                self.interval = True
                self.time_limit = int(self.check_time(value, min_value=1))
            elif unit == "date":
                format_right = True
                no_date = False
                self.cron = True
                date = datetime.datetime.strptime(value, "%H:%M:%S")
                self.hour = str(date.hour)
                self.minute = str(date.minute)
                self.second = str(date.second)
            else:
                raise ValueError("Invalid task format")
        if not format_right:
            raise ValueError("Invalid task format")
        if no_date:
            self.interval = True
            # A zero-length interval would make the scheduler fire constantly.
            if not any(int(v) for v in (self.hour, self.minute, self.second)):
                raise ValueError("Interval must be greater than zero")
        self.time_limit = self.time_limit if self.time_limit > 0 else -1


class SendTasks:
    """In-memory registry of ``SendTask`` objects.

    The registry is mirrored to sqlite and to the scheduler.
    """

    tasks: List[SendTask]

    def __init__(self):
        self.tasks = []

    def add(self, task: SendTask):
        """Add *task* unless a task with the same id is already present."""
        if any(existing.task_id == task.task_id for existing in self.tasks):
            return
        self.tasks.append(task)

    def remove(self, task_id: int) -> bool:
        """Unschedule and drop the task with *task_id*.

        Returns True if the task was found.
        """
        for task in self.tasks:
            if task.task_id == task_id:
                task.remove_job()
                self.tasks.remove(task)
                return True
        return False

    def get(self, task_id: int) -> Optional[SendTask]:
        """Return the task with *task_id*, or ``None``."""
        return next((task for task in self.tasks if task.task_id == task_id), None)

    def get_all(self) -> List[SendTask]:
        return self.tasks

    def get_all_ids(self) -> List[int]:
        return [task.task_id for task in self.tasks]

    def print_all_tasks(self, show_all: bool = False, cid: int = 0) -> str:
        """List the tasks of chat *cid* (or all tasks), one per line."""
        return "\n".join(
            task.export_str(show_all)
            for task in self.tasks
            if task.cid == cid or show_all
        )

    def save_to_file(self):
        """Overwrite the persisted task list with the in-memory tasks."""
        sqlite["sendat_tasks"] = [task.export() for task in self.tasks]

    def load_from_file(self):
        """Add persisted tasks that are not already in memory."""
        for item in sqlite.get("sendat_tasks", []):
            self.add(SendTask(**item))

    def pause_task(self, task_id) -> bool:
        """Mark a task paused, unschedule it, and persist the change."""
        if task := self.get(task_id):
            task.pause = True
            task.remove_job()
            self.save_to_file()
            return True
        return False

    @staticmethod
    async def send_message(task: SendTask, tasks):
        """Job callback: send the message, then retire the task if it is done.

        A task is done when it is single-shot or has used up its send limit.
        Done tasks are removed and the removal is persisted. Otherwise a
        single-shot task would stay in storage and be rescheduled on the
        next start.
        """
        # Delivery is best-effort: a failed send must not break the job.
        with contextlib.suppress(Exception):
            await bot.send_message(task.cid, task.msg)
        task.reduce_time()
        if task.time_limit == 0 or not task.interval:
            task.remove_job()
            tasks.remove(task.task_id)
            tasks.save_to_file()

    def _add_job(self, task: SendTask, trigger: str, **trigger_args):
        """Schedule *task* with *trigger*, replacing any existing job."""
        scheduler.add_job(
            self.send_message,
            trigger,
            id=task.job_id,
            name=task.job_id,
            args=[task, self],
            # Replacing avoids a ConflictingIdError when the task is
            # re-registered, for example on resume or a module reload.
            replace_existing=True,
            **trigger_args,
        )

    def register_interval_task(self, task: SendTask):
        """Repeat *task* every hour/minute/second interval."""
        self._add_job(
            task,
            "interval",
            hours=int(task.hour),
            minutes=int(task.minute),
            seconds=int(task.second),
        )

    def register_cron_task(self, task: SendTask):
        """Repeat *task* at the wall-clock time hour:minute:second."""
        self._add_job(
            task,
            "cron",
            hour=int(task.hour),
            minute=int(task.minute),
            second=int(task.second),
        )

    def register_date_task(self, task: SendTask):
        """Run *task* once, at the next occurrence of hour:minute:second.

        The time is interpreted in ``Config.TIME_ZONE``.
        """
        tz = pytz.timezone(Config.TIME_ZONE)
        now = datetime.datetime.now(tz)
        at = datetime.time(int(task.hour), int(task.minute), int(task.second))
        # localize() picks the correct UTC offset for the target day, which
        # adding a timedelta to an aware pytz datetime does not do across DST.
        run_date = tz.localize(datetime.datetime.combine(now.date(), at))
        if run_date <= now:
            tomorrow = now.date() + datetime.timedelta(days=1)
            run_date = tz.localize(datetime.datetime.combine(tomorrow, at))
        self._add_job(task, "date", run_date=run_date)

    def register_single_task(self, task: SendTask):
        """Schedule *task* according to its mode.

        Paused and finished tasks are skipped.
        """
        if task.pause or task.time_limit == 0:
            return
        if not task.interval:
            self.register_date_task(task)
        elif task.cron:
            self.register_cron_task(task)
        else:
            self.register_interval_task(task)

    def resume_task(self, task_id: int) -> bool:
        """Clear the pause flag, reschedule the task, and persist the change."""
        if task := self.get(task_id):
            task.pause = False
            self.register_single_task(task)
            self.save_to_file()
            return True
        return False

    def register_all_tasks(self):
        """Schedule every loaded task."""
        for task in self.tasks:
            self.register_single_task(task)

    def get_next_task_id(self) -> int:
        """Return one more than the largest task id in memory, or 1 if none."""
        return max(task.task_id for task in self.tasks) + 1 if self.tasks else 1


send_tasks = SendTasks()
send_tasks.load_from_file()
send_tasks.register_all_tasks()

send_help_msg = f"""
定时发送消息。
,{alias_command("sendat")} 时间 | 消息内容
i.e.
,{alias_command("sendat")} 16:00:00 date | 投票截止!
,{alias_command("sendat")} every 23:59:59 date | 又是无所事事的一天呢。
,{alias_command("sendat")} every 1 minutes | 又过去了一分钟。
,{alias_command("sendat")} 3 times 1 minutes | 此消息将出现三次,间隔均为一分钟。
,{alias_command("sendat")} rm 2 - 删除某个任务
,{alias_command("sendat")} pause 1 - 暂停某个任务
,{alias_command("sendat")} resume 1 - 恢复某个任务
,{alias_command("sendat")} list - 获取任务列表
"""


async def from_msg_get_task_id(message: Message) -> Optional[int]:
    """Parse the task id in ``message.parameter[1]`` and check that it exists.

    On a bad or unknown id, the message is edited with an error and
    ``message.continue_propagation()`` is called. Returns ``None`` if that
    call does not abort the handler.
    """
    try:
        uid = int(message.parameter[1])
    except ValueError:
        await message.edit("请输入正确的参数")
        message.continue_propagation()
        return None
    if uid not in send_tasks.get_all_ids():
        await message.edit("该任务不存在")
        message.continue_propagation()
        return None
    return uid


@listener(
    command="sendat",
    parameters="时间 | 消息内容",
    need_admin=True,
    description=f"定时发送消息\n请使用 ,{alias_command('sendat')} h 查看可用命令",
)
async def send_at(message: Message):
    """Handle ``sendat`` subcommands (h/list/rm/pause/resume) or add a task."""
    if message.arguments == "h" or len(message.parameter) == 0:
        return await message.edit(send_help_msg)

    if len(message.parameter) == 1:
        if message.parameter[0] != "list":
            return await message.edit("请输入正确的参数")
        # Only this chat's tasks. Check the rendered list, not the global ids,
        # so a chat without tasks does not get an empty header.
        listing = send_tasks.print_all_tasks(show_all=False, cid=message.chat.id)
        if listing:
            return await message.edit(f"已注册的任务:\n\n{listing}")
        return await message.edit("没有已注册的任务。")

    if len(message.parameter) == 2:
        action = message.parameter[0]
        if action in ("rm", "pause", "resume"):
            uid = await from_msg_get_task_id(message)
            if uid is None:
                return
            if action == "rm":
                send_tasks.remove(uid)
                send_tasks.save_to_file()
                return await message.edit(f"已删除任务 {uid}")
            if action == "pause":
                send_tasks.pause_task(uid)
                return await message.edit(f"已暂停任务 {uid}")
            send_tasks.resume_task(uid)
            return await message.edit(f"已恢复任务 {uid}")
        if action == "list":
            listing = send_tasks.print_all_tasks(show_all=True)
            if listing:
                return await message.edit(f"已注册的任务:\n\n{listing}")
            return await message.edit("没有已注册的任务。")

    # Anything else is a new task specification.
    task = SendTask(send_tasks.get_next_task_id())
    task.cid = message.chat.id
    try:
        task.parse_task(message.arguments)
    except ValueError as e:
        return await message.edit(f"参数错误:{e}")
    send_tasks.add(task)
    send_tasks.register_single_task(task)
    send_tasks.save_to_file()
    await message.edit(f"已添加任务 {task.task_id}")