import time import cv2 import re from ppocronnx.predict_system import BoxedResult import module.config.server as server from module.base.button import ButtonWrapper from module.base.decorator import cached_property from module.base.utils import area_pad, corner2area, crop, float2str from module.exception import ScriptError from module.logger import logger from module.ocr.models import OCR_MODEL from module.ocr.ppocr import TextSystem from module.ocr.utils import merge_buttons def enlarge_canvas(image): """ Enlarge image into a square fill with black background. In the structure of PaddleOCR, image with w:h=1:1 is the best while 3:1 rectangles takes three times as long. Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32. """ height, width = image.shape[:2] length = int(max(width, height) // 32 * 32 + 32) border = (0, length - height, 0, length - width) if sum(border) > 0: image = cv2.copyMakeBorder(image, *border, borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0)) return image class OcrResultButton: def __init__(self, boxed_result: BoxedResult, keyword_classes: list): """ Args: boxed_result: BoxedResult from ppocr-onnx keyword_classes: List of Keyword classes """ self.area = boxed_result.box self.search = area_pad(self.area, pad=-20) # self.color = self.button = boxed_result.box try: self.matched_keyword = self.match_keyword(boxed_result.ocr_text, keyword_classes) self.name = str(self.matched_keyword) except ScriptError: self.matched_keyword = None self.name = boxed_result.ocr_text self.text = boxed_result.ocr_text self.score = boxed_result.score def match_keyword(self, ocr_text, keyword_classes): """ Args: ocr_text (str): keyword_classes: List of Keyword classes Returns: Keyword: Raises: ScriptError: If no keywords matched """ for keyword_class in keyword_classes: try: matched = keyword_class.find(ocr_text, in_current_server=True, ignore_punctuation=True) return matched except ScriptError: continue raise ScriptError def __str__(self): return self.name __repr__ = __str__ def __eq__(self, other): return str(self) == str(other) def __hash__(self): return hash(self.name) def __bool__(self): return True class Ocr: # Merge results with box distance <= thres merge_thres_x = 0 merge_thres_y = 0 def __init__(self, button: ButtonWrapper, lang=None, name=None): self.button: ButtonWrapper = button self.lang: str = lang if lang is not None else Ocr.server2lang() self.name: str = name if name is not None else button.name @classmethod def server2lang(cls, ser=None) -> str: if ser is None: ser = server.server match ser: case 'cn': return 'ch' case _: return 'ch' @cached_property def model(self) -> TextSystem: return OCR_MODEL.__getattribute__(self.lang) def pre_process(self, image): """ Args: image (np.ndarray): Shape (height, width, channel) Returns: np.ndarray: Shape (width, height) """ return image def after_process(self, result): """ Args: result (str): '第二行' Returns: str: """ if result.startswith('UID'): result = 'UID' return result def ocr_single_line(self, image): # pre process start_time = time.time() image = crop(image, self.button.area) image = self.pre_process(image) # ocr result, _ = self.model.ocr_single_line(image) # after proces result = self.after_process(result) logger.attr(name='%s %ss' % (self.name, float2str(time.time() - start_time)), text=str(result)) return result def detect_and_ocr(self, image, direct_ocr=False) -> list[BoxedResult]: """ Args: image: direct_ocr: True to ignore `button` attribute and feed the image to OCR model without cropping. Returns: """ # pre process start_time = time.time() if not direct_ocr: image = crop(image, self.button.area) image = self.pre_process(image) # ocr image = enlarge_canvas(image) results: list[BoxedResult] = self.model.detect_and_ocr(image) # after proces for result in results: if not direct_ocr: result.box += self.button.area[:2] result.box = tuple(corner2area(result.box)) results = merge_buttons(results, thres_x=self.merge_thres_x, thres_y=self.merge_thres_y) for result in results: result.ocr_text = self.after_process(result.ocr_text) logger.attr(name='%s %ss' % (self.name, float2str(time.time() - start_time)), text=str([result.ocr_text for result in results])) return results def matched_ocr(self, image, keyword_classes, direct_ocr=False) -> list[OcrResultButton]: """ Args: image: Screenshot keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them. direct_ocr: True to ignore `button` attribute and feed the image to OCR model without cropping. Returns: List of matched OcrResultButton. OCR result which didn't matched known keywords will be dropped. """ if not isinstance(keyword_classes, list): keyword_classes = [keyword_classes] def is_valid(keyword): # Digits will be considered as the index of keyword if keyword.isdigit(): return False return True results = self.detect_and_ocr(image, direct_ocr=direct_ocr) results = [ OcrResultButton(result, keyword_classes) for result in results if is_valid(result.ocr_text) ] results = [result for result in results if result.matched_keyword is not None] logger.attr(name=f'{self.name} matched', text=results) return results class Digit(Ocr): def __init__(self, button: ButtonWrapper, lang='ch', name=None): super().__init__(button, lang=lang, name=name) def after_process(self, result) -> int: """ Returns: int: """ result = super().after_process(result) logger.attr(name=self.name, text=str(result)) res = re.search(r'(\d+)', result) if res: return int(res.group(1)) else: logger.warning(f'No digit found in {result}') return 0 class DigitCounter(Ocr): def __init__(self, button: ButtonWrapper, lang='ch', name=None): super().__init__(button, lang=lang, name=name) def after_process(self, result) -> tuple[int, int, int]: """ Do OCR on a counter, such as `14/15`, and returns 14, 1, 15 Returns: int: """ result = super().after_process(result) logger.attr(name=self.name, text=str(result)) res = re.search(r'(\d+)/(\d+)', result) if res: groups = [int(s) for s in res.groups()] current, total = int(groups[0]), int(groups[1]) # current = min(current, total) return current, total - current, total else: logger.warning(f'No digit counter found in {result}') return 0, 0, 0