mirror of
https://github.com/LmeSzinc/StarRailCopilot.git
synced 2024-11-16 06:25:24 +00:00
Refactor: Get in-game language from plane name
This commit is contained in:
parent
2c2c31cad7
commit
37e29838c5
BIN
assets/share/base/main_page/OCR_MAP_NAME.png
Normal file
BIN
assets/share/base/main_page/OCR_MAP_NAME.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 18 KiB |
@ -61,9 +61,12 @@ class Keyword:
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True):
|
||||
if in_current_server:
|
||||
match server.lang:
|
||||
def _keywords_to_find(self, lang: str = None, ignore_punctuation=True):
|
||||
if lang is None:
|
||||
lang = server.lang
|
||||
|
||||
if lang in server.VALID_LANG:
|
||||
match lang:
|
||||
case 'cn':
|
||||
if ignore_punctuation:
|
||||
return [self.cn_parsed]
|
||||
@ -122,11 +125,12 @@ class Keyword:
|
||||
return name == keyword
|
||||
|
||||
@classmethod
|
||||
def find(cls, name, in_current_server=False, ignore_punctuation=True):
|
||||
def find(cls, name, lang: str = None, ignore_punctuation=True):
|
||||
"""
|
||||
Args:
|
||||
name: Name in any server or instance id.
|
||||
in_current_server: True to search the names from current server only.
|
||||
lang: Lang to find from
|
||||
None to search the names from current server only.
|
||||
ignore_punctuation: True to remove punctuations and turn into lowercase before searching.
|
||||
|
||||
Returns:
|
||||
@ -157,7 +161,7 @@ class Keyword:
|
||||
instance: Keyword
|
||||
for instance in cls.instances.values():
|
||||
for keyword in instance._keywords_to_find(
|
||||
in_current_server=in_current_server, ignore_punctuation=ignore_punctuation):
|
||||
lang=lang, ignore_punctuation=ignore_punctuation):
|
||||
if cls._compare(name, keyword):
|
||||
return instance
|
||||
|
||||
|
@ -1,8 +1,9 @@
|
||||
import re
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from pponnxcr.predict_system import BoxedResult
|
||||
|
||||
import module.config.server as server
|
||||
@ -11,70 +12,33 @@ from module.base.decorator import cached_property
|
||||
from module.base.utils import area_pad, corner2area, crop, float2str
|
||||
from module.exception import ScriptError
|
||||
from module.logger import logger
|
||||
from module.ocr.keyword import Keyword
|
||||
from module.ocr.models import OCR_MODEL, TextSystem
|
||||
from module.ocr.utils import merge_buttons
|
||||
|
||||
|
||||
def enlarge_canvas(image):
|
||||
"""
|
||||
Enlarge image into a square fill with black background. In the structure of PaddleOCR,
|
||||
image with w:h=1:1 is the best while 3:1 rectangles takes three times as long.
|
||||
Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32.
|
||||
|
||||
No longer needed, already included in pponnxcr.
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
length = int(max(width, height) // 32 * 32 + 32)
|
||||
border = (0, length - height, 0, length - width)
|
||||
if sum(border) > 0:
|
||||
image = cv2.copyMakeBorder(image, *border, borderType=cv2.BORDER_CONSTANT, value=(0, 0, 0))
|
||||
return image
|
||||
|
||||
|
||||
class OcrResultButton:
|
||||
def __init__(self, boxed_result: BoxedResult, keyword_classes: list):
|
||||
def __init__(self, boxed_result: BoxedResult, matched_keyword: Optional[Keyword]):
|
||||
"""
|
||||
Args:
|
||||
boxed_result: BoxedResult from ppocr-onnx
|
||||
keyword_classes: List of Keyword classes
|
||||
matched_keyword: Keyword object or None
|
||||
"""
|
||||
self.area = boxed_result.box
|
||||
self.search = area_pad(self.area, pad=-20)
|
||||
# self.color =
|
||||
self.button = boxed_result.box
|
||||
|
||||
try:
|
||||
self.matched_keyword = self.match_keyword(boxed_result.ocr_text, keyword_classes)
|
||||
self.name = str(self.matched_keyword)
|
||||
except ScriptError:
|
||||
if matched_keyword is not None:
|
||||
self.matched_keyword = matched_keyword
|
||||
self.name = str(matched_keyword)
|
||||
else:
|
||||
self.matched_keyword = None
|
||||
self.name = boxed_result.ocr_text
|
||||
|
||||
self.text = boxed_result.ocr_text
|
||||
self.score = boxed_result.score
|
||||
|
||||
@staticmethod
|
||||
def match_keyword(ocr_text, keyword_classes):
|
||||
"""
|
||||
Args:
|
||||
ocr_text (str):
|
||||
keyword_classes: List of Keyword classes
|
||||
|
||||
Returns:
|
||||
Keyword:
|
||||
|
||||
Raises:
|
||||
ScriptError: If no keywords matched
|
||||
"""
|
||||
for keyword_class in keyword_classes:
|
||||
try:
|
||||
matched = keyword_class.find(ocr_text, in_current_server=True, ignore_punctuation=True)
|
||||
return matched
|
||||
except ScriptError:
|
||||
continue
|
||||
|
||||
raise ScriptError
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@ -89,6 +53,10 @@ class OcrResultButton:
|
||||
def __bool__(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_keyword_matched(self) -> bool:
|
||||
return self.matched_keyword is not None
|
||||
|
||||
|
||||
class Ocr:
|
||||
# Merge results with box distance <= thres
|
||||
@ -201,6 +169,127 @@ class Ocr:
|
||||
text=str([result.ocr_text for result in results]))
|
||||
return results
|
||||
|
||||
def _match_result(
|
||||
self,
|
||||
result: str,
|
||||
keyword_classes,
|
||||
lang: str = None,
|
||||
ignore_punctuation=True,
|
||||
ignore_digit=True):
|
||||
"""
|
||||
Args:
|
||||
result (str):
|
||||
keyword_classes: A list of `Keyword` class or classes inherited `Keyword`
|
||||
|
||||
Returns:
|
||||
If matched, return `Keyword` object or objects inherited `Keyword`
|
||||
If not match, return None
|
||||
"""
|
||||
if not isinstance(keyword_classes, list):
|
||||
keyword_classes = [keyword_classes]
|
||||
|
||||
# Digits will be considered as the index of keyword
|
||||
if ignore_digit:
|
||||
if result.isdigit():
|
||||
return None
|
||||
|
||||
# Try in current lang
|
||||
for keyword_class in keyword_classes:
|
||||
try:
|
||||
matched = keyword_class.find(
|
||||
result,
|
||||
lang=lang,
|
||||
ignore_punctuation=ignore_punctuation
|
||||
)
|
||||
return matched
|
||||
except ScriptError:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def matched_single_line(
|
||||
self,
|
||||
image,
|
||||
keyword_classes,
|
||||
lang: str = None,
|
||||
ignore_punctuation=True
|
||||
) -> OcrResultButton:
|
||||
"""
|
||||
Args:
|
||||
image: Image to detect
|
||||
keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
|
||||
lang:
|
||||
ignore_punctuation:
|
||||
|
||||
Returns:
|
||||
OcrResultButton: Or None if it didn't matched known keywords.
|
||||
"""
|
||||
result = self.ocr_single_line(image)
|
||||
|
||||
result = self._match_result(
|
||||
result,
|
||||
keyword_classes=keyword_classes,
|
||||
lang=lang,
|
||||
ignore_punctuation=ignore_punctuation,
|
||||
)
|
||||
|
||||
logger.attr(name=f'{self.name} matched',
|
||||
text=result)
|
||||
return result
|
||||
|
||||
def matched_multi_lines(
|
||||
self,
|
||||
image_list,
|
||||
keyword_classes,
|
||||
lang: str = None,
|
||||
ignore_punctuation=True
|
||||
) -> list[OcrResultButton]:
|
||||
"""
|
||||
Args:
|
||||
image_list:
|
||||
keyword_classes: `Keyword` class or classes inherited `Keyword`, or a list of them.
|
||||
lang:
|
||||
ignore_punctuation:
|
||||
|
||||
Returns:
|
||||
List of matched OcrResultButton.
|
||||
OCR result which didn't matched known keywords will be dropped.
|
||||
"""
|
||||
results = self.ocr_multi_lines(image_list)
|
||||
|
||||
results = [self._match_result(
|
||||
result,
|
||||
keyword_classes=keyword_classes,
|
||||
lang=lang,
|
||||
ignore_punctuation=ignore_punctuation,
|
||||
) for result in results]
|
||||
results = [result for result in results if result.is_keyword_matched]
|
||||
|
||||
logger.attr(name=f'{self.name} matched',
|
||||
text=results)
|
||||
return results
|
||||
|
||||
def _product_button(
|
||||
self,
|
||||
boxed_result: BoxedResult,
|
||||
keyword_classes,
|
||||
lang: str = None,
|
||||
ignore_punctuation=True,
|
||||
ignore_digit=True
|
||||
) -> OcrResultButton:
|
||||
if not isinstance(keyword_classes, list):
|
||||
keyword_classes = [keyword_classes]
|
||||
|
||||
matched_keyword = self._match_result(
|
||||
boxed_result.ocr_text,
|
||||
keyword_classes=keyword_classes,
|
||||
lang=lang,
|
||||
ignore_punctuation=ignore_punctuation,
|
||||
ignore_digit=ignore_digit,
|
||||
)
|
||||
button = OcrResultButton(boxed_result, matched_keyword)
|
||||
return button
|
||||
|
||||
def matched_ocr(self, image, keyword_classes, direct_ocr=False) -> list[OcrResultButton]:
|
||||
"""
|
||||
Args:
|
||||
@ -212,21 +301,11 @@ class Ocr:
|
||||
List of matched OcrResultButton.
|
||||
OCR result which didn't matched known keywords will be dropped.
|
||||
"""
|
||||
if not isinstance(keyword_classes, list):
|
||||
keyword_classes = [keyword_classes]
|
||||
|
||||
def is_valid(keyword):
|
||||
# Digits will be considered as the index of keyword
|
||||
if keyword.isdigit():
|
||||
return False
|
||||
return True
|
||||
|
||||
results = self.detect_and_ocr(image, direct_ocr=direct_ocr)
|
||||
results = [
|
||||
OcrResultButton(result, keyword_classes)
|
||||
for result in results if is_valid(result.ocr_text)
|
||||
]
|
||||
results = [result for result in results if result.matched_keyword is not None]
|
||||
|
||||
results = [self._product_button(result, keyword_classes) for result in results]
|
||||
results = [result for result in results if result.is_keyword_matched]
|
||||
|
||||
logger.attr(name=f'{self.name} matched',
|
||||
text=results)
|
||||
return results
|
||||
|
@ -182,7 +182,7 @@ class DraggableList:
|
||||
main.wait_until_stable(self.search_button, timer=Timer(
|
||||
0, count=0), timeout=Timer(1.5, count=5))
|
||||
skip_first_screenshot = True
|
||||
if last_buttons == set(self.cur_buttons):
|
||||
if self.cur_buttons and last_buttons == set(self.cur_buttons):
|
||||
logger.warning(f'No more rows in {self}')
|
||||
return False
|
||||
last_buttons = set(self.cur_buttons)
|
||||
|
15
tasks/base/assets/assets_base_main_page.py
Normal file
15
tasks/base/assets/assets_base_main_page.py
Normal file
@ -0,0 +1,15 @@
|
||||
from module.base.button import Button, ButtonWrapper
|
||||
|
||||
# This file was auto-generated, do not modify it manually. To generate:
|
||||
# ``` python -m dev_tools.button_extract ```
|
||||
|
||||
OCR_MAP_NAME = ButtonWrapper(
|
||||
name='OCR_MAP_NAME',
|
||||
share=Button(
|
||||
file='./assets/share/base/main_page/OCR_MAP_NAME.png',
|
||||
area=(48, 15, 373, 32),
|
||||
search=(28, 0, 393, 52),
|
||||
color=(69, 72, 78),
|
||||
button=(48, 15, 373, 32),
|
||||
),
|
||||
)
|
80
tasks/base/main_page.py
Normal file
80
tasks/base/main_page.py
Normal file
@ -0,0 +1,80 @@
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import module.config.server as server
|
||||
from module.base.base import ModuleBase
|
||||
from module.config.server import VALID_LANG
|
||||
from module.exception import RequestHumanTakeover, ScriptError
|
||||
from module.logger import logger
|
||||
from module.ocr.ocr import Ocr
|
||||
from tasks.base.assets.assets_base_main_page import OCR_MAP_NAME
|
||||
from tasks.base.page import Page, page_main
|
||||
from tasks.map.keywords import KEYWORDS_MAP_PLANE, MapPlane
|
||||
|
||||
|
||||
class OcrPlaneName(Ocr):
|
||||
def after_process(self, result):
|
||||
# RobotSettlement1
|
||||
result = re.sub(r'\d+$', '', result)
|
||||
|
||||
return super().after_process(result)
|
||||
|
||||
|
||||
class MainPage(ModuleBase):
|
||||
# Same as BigmapPlane class
|
||||
# Current plane
|
||||
plane: MapPlane = KEYWORDS_MAP_PLANE.Herta_ParlorCar
|
||||
|
||||
_lang_checked = False
|
||||
|
||||
def check_lang_from_map_plane(self) -> Optional[str]:
|
||||
logger.info('check_lang_from_map_plane')
|
||||
lang_unknown = self.config.Emulator_GameLanguage == 'auto'
|
||||
|
||||
if lang_unknown:
|
||||
lang_list = VALID_LANG
|
||||
else:
|
||||
# Try current lang first
|
||||
lang_list = [server.lang] + [lang for lang in VALID_LANG if lang != server.lang]
|
||||
|
||||
for lang in lang_list:
|
||||
logger.info(f'Try ocr in lang {lang}')
|
||||
ocr = OcrPlaneName(OCR_MAP_NAME, lang=lang)
|
||||
result = ocr.ocr_single_line(self.device.image)
|
||||
keyword = ocr._match_result(result, keyword_classes=MapPlane, lang=lang)
|
||||
if keyword is not None:
|
||||
self.plane = keyword
|
||||
logger.attr('CurrentPlane', self.plane)
|
||||
logger.info(f'check_lang_from_map_plane matched lang: {lang}')
|
||||
if lang_unknown or lang != server.lang:
|
||||
self.config.Emulator_GameLanguage = lang
|
||||
return lang
|
||||
|
||||
if lang_unknown:
|
||||
logger.critical('Cannot detect in-game text language, please set it to 简体中文 or English')
|
||||
raise RequestHumanTakeover
|
||||
else:
|
||||
logger.warning(f'Cannot detect in-game text language, assume current lang={server.lang} is correct')
|
||||
return server.lang
|
||||
|
||||
def handle_lang_check(self, page: Page):
|
||||
if MainPage._lang_checked:
|
||||
return
|
||||
if page != page_main:
|
||||
return
|
||||
|
||||
self.check_lang_from_map_plane()
|
||||
MainPage._lang_checked = True
|
||||
|
||||
def acquire_lang_checked(self):
|
||||
if MainPage._lang_checked:
|
||||
return
|
||||
|
||||
logger.info('acquire_lang_checked')
|
||||
try:
|
||||
self.ui_goto(page_main)
|
||||
except AttributeError:
|
||||
logger.critical('Method ui_goto() not found, class MainPage must be inherited by class UI')
|
||||
raise ScriptError
|
||||
|
||||
self.handle_lang_check(page=page_main)
|
@ -5,13 +5,14 @@ from module.exception import GameNotRunningError, GamePageUnknownError
|
||||
from module.logger import logger
|
||||
from module.ocr.ocr import Ocr
|
||||
from tasks.base.assets.assets_base_page import CLOSE
|
||||
from tasks.base.main_page import MainPage
|
||||
from tasks.base.page import Page, page_main
|
||||
from tasks.base.popup import PopupHandler
|
||||
from tasks.combat.assets.assets_combat_finish import COMBAT_EXIT
|
||||
from tasks.combat.assets.assets_combat_prepare import COMBAT_PREPARE
|
||||
|
||||
|
||||
class UI(PopupHandler):
|
||||
class UI(PopupHandler, MainPage):
|
||||
ui_current: Page
|
||||
ui_main_confirm_timer = Timer(0.2, count=0)
|
||||
|
||||
@ -124,6 +125,7 @@ class UI(PopupHandler):
|
||||
continue
|
||||
if self.appear(page.check_button, interval=5):
|
||||
logger.info(f'Page switch: {page} -> {page.parent}')
|
||||
self.handle_lang_check(page)
|
||||
if self.ui_page_confirm(page):
|
||||
logger.info(f'Page arrive confirm {page}')
|
||||
button = page.links[page.parent]
|
||||
@ -151,6 +153,7 @@ class UI(PopupHandler):
|
||||
bool: If UI switched.
|
||||
"""
|
||||
logger.hr("UI ensure")
|
||||
self.acquire_lang_checked()
|
||||
self.ui_get_current_page(skip_first_screenshot=skip_first_screenshot)
|
||||
if self.ui_current == destination:
|
||||
logger.info("Already at %s" % destination)
|
||||
|
@ -22,7 +22,7 @@ class ForgottenHallStageOcr(Ocr):
|
||||
raw = image.copy()
|
||||
area = OCR_STAGE.area
|
||||
image = crop(raw, area)
|
||||
yellow = color_similarity_2d(image, color=(250, 201, 111))
|
||||
yellow = color_similarity_2d(image, color=(255, 200, 112))
|
||||
gray = color_similarity_2d(image, color=(100, 109, 134))
|
||||
image = np.maximum(yellow, gray)
|
||||
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
|
||||
@ -36,7 +36,7 @@ class ForgottenHallStageOcr(Ocr):
|
||||
for cont in contours:
|
||||
rect = cv2.boundingRect(cv2.convexHull(cont).astype(np.float32))
|
||||
# Filter with rectangle width, usually to be 62~64
|
||||
if not 62 - 10 < rect[2] < 62 + 10:
|
||||
if not 62 - 10 < rect[2] < 65 + 10:
|
||||
continue
|
||||
rect = (rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3])
|
||||
rect = area_offset(rect, offset=area[:2])
|
||||
@ -52,16 +52,16 @@ class ForgottenHallStageOcr(Ocr):
|
||||
boxes = self._find_number(image)
|
||||
image_list = [crop(image, area) for area in boxes]
|
||||
results = self.ocr_multi_lines(image_list)
|
||||
boxed_results = [
|
||||
results = [
|
||||
BoxedResult(area_offset(boxes[index], (-50, 0)), image_list[index], text, score)
|
||||
for index, (text, score) in enumerate(results)
|
||||
]
|
||||
results_buttons = [
|
||||
OcrResultButton(result, keyword_classes)
|
||||
for result in boxed_results
|
||||
]
|
||||
logger.attr(name=f'{self.name} matched', text=results_buttons)
|
||||
return results_buttons
|
||||
|
||||
results = [self._product_button(result, keyword_classes, ignore_digit=False) for result in results]
|
||||
results = [result for result in results if result.is_keyword_matched]
|
||||
|
||||
logger.attr(name=f'{self.name} matched', text=results)
|
||||
return results
|
||||
|
||||
|
||||
class DraggableStageList(DraggableList):
|
||||
|
Loading…
Reference in New Issue
Block a user