Remove trailing username when adding args to Message.command

Fixes #676
This commit is contained in:
Dan 2021-05-06 13:02:26 +02:00
parent 1d940b96a3
commit 4fc4501445
2 changed files with 9 additions and 5 deletions

View File

@ -770,7 +770,7 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = "
nonlocal username nonlocal username
if username is None: if username is None:
username = (await client.get_me()).username username = (await client.get_me()).username or ""
text = message.text or message.caption text = message.text or message.caption
message.command = None message.command = None
@ -778,8 +778,6 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = "
if not text: if not text:
return False return False
pattern = rf"^(?:{{cmd}}|{{cmd}}@{username})(?:\s|$)" if username else r"^{cmd}(?:\s|$)"
for prefix in flt.prefixes: for prefix in flt.prefixes:
if not text.startswith(prefix): if not text.startswith(prefix):
continue continue
@ -787,17 +785,19 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = "
without_prefix = text[len(prefix):] without_prefix = text[len(prefix):]
for cmd in flt.commands: for cmd in flt.commands:
if not re.match(pattern.format(cmd=re.escape(cmd)), without_prefix, if not re.match(rf"^(?:{cmd}(?:@?{username})?)(?:\s|$)", without_prefix,
flags=re.IGNORECASE if not flt.case_sensitive else 0): flags=re.IGNORECASE if not flt.case_sensitive else 0):
continue continue
without_command = re.sub(rf"{cmd}(?:@?{username})?\s", "", without_prefix, count=1)
# match.groups are 1-indexed, group(1) is the quote, group(2) is the text # match.groups are 1-indexed, group(1) is the quote, group(2) is the text
# between the quotes, group(3) is unquoted, whitespace-split text # between the quotes, group(3) is unquoted, whitespace-split text
# Remove the escape character from the arguments # Remove the escape character from the arguments
message.command = [cmd] + [ message.command = [cmd] + [
re.sub(r"\\([\"'])", r"\1", m.group(2) or m.group(3) or "") re.sub(r"\\([\"'])", r"\1", m.group(2) or m.group(3) or "")
for m in command_re.finditer(without_prefix[len(cmd):]) for m in command_re.finditer(without_command)
] ]
return True return True

View File

@ -107,6 +107,10 @@ async def test_with_args():
await f(c, m) await f(c, m)
assert m.command == ["start"] + list("abc") assert m.command == ["start"] + list("abc")
m = Message('/start@username a b c')
await f(c, m)
assert m.command == ["start"] + list("abc")
m = Message("/start 'a b' c") m = Message("/start 'a b' c")
await f(c, m) await f(c, m)
assert m.command == ["start", "a b", "c"] assert m.command == ["start", "a b", "c"]