🔖 Update to v1.1.2

This commit is contained in:
xtaodada 2022-06-26 20:59:36 +08:00
parent 8ebc6223a6
commit 5bb2d7a8ae
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
19 changed files with 252 additions and 53 deletions

View File

@ -556,3 +556,6 @@ sb_channel: Successfully blocked this channel in this group.
#reload #reload
reload_des: Reload the PagerMaid-Pyro instance. reload_des: Reload the PagerMaid-Pyro instance.
reload_ok: Successfully reloaded. reload_ok: Successfully reloaded.
#conversation
conversation_already_in_error: Another conversation is already in progress.
conversation_timed_out_error: Response timed out.

View File

@ -556,3 +556,6 @@ sb_channel: Successfully blocked this channel in this group.
#reload #reload
reload_des: PagerMaid-Pyroインスタンスをリロードします。 reload_des: PagerMaid-Pyroインスタンスをリロードします。
reload_ok: 再読み込みに成功しました。 reload_ok: 再読み込みに成功しました。
#conversation
conversation_already_in_error: すでに別の会話が進行中です。
conversation_timed_out_error: 応答がタイムアウトしました。

View File

@ -556,3 +556,6 @@ sb_channel: 成功在本群封禁此频道。
#reload #reload
reload_des: 重新加载内置模块和插件。 reload_des: 重新加载内置模块和插件。
reload_ok: 重新加载内置模块和插件成功。 reload_ok: 重新加载内置模块和插件成功。
#conversation
conversation_already_in_error: 另一个操作正在进行中,请稍等片刻后重试。
conversation_timed_out_error: 响应超时,请稍等片刻后重试。

View File

@ -556,3 +556,6 @@ sb_channel: 成功在本群封禁此頻道。
#reload #reload
reload_des: 重新加載內置模塊和插件。 reload_des: 重新加載內置模塊和插件。
reload_ok: 重新加載內置模塊和插件成功。 reload_ok: 重新加載內置模塊和插件成功。
#conversation
conversation_already_in_error: 另一個操作正在進行中,請稍等片刻後重試。
conversation_timed_out_error: 響應超時,請稍等片刻後重試。

View File

@ -4,13 +4,15 @@ from datetime import datetime, timezone
from logging import getLogger, StreamHandler, CRITICAL, INFO, basicConfig, DEBUG from logging import getLogger, StreamHandler, CRITICAL, INFO, basicConfig, DEBUG
from os import getcwd from os import getcwd
from pyrogram.errors import PeerIdInvalid
from pagermaid.config import Config from pagermaid.config import Config
from pagermaid.scheduler import scheduler from pagermaid.scheduler import scheduler
import pyromod.listen import pyromod.listen
from pyrogram import Client from pyrogram import Client
import sys import sys
pgm_version = "1.1.1" pgm_version = "1.1.2"
CMD_LIST = {} CMD_LIST = {}
module_dir = __path__[0] module_dir = __path__[0]
working_dir = getcwd() working_dir = getcwd()
@ -63,7 +65,10 @@ async def log(message):
) )
if not Config.LOG: if not Config.LOG:
return return
await bot.send_message( try:
Config.LOG_ID, await bot.send_message(
message Config.LOG_ID,
) message
)
except PeerIdInvalid:
Config.LOG = False

View File

