From 4fc45014457b8a4dbceda6bf696bcbbc11a6aa41 Mon Sep 17 00:00:00 2001 From: Dan <14043624+delivrance@users.noreply.github.com> Date: Thu, 6 May 2021 13:02:26 +0200 Subject: [PATCH] Remove trailing username when adding args to Message.command Fixes #676 --- pyrogram/filters.py | 10 +++++----- tests/filters/test_command.py | 4 ++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pyrogram/filters.py b/pyrogram/filters.py index aaff1aa1..4a07413d 100644 --- a/pyrogram/filters.py +++ b/pyrogram/filters.py @@ -770,7 +770,7 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = " nonlocal username if username is None: - username = (await client.get_me()).username + username = (await client.get_me()).username or "" text = message.text or message.caption message.command = None @@ -778,8 +778,6 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = " if not text: return False - pattern = rf"^(?:{{cmd}}|{{cmd}}@{username})(?:\s|$)" if username else r"^{cmd}(?:\s|$)" - for prefix in flt.prefixes: if not text.startswith(prefix): continue @@ -787,17 +785,19 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = " without_prefix = text[len(prefix):] for cmd in flt.commands: - if not re.match(pattern.format(cmd=re.escape(cmd)), without_prefix, + if not re.match(rf"^(?:{cmd}(?:@?{username})?)(?:\s|$)", without_prefix, flags=re.IGNORECASE if not flt.case_sensitive else 0): continue + without_command = re.sub(rf"{cmd}(?:@?{username})?\s", "", without_prefix, count=1) + # match.groups are 1-indexed, group(1) is the quote, group(2) is the text # between the quotes, group(3) is unquoted, whitespace-split text # Remove the escape character from the arguments message.command = [cmd] + [ re.sub(r"\\([\"'])", r"\1", m.group(2) or m.group(3) or "") - for m in command_re.finditer(without_prefix[len(cmd):]) + for m in command_re.finditer(without_command) ] return True diff --git a/tests/filters/test_command.py b/tests/filters/test_command.py index a29ee8f6..0cd00e95 100644 --- a/tests/filters/test_command.py +++ b/tests/filters/test_command.py @@ -107,6 +107,10 @@ async def test_with_args(): await f(c, m) assert m.command == ["start"] + list("abc") + m = Message('/start@username a b c') + await f(c, m) + assert m.command == ["start"] + list("abc") + m = Message("/start 'a b' c") await f(c, m) assert m.command == ["start", "a b", "c"]