Add support for both sync and async filters (#437)

* support for both sync and async filters

* Add whitespace for readability

* moving to handler.check for coroutine function

Ref: https://github.com/pyrogram/pyrogram/pull/437#discussion_r451626488

* add last line

Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com>
This commit is contained in:
Hasibul Kobir 2020-08-21 11:33:24 +06:00 committed by GitHub
parent 55fc4faf34
commit ecab62ce84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 52 additions and 15 deletions

View File

@ -188,8 +188,9 @@ class Dispatcher:
if isinstance(handler, handler_type):
try:
if handler.check(parsed_update):
if (await handler.check(parsed_update)):
args = (parsed_update,)
except Exception as e:
log.error(e, exc_info=True)
continue

View File

@ -16,6 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
class Filter:
def __call__(self, message):
raise NotImplementedError
@ -34,8 +37,13 @@ class InvertFilter(Filter):
def __init__(self, base):
self.base = base
def __call__(self, message):
return not self.base(message)
async def __call__(self, message):
if asyncio.iscoroutinefunction(self.base.__call__):
x = await self.base(message)
else:
x = self.base(message)
return not x
class AndFilter(Filter):
@ -43,8 +51,18 @@ class AndFilter(Filter):
self.base = base
self.other = other
def __call__(self, message):
return self.base(message) and self.other(message)
async def __call__(self, message):
if asyncio.iscoroutinefunction(self.base.__call__):
x = await self.base(message)
else:
x = self.base(message)
if asyncio.iscoroutinefunction(self.other.__call__):
y = await self.other(message)
else:
y = self.other(message)
return x and y
class OrFilter(Filter):
@ -52,5 +70,15 @@ class OrFilter(Filter):
self.base = base
self.other = other
def __call__(self, message):
return self.base(message) or self.other(message)
async def __call__(self, message):
if asyncio.iscoroutinefunction(self.base.__call__):
x = await self.base(message)
else:
x = self.base(message)
if asyncio.iscoroutinefunction(self.other.__call__):
y = await self.other(message)
else:
y = self.other(message)
return x or y

View File

@ -46,5 +46,5 @@ class DeletedMessagesHandler(Handler):
def __init__(self, callback: callable, filters=None):
super().__init__(callback, filters)
def check(self, messages):
return super().check(messages[0])
async def check(self, messages):
return (await super().check(messages[0]))

View File

@ -16,14 +16,22 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
class Handler:
def __init__(self, callback: callable, filters=None):
self.callback = callback
self.filters = filters
def check(self, update):
return (
self.filters(update)
if callable(self.filters)
else True
)
async def check(self, update):
if callable(self.filters):
if asyncio.iscoroutinefunction(self.filters.__call__):
return (await self.filters(update))
else:
return self.filters(update)
else:
return True