From ecab62ce84f9f2cb3eadd5d6664cf5b99c9ab66d Mon Sep 17 00:00:00 2001
From: Hasibul Kobir <46620128+HasibulKabir@users.noreply.github.com>
Date: Fri, 21 Aug 2020 11:33:24 +0600
Subject: [PATCH] Add support for both sync and async filters (#437)
* support for both sync and async filters
* Add whitespace for readability
* moving to handler.check for coroutine function
Ref: https://github.com/pyrogram/pyrogram/pull/437#discussion_r451626488
* add last line
Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com>
---
pyrogram/client/ext/dispatcher.py | 3 +-
pyrogram/client/filters/filter.py | 40 ++++++++++++++++---
.../handlers/deleted_messages_handler.py | 4 +-
pyrogram/client/handlers/handler.py | 20 +++++++---
4 files changed, 52 insertions(+), 15 deletions(-)
diff --git a/pyrogram/client/ext/dispatcher.py b/pyrogram/client/ext/dispatcher.py
index 818bc60c..d8308944 100644
--- a/pyrogram/client/ext/dispatcher.py
+++ b/pyrogram/client/ext/dispatcher.py
@@ -188,8 +188,9 @@ class Dispatcher:
if isinstance(handler, handler_type):
try:
- if handler.check(parsed_update):
+ if (await handler.check(parsed_update)):
args = (parsed_update,)
+
except Exception as e:
log.error(e, exc_info=True)
continue
diff --git a/pyrogram/client/filters/filter.py b/pyrogram/client/filters/filter.py
index eb89b3c3..67067e03 100644
--- a/pyrogram/client/filters/filter.py
+++ b/pyrogram/client/filters/filter.py
@@ -16,6 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
+import asyncio
+
+
class Filter:
def __call__(self, message):
raise NotImplementedError
@@ -34,8 +37,13 @@ class InvertFilter(Filter):
def __init__(self, base):
self.base = base
- def __call__(self, message):
- return not self.base(message)
+ async def __call__(self, message):
+ if asyncio.iscoroutinefunction(self.base.__call__):
+ x = await self.base(message)
+ else:
+ x = self.base(message)
+
+ return not x
class AndFilter(Filter):
@@ -43,8 +51,18 @@ class AndFilter(Filter):
self.base = base
self.other = other
- def __call__(self, message):
- return self.base(message) and self.other(message)
+ async def __call__(self, message):
+ if asyncio.iscoroutinefunction(self.base.__call__):
+ x = await self.base(message)
+ else:
+ x = self.base(message)
+
+ if asyncio.iscoroutinefunction(self.other.__call__):
+ y = await self.other(message)
+ else:
+ y = self.other(message)
+
+ return x and y
class OrFilter(Filter):
@@ -52,5 +70,15 @@ class OrFilter(Filter):
self.base = base
self.other = other
- def __call__(self, message):
- return self.base(message) or self.other(message)
+ async def __call__(self, message):
+ if asyncio.iscoroutinefunction(self.base.__call__):
+ x = await self.base(message)
+ else:
+ x = self.base(message)
+
+ if asyncio.iscoroutinefunction(self.other.__call__):
+ y = await self.other(message)
+ else:
+ y = self.other(message)
+
+ return x or y
diff --git a/pyrogram/client/handlers/deleted_messages_handler.py b/pyrogram/client/handlers/deleted_messages_handler.py
index 7312ba90..03863541 100644
--- a/pyrogram/client/handlers/deleted_messages_handler.py
+++ b/pyrogram/client/handlers/deleted_messages_handler.py
@@ -46,5 +46,5 @@ class DeletedMessagesHandler(Handler):
def __init__(self, callback: callable, filters=None):
super().__init__(callback, filters)
- def check(self, messages):
- return super().check(messages[0])
+ async def check(self, messages):
+ return (await super().check(messages[0]))
diff --git a/pyrogram/client/handlers/handler.py b/pyrogram/client/handlers/handler.py
index 0eb132d1..d50b069b 100644
--- a/pyrogram/client/handlers/handler.py
+++ b/pyrogram/client/handlers/handler.py
@@ -16,14 +16,22 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
+import asyncio
+
+
class Handler:
def __init__(self, callback: callable, filters=None):
self.callback = callback
self.filters = filters
- def check(self, update):
- return (
- self.filters(update)
- if callable(self.filters)
- else True
- )
+ async def check(self, update):
+
+ if callable(self.filters):
+ if asyncio.iscoroutinefunction(self.filters.__call__):
+ return (await self.filters(update))
+
+ else:
+ return self.filters(update)
+
+ else:
+ return True