Remove trailing username when adding args to Message.command
Fixes #676
This commit is contained in:
parent
1d940b96a3
commit
4fc4501445
@ -770,7 +770,7 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = "
|
||||
nonlocal username
|
||||
|
||||
if username is None:
|
||||
username = (await client.get_me()).username
|
||||
username = (await client.get_me()).username or ""
|
||||
|
||||
text = message.text or message.caption
|
||||
message.command = None
|
||||
@ -778,8 +778,6 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = "
|
||||
if not text:
|
||||
return False
|
||||
|
||||
pattern = rf"^(?:{{cmd}}|{{cmd}}@{username})(?:\s|$)" if username else r"^{cmd}(?:\s|$)"
|
||||
|
||||
for prefix in flt.prefixes:
|
||||
if not text.startswith(prefix):
|
||||
continue
|
||||
@ -787,17 +785,19 @@ def command(commands: Union[str, List[str]], prefixes: Union[str, List[str]] = "
|
||||
without_prefix = text[len(prefix):]
|
||||
|
||||
for cmd in flt.commands:
|
||||
if not re.match(pattern.format(cmd=re.escape(cmd)), without_prefix,
|
||||
if not re.match(rf"^(?:{cmd}(?:@?{username})?)(?:\s|$)", without_prefix,
|
||||
flags=re.IGNORECASE if not flt.case_sensitive else 0):
|
||||
continue
|
||||
|
||||
without_command = re.sub(rf"{cmd}(?:@?{username})?\s", "", without_prefix, count=1)
|
||||
|
||||
# match.groups are 1-indexed, group(1) is the quote, group(2) is the text
|
||||
# between the quotes, group(3) is unquoted, whitespace-split text
|
||||
|
||||
# Remove the escape character from the arguments
|
||||
message.command = [cmd] + [
|
||||
re.sub(r"\\([\"'])", r"\1", m.group(2) or m.group(3) or "")
|
||||
for m in command_re.finditer(without_prefix[len(cmd):])
|
||||
for m in command_re.finditer(without_command)
|
||||
]
|
||||
|
||||
return True
|
||||
|
@ -107,6 +107,10 @@ async def test_with_args():
|
||||
await f(c, m)
|
||||
assert m.command == ["start"] + list("abc")
|
||||
|
||||
m = Message('/start@username a b c')
|
||||
await f(c, m)
|
||||
assert m.command == ["start"] + list("abc")
|
||||
|
||||
m = Message("/start 'a b' c")
|
||||
await f(c, m)
|
||||
assert m.command == ["start", "a b", "c"]
|
||||
|
Loading…
Reference in New Issue
Block a user