mirror of
https://github.com/TeamPGM/pyrogram.git
synced 2024-11-16 20:59:29 +00:00
Revamp HTML and Markdown parsers to allow multiple nested entities
This commit is contained in:
parent
648f37cf6d
commit
e7c49c6a1b
@ -16,127 +16,110 @@
|
||||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import html
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from html.parser import HTMLParser
|
||||
|
||||
import pyrogram
|
||||
from pyrogram.api.types import (
|
||||
MessageEntityBold as Bold,
|
||||
MessageEntityItalic as Italic,
|
||||
MessageEntityCode as Code,
|
||||
MessageEntityTextUrl as Url,
|
||||
MessageEntityPre as Pre,
|
||||
MessageEntityUnderline as Underline,
|
||||
MessageEntityStrike as Strike,
|
||||
MessageEntityBlockquote as Blockquote,
|
||||
MessageEntityMentionName as MentionInvalid,
|
||||
InputMessageEntityMentionName as Mention,
|
||||
)
|
||||
from pyrogram.api import types
|
||||
from pyrogram.errors import PeerIdInvalid
|
||||
from . import utils
|
||||
|
||||
|
||||
class HTML:
|
||||
HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)</\1>")
|
||||
class Parser(HTMLParser):
|
||||
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
||||
|
||||
def __init__(self, client: "pyrogram.BaseClient"):
|
||||
super().__init__()
|
||||
|
||||
self.client = client
|
||||
|
||||
self.text = ""
|
||||
self.entities = []
|
||||
self.temp_entities = []
|
||||
self.tags = []
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
attrs = dict(attrs)
|
||||
extra = {}
|
||||
|
||||
if tag in ["b", "strong"]:
|
||||
entity = types.MessageEntityBold
|
||||
elif tag in ["i", "em"]:
|
||||
entity = types.MessageEntityItalic
|
||||
elif tag == "u":
|
||||
entity = types.MessageEntityUnderline
|
||||
elif tag in ["s", "del", "strike"]:
|
||||
entity = types.MessageEntityStrike
|
||||
elif tag == "blockquote":
|
||||
entity = types.MessageEntityBlockquote
|
||||
elif tag == "code":
|
||||
entity = types.MessageEntityCode
|
||||
elif tag == "pre":
|
||||
entity = types.MessageEntityPre
|
||||
extra["language"] = ""
|
||||
elif tag == "a":
|
||||
url = attrs.get("href", "")
|
||||
|
||||
mention = Parser.MENTION_RE.match(url)
|
||||
|
||||
if mention:
|
||||
user_id = int(mention.group(1))
|
||||
|
||||
try:
|
||||
user = self.client.resolve_peer(user_id)
|
||||
except PeerIdInvalid:
|
||||
entity = types.MessageEntityMentionName
|
||||
extra["user_id"] = user_id
|
||||
else:
|
||||
entity = types.InputMessageEntityMentionName
|
||||
extra["user_id"] = user
|
||||
else:
|
||||
entity = types.MessageEntityTextUrl
|
||||
extra["url"] = url
|
||||
else:
|
||||
return
|
||||
|
||||
self.tags.append(tag)
|
||||
self.temp_entities.append(entity(offset=len(self.text), length=0, **extra))
|
||||
|
||||
def handle_data(self, data):
|
||||
data = html.unescape(data)
|
||||
|
||||
for entity in self.temp_entities:
|
||||
entity.length += len(data)
|
||||
|
||||
self.text += data
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
start_tag = self.tags.pop()
|
||||
|
||||
if start_tag != tag:
|
||||
line, offset = self.getpos()
|
||||
offset += 1
|
||||
|
||||
raise ValueError("Expected end tag </{}>, but found </{}> at {}:{}".format(start_tag, tag, line, offset))
|
||||
|
||||
self.entities.append(self.temp_entities.pop())
|
||||
|
||||
def error(self, message):
|
||||
pass
|
||||
|
||||
|
||||
class HTML:
|
||||
def __init__(self, client: "pyrogram.BaseClient" = None):
|
||||
self.client = client
|
||||
|
||||
def parse(self, message: str):
|
||||
entities = []
|
||||
message = utils.add_surrogates(str(message or ""))
|
||||
offset = 0
|
||||
def parse(self, text: str):
|
||||
text = utils.add_surrogates(str(text or "").strip())
|
||||
|
||||
for match in self.HTML_RE.finditer(message):
|
||||
start = match.start() - offset
|
||||
style, url, body = match.group(1, 3, 4)
|
||||
parser = Parser(self.client)
|
||||
parser.feed(text)
|
||||
print(parser.entities)
|
||||
|
||||
if url:
|
||||
mention = self.MENTION_RE.match(url)
|
||||
|
||||
if mention:
|
||||
user_id = int(mention.group(1))
|
||||
|
||||
try:
|
||||
input_user = self.client.resolve_peer(user_id)
|
||||
except PeerIdInvalid:
|
||||
input_user = None
|
||||
|
||||
entity = (
|
||||
Mention(offset=start, length=len(body), user_id=input_user)
|
||||
if input_user else MentionInvalid(offset=start, length=len(body), user_id=user_id)
|
||||
)
|
||||
else:
|
||||
entity = Url(offset=start, length=len(body), url=url)
|
||||
else:
|
||||
if style == "b" or style == "strong":
|
||||
entity = Bold(offset=start, length=len(body))
|
||||
elif style == "i" or style == "em":
|
||||
entity = Italic(offset=start, length=len(body))
|
||||
elif style == "code":
|
||||
entity = Code(offset=start, length=len(body))
|
||||
elif style == "pre":
|
||||
entity = Pre(offset=start, length=len(body), language="")
|
||||
elif style == "u":
|
||||
entity = Underline(offset=start, length=len(body))
|
||||
elif style in ["strike", "s", "del"]:
|
||||
entity = Strike(offset=start, length=len(body))
|
||||
elif style == "blockquote":
|
||||
entity = Blockquote(offset=start, length=len(body))
|
||||
else:
|
||||
continue
|
||||
|
||||
entities.append(entity)
|
||||
message = message.replace(match.group(), body)
|
||||
offset += len(style) * 2 + 5 + (len(url) + 8 if url else 0)
|
||||
|
||||
# TODO: OrderedDict to be removed in Python3.6
|
||||
# TODO: OrderedDict to be removed in Python 3.6
|
||||
return OrderedDict([
|
||||
("message", utils.remove_surrogates(message)),
|
||||
("entities", entities)
|
||||
("message", utils.remove_surrogates(parser.text)),
|
||||
("entities", parser.entities)
|
||||
])
|
||||
|
||||
def unparse(self, message: str, entities: list):
|
||||
message = utils.add_surrogates(message).strip()
|
||||
offset = 0
|
||||
|
||||
for entity in entities:
|
||||
start = entity.offset + offset
|
||||
type = entity.type
|
||||
url = entity.url
|
||||
user = entity.user
|
||||
sub = message[start: start + entity.length]
|
||||
|
||||
if type == "bold":
|
||||
style = "b"
|
||||
elif type == "italic":
|
||||
style = "i"
|
||||
elif type == "code":
|
||||
style = "code"
|
||||
elif type == "pre":
|
||||
style = "pre"
|
||||
elif type == "underline":
|
||||
style = "u"
|
||||
elif type == "strike":
|
||||
style = "s"
|
||||
elif type == "blockquote":
|
||||
style = "blockquote"
|
||||
elif type == "text_link":
|
||||
offset += 15 + len(url)
|
||||
message = message[:start] + message[start:].replace(
|
||||
sub, "<a href=\"{}\">{}</a>".format(url, sub), 1)
|
||||
continue
|
||||
elif type == "text_mention":
|
||||
offset += 28 + len(str(user.id))
|
||||
message = message[:start] + message[start:].replace(
|
||||
sub, "<a href=\"tg://user?id={}\">{}</a>".format(user.id, sub), 1)
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
offset += len(style) * 2 + 5
|
||||
message = message[:start] + message[start:].replace(
|
||||
sub, "<{0}>{1}</{0}>".format(style, sub), 1)
|
||||
|
||||
return utils.remove_surrogates(message)
|
||||
|
@ -17,22 +17,9 @@
|
||||
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
import pyrogram
|
||||
from pyrogram.api.types import (
|
||||
MessageEntityBold as Bold,
|
||||
MessageEntityItalic as Italic,
|
||||
MessageEntityCode as Code,
|
||||
MessageEntityTextUrl as Url,
|
||||
MessageEntityPre as Pre,
|
||||
MessageEntityUnderline as Underline,
|
||||
MessageEntityStrike as Strike,
|
||||
MessageEntityMentionName as MentionInvalid,
|
||||
InputMessageEntityMentionName as Mention
|
||||
)
|
||||
from pyrogram.errors import PeerIdInvalid
|
||||
from . import utils
|
||||
from .html import HTML
|
||||
|
||||
|
||||
class Markdown:
|
||||
@ -43,10 +30,10 @@ class Markdown:
|
||||
CODE_DELIMITER = "`"
|
||||
PRE_DELIMITER = "```"
|
||||
|
||||
MARKDOWN_RE = re.compile(r"({d})([\w\W]*?)\1|\[([^[]+?)\]\(([^(]+?)\)".format(
|
||||
MARKDOWN_RE = re.compile(r"({d})".format(
|
||||
d="|".join(
|
||||
["".join(i) for i in [
|
||||
["\{}".format(j) for j in i]
|
||||
[r"\{}".format(j) for j in i]
|
||||
for i in [
|
||||
PRE_DELIMITER,
|
||||
CODE_DELIMITER,
|
||||
@ -56,107 +43,56 @@ class Markdown:
|
||||
BOLD_DELIMITER
|
||||
]
|
||||
]]
|
||||
)
|
||||
))
|
||||
MENTION_RE = re.compile(r"tg://user\?id=(\d+)")
|
||||
)))
|
||||
|
||||
def __init__(self, client: "pyrogram.BaseClient" = None):
|
||||
self.client = client
|
||||
URL_RE = re.compile(r"\[([^[]+)]\(([^(]+)\)")
|
||||
|
||||
def parse(self, message: str):
|
||||
message = utils.add_surrogates(str(message or "")).strip()
|
||||
entities = []
|
||||
def __init__(self, client: "pyrogram.BaseClient"):
|
||||
self.html = HTML(client)
|
||||
|
||||
def parse(self, text: str):
|
||||
offset = 0
|
||||
delimiters = set()
|
||||
|
||||
for match in self.MARKDOWN_RE.finditer(message):
|
||||
start = match.start() - offset
|
||||
style, body, text, url = match.groups()
|
||||
for i, match in enumerate(re.finditer(Markdown.MARKDOWN_RE, text)):
|
||||
start, stop = match.span()
|
||||
delimiter = match.group(1)
|
||||
|
||||
if url:
|
||||
mention = self.MENTION_RE.match(url)
|
||||
|
||||
if mention:
|
||||
user_id = int(mention.group(1))
|
||||
|
||||
try:
|
||||
input_user = self.client.resolve_peer(user_id)
|
||||
except PeerIdInvalid:
|
||||
input_user = None
|
||||
|
||||
entity = (
|
||||
Mention(offset=start, length=len(text), user_id=input_user)
|
||||
if input_user else MentionInvalid(offset=start, length=len(text), user_id=user_id)
|
||||
)
|
||||
else:
|
||||
entity = Url(offset=start, length=len(text), url=url)
|
||||
|
||||
body = text
|
||||
offset += len(url) + 4
|
||||
else:
|
||||
if style == self.BOLD_DELIMITER:
|
||||
entity = Bold(offset=start, length=len(body))
|
||||
elif style == self.ITALIC_DELIMITER:
|
||||
entity = Italic(offset=start, length=len(body))
|
||||
elif style == self.UNDERLINE_DELIMITER:
|
||||
entity = Underline(offset=start, length=len(body))
|
||||
elif style == self.STRIKE_DELIMITER:
|
||||
entity = Strike(offset=start, length=len(body))
|
||||
elif style == self.CODE_DELIMITER:
|
||||
entity = Code(offset=start, length=len(body))
|
||||
elif style == self.PRE_DELIMITER:
|
||||
entity = Pre(offset=start, length=len(body), language="")
|
||||
else:
|
||||
continue
|
||||
|
||||
offset += len(style) * 2
|
||||
|
||||
entities.append(entity)
|
||||
message = message.replace(match.group(), body)
|
||||
|
||||
# TODO: OrderedDict to be removed in Python3.6
|
||||
return OrderedDict([
|
||||
("message", utils.remove_surrogates(message)),
|
||||
("entities", entities)
|
||||
])
|
||||
|
||||
def unparse(self, message: str, entities: list):
|
||||
message = utils.add_surrogates(message).strip()
|
||||
offset = 0
|
||||
|
||||
for entity in entities:
|
||||
start = entity.offset + offset
|
||||
type = entity.type
|
||||
url = entity.url
|
||||
user = entity.user
|
||||
sub = message[start: start + entity.length]
|
||||
|
||||
if type == "bold":
|
||||
style = self.BOLD_DELIMITER
|
||||
elif type == "italic":
|
||||
style = self.ITALIC_DELIMITER
|
||||
elif type == "underline":
|
||||
style = self.UNDERLINE_DELIMITER
|
||||
elif type == "strike":
|
||||
style = self.STRIKE_DELIMITER
|
||||
elif type == "code":
|
||||
style = self.CODE_DELIMITER
|
||||
elif type == "pre":
|
||||
style = self.PRE_DELIMITER
|
||||
elif type == "text_link":
|
||||
offset += 4 + len(url)
|
||||
message = message[:start] + message[start:].replace(
|
||||
sub, "[{}]({})".format(sub, url), 1)
|
||||
continue
|
||||
elif type == "text_mention":
|
||||
offset += 17 + len(str(user.id))
|
||||
message = message[:start] + message[start:].replace(
|
||||
sub, "[{}](tg://user?id={})".format(sub, user.id), 1)
|
||||
continue
|
||||
if delimiter == Markdown.BOLD_DELIMITER:
|
||||
tag = "b"
|
||||
elif delimiter == Markdown.ITALIC_DELIMITER:
|
||||
tag = "i"
|
||||
elif delimiter == Markdown.UNDERLINE_DELIMITER:
|
||||
tag = "u"
|
||||
elif delimiter == Markdown.STRIKE_DELIMITER:
|
||||
tag = "s"
|
||||
elif delimiter == Markdown.CODE_DELIMITER:
|
||||
tag = "code"
|
||||
elif delimiter == Markdown.PRE_DELIMITER:
|
||||
tag = "pre"
|
||||
else:
|
||||
continue
|
||||
|
||||
offset += len(style) * 2
|
||||
message = message[:start] + message[start:].replace(
|
||||
sub, "{0}{1}{0}".format(style, sub), 1)
|
||||
if delimiter not in delimiters:
|
||||
delimiters.add(delimiter)
|
||||
tag = "<{}>".format(tag)
|
||||
else:
|
||||
delimiters.remove(delimiter)
|
||||
tag = "</{}>".format(tag)
|
||||
|
||||
return utils.remove_surrogates(message)
|
||||
text = text[:start + offset] + tag + text[stop + offset:]
|
||||
|
||||
offset += len(tag) - len(delimiter)
|
||||
|
||||
offset = 0
|
||||
|
||||
for match in re.finditer(Markdown.URL_RE, text):
|
||||
start, stop = match.span()
|
||||
full = match.group(0)
|
||||
body, url = match.groups()
|
||||
replace = '<a href="{}">{}</a>'.format(url, body)
|
||||
|
||||
text = text[:start + offset] + replace + text[stop + offset:]
|
||||
offset += len(replace) - len(full)
|
||||
|
||||
return self.html.parse(text)
|
||||
|
Loading…
Reference in New Issue
Block a user