Refactor: Get in-game language from plane name

This commit is contained in:
LmeSzinc 2023-09-15 13:29:22 +08:00
parent 2c2c31cad7
commit 37e29838c5
8 changed files with 257 additions and 76 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@ -61,9 +61,12 @@ class Keyword:
def __bool__(self): def __bool__(self):
return True return True
def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True): def _keywords_to_find(self, lang: str = None, ignore_punctuation=True):
if in_current_server: if lang is None:
match server.lang: lang = server.lang
if lang in server.VALID_LANG:
match lang:
case 'cn': case 'cn':
if ignore_punctuation: if ignore_punctuation:
return [self.cn_parsed] return [self.cn_parsed]
@ -122,11 +125,12 @@ class Keyword:
return name == keyword return name == keyword
@classmethod @classmethod
def find(cls, name, in_current_server=False, ignore_punctuation=True): def find(cls, name, lang: str = None, ignore_punctuation=True):
""" """
Args: Args:
name: Name in any server or instance id. name: Name in any server or instance id.
in_current_server: True to search the names from current server only. lang: Lang to find from
None to search the names from current server only.
ignore_punctuation: True to remove punctuations and turn into lowercase before searching. ignore_punctuation: True to remove punctuations and turn into lowercase before searching.
Returns: Returns:
@ -157,7 +161,7 @@ class Keyword:
instance: Keyword instance: Keyword
for instance in cls.instances.values(): for instance in cls.instances.values():
for keyword in instance._keywords_to_find( for keyword in instance._keywords_to_find(
in_current_server=in_current_server, ignore_punctuation=ignore_punctuation): lang=lang, ignore_punctuation=ignore_punctuation):
if cls._compare(name, keyword): if cls._compare(name, keyword):
return instance return instance

View File

