Add support for both sync and async filters (#437)
* support for both sync and async filters * Add whitespace for readability * moving to handler.check for coroutine function Ref: https://github.com/pyrogram/pyrogram/pull/437#discussion_r451626488 * add last line Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com>
This commit is contained in:
parent
55fc4faf34
commit
ecab62ce84
@ -188,8 +188,9 @@ class Dispatcher:
|
|||||||
|
|
||||||
if isinstance(handler, handler_type):
|
if isinstance(handler, handler_type):
|
||||||
try:
|
try:
|
||||||
if handler.check(parsed_update):
|
if (await handler.check(parsed_update)):
|
||||||
args = (parsed_update,)
|
args = (parsed_update,)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error(e, exc_info=True)
|
log.error(e, exc_info=True)
|
||||||
continue
|
continue
|
||||||
|
@ -16,6 +16,9 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
def __call__(self, message):
|
def __call__(self, message):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -34,8 +37,13 @@ class InvertFilter(Filter):
|
|||||||
def __init__(self, base):
|
def __init__(self, base):
|
||||||
self.base = base
|
self.base = base
|
||||||
|
|
||||||
def __call__(self, message):
|
async def __call__(self, message):
|
||||||
return not self.base(message)
|
if asyncio.iscoroutinefunction(self.base.__call__):
|
||||||
|
x = await self.base(message)
|
||||||
|
else:
|
||||||
|
x = self.base(message)
|
||||||
|
|
||||||
|
return not x
|
||||||
|
|
||||||
|
|
||||||
class AndFilter(Filter):
|
class AndFilter(Filter):
|
||||||
@ -43,8 +51,18 @@ class AndFilter(Filter):
|
|||||||
self.base = base
|
self.base = base
|
||||||
self.other = other
|
self.other = other
|
||||||
|
|
||||||
def __call__(self, message):
|
async def __call__(self, message):
|
||||||
return self.base(message) and self.other(message)
|
if asyncio.iscoroutinefunction(self.base.__call__):
|
||||||
|
x = await self.base(message)
|
||||||
|
else:
|
||||||
|
x = self.base(message)
|
||||||
|
|
||||||
|
if asyncio.iscoroutinefunction(self.other.__call__):
|
||||||
|
y = await self.other(message)
|
||||||
|
else:
|
||||||
|
y = self.other(message)
|
||||||
|
|
||||||
|
return x and y
|
||||||
|
|
||||||
|
|
||||||
class OrFilter(Filter):
|
class OrFilter(Filter):
|
||||||
@ -52,5 +70,15 @@ class OrFilter(Filter):
|
|||||||
self.base = base
|
self.base = base
|
||||||
self.other = other
|
self.other = other
|
||||||
|
|
||||||
def __call__(self, message):
|
async def __call__(self, message):
|
||||||
return self.base(message) or self.other(message)
|
if asyncio.iscoroutinefunction(self.base.__call__):
|
||||||
|
x = await self.base(message)
|
||||||
|
else:
|
||||||
|
x = self.base(message)
|
||||||
|
|
||||||
|
if asyncio.iscoroutinefunction(self.other.__call__):
|
||||||
|
y = await self.other(message)
|
||||||
|
else:
|
||||||
|
y = self.other(message)
|
||||||
|
|
||||||
|
return x or y
|
||||||
|
@ -46,5 +46,5 @@ class DeletedMessagesHandler(Handler):
|
|||||||
def __init__(self, callback: callable, filters=None):
|
def __init__(self, callback: callable, filters=None):
|
||||||
super().__init__(callback, filters)
|
super().__init__(callback, filters)
|
||||||
|
|
||||||
def check(self, messages):
|
async def check(self, messages):
|
||||||
return super().check(messages[0])
|
return (await super().check(messages[0]))
|
||||||
|
@ -16,14 +16,22 @@
|
|||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
def __init__(self, callback: callable, filters=None):
|
def __init__(self, callback: callable, filters=None):
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.filters = filters
|
self.filters = filters
|
||||||
|
|
||||||
def check(self, update):
|
async def check(self, update):
|
||||||
return (
|
|
||||||
self.filters(update)
|
if callable(self.filters):
|
||||||
if callable(self.filters)
|
if asyncio.iscoroutinefunction(self.filters.__call__):
|
||||||
else True
|
return (await self.filters(update))
|
||||||
)
|
|
||||||
|
else:
|
||||||
|
return self.filters(update)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
Loading…
Reference in New Issue
Block a user