From ecab62ce84f9f2cb3eadd5d6664cf5b99c9ab66d Mon Sep 17 00:00:00 2001 From: Hasibul Kobir <46620128+HasibulKabir@users.noreply.github.com> Date: Fri, 21 Aug 2020 11:33:24 +0600 Subject: [PATCH] Add support for both sync and async filters (#437) * support for both sync and async filters * Add whitespace for readability * moving to handler.check for coroutine function Ref: https://github.com/pyrogram/pyrogram/pull/437#discussion_r451626488 * add last line Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com> --- pyrogram/client/ext/dispatcher.py | 3 +- pyrogram/client/filters/filter.py | 40 ++++++++++++++++--- .../handlers/deleted_messages_handler.py | 4 +- pyrogram/client/handlers/handler.py | 20 +++++++--- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/pyrogram/client/ext/dispatcher.py b/pyrogram/client/ext/dispatcher.py index 818bc60c..d8308944 100644 --- a/pyrogram/client/ext/dispatcher.py +++ b/pyrogram/client/ext/dispatcher.py @@ -188,8 +188,9 @@ class Dispatcher: if isinstance(handler, handler_type): try: - if handler.check(parsed_update): + if (await handler.check(parsed_update)): args = (parsed_update,) + except Exception as e: log.error(e, exc_info=True) continue diff --git a/pyrogram/client/filters/filter.py b/pyrogram/client/filters/filter.py index eb89b3c3..67067e03 100644 --- a/pyrogram/client/filters/filter.py +++ b/pyrogram/client/filters/filter.py @@ -16,6 +16,9 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio + + class Filter: def __call__(self, message): raise NotImplementedError @@ -34,8 +37,13 @@ class InvertFilter(Filter): def __init__(self, base): self.base = base - def __call__(self, message): - return not self.base(message) + async def __call__(self, message): + if asyncio.iscoroutinefunction(self.base.__call__): + x = await self.base(message) + else: + x = self.base(message) + + return not x class AndFilter(Filter): @@ -43,8 +51,18 @@ class AndFilter(Filter): self.base = base self.other = other - def __call__(self, message): - return self.base(message) and self.other(message) + async def __call__(self, message): + if asyncio.iscoroutinefunction(self.base.__call__): + x = await self.base(message) + else: + x = self.base(message) + + if asyncio.iscoroutinefunction(self.other.__call__): + y = await self.other(message) + else: + y = self.other(message) + + return x and y class OrFilter(Filter): @@ -52,5 +70,15 @@ class OrFilter(Filter): self.base = base self.other = other - def __call__(self, message): - return self.base(message) or self.other(message) + async def __call__(self, message): + if asyncio.iscoroutinefunction(self.base.__call__): + x = await self.base(message) + else: + x = self.base(message) + + if asyncio.iscoroutinefunction(self.other.__call__): + y = await self.other(message) + else: + y = self.other(message) + + return x or y diff --git a/pyrogram/client/handlers/deleted_messages_handler.py b/pyrogram/client/handlers/deleted_messages_handler.py index 7312ba90..03863541 100644 --- a/pyrogram/client/handlers/deleted_messages_handler.py +++ b/pyrogram/client/handlers/deleted_messages_handler.py @@ -46,5 +46,5 @@ class DeletedMessagesHandler(Handler): def __init__(self, callback: callable, filters=None): super().__init__(callback, filters) - def check(self, messages): - return super().check(messages[0]) + async def check(self, messages): + return (await super().check(messages[0])) diff --git a/pyrogram/client/handlers/handler.py b/pyrogram/client/handlers/handler.py index 0eb132d1..d50b069b 100644 --- a/pyrogram/client/handlers/handler.py +++ b/pyrogram/client/handlers/handler.py @@ -16,14 +16,22 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio + + class Handler: def __init__(self, callback: callable, filters=None): self.callback = callback self.filters = filters - def check(self, update): - return ( - self.filters(update) - if callable(self.filters) - else True - ) + async def check(self, update): + + if callable(self.filters): + if asyncio.iscoroutinefunction(self.filters.__call__): + return (await self.filters(update)) + + else: + return self.filters(update) + + else: + return True