Merge branch 'asyncio-dev'

This commit is contained in:
Dan 2020-08-22 07:59:45 +02:00
commit 2f0a1f4119
4 changed files with 52 additions and 15 deletions

View File

@ -188,8 +188,9 @@ class Dispatcher:
if isinstance(handler, handler_type): if isinstance(handler, handler_type):
try: try:
if handler.check(parsed_update): if (await handler.check(parsed_update)):
args = (parsed_update,) args = (parsed_update,)
except Exception as e: except Exception as e:
log.error(e, exc_info=True) log.error(e, exc_info=True)
continue continue

View File

@ -16,6 +16,9 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
class Filter: class Filter:
def __call__(self, message): def __call__(self, message):
raise NotImplementedError raise NotImplementedError
@ -34,8 +37,13 @@ class InvertFilter(Filter):
def __init__(self, base): def __init__(self, base):
self.base = base self.base = base
def __call__(self, message): async def __call__(self, message):
return not self.base(message) if asyncio.iscoroutinefunction(self.base.__call__):
x = await self.base(message)
else:
x = self.base(message)
return not x
class AndFilter(Filter): class AndFilter(Filter):
@ -43,8 +51,18 @@ class AndFilter(Filter):
self.base = base self.base = base
self.other = other self.other = other
def __call__(self, message): async def __call__(self, message):
return self.base(message) and self.other(message) if asyncio.iscoroutinefunction(self.base.__call__):
x = await self.base(message)
else:
x = self.base(message)
if asyncio.iscoroutinefunction(self.other.__call__):
y = await self.other(message)
else:
y = self.other(message)
return x and y
class OrFilter(Filter): class OrFilter(Filter):
@ -52,5 +70,15 @@ class OrFilter(Filter):
self.base = base self.base = base
self.other = other self.other = other
def __call__(self, message): async def __call__(self, message):
return self.base(message) or self.other(message) if asyncio.iscoroutinefunction(self.base.__call__):
x = await self.base(message)
else:
x = self.base(message)
if asyncio.iscoroutinefunction(self.other.__call__):
y = await self.other(message)
else:
y = self.other(message)
return x or y

View File

@ -46,5 +46,5 @@ class DeletedMessagesHandler(Handler):
def __init__(self, callback: callable, filters=None): def __init__(self, callback: callable, filters=None):
super().__init__(callback, filters) super().__init__(callback, filters)
def check(self, messages): async def check(self, messages):
return super().check(messages[0]) return (await super().check(messages[0]))

View File

@ -16,14 +16,22 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
class Handler: class Handler:
def __init__(self, callback: callable, filters=None): def __init__(self, callback: callable, filters=None):
self.callback = callback self.callback = callback
self.filters = filters self.filters = filters
def check(self, update): async def check(self, update):
return (
self.filters(update) if callable(self.filters):
if callable(self.filters) if asyncio.iscoroutinefunction(self.filters.__call__):
else True return (await self.filters(update))
)
else:
return self.filters(update)
else:
return True