@ -1,8 +1,9 @@
import re import re
import time import time
from datetime import timedelta from datetime import timedelta
from typing import Optional
import cv2 import numpy as np
from pponnxcr.predict_system import BoxedResult from pponnxcr.predict_system import BoxedResult
import module.config.server as server import module.config.server as server
@ -11,70 +12,33 @@ from module.base.decorator import cached_property
from module.base.utils import area_pad, corner2area, crop, float2str from module.base.utils import area_pad, corner2area, crop, float2str
from module.exception import ScriptError from module.exception import ScriptError
from module.logger import logger from module.logger import logger
from module.ocr.keyword import Keyword
from module.ocr.models import OCR_MODEL, TextSystem from module.ocr.models import OCR_MODEL, TextSystem
from module.ocr.utils import merge_buttons from module.ocr.utils import merge_buttons
def enlarge_canvas(image):
"""
Enlarge image into a square fill with black background. In the structure of PaddleOCR,
image with w:h=1:1 is the best while 3:1 rectangles takes three times as long.
Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32.
No longer needed, already included in pponnxcr.
"""
height, width = image.shape[:2]
length = int(max(width, height) // 32 * 32 + 32)
border = (0, length - height, 0, length - width)
if sum(border) > 0:
image = cv2.copyMakeBorder(image, *border, borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0))
return image
class OcrResultButton: class OcrResultButton:
def __init__(self, boxed_result: BoxedResult, keyword_classes: list): def __init__(self, boxed_result: BoxedResult, matched_keyword: Optional[Keyword]):
""" """
Args: Args:
boxed_result: BoxedResult from ppocr-onnx boxed_result: BoxedResult from ppocr-onnx
keyword_classes: List of Keyword classes matched_keyword: Keyword object or None
""" """
self.area = boxed_result.box self.area = boxed_result.box
self.search = area_pad(self.area, pad=-20) self.search = area_pad(self.area, pad=-20)
# self.color = # self.color =
self.button = boxed_result.box self.button = boxed_result.box
try: if matched_keyword is not None:
self.matched_keyword = self.match_keyword(boxed_result.ocr_text, keyword_classes) self.matched_keyword = matched_keyword
self.name = str(self.matched_keyword) self.name = str(matched_keyword)
except ScriptError: else:
self.matched_keyword = None self.matched_keyword = None
self.name = boxed_result.ocr_text self.name = boxed_result.ocr_text
self.text = boxed_result.ocr_text self.text = boxed_result.ocr_text
self.score = boxed_result.score self.score = boxed_result.score
@staticmethod
def match_keyword(ocr_text, keyword_classes):
"""
Args:
ocr_text (str):
keyword_classes: List of Keyword classes
Returns:
Keyword:
Raises:
ScriptError: If no keywords matched
"""
for keyword_class in keyword_classes:
try:
matched = keyword_class.find(ocr_text, in_current_server=True, ignore_punctuation=True)
return matched
except ScriptError:
continue
raise ScriptError
def __str__(self): def __str__(self):
return self.name return self.name
@ -89,6 +53,10 @@ class OcrResultButton:
def __bool__(self): def __bool__(self):
return True return True
@property
def is_keyword_matched(self) -> bool:
return self.matched_keyword is not None
class Ocr: class Ocr:
# Merge results with box distance <= thres # Merge results with box distance <= thres
@ -201,6 +169,127 @@ class Ocr:
text=str([result.ocr_text for result in results])) text=str([result.ocr_text for result in results]))
return results return results
def _match_result(
self,
result: str,
keyword_classes,
lang: str = None,
ignore_punctuation=True,
ignore_digit=True):
"""
Args:
result (str):
keyword_classes: A list of `Keyword` class or classes inherited `Keyword`
Returns:
If matched, return `Keyword` object or objects inherited `Keyword`
If not match, return None
"""
if not isinstance(keyword_classes, list):
keyword_classes = [keyword_classes]
# Digits will be considered as the index of keyword
if ignore_digit:
if result.isdigit():
return None
# Try in current lang
for keyword_class in keyword_classes:
try:
matched = keyword_class.find(
result,
lang=lang,
ignore_punctuation=ignore_punctuation
)
return matched
except ScriptError:
continue
return None
def matched_single_line(
self,
image,
keyword_classes,
lang: str = None,
ignore_punctuation=True
) -> OcrResultButton:
"""
Args:
image: Image to detect
keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
lang:
ignore_punctuation:
Returns:
OcrResultButton: Or None if it didn't matched known keywords.
"""
result = self.ocr_single_line(image)
result = self._match_result(
result,
keyword_classes=keyword_classes,
lang=lang,
ignore_punctuation=ignore_punctuation,
)
logger.attr(name=f'{self.name} matched',
text=result)
return result
def matched_multi_lines(
self,
image_list,
keyword_classes,
lang: str = None,
ignore_punctuation=True
) -> list[OcrResultButton]:
"""
Args:
image_list:
keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
lang:
ignore_punctuation:
Returns:
List of matched OcrResultButton.
OCR result which didn't matched known keywords will be dropped.
"""
results = self.ocr_multi_lines(image_list)
results = [self._match_result(
result,
keyword_classes=keyword_classes,
lang=lang,
ignore_punctuation=ignore_punctuation,
) for result in results]
results = [result for result in results if result.is_keyword_matched]
logger.attr(name=f'{self.name} matched',
text=results)
return results
def _product_button(
self,
boxed_result: BoxedResult,
keyword_classes,
lang: str = None,
ignore_punctuation=True,
ignore_digit=True
) -> OcrResultButton:
if not isinstance(keyword_classes, list):
keyword_classes = [keyword_classes]
matched_keyword = self._match_result(
boxed_result.ocr_text,
keyword_classes=keyword_classes,
lang=lang,
ignore_punctuation=ignore_punctuation,
ignore_digit=ignore_digit,
)
button = OcrResultButton(boxed_result, matched_keyword)
return button
def matched_ocr(self, image, keyword_classes, direct_ocr=False) -> list[OcrResultButton]: def matched_ocr(self, image, keyword_classes, direct_ocr=False) -> list[OcrResultButton]:
""" """
Args: Args:
@ -212,21 +301,11 @@ class Ocr:
List of matched OcrResultButton. List of matched OcrResultButton.
OCR result which didn't matched known keywords will be dropped. OCR result which didn't matched known keywords will be dropped.
""" """
if not isinstance(keyword_classes, list):
keyword_classes = [keyword_classes]
def is_valid(keyword):
# Digits will be considered as the index of keyword
if keyword.isdigit():
return False
return True
results = self.detect_and_ocr(image, direct_ocr=direct_ocr) results = self.detect_and_ocr(image, direct_ocr=direct_ocr)
results = [
OcrResultButton(result, keyword_classes) results = [self._product_button(result, keyword_classes) for result in results]
for result in results if is_valid(result.ocr_text) results = [result for result in results if result.is_keyword_matched]
]
results = [result for result in results if result.matched_keyword is not None]
logger.attr(name=f'{self.name} matched', logger.attr(name=f'{self.name} matched',
text=results) text=results)
return results return results

