diff --git a/pyrogram/client/style/html.py b/pyrogram/client/style/html.py index b42114a8..82921f4c 100644 --- a/pyrogram/client/style/html.py +++ b/pyrogram/client/style/html.py @@ -16,127 +16,110 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import html import re from collections import OrderedDict +from html.parser import HTMLParser import pyrogram -from pyrogram.api.types import ( - MessageEntityBold as Bold, - MessageEntityItalic as Italic, - MessageEntityCode as Code, - MessageEntityTextUrl as Url, - MessageEntityPre as Pre, - MessageEntityUnderline as Underline, - MessageEntityStrike as Strike, - MessageEntityBlockquote as Blockquote, - MessageEntityMentionName as MentionInvalid, - InputMessageEntityMentionName as Mention, -) +from pyrogram.api import types from pyrogram.errors import PeerIdInvalid from . import utils -class HTML: - HTML_RE = re.compile(r"<(\w+)(?: href=([\"'])([^<]+)\2)?>([^>]+)") +class Parser(HTMLParser): MENTION_RE = re.compile(r"tg://user\?id=(\d+)") + def __init__(self, client: "pyrogram.BaseClient"): + super().__init__() + + self.client = client + + self.text = "" + self.entities = [] + self.temp_entities = [] + self.tags = [] + + def handle_starttag(self, tag, attrs): + attrs = dict(attrs) + extra = {} + + if tag in ["b", "strong"]: + entity = types.MessageEntityBold + elif tag in ["i", "em"]: + entity = types.MessageEntityItalic + elif tag == "u": + entity = types.MessageEntityUnderline + elif tag in ["s", "del", "strike"]: + entity = types.MessageEntityStrike + elif tag == "blockquote": + entity = types.MessageEntityBlockquote + elif tag == "code": + entity = types.MessageEntityCode + elif tag == "pre": + entity = types.MessageEntityPre + extra["language"] = "" + elif tag == "a": + url = attrs.get("href", "") + + mention = Parser.MENTION_RE.match(url) + + if mention: + user_id = int(mention.group(1)) + + try: + user = self.client.resolve_peer(user_id) + except PeerIdInvalid: + entity = types.MessageEntityMentionName + extra["user_id"] = user_id + else: + entity = types.InputMessageEntityMentionName + extra["user_id"] = user + else: + entity = types.MessageEntityTextUrl + extra["url"] = url + else: + return + + self.tags.append(tag) + self.temp_entities.append(entity(offset=len(self.text), length=0, **extra)) + + def handle_data(self, data): + data = html.unescape(data) + + for entity in self.temp_entities: + entity.length += len(data) + + self.text += data + + def handle_endtag(self, tag): + start_tag = self.tags.pop() + + if start_tag != tag: + line, offset = self.getpos() + offset += 1 + + raise ValueError("Expected end tag , but found at {}:{}".format(start_tag, tag, line, offset)) + + self.entities.append(self.temp_entities.pop()) + + def error(self, message): + pass + + +class HTML: def __init__(self, client: "pyrogram.BaseClient" = None): self.client = client - def parse(self, message: str): - entities = [] - message = utils.add_surrogates(str(message or "")) - offset = 0 + def parse(self, text: str): + text = utils.add_surrogates(str(text or "").strip()) - for match in self.HTML_RE.finditer(message): - start = match.start() - offset - style, url, body = match.group(1, 3, 4) + parser = Parser(self.client) + parser.feed(text) + print(parser.entities) - if url: - mention = self.MENTION_RE.match(url) - - if mention: - user_id = int(mention.group(1)) - - try: - input_user = self.client.resolve_peer(user_id) - except PeerIdInvalid: - input_user = None - - entity = ( - Mention(offset=start, length=len(body), user_id=input_user) - if input_user else MentionInvalid(offset=start, length=len(body), user_id=user_id) - ) - else: - entity = Url(offset=start, length=len(body), url=url) - else: - if style == "b" or style == "strong": - entity = Bold(offset=start, length=len(body)) - elif style == "i" or style == "em": - entity = Italic(offset=start, length=len(body)) - elif style == "code": - entity = Code(offset=start, length=len(body)) - elif style == "pre": - entity = Pre(offset=start, length=len(body), language="") - elif style == "u": - entity = Underline(offset=start, length=len(body)) - elif style in ["strike", "s", "del"]: - entity = Strike(offset=start, length=len(body)) - elif style == "blockquote": - entity = Blockquote(offset=start, length=len(body)) - else: - continue - - entities.append(entity) - message = message.replace(match.group(), body) - offset += len(style) * 2 + 5 + (len(url) + 8 if url else 0) - - # TODO: OrderedDict to be removed in Python3.6 + # TODO: OrderedDict to be removed in Python 3.6 return OrderedDict([ - ("message", utils.remove_surrogates(message)), - ("entities", entities) + ("message", utils.remove_surrogates(parser.text)), + ("entities", parser.entities) ]) - - def unparse(self, message: str, entities: list): - message = utils.add_surrogates(message).strip() - offset = 0 - - for entity in entities: - start = entity.offset + offset - type = entity.type - url = entity.url - user = entity.user - sub = message[start: start + entity.length] - - if type == "bold": - style = "b" - elif type == "italic": - style = "i" - elif type == "code": - style = "code" - elif type == "pre": - style = "pre" - elif type == "underline": - style = "u" - elif type == "strike": - style = "s" - elif type == "blockquote": - style = "blockquote" - elif type == "text_link": - offset += 15 + len(url) - message = message[:start] + message[start:].replace( - sub, "{}".format(url, sub), 1) - continue - elif type == "text_mention": - offset += 28 + len(str(user.id)) - message = message[:start] + message[start:].replace( - sub, "{}".format(user.id, sub), 1) - continue - else: - continue - - offset += len(style) * 2 + 5 - message = message[:start] + message[start:].replace( - sub, "<{0}>{1}".format(style, sub), 1) - - return utils.remove_surrogates(message) diff --git a/pyrogram/client/style/markdown.py b/pyrogram/client/style/markdown.py index 9dded1f3..26effe5c 100644 --- a/pyrogram/client/style/markdown.py +++ b/pyrogram/client/style/markdown.py @@ -17,22 +17,9 @@ # along with Pyrogram. If not, see . import re -from collections import OrderedDict import pyrogram -from pyrogram.api.types import ( - MessageEntityBold as Bold, - MessageEntityItalic as Italic, - MessageEntityCode as Code, - MessageEntityTextUrl as Url, - MessageEntityPre as Pre, - MessageEntityUnderline as Underline, - MessageEntityStrike as Strike, - MessageEntityMentionName as MentionInvalid, - InputMessageEntityMentionName as Mention -) -from pyrogram.errors import PeerIdInvalid -from . import utils +from .html import HTML class Markdown: @@ -43,10 +30,10 @@ class Markdown: CODE_DELIMITER = "`" PRE_DELIMITER = "```" - MARKDOWN_RE = re.compile(r"({d})([\w\W]*?)\1|\[([^[]+?)\]\(([^(]+?)\)".format( + MARKDOWN_RE = re.compile(r"({d})".format( d="|".join( ["".join(i) for i in [ - ["\{}".format(j) for j in i] + [r"\{}".format(j) for j in i] for i in [ PRE_DELIMITER, CODE_DELIMITER, @@ -56,107 +43,56 @@ class Markdown: BOLD_DELIMITER ] ]] - ) - )) - MENTION_RE = re.compile(r"tg://user\?id=(\d+)") + ))) - def __init__(self, client: "pyrogram.BaseClient" = None): - self.client = client + URL_RE = re.compile(r"\[([^[]+)]\(([^(]+)\)") - def parse(self, message: str): - message = utils.add_surrogates(str(message or "")).strip() - entities = [] + def __init__(self, client: "pyrogram.BaseClient"): + self.html = HTML(client) + + def parse(self, text: str): offset = 0 + delimiters = set() - for match in self.MARKDOWN_RE.finditer(message): - start = match.start() - offset - style, body, text, url = match.groups() + for i, match in enumerate(re.finditer(Markdown.MARKDOWN_RE, text)): + start, stop = match.span() + delimiter = match.group(1) - if url: - mention = self.MENTION_RE.match(url) - - if mention: - user_id = int(mention.group(1)) - - try: - input_user = self.client.resolve_peer(user_id) - except PeerIdInvalid: - input_user = None - - entity = ( - Mention(offset=start, length=len(text), user_id=input_user) - if input_user else MentionInvalid(offset=start, length=len(text), user_id=user_id) - ) - else: - entity = Url(offset=start, length=len(text), url=url) - - body = text - offset += len(url) + 4 - else: - if style == self.BOLD_DELIMITER: - entity = Bold(offset=start, length=len(body)) - elif style == self.ITALIC_DELIMITER: - entity = Italic(offset=start, length=len(body)) - elif style == self.UNDERLINE_DELIMITER: - entity = Underline(offset=start, length=len(body)) - elif style == self.STRIKE_DELIMITER: - entity = Strike(offset=start, length=len(body)) - elif style == self.CODE_DELIMITER: - entity = Code(offset=start, length=len(body)) - elif style == self.PRE_DELIMITER: - entity = Pre(offset=start, length=len(body), language="") - else: - continue - - offset += len(style) * 2 - - entities.append(entity) - message = message.replace(match.group(), body) - - # TODO: OrderedDict to be removed in Python3.6 - return OrderedDict([ - ("message", utils.remove_surrogates(message)), - ("entities", entities) - ]) - - def unparse(self, message: str, entities: list): - message = utils.add_surrogates(message).strip() - offset = 0 - - for entity in entities: - start = entity.offset + offset - type = entity.type - url = entity.url - user = entity.user - sub = message[start: start + entity.length] - - if type == "bold": - style = self.BOLD_DELIMITER - elif type == "italic": - style = self.ITALIC_DELIMITER - elif type == "underline": - style = self.UNDERLINE_DELIMITER - elif type == "strike": - style = self.STRIKE_DELIMITER - elif type == "code": - style = self.CODE_DELIMITER - elif type == "pre": - style = self.PRE_DELIMITER - elif type == "text_link": - offset += 4 + len(url) - message = message[:start] + message[start:].replace( - sub, "[{}]({})".format(sub, url), 1) - continue - elif type == "text_mention": - offset += 17 + len(str(user.id)) - message = message[:start] + message[start:].replace( - sub, "[{}](tg://user?id={})".format(sub, user.id), 1) - continue + if delimiter == Markdown.BOLD_DELIMITER: + tag = "b" + elif delimiter == Markdown.ITALIC_DELIMITER: + tag = "i" + elif delimiter == Markdown.UNDERLINE_DELIMITER: + tag = "u" + elif delimiter == Markdown.STRIKE_DELIMITER: + tag = "s" + elif delimiter == Markdown.CODE_DELIMITER: + tag = "code" + elif delimiter == Markdown.PRE_DELIMITER: + tag = "pre" else: continue - offset += len(style) * 2 - message = message[:start] + message[start:].replace( - sub, "{0}{1}{0}".format(style, sub), 1) + if delimiter not in delimiters: + delimiters.add(delimiter) + tag = "<{}>".format(tag) + else: + delimiters.remove(delimiter) + tag = "".format(tag) - return utils.remove_surrogates(message) + text = text[:start + offset] + tag + text[stop + offset:] + + offset += len(tag) - len(delimiter) + + offset = 0 + + for match in re.finditer(Markdown.URL_RE, text): + start, stop = match.span() + full = match.group(0) + body, url = match.groups() + replace = '{}'.format(url, body) + + text = text[:start + offset] + replace + text[stop + offset:] + offset += len(replace) - len(full) + + return self.html.parse(text)