mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-30 17:43:32 +00:00
Add support for both sync and async filters (#437)
* support for both sync and async filters * Add whitespace for readability * moving to handler.check for coroutine function Ref: https://github.com/pyrogram/pyrogram/pull/437#discussion_r451626488 * add last line Co-authored-by: Dan <14043624+delivrance@users.noreply.github.com>
This commit is contained in:
parent
55fc4faf34
commit
ecab62ce84
@ -188,8 +188,9 @@ class Dispatcher:
|
||||
|
||||
if isinstance(handler, handler_type):
|
||||
try:
|
||||
if handler.check(parsed_update):
|
||||
if (await handler.check(parsed_update)):
|
||||
args = (parsed_update,)
|
||||
|
||||
except Exception as e:
|
||||
log.error(e, exc_info=True)
|
||||
continue
|
||||
|
@ -16,6 +16,9 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class Filter:
|
||||
def __call__(self, message):
|
||||
raise NotImplementedError
|
||||
@ -34,8 +37,13 @@ class InvertFilter(Filter):
|
||||
def __init__(self, base):
|
||||
self.base = base
|
||||
|
||||
def __call__(self, message):
|
||||
return not self.base(message)
|
||||
async def __call__(self, message):
|
||||
if asyncio.iscoroutinefunction(self.base.__call__):
|
||||
x = await self.base(message)
|
||||
else:
|
||||
x = self.base(message)
|
||||
|
||||
return not x
|
||||
|
||||
|
||||
class AndFilter(Filter):
|
||||
@ -43,8 +51,18 @@ class AndFilter(Filter):
|
||||
self.base = base
|
||||
self.other = other
|
||||
|
||||
def __call__(self, message):
|
||||
return self.base(message) and self.other(message)
|
||||
async def __call__(self, message):
|
||||
if asyncio.iscoroutinefunction(self.base.__call__):
|
||||
x = await self.base(message)
|
||||
else:
|
||||
x = self.base(message)
|
||||
|
||||
if asyncio.iscoroutinefunction(self.other.__call__):
|
||||
y = await self.other(message)
|
||||
else:
|
||||
y = self.other(message)
|
||||
|
||||
return x and y
|
||||
|
||||
|
||||
class OrFilter(Filter):
|
||||
@ -52,5 +70,15 @@ class OrFilter(Filter):
|
||||
self.base = base
|
||||
self.other = other
|
||||
|
||||
def __call__(self, message):
|
||||
return self.base(message) or self.other(message)
|
||||
async def __call__(self, message):
|
||||
if asyncio.iscoroutinefunction(self.base.__call__):
|
||||
x = await self.base(message)
|
||||
else:
|
||||
x = self.base(message)
|
||||
|
||||
if asyncio.iscoroutinefunction(self.other.__call__):
|
||||
y = await self.other(message)
|
||||
else:
|
||||
y = self.other(message)
|
||||
|
||||
return x or y
|
||||
|
@ -46,5 +46,5 @@ class DeletedMessagesHandler(Handler):
|
||||
def __init__(self, callback: callable, filters=None):
|
||||
super().__init__(callback, filters)
|
||||
|
||||
def check(self, messages):
|
||||
return super().check(messages[0])
|
||||
async def check(self, messages):
|
||||
return (await super().check(messages[0]))
|
||||
|
@ -16,14 +16,22 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class Handler:
|
||||
def __init__(self, callback: callable, filters=None):
|
||||
self.callback = callback
|
||||
self.filters = filters
|
||||
|
||||
def check(self, update):
|
||||
return (
|
||||
self.filters(update)
|
||||
if callable(self.filters)
|
||||
else True
|
||||
)
|
||||
async def check(self, update):
|
||||
|
||||
if callable(self.filters):
|
||||
if asyncio.iscoroutinefunction(self.filters.__call__):
|
||||
return (await self.filters(update))
|
||||
|
||||
else:
|
||||
return self.filters(update)
|
||||
|
||||
else:
|
||||
return True
|
||||
|
Loading…
Reference in New Issue
Block a user