View File

@ -182,7 +182,7 @@ class DraggableList:
main.wait_until_stable(self.search_button, timer=Timer( main.wait_until_stable(self.search_button, timer=Timer(
0, count=0), timeout=Timer(1.5, count=5)) 0, count=0), timeout=Timer(1.5, count=5))
skip_first_screenshot = True skip_first_screenshot = True
if last_buttons == set(self.cur_buttons): if self.cur_buttons and last_buttons == set(self.cur_buttons):
logger.warning(f'No more rows in {self}') logger.warning(f'No more rows in {self}')
return False return False
last_buttons = set(self.cur_buttons) last_buttons = set(self.cur_buttons)

View File

@ -0,0 +1,15 @@
from module.base.button import Button, ButtonWrapper
# This file was auto-generated, do not modify it manually. To generate:
# ``` python -m dev_tools.button_extract ```
OCR_MAP_NAME = ButtonWrapper(
name='OCR_MAP_NAME',
share=Button(
file='./assets/share/base/main_page/OCR_MAP_NAME.png',
area=(48, 15, 373, 32),
search=(28, 0, 393, 52),
color=(69, 72, 78),
button=(48, 15, 373, 32),
),
)

80
tasks/base/main_page.py Normal file
View File

@ -0,0 +1,80 @@
import re
from typing import Optional
import module.config.server as server
from module.base.base import ModuleBase
from module.config.server import VALID_LANG
from module.exception import RequestHumanTakeover, ScriptError
from module.logger import logger
from module.ocr.ocr import Ocr
from tasks.base.assets.assets_base_main_page import OCR_MAP_NAME
from tasks.base.page import Page, page_main
from tasks.map.keywords import KEYWORDS_MAP_PLANE, MapPlane
class OcrPlaneName(Ocr):
def after_process(self, result):
# RobotSettlement1
result = re.sub(r'\d+$', '', result)
return super().after_process(result)
class MainPage(ModuleBase):
# Same as BigmapPlane class
# Current plane
plane: MapPlane = KEYWORDS_MAP_PLANE.Herta_ParlorCar
_lang_checked = False
def check_lang_from_map_plane(self) -> Optional[str]:
logger.info('check_lang_from_map_plane')
lang_unknown = self.config.Emulator_GameLanguage == 'auto'
if lang_unknown:
lang_list = VALID_LANG
else:
# Try current lang first
lang_list = [server.lang] + [lang for lang in VALID_LANG if lang != server.lang]
for lang in lang_list:
logger.info(f'Try ocr in lang {lang}')
ocr = OcrPlaneName(OCR_MAP_NAME, lang=lang)
result = ocr.ocr_single_line(self.device.image)
keyword = ocr._match_result(result, keyword_classes=MapPlane, lang=lang)
if keyword is not None:
self.plane = keyword
logger.attr('CurrentPlane', self.plane)
logger.info(f'check_lang_from_map_plane matched lang: {lang}')
if lang_unknown or lang != server.lang:
self.config.Emulator_GameLanguage = lang
return lang
if lang_unknown:
logger.critical('Cannot detect in-game text language, please set it to 简体中文 or English')
raise RequestHumanTakeover
else:
logger.warning(f'Cannot detect in-game text language, assume current lang={server.lang} is correct')
return server.lang
def handle_lang_check(self, page: Page):
if MainPage._lang_checked:
return
if page != page_main:
return
self.check_lang_from_map_plane()
MainPage._lang_checked = True
def acquire_lang_checked(self):
if MainPage._lang_checked:
return
logger.info('acquire_lang_checked')
try:
self.ui_goto(page_main)
except AttributeError:
logger.critical('Method ui_goto() not found, class MainPage must be inherited by class UI')
raise ScriptError
self.handle_lang_check(page=page_main)

View File

@ -5,13 +5,14 @@ from module.exception import GameNotRunningError, GamePageUnknownError
from module.logger import logger from module.logger import logger
from module.ocr.ocr import Ocr from module.ocr.ocr import Ocr
from tasks.base.assets.assets_base_page import CLOSE from tasks.base.assets.assets_base_page import CLOSE
from tasks.base.main_page import MainPage
from tasks.base.page import Page, page_main from tasks.base.page import Page, page_main
from tasks.base.popup import PopupHandler from tasks.base.popup import PopupHandler
from tasks.combat.assets.assets_combat_finish import COMBAT_EXIT from tasks.combat.assets.assets_combat_finish import COMBAT_EXIT
from tasks.combat.assets.assets_combat_prepare import COMBAT_PREPARE from tasks.combat.assets.assets_combat_prepare import COMBAT_PREPARE
class UI(PopupHandler): class UI(PopupHandler, MainPage):
ui_current: Page ui_current: Page
ui_main_confirm_timer = Timer(0.2, count=0) ui_main_confirm_timer = Timer(0.2, count=0)
@ -124,6 +125,7 @@ class UI(PopupHandler):
continue continue
if self.appear(page.check_button, interval=5): if self.appear(page.check_button, interval=5):
logger.info(f'Page switch: {page} -> {page.parent}') logger.info(f'Page switch: {page} -> {page.parent}')
self.handle_lang_check(page)
if self.ui_page_confirm(page): if self.ui_page_confirm(page):
logger.info(f'Page arrive confirm {page}') logger.info(f'Page arrive confirm {page}')
button = page.links[page.parent] button = page.links[page.parent]
@ -151,6 +153,7 @@ class UI(PopupHandler):
bool: If UI switched. bool: If UI switched.
""" """
logger.hr("UI ensure") logger.hr("UI ensure")
self.acquire_lang_checked()
self.ui_get_current_page(skip_first_screenshot=skip_first_screenshot) self.ui_get_current_page(skip_first_screenshot=skip_first_screenshot)
if self.ui_current == destination: if self.ui_current == destination:
logger.info("Already at %s" % destination) logger.info("Already at %s" % destination)

View File

@ -22,7 +22,7 @@ class ForgottenHallStageOcr(Ocr):
raw = image.copy() raw = image.copy()
area = OCR_STAGE.area area = OCR_STAGE.area
image = crop(raw, area) image = crop(raw, area)
yellow = color_similarity_2d(image, color=(250, 201, 111)) yellow = color_similarity_2d(image, color=(255, 200, 112))
gray = color_similarity_2d(image, color=(100, 109, 134)) gray = color_similarity_2d(image, color=(100, 109, 134))
image = np.maximum(yellow, gray) image = np.maximum(yellow, gray)
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)) kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
@ -36,7 +36,7 @@ class ForgottenHallStageOcr(Ocr):
for cont in contours: for cont in contours:
rect = cv2.boundingRect(cv2.convexHull(cont).astype(np.float32)) rect = cv2.boundingRect(cv2.convexHull(cont).astype(np.float32))
# Filter with rectangle width, usually to be 62~64 # Filter with rectangle width, usually to be 62~64
if not 62 - 10 < rect[2] < 62 + 10: if not 62 - 10 < rect[2] < 65 + 10:
continue continue
rect = (rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3]) rect = (rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3])
rect = area_offset(rect, offset=area[:2]) rect = area_offset(rect, offset=area[:2])
@ -52,16 +52,16 @@ class ForgottenHallStageOcr(Ocr):
boxes = self._find_number(image) boxes = self._find_number(image)
image_list = [crop(image, area) for area in boxes] image_list = [crop(image, area) for area in boxes]
results = self.ocr_multi_lines(image_list) results = self.ocr_multi_lines(image_list)
boxed_results = [ results = [
BoxedResult(area_offset(boxes[index], (-50, 0)), image_list[index], text, score) BoxedResult(area_offset(boxes[index], (-50, 0)), image_list[index], text, score)
for index, (text, score) in enumerate(results) for index, (text, score) in enumerate(results)
] ]
results_buttons = [
OcrResultButton(result, keyword_classes) results = [self._product_button(result, keyword_classes, ignore_digit=False) for result in results]
for result in boxed_results results = [result for result in results if result.is_keyword_matched]
]
logger.attr(name=f'{self.name} matched', text=results_buttons) logger.attr(name=f'{self.name} matched', text=results)
return results_buttons return results
class DraggableStageList(DraggableList): class DraggableStageList(DraggableList):