Refactor: Get in-game language from plane name

This commit is contained in:
LmeSzinc 2023-09-15 13:29:22 +08:00
parent 2c2c31cad7
commit 37e29838c5
8 changed files with 257 additions and 76 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -61,9 +61,12 @@ class Keyword:
def __bool__(self):
return True
def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True):
if in_current_server:
match server.lang:
def _keywords_to_find(self, lang: str = None, ignore_punctuation=True):
if lang is None:
lang = server.lang
if lang in server.VALID_LANG:
match lang:
case 'cn':
if ignore_punctuation:
return [self.cn_parsed]
@ -122,11 +125,12 @@ class Keyword:
return name == keyword
@classmethod
def find(cls, name, in_current_server=False, ignore_punctuation=True):
def find(cls, name, lang: str = None, ignore_punctuation=True):
"""
Args:
name: Name in any server or instance id.
in_current_server: True to search the names from current server only.
lang: Lang to find from
None to search the names from current server only.
ignore_punctuation: True to remove punctuations and turn into lowercase before searching.
Returns:
@ -157,7 +161,7 @@ class Keyword:
instance: Keyword
for instance in cls.instances.values():
for keyword in instance._keywords_to_find(
in_current_server=in_current_server, ignore_punctuation=ignore_punctuation):
lang=lang, ignore_punctuation=ignore_punctuation):
if cls._compare(name, keyword):
return instance

View File

@ -1,8 +1,9 @@
import re
import time
from datetime import timedelta
from typing import Optional
import cv2
import numpy as np
from pponnxcr.predict_system import BoxedResult
import module.config.server as server
@ -11,70 +12,33 @@ from module.base.decorator import cached_property
from module.base.utils import area_pad, corner2area, crop, float2str
from module.exception import ScriptError
from module.logger import logger
from module.ocr.keyword import Keyword
from module.ocr.models import OCR_MODEL, TextSystem
from module.ocr.utils import merge_buttons
def enlarge_canvas(image):
"""
Enlarge image into a square fill with black background. In the structure of PaddleOCR,
image with w:h=1:1 is the best while 3:1 rectangles takes three times as long.
Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32.
No longer needed, already included in pponnxcr.
"""
height, width = image.shape[:2]
length = int(max(width, height) // 32 * 32 + 32)
border = (0, length - height, 0, length - width)
if sum(border) > 0:
image = cv2.copyMakeBorder(image, *border, borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0))
return image
class OcrResultButton:
def __init__(self, boxed_result: BoxedResult, keyword_classes: list):
def __init__(self, boxed_result: BoxedResult, matched_keyword: Optional[Keyword]):
"""
Args:
boxed_result: BoxedResult from ppocr-onnx
keyword_classes: List of Keyword classes
matched_keyword: Keyword object or None
"""
self.area = boxed_result.box
self.search = area_pad(self.area, pad=-20)
# self.color =
self.button = boxed_result.box
try:
self.matched_keyword = self.match_keyword(boxed_result.ocr_text, keyword_classes)
self.name = str(self.matched_keyword)
except ScriptError:
if matched_keyword is not None:
self.matched_keyword = matched_keyword
self.name = str(matched_keyword)
else:
self.matched_keyword = None
self.name = boxed_result.ocr_text
self.text = boxed_result.ocr_text
self.score = boxed_result.score
@staticmethod
def match_keyword(ocr_text, keyword_classes):
"""
Args:
ocr_text (str):
keyword_classes: List of Keyword classes
Returns:
Keyword:
Raises:
ScriptError: If no keywords matched
"""
for keyword_class in keyword_classes:
try:
matched = keyword_class.find(ocr_text, in_current_server=True, ignore_punctuation=True)
return matched
except ScriptError:
continue
raise ScriptError
def __str__(self):
return self.name
@ -89,6 +53,10 @@ class OcrResultButton:
def __bool__(self):
return True
@property
def is_keyword_matched(self) -> bool:
return self.matched_keyword is not None
class Ocr:
# Merge results with box distance <= thres
@ -201,6 +169,127 @@ class Ocr:
text=str([result.ocr_text for result in results]))
return results
def _match_result(
self,
result: str,
keyword_classes,
lang: str = None,
ignore_punctuation=True,
ignore_digit=True):
"""
Args:
result (str):
keyword_classes: A list of `Keyword` class or classes inherited `Keyword`
Returns:
If matched, return `Keyword` object or objects inherited `Keyword`
If not match, return None
"""
if not isinstance(keyword_classes, list):
keyword_classes = [keyword_classes]
# Digits will be considered as the index of keyword
if ignore_digit:
if result.isdigit():
return None
# Try in current lang
for keyword_class in keyword_classes:
try:
matched = keyword_class.find(
result,
lang=lang,
ignore_punctuation=ignore_punctuation
)
return matched
except ScriptError:
continue
return None
def matched_single_line(
self,
image,
keyword_classes,
lang: str = None,
ignore_punctuation=True
) -> OcrResultButton:
"""
Args:
image: Image to detect
keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
lang:
ignore_punctuation:
Returns:
OcrResultButton: Or None if it didn't matched known keywords.
"""
result = self.ocr_single_line(image)
result = self._match_result(
result,
keyword_classes=keyword_classes,
lang=lang,
ignore_punctuation=ignore_punctuation,
)
logger.attr(name=f'{self.name} matched',
text=result)
return result
def matched_multi_lines(
self,
image_list,
keyword_classes,
lang: str = None,
ignore_punctuation=True
) -> list[OcrResultButton]:
"""
Args:
image_list:
keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
lang:
ignore_punctuation:
Returns:
List of matched OcrResultButton.
OCR result which didn't matched known keywords will be dropped.
"""
results = self.ocr_multi_lines(image_list)
results = [self._match_result(
result,
keyword_classes=keyword_classes,
lang=lang,
ignore_punctuation=ignore_punctuation,
) for result in results]
results = [result for result in results if result.is_keyword_matched]
logger.attr(name=f'{self.name} matched',
text=results)
return results
def _product_button(
self,
boxed_result: BoxedResult,
keyword_classes,
lang: str = None,
ignore_punctuation=True,
ignore_digit=True
) -> OcrResultButton:
if not isinstance(keyword_classes, list):
keyword_classes = [keyword_classes]
matched_keyword = self._match_result(
boxed_result.ocr_text,
keyword_classes=keyword_classes,
lang=lang,
ignore_punctuation=ignore_punctuation,
ignore_digit=ignore_digit,
)
button = OcrResultButton(boxed_result, matched_keyword)
return button
def matched_ocr(self, image, keyword_classes, direct_ocr=False) -> list[OcrResultButton]:
"""
Args:
@ -212,21 +301,11 @@ class Ocr:
List of matched OcrResultButton.
OCR result which didn't matched known keywords will be dropped.
"""
if not isinstance(keyword_classes, list):
keyword_classes = [keyword_classes]
def is_valid(keyword):
# Digits will be considered as the index of keyword
if keyword.isdigit():
return False
return True
results = self.detect_and_ocr(image, direct_ocr=direct_ocr)
results = [
OcrResultButton(result, keyword_classes)
for result in results if is_valid(result.ocr_text)
]
results = [result for result in results if result.matched_keyword is not None]
results = [self._product_button(result, keyword_classes) for result in results]
results = [result for result in results if result.is_keyword_matched]
logger.attr(name=f'{self.name} matched',
text=results)
return results

View File

@ -182,7 +182,7 @@ class DraggableList:
main.wait_until_stable(self.search_button, timer=Timer(
0, count=0), timeout=Timer(1.5, count=5))
skip_first_screenshot = True
if last_buttons == set(self.cur_buttons):
if self.cur_buttons and last_buttons == set(self.cur_buttons):
logger.warning(f'No more rows in {self}')
return False
last_buttons = set(self.cur_buttons)

View File

@ -0,0 +1,15 @@
from module.base.button import Button, ButtonWrapper
# This file was auto-generated, do not modify it manually. To generate:
# ``` python -m dev_tools.button_extract ```
OCR_MAP_NAME = ButtonWrapper(
name='OCR_MAP_NAME',
share=Button(
file='./assets/share/base/main_page/OCR_MAP_NAME.png',
area=(48, 15, 373, 32),
search=(28, 0, 393, 52),
color=(69, 72, 78),
button=(48, 15, 373, 32),
),
)

80
tasks/base/main_page.py Normal file
View File

@ -0,0 +1,80 @@
import re
from typing import Optional
import module.config.server as server
from module.base.base import ModuleBase
from module.config.server import VALID_LANG
from module.exception import RequestHumanTakeover, ScriptError
from module.logger import logger
from module.ocr.ocr import Ocr
from tasks.base.assets.assets_base_main_page import OCR_MAP_NAME
from tasks.base.page import Page, page_main
from tasks.map.keywords import KEYWORDS_MAP_PLANE, MapPlane
class OcrPlaneName(Ocr):
def after_process(self, result):
# RobotSettlement1
result = re.sub(r'\d+$', '', result)
return super().after_process(result)
class MainPage(ModuleBase):
# Same as BigmapPlane class
# Current plane
plane: MapPlane = KEYWORDS_MAP_PLANE.Herta_ParlorCar
_lang_checked = False
def check_lang_from_map_plane(self) -> Optional[str]:
logger.info('check_lang_from_map_plane')
lang_unknown = self.config.Emulator_GameLanguage == 'auto'
if lang_unknown:
lang_list = VALID_LANG
else:
# Try current lang first
lang_list = [server.lang] + [lang for lang in VALID_LANG if lang != server.lang]
for lang in lang_list:
logger.info(f'Try ocr in lang {lang}')
ocr = OcrPlaneName(OCR_MAP_NAME, lang=lang)
result = ocr.ocr_single_line(self.device.image)
keyword = ocr._match_result(result, keyword_classes=MapPlane, lang=lang)
if keyword is not None:
self.plane = keyword
logger.attr('CurrentPlane', self.plane)
logger.info(f'check_lang_from_map_plane matched lang: {lang}')
if lang_unknown or lang != server.lang:
self.config.Emulator_GameLanguage = lang
return lang
if lang_unknown:
logger.critical('Cannot detect in-game text language, please set it to 简体中文 or English')
raise RequestHumanTakeover
else:
logger.warning(f'Cannot detect in-game text language, assume current lang={server.lang} is correct')
return server.lang
def handle_lang_check(self, page: Page):
if MainPage._lang_checked:
return
if page != page_main:
return
self.check_lang_from_map_plane()
MainPage._lang_checked = True
def acquire_lang_checked(self):
if MainPage._lang_checked:
return
logger.info('acquire_lang_checked')
try:
self.ui_goto(page_main)
except AttributeError:
logger.critical('Method ui_goto() not found, class MainPage must be inherited by class UI')
raise ScriptError
self.handle_lang_check(page=page_main)

View File

@ -5,13 +5,14 @@ from module.exception import GameNotRunningError, GamePageUnknownError
from module.logger import logger
from module.ocr.ocr import Ocr
from tasks.base.assets.assets_base_page import CLOSE
from tasks.base.main_page import MainPage
from tasks.base.page import Page, page_main
from tasks.base.popup import PopupHandler
from tasks.combat.assets.assets_combat_finish import COMBAT_EXIT
from tasks.combat.assets.assets_combat_prepare import COMBAT_PREPARE
class UI(PopupHandler):
class UI(PopupHandler, MainPage):
ui_current: Page
ui_main_confirm_timer = Timer(0.2, count=0)
@ -124,6 +125,7 @@ class UI(PopupHandler):
continue
if self.appear(page.check_button, interval=5):
logger.info(f'Page switch: {page} -> {page.parent}')
self.handle_lang_check(page)
if self.ui_page_confirm(page):
logger.info(f'Page arrive confirm {page}')
button = page.links[page.parent]
@ -151,6 +153,7 @@ class UI(PopupHandler):
bool: If UI switched.
"""
logger.hr("UI ensure")
self.acquire_lang_checked()
self.ui_get_current_page(skip_first_screenshot=skip_first_screenshot)
if self.ui_current == destination:
logger.info("Already at %s" % destination)