@ -16,7 +16,7 @@ from pyrogram.handlers import MessageHandler, EditedMessageHandler
from pagermaid import help_messages, logs, Config, bot, read_context, all_permissions from pagermaid import help_messages, logs, Config, bot, read_context, all_permissions
from pagermaid.group_manager import Permission from pagermaid.group_manager import Permission
from pagermaid.single_utils import Message from pagermaid.single_utils import Message, AlreadyInConversationError, TimeoutConversationError
from pagermaid.utils import lang, attach_report, sudo_filter, alias_command, get_permission_name, process_exit from pagermaid.utils import lang, attach_report, sudo_filter, alias_command, get_permission_name, process_exit
from pagermaid.utils import client as httpx_client from pagermaid.utils import client as httpx_client
@ -133,6 +133,16 @@ def listener(**args):
logs.warning( logs.warning(
"Please Don't Delete Commands While it's Processing.." "Please Don't Delete Commands While it's Processing.."
) )
except AlreadyInConversationError:
logs.warning(
"Please Don't Send Commands In The Same Conversation.."
)
await message.edit(lang("conversation_already_in_error"))
except TimeoutConversationError:
logs.warning(
"Conversation Timed out while processing commands.."
)
await message.edit(lang("conversation_timed_out_error"))
except UserNotParticipant: except UserNotParticipant:
pass pass
except ContinuePropagation as e: except ContinuePropagation as e:
@ -195,6 +205,20 @@ def raw_listener(filter_s):
raise StopPropagation from e raise StopPropagation from e
except ContinuePropagation as e: except ContinuePropagation as e:
raise ContinuePropagation from e raise ContinuePropagation from e
except MessageIdInvalid:
logs.warning(
"Please Don't Delete Commands While it's Processing.."
)
except AlreadyInConversationError:
logs.warning(
"Please Don't Send Commands In The Same Conversation.."
)
await message.edit(lang("conversation_already_in_error"))
except TimeoutConversationError:
logs.warning(
"Conversation Timed out while processing commands.."
)
await message.edit(lang("conversation_timed_out_error"))
except SystemExit: except SystemExit:
await process_exit(start=False, _client=client, message=message) await process_exit(start=False, _client=client, message=message)
sys.exit(0) sys.exit(0)

View File

@ -77,8 +77,7 @@ async def plugin(message: Message):
move_plugin(file_path) move_plugin(file_path)
await message.edit(f"<b>{lang('apt_name')}</b>\n\n" await message.edit(f"<b>{lang('apt_name')}</b>\n\n"
f"{lang('apt_plugin')} " f"{lang('apt_plugin')} "
f"{path.basename(file_path)[:-3]} {lang('apt_installed')}," f"{path.basename(file_path)[:-3]} {lang('apt_installed')}")
f"{lang('apt_reboot')}")
await log(f"{lang('apt_install_success')} {path.basename(file_path)[:-3]}.") await log(f"{lang('apt_install_success')} {path.basename(file_path)[:-3]}.")
reload_all() reload_all()
elif len(message.parameter) >= 2: elif len(message.parameter) >= 2:
@ -128,8 +127,6 @@ async def plugin(message: Message):
text += lang('apt_no_update') + " %s\n" % ", ".join(no_need_list) text += lang('apt_no_update') + " %s\n" % ", ".join(no_need_list)
await log(text) await log(text)
restart = len(success_list) > 0 restart = len(success_list) > 0
if restart:
text += lang('apt_reboot')
await message.edit(text) await message.edit(text)
if restart: if restart:
reload_all() reload_all()
@ -146,8 +143,7 @@ async def plugin(message: Message):
version_json[message.parameter[1]] = "0.0" version_json[message.parameter[1]] = "0.0"
with open(f"{plugin_directory}version.json", 'w') as f: with open(f"{plugin_directory}version.json", 'w') as f:
json.dump(version_json, f) json.dump(version_json, f)
await message.edit(f"{lang('apt_remove_success')} {message.parameter[1]}, " await message.edit(f"{lang('apt_remove_success')} {message.parameter[1]}")
f"{lang('apt_reboot')} ")
await log(f"{lang('apt_remove')} {message.parameter[1]}.") await log(f"{lang('apt_remove')} {message.parameter[1]}.")
reload_all() reload_all()
elif "/" in message.parameter[1]: elif "/" in message.parameter[1]:
@ -198,7 +194,7 @@ async def plugin(message: Message):
rename(f"{plugin_directory}{message.parameter[1]}.py.disabled", rename(f"{plugin_directory}{message.parameter[1]}.py.disabled",
f"{plugin_directory}{message.parameter[1]}.py") f"{plugin_directory}{message.parameter[1]}.py")
await message.edit(f"{lang('apt_plugin')} {message.parameter[1]} " await message.edit(f"{lang('apt_plugin')} {message.parameter[1]} "
f"{lang('apt_enable')},{lang('apt_reboot')}") f"{lang('apt_enable')}")
await log(f"{lang('apt_enable')} {message.parameter[1]}.") await log(f"{lang('apt_enable')} {message.parameter[1]}.")
reload_all() reload_all()
else: else:
@ -211,7 +207,7 @@ async def plugin(message: Message):
rename(f"{plugin_directory}{message.parameter[1]}.py", rename(f"{plugin_directory}{message.parameter[1]}.py",
f"{plugin_directory}{message.parameter[1]}.py.disabled") f"{plugin_directory}{message.parameter[1]}.py.disabled")
await message.edit(f"{lang('apt_plugin')} {message.parameter[1]} " await message.edit(f"{lang('apt_plugin')} {message.parameter[1]} "
f"{lang('apt_disable')},{lang('apt_reboot')}") f"{lang('apt_disable')}")
await log(f"{lang('apt_disable')} {message.parameter[1]}.") await log(f"{lang('apt_disable')} {message.parameter[1]}.")
reload_all() reload_all()
else: else:

View File

@ -1,12 +1,16 @@
import contextlib import contextlib
from os import sep, remove, mkdir from os import sep, remove, mkdir
from os.path import exists from os.path import exists
from typing import List, Optional from typing import List, Optional, Union
from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.schedulers.asyncio import AsyncIOScheduler
from httpx import AsyncClient from httpx import AsyncClient
from pyrogram import Client from pyrogram import Client
from pyrogram.types import Message from pyrogram.types import Message
from pyromod.utils.conversation import Conversation
from pyromod.utils.errors import AlreadyInConversationError, TimeoutConversationError
from sqlitedict import SqliteDict from sqlitedict import SqliteDict
# init folders # init folders
@ -31,6 +35,25 @@ def safe_remove(name: str) -> None:
class Client(Client): # noqa class Client(Client): # noqa
job: Optional[AsyncIOScheduler] = None job: Optional[AsyncIOScheduler] = None
async def listen(self, chat_id, filters=None, timeout=None) -> Optional[Message]:
return
async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs) -> Optional[Message]:
return
def cancel_listener(self, chat_id):
""" Cancel the conversation with the given chat_id. """
return
def cancel_all_listeners(self):
""" Cancel all conversations. """
return
def conversation(self, chat_id: Union[int, str],
once_timeout: int = 60, filters=None) -> Optional[Conversation]:
""" Initialize a conversation with the given chat_id. """
return
class Message(Message): # noqa class Message(Message): # noqa
arguments: str arguments: str

View File

@ -18,4 +18,4 @@ You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>. along with pyromod. If not, see <https://www.gnu.org/licenses/>.
""" """
from .filters import dice from .filters import dice

View File

@ -18,8 +18,11 @@ You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>. along with pyromod. If not, see <https://www.gnu.org/licenses/>.
""" """
import pyrogram import pyrogram
def dice(ctx, message): def dice(ctx, message):
return hasattr(message, 'dice') and message.dice return hasattr(message, 'dice') and message.dice
pyrogram.filters.dice = dice pyrogram.filters.dice = dice

View File

@ -17,4 +17,4 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>. along with pyromod. If not, see <https://www.gnu.org/licenses/>.
""" """
from .helpers import ikb, bki, ntb, btn, kb, kbtn, array_chunk, force_reply from .helpers import ikb, bki, ntb, btn, kb, kbtn, array_chunk, force_reply

View File

@ -1,19 +1,24 @@
from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup, KeyboardButton, ReplyKeyboardMarkup, ForceReply from pyrogram.types import InlineKeyboardButton, InlineKeyboardMarkup, KeyboardButton, ReplyKeyboardMarkup, ForceReply
def ikb(rows = []):
def ikb(rows=None):
if rows is None:
rows = []
lines = [] lines = []
for row in rows: for row in rows:
line = [] line = []
for button in row: for button in row:
button = btn(*button) # InlineKeyboardButton button = btn(*button) # InlineKeyboardButton
line.append(button) line.append(button)
lines.append(line) lines.append(line)
return InlineKeyboardMarkup(inline_keyboard=lines) return InlineKeyboardMarkup(inline_keyboard=lines)
#return {'inline_keyboard': lines} # return {'inline_keyboard': lines}
def btn(text, value, type = 'callback_data'):
def btn(text, value, type='callback_data'):
return InlineKeyboardButton(text, **{type: value}) return InlineKeyboardButton(text, **{type: value})
#return {'text': text, type: value} # return {'text': text, type: value}
# The inverse of above # The inverse of above
def bki(keyboard): def bki(keyboard):
@ -21,14 +26,16 @@ def bki(keyboard):
for row in keyboard.inline_keyboard: for row in keyboard.inline_keyboard:
line = [] line = []
for button in row: for button in row:
button = ntb(button) # btn() format button = ntb(button) # btn() format
line.append(button) line.append(button)
lines.append(line) lines.append(line)
return lines return lines
#return ikb() format # return ikb() format
def ntb(button): def ntb(button):
for btn_type in ['callback_data', 'url', 'switch_inline_query', 'switch_inline_query_current_chat', 'callback_game']: for btn_type in ['callback_data', 'url', 'switch_inline_query', 'switch_inline_query_current_chat',
'callback_game']:
value = getattr(button, btn_type) value = getattr(button, btn_type)
if value: if value:
break break
@ -36,9 +43,12 @@ def ntb(button):
if btn_type != 'callback_data': if btn_type != 'callback_data':
button.append(btn_type) button.append(btn_type)
return button return button
#return {'text': text, type: value} # return {'text': text, type: value}
def kb(rows = [], **kwargs):
def kb(rows=None, **kwargs):
if rows is None:
rows = []
lines = [] lines = []
for row in rows: for row in rows:
line = [] line = []
@ -48,16 +58,18 @@ def kb(rows = [], **kwargs):
button = KeyboardButton(button) button = KeyboardButton(button)
elif button_type == dict: elif button_type == dict:
button = KeyboardButton(**button) button = KeyboardButton(**button)
line.append(button) line.append(button)
lines.append(line) lines.append(line)
return ReplyKeyboardMarkup(keyboard=lines, **kwargs) return ReplyKeyboardMarkup(keyboard=lines, **kwargs)
kbtn = KeyboardButton kbtn = KeyboardButton
def force_reply(selective=True): def force_reply(selective=True):
return ForceReply(selective=selective) return ForceReply(selective=selective)
def array_chunk(input, size):
return [input[i:i+size] for i in range(0, len(input), size)]
def array_chunk(input_, size):
return [input_[i:i + size] for i in range(0, len(input_), size)]

View File

@ -20,7 +20,7 @@ along with pyromod. If not, see <https://www.gnu.org/licenses/>.
import asyncio import asyncio
import functools import functools
from typing import Optional, List from typing import Optional, List, Union
import pyrogram import pyrogram
@ -28,6 +28,8 @@ from pagermaid.single_utils import get_sudo_list, Message
from pagermaid.scheduler import add_delete_message_job from pagermaid.scheduler import add_delete_message_job
from ..utils import patch, patchable from ..utils import patch, patchable
from ..utils.conversation import Conversation
from ..utils.errors import TimeoutConversationError
class ListenerCanceled(Exception): class ListenerCanceled(Exception):
@ -59,7 +61,10 @@ class Client:
self.listening.update({ self.listening.update({
chat_id: {"future": future, "filters": filters} chat_id: {"future": future, "filters": filters}
}) })
return await asyncio.wait_for(future, timeout) try:
return await asyncio.wait_for(future, timeout)
except asyncio.exceptions.TimeoutError as e:
raise TimeoutConversationError() from e
@patchable @patchable
async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs): async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs):
@ -82,6 +87,15 @@ class Client:
listener['future'].set_exception(ListenerCanceled()) listener['future'].set_exception(ListenerCanceled())
self.clear_listener(chat_id, listener['future']) self.clear_listener(chat_id, listener['future'])
@patchable
def cancel_all_listener(self):
for chat_id in self.listening:
self.cancel_listener(chat_id)
@patchable
def conversation(self, chat_id: Union[int, str], once_timeout: int = 60, filters=None):
return Conversation(self, chat_id, once_timeout, filters)
@patch(pyrogram.handlers.message_handler.MessageHandler) @patch(pyrogram.handlers.message_handler.MessageHandler)
class MessageHandler: class MessageHandler:

View File

@ -17,4 +17,4 @@ GNU General Public License for more details.
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>. along with pyromod. If not, see <https://www.gnu.org/licenses/>.
""" """
from .pagination import Pagination from .pagination import Pagination

View File

@ -20,6 +20,7 @@ along with pyromod. If not, see <https://www.gnu.org/licenses/>.
import math import math
from ..helpers import array_chunk from ..helpers import array_chunk
class Pagination: class Pagination:
def __init__(self, objects, page_data=None, item_data=None, item_title=None): def __init__(self, objects, page_data=None, item_data=None, item_title=None):
default_page_callback = (lambda x: str(x)) default_page_callback = (lambda x: str(x))
@ -28,26 +29,25 @@ class Pagination:
self.page_data = page_data or default_page_callback self.page_data = page_data or default_page_callback
self.item_data = item_data or default_item_callback self.item_data = item_data or default_item_callback
self.item_title = item_title or default_item_callback self.item_title = item_title or default_item_callback
def create(self, page, lines=5, columns=1): def create(self, page, lines=5, columns=1):
quant_per_page = lines*columns quant_per_page = lines * columns
page = 1 if page <= 0 else page page = 1 if page <= 0 else page
offset = (page-1)*quant_per_page offset = (page - 1) * quant_per_page
stop = offset+quant_per_page stop = offset + quant_per_page
cutted = self.objects[offset:stop] cutted = self.objects[offset:stop]
total = len(self.objects) total = len(self.objects)
pages_range = [*range(1, math.ceil(total/quant_per_page)+1)] # each item is a page pages_range = [*range(1, math.ceil(total / quant_per_page) + 1)] # each item is a page
last_page = len(pages_range) last_page = len(pages_range)
nav = [] nav = []
if page <= 3: if page <= 3:
for n in [1,2,3]: for n in [1, 2, 3]:
if n not in pages_range: if n not in pages_range:
continue continue
text = f"· {n} ·" if n == page else n text = f"· {n} ·" if n == page else n
nav.append( (text, self.page_data(n)) ) nav.append((text, self.page_data(n)))
if last_page >= 4: if last_page >= 4:
nav.append( nav.append(
('4 ' if last_page > 5 else 4, self.page_data(4)) ('4 ' if last_page > 5 else 4, self.page_data(4))
@ -56,30 +56,29 @@ class Pagination:
nav.append( nav.append(
(f'{last_page} »' if last_page > 5 else last_page, self.page_data(last_page)) (f'{last_page} »' if last_page > 5 else last_page, self.page_data(last_page))
) )
elif page >= last_page-2: elif page >= last_page - 2:
nav.extend( nav.extend(
[ [
('« 1' if last_page > 5 else 1, self.page_data(1)), ('« 1' if last_page > 5 else 1, self.page_data(1)),
( (
f' {last_page-3}' if last_page > 5 else last_page - 3, f' {last_page - 3}' if last_page > 5 else last_page - 3,
self.page_data(last_page - 3), self.page_data(last_page - 3),
), ),
] ]
) )
for n in range(last_page-2, last_page+1): for n in range(last_page - 2, last_page + 1):
text = f"· {n} ·" if n == page else n text = f"· {n} ·" if n == page else n
nav.append( (text, self.page_data(n)) ) nav.append((text, self.page_data(n)))
else: else:
nav = [ nav = [
('« 1', self.page_data(1)), ('« 1', self.page_data(1)),
(f' {page-1}', self.page_data(page - 1)), (f' {page - 1}', self.page_data(page - 1)),
(f'· {page} ·', "noop"), (f'· {page} ·', "noop"),
(f'{page+1} ', self.page_data(page + 1)), (f'{page + 1} ', self.page_data(page + 1)),
(f'{last_page} »', self.page_data(last_page)), (f'{last_page} »', self.page_data(last_page)),
] ]
buttons = [ buttons = [
(self.item_title(item, page), self.item_data(item, page)) (self.item_title(item, page), self.item_data(item, page))
for item in cutted for item in cutted
@ -89,4 +88,4 @@ class Pagination:
if last_page > 1: if last_page > 1:
kb_lines.append(nav) kb_lines.append(nav)
return kb_lines return kb_lines

View File

@ -18,4 +18,4 @@ You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>. along with pyromod. If not, see <https://www.gnu.org/licenses/>.
""" """
from .utils import patch, patchable from .utils import patch, patchable

View File

@ -0,0 +1,89 @@
import asyncio
import functools
from typing import Union
from pyrogram.raw.types import InputPeerUser, InputPeerChat, InputPeerChannel
from pyromod.utils.errors import AlreadyInConversationError
def _checks_cancelled(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
if self._cancelled:
raise asyncio.CancelledError("The conversation was cancelled before")
return f(self, *args, **kwargs)
return wrapper
class Conversation:
def __init__(self, client, chat_id: Union[int, str],
once_timeout: int = 60, filters=None):
self._client = client
self._chat_id = chat_id
self._once_timeout = once_timeout
self._filters = filters
self._cancelled = False
@_checks_cancelled
async def send_message(self, *args, **kwargs):
return await self._client.send_message(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def send_media_group(self, *args, **kwargs):
return await self._client.send_media_group(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def send_photo(self, *args, **kwargs):
return await self._client.send_photo(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def send_document(self, *args, **kwargs):
return await self._client.send_document(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def send_sticker(self, *args, **kwargs):
return await self._client.send_sticker(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def send_voice(self, *args, **kwargs):
return await self._client.send_voice(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def send_video(self, *args, **kwargs):
return await self._client.send_video(self._chat_id, *args, **kwargs)
@_checks_cancelled
async def ask(self, text, filters=None, timeout=None, *args, **kwargs):
filters = filters or self._filters
timeout = timeout or self._once_timeout
return await self._client.ask(self._chat_id, text, filters=filters, timeout=timeout, *args, **kwargs)
@_checks_cancelled
async def get_response(self, filters=None, timeout=None):
filters = filters or self._filters
timeout = timeout or self._once_timeout
return await self._client.listen(self._chat_id, filters, timeout)
def mark_as_read(self, message=None):
return self._client.read_chat_history(self._chat_id, max_id=message.id if message else 0)
def cancel(self):
self._cancelled = True
self._client.cancel_listener(self._chat_id)
async def __aenter__(self):
self._peer_chat = await self._client.resolve_peer(self._chat_id)
if isinstance(self._peer_chat, InputPeerUser):
self._chat_id = self._peer_chat.user_id
elif isinstance(self._peer_chat, InputPeerChat):
self._chat_id = self._peer_chat.chat_id
elif isinstance(self._peer_chat, InputPeerChannel):
self._chat_id = self._peer_chat.channel_id
if self._client.listening.get(self._chat_id, False):
raise AlreadyInConversationError()
self._cancelled = False
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.cancel()

19
pyromod/utils/errors.py Normal file
View File

@ -0,0 +1,19 @@
class AlreadyInConversationError(Exception):
"""
Occurs when another exclusive conversation is opened in the same chat.
"""
def __init__(self):
super().__init__(
"Cannot open exclusive conversation in a "
"chat that already has one open conversation"
)
class TimeoutConversationError(Exception):
"""
Occurs when the conversation times out.
"""
def __init__(self):
super().__init__(
"Response read timed out"
)

View File

@ -18,18 +18,21 @@ You should have received a copy of the GNU General Public License
along with pyromod. If not, see <https://www.gnu.org/licenses/>. along with pyromod. If not, see <https://www.gnu.org/licenses/>.
""" """
def patch(obj): def patch(obj):
def is_patchable(item): def is_patchable(item):
return getattr(item[1], 'patchable', False) return getattr(item[1], 'patchable', False)
def wrapper(container): def wrapper(container):
for name,func in filter(is_patchable, container.__dict__.items()): for name, func in filter(is_patchable, container.__dict__.items()):
old = getattr(obj, name, None) old = getattr(obj, name, None)
setattr(obj, f'old{name}', old) setattr(obj, f'old{name}', old)
setattr(obj, name, func) setattr(obj, name, func)
return container return container
return wrapper return wrapper
def patchable(func): def patchable(func):
func.patchable = True func.patchable = True
return func return func