View File

@ -22,7 +22,7 @@ class ForgottenHallStageOcr(Ocr):
raw = image.copy()
area = OCR_STAGE.area
image = crop(raw, area)
yellow = color_similarity_2d(image, color=(250, 201, 111))
yellow = color_similarity_2d(image, color=(255, 200, 112))
gray = color_similarity_2d(image, color=(100, 109, 134))
image = np.maximum(yellow, gray)
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
@ -36,7 +36,7 @@ class ForgottenHallStageOcr(Ocr):
for cont in contours:
rect = cv2.boundingRect(cv2.convexHull(cont).astype(np.float32))
# Filter with rectangle width, usually to be 62~64
if not 62 - 10 < rect[2] < 62 + 10:
if not 62 - 10 < rect[2] < 65 + 10:
continue
rect = (rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3])
rect = area_offset(rect, offset=area[:2])
@ -52,16 +52,16 @@ class ForgottenHallStageOcr(Ocr):
boxes = self._find_number(image)
image_list = [crop(image, area) for area in boxes]
results = self.ocr_multi_lines(image_list)
boxed_results = [
results = [
BoxedResult(area_offset(boxes[index], (-50, 0)), image_list[index], text, score)
for index, (text, score) in enumerate(results)
]
results_buttons = [
OcrResultButton(result, keyword_classes)
for result in boxed_results
]
logger.attr(name=f'{self.name} matched', text=results_buttons)
return results_buttons
results = [self._product_button(result, keyword_classes, ignore_digit=False) for result in results]
results = [result for result in results if result.is_keyword_matched]
logger.attr(name=f'{self.name} matched', text=results)
return results
class DraggableStageList(DraggableList):