Refactor: Migrate to pponnxcr

This commit is contained in:
LmeSzinc 2023-09-08 22:23:57 +08:00
parent d33128c57d
commit 16f4f061c1
18 changed files with 110 additions and 122 deletions

View File

@ -9,12 +9,12 @@ from tqdm import tqdm
from module.base.code_generator import CodeGenerator from module.base.code_generator import CodeGenerator
from module.base.utils import SelectedGrids, area_limit, area_pad, get_bbox, get_color, image_size, load_image from module.base.utils import SelectedGrids, area_limit, area_pad, get_bbox, get_color, image_size, load_image
from module.config.config_manual import ManualConfig as AzurLaneConfig from module.config.config_manual import ManualConfig as AzurLaneConfig
from module.config.server import VALID_SERVER from module.config.server import VALID_LANG
from module.config.utils import deep_get, deep_iter, deep_set, iter_folder from module.config.utils import deep_get, deep_iter, deep_set, iter_folder
from module.logger import logger from module.logger import logger
SHARE_SERVER = 'share' SHARE_SERVER = 'share'
ASSET_SERVER = [SHARE_SERVER] + VALID_SERVER ASSET_SERVER = [SHARE_SERVER] + VALID_LANG
class AssetsImage: class AssetsImage:
@ -217,7 +217,7 @@ def generate_code():
if has_share: if has_share:
servers = assets_data.keys() servers = assets_data.keys()
else: else:
servers = VALID_SERVER servers = VALID_LANG
for server in servers: for server in servers:
frames = list(assets_data.get(server, {}).values()) frames = list(assets_data.get(server, {}).values())
if len(frames) > 1: if len(frames) > 1:

View File

@ -3,7 +3,6 @@ from module.base.button import Button, ButtonWrapper, ClickButton, match_templat
from module.base.timer import Timer from module.base.timer import Timer
from module.base.utils import * from module.base.utils import *
from module.config.config import AzurLaneConfig from module.config.config import AzurLaneConfig
from module.config.server import set_server, to_package
from module.device.device import Device from module.device.device import Device
from module.logger import logger from module.logger import logger
@ -261,22 +260,11 @@ class ModuleBase:
self.device.image = value self.device.image = value
def set_server(self, server):
"""
For development.
Change server and affect globally,
including assets and server specific methods.
"""
package = to_package(server)
self.device.package = package
set_server(server)
logger.attr('Server', self.config.SERVER)
def set_lang(self, lang): def set_lang(self, lang):
""" """
For development. For development.
Change lang and affect globally, Change lang and affect globally,
including assets and server specific methods. including assets and server specific methods.
""" """
server_.server = lang server_.set_lang(lang)
logger.attr('Language', self.config.SERVER) logger.attr('Lang', self.config.LANG)

View File

@ -143,7 +143,7 @@ class ButtonWrapper(Resource):
@cached_property @cached_property
def buttons(self) -> t.List[Button]: def buttons(self) -> t.List[Button]:
for trial in [server.server, 'share', 'cn']: for trial in [server.lang, 'share', 'cn']:
assets = self.data_buttons.get(trial, None) assets = self.data_buttons.get(trial, None)
if assets is not None: if assets is not None:
if isinstance(assets, Button): if isinstance(assets, Button):
@ -151,7 +151,7 @@ class ButtonWrapper(Resource):
elif isinstance(assets, list): elif isinstance(assets, list):
return assets return assets
raise ScriptError(f'ButtonWrapper({self}) on server {server.server} has no fallback button') raise ScriptError(f'ButtonWrapper({self}) on server {server.lang} has no fallback button')
def match_color(self, image, threshold=10) -> bool: def match_color(self, image, threshold=10) -> bool:
for assets in self.buttons: for assets in self.buttons:

View File

@ -74,7 +74,7 @@ class AzurLaneConfig(ConfigUpdater, ManualConfig, GeneratedConfig, ConfigWatcher
super().__setattr__(key, value) super().__setattr__(key, value)
def __init__(self, config_name, task=None): def __init__(self, config_name, task=None):
logger.attr("Server", self.SERVER) logger.attr("Lang", self.LANG)
# This will read ./config/<config_name>.json # This will read ./config/<config_name>.json
self.config_name = config_name self.config_name = config_name
# Raw json data in yaml file. # Raw json data in yaml file.

View File

@ -3,8 +3,8 @@ import module.config.server as server
class ManualConfig: class ManualConfig:
@property @property
def SERVER(self): def LANG(self):
return server.server return server.lang
SCHEDULER_PRIORITY = """ SCHEDULER_PRIORITY = """
Restart Restart

View File

@ -2,9 +2,9 @@
This file stores server, such as 'cn', 'en'. This file stores server, such as 'cn', 'en'.
Use 'import module.config.server as server' to import, don't use 'from xxx import xxx'. Use 'import module.config.server as server' to import, don't use 'from xxx import xxx'.
""" """
server = 'cn' # Setting default to cn, will avoid errors when using dev_tools lang = 'cn' # Setting default to cn, will avoid errors when using dev_tools
VALID_SERVER = ['cn', ] VALID_LANG = ['cn', 'en']
VALID_PACKAGE = { VALID_PACKAGE = {
'com.miHoYo.hkrpg': 'cn', 'com.miHoYo.hkrpg': 'cn',
'com.HoYoverse.hkrpgoversea': 'oversea' 'com.HoYoverse.hkrpgoversea': 'oversea'
@ -14,16 +14,16 @@ VALID_CHANNEL_PACKAGE = {
} }
def set_server(package_or_server: str): def set_lang(lang_: str):
""" """
Change server and this will effect globally, Change language and this will affect globally,
including assets and server specific methods. including assets and language specific methods.
Args: Args:
package_or_server: package name or server. lang_: package name or server.
""" """
global server global lang
server = to_server(package_or_server) lang = lang_
from module.base.resource import release_resources from module.base.resource import release_resources
release_resources() release_resources()

View File

@ -395,7 +395,7 @@ def dict_to_kv(dictionary, allow_none=True):
def server_timezone() -> timedelta: def server_timezone() -> timedelta:
return SERVER_TO_TIMEZONE.get(server_.server, SERVER_TO_TIMEZONE['cn']) return SERVER_TO_TIMEZONE.get(server_.lang, SERVER_TO_TIMEZONE['cn'])
def server_time_offset() -> timedelta: def server_time_offset() -> timedelta:

View File

@ -109,7 +109,7 @@ class Connection(ConnectionAttr):
# else: # else:
# set_server(self.package) # set_server(self.package)
logger.attr('PackageName', self.package) logger.attr('PackageName', self.package)
logger.attr('Server', self.config.SERVER) logger.attr('Lang', self.config.LANG)
self.check_mumu_app_keep_alive() self.check_mumu_app_keep_alive()

View File

@ -62,7 +62,7 @@ class Keyword:
def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True): def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True):
if in_current_server: if in_current_server:
match server.server: match server.lang:
case 'cn': case 'cn':
if ignore_punctuation: if ignore_punctuation:
return [self.cn_parsed] return [self.cn_parsed]

View File

@ -1,11 +1,70 @@
from pponnxcr import TextSystem
from module.base.decorator import cached_property from module.base.decorator import cached_property
from module.ocr.ppocr import TextSystem from module.exception import ScriptError
DIC_LANG_TO_MODEL = {
'cn': 'zhs',
'en': 'en',
'jp': 'ja',
'tw': 'zht',
}
def lang2model(lang: str) -> str:
"""
Args:
lang: In-game language name, defined in VALID_LANG
Returns:
str: Model name, defined in pponnxcr.utility
"""
return DIC_LANG_TO_MODEL.get(lang, lang)
def model2lang(model: str) -> str:
"""
Args:
model: Model name, defined in pponnxcr.utility
Returns:
str: In-game language name, defined in VALID_LANG
"""
for k, v in DIC_LANG_TO_MODEL.items():
if model == v:
return k
return model
class OcrModel: class OcrModel:
def get_by_model(self, model: str) -> TextSystem:
try:
return self.__getattribute__(model)
except AttributeError:
raise ScriptError(f'OCR model "{model}" does not exists')
def get_by_lang(self, lang: str) -> TextSystem:
try:
model = lang2model(lang)
return self.__getattribute__(model)
except AttributeError:
raise ScriptError(f'OCR model under lang "{lang}" does not exists')
@cached_property @cached_property
def ch(self): def zhs(self):
return TextSystem() return TextSystem('zhs')
@cached_property
def en(self):
return TextSystem('en')
@cached_property
def ja(self):
return TextSystem('zht')
@cached_property
def zht(self):
return TextSystem('zht')
OCR_MODEL = OcrModel() OCR_MODEL = OcrModel()

View File

@ -4,6 +4,7 @@ from datetime import timedelta
import cv2 import cv2
from ppocronnx.predict_system import BoxedResult from ppocronnx.predict_system import BoxedResult
from pponnxcr import TextSystem
import module.config.server as server import module.config.server as server
from module.base.button import ButtonWrapper from module.base.button import ButtonWrapper
@ -12,7 +13,6 @@ from module.base.utils import area_pad, corner2area, crop, float2str
from module.exception import ScriptError from module.exception import ScriptError
from module.logger import logger from module.logger import logger
from module.ocr.models import OCR_MODEL from module.ocr.models import OCR_MODEL
from module.ocr.ppocr import TextSystem
from module.ocr.utils import merge_buttons from module.ocr.utils import merge_buttons
@ -21,6 +21,8 @@ def enlarge_canvas(image):
Enlarge image into a square fill with black background. In the structure of PaddleOCR, Enlarge image into a square fill with black background. In the structure of PaddleOCR,
image with w:h=1:1 is the best while 3:1 rectangles takes three times as long. image with w:h=1:1 is the best while 3:1 rectangles takes three times as long.
Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32. Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32.
No longer needed, already included in pponnxcr.
""" """
height, width = image.shape[:2] height, width = image.shape[:2]
length = int(max(width, height) // 32 * 32 + 32) length = int(max(width, height) // 32 * 32 + 32)
@ -95,27 +97,24 @@ class Ocr:
merge_thres_y = 0 merge_thres_y = 0
def __init__(self, button: ButtonWrapper, lang=None, name=None): def __init__(self, button: ButtonWrapper, lang=None, name=None):
"""
Args:
button:
lang: If None, use in-game language
name: If None, use button.name
"""
if lang is None:
lang = server.lang
if name is None:
name = button.name
self.button: ButtonWrapper = button self.button: ButtonWrapper = button
self._lang = lang self.lang: str = lang
self.name: str = name if name is not None else button.name self.name: str = name
@classmethod
def server2lang(cls, ser=None) -> str:
if ser is None:
ser = server.server
match ser:
case 'cn':
return 'ch'
case _:
return 'ch'
@cached_property
def lang(self) -> str:
return self._lang if self._lang is not None else Ocr.server2lang()
@cached_property @cached_property
def model(self) -> TextSystem: def model(self) -> TextSystem:
return OCR_MODEL.__getattribute__(self.lang) return OCR_MODEL.get_by_lang(self.lang)
def pre_process(self, image): def pre_process(self, image):
""" """
@ -188,7 +187,7 @@ class Ocr:
image = crop(image, self.button.area) image = crop(image, self.button.area)
image = self.pre_process(image) image = self.pre_process(image)
# ocr # ocr
image = enlarge_canvas(image) # image = enlarge_canvas(image)
results: list[BoxedResult] = self.model.detect_and_ocr(image) results: list[BoxedResult] = self.model.detect_and_ocr(image)
# after proces # after proces
for result in results: for result in results:
@ -235,7 +234,7 @@ class Ocr:
class Digit(Ocr): class Digit(Ocr):
def __init__(self, button: ButtonWrapper, lang='ch', name=None): def __init__(self, button: ButtonWrapper, lang='en', name=None):
super().__init__(button, lang=lang, name=name) super().__init__(button, lang=lang, name=name)
def format_result(self, result) -> int: def format_result(self, result) -> int:
@ -255,7 +254,7 @@ class Digit(Ocr):
class DigitCounter(Ocr): class DigitCounter(Ocr):
def __init__(self, button: ButtonWrapper, lang='ch', name=None): def __init__(self, button: ButtonWrapper, lang='en', name=None):
super().__init__(button, lang=lang, name=name) super().__init__(button, lang=lang, name=name)
def format_result(self, result) -> tuple[int, int, int]: def format_result(self, result) -> tuple[int, int, int]:
@ -283,7 +282,7 @@ class Duration(Ocr):
@cached_property @cached_property
def timedelta_regex(self): def timedelta_regex(self):
regex_str = { regex_str = {
'ch': r'\D*((?P<days>\d{1,2})天)?((?P<hours>\d{1,2})小时)?((?P<minutes>\d{1,2})分钟)?((?P<seconds>\d{1,2})秒)?', 'cn': r'\D*((?P<days>\d{1,2})天)?((?P<hours>\d{1,2})小时)?((?P<minutes>\d{1,2})分钟)?((?P<seconds>\d{1,2})秒)?',
'en': r'\D*((?P<days>\d{1,2})d\s*)?((?P<hours>\d{1,2})h\s*)?((?P<minutes>\d{1,2})m\s*)?((?P<seconds>\d{1,2})s)?' 'en': r'\D*((?P<days>\d{1,2})d\s*)?((?P<hours>\d{1,2})h\s*)?((?P<minutes>\d{1,2})m\s*)?((?P<seconds>\d{1,2})s)?'
}[self.lang] }[self.lang]
return re.compile(regex_str) return re.compile(regex_str)

View File

@ -1,56 +0,0 @@
import ppocronnx.predict_system
class TextSystem(ppocronnx.predict_system.TextSystem):
def __init__(
self,
use_angle_cls=False,
box_thresh=0.6,
unclip_ratio=1.6,
rec_model_path=None,
det_model_path=None,
ort_providers=None
):
super().__init__(
use_angle_cls=use_angle_cls,
box_thresh=box_thresh,
unclip_ratio=unclip_ratio,
rec_model_path=rec_model_path,
det_model_path=det_model_path,
ort_providers=ort_providers
)
# def ocr_single_line(self, img):
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# return super().ocr_single_line(img)
#
# def detect_and_ocr(self, img: np.ndarray,**kwargs):
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# return super().detect_and_ocr(img, **kwargs)
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
# sorted_boxes() from PaddleOCR 2.6, newer and better than the one in ppocr-onnx
ppocronnx.predict_system.sorted_boxes = sorted_boxes

View File

@ -62,8 +62,6 @@ class AssignmentOcr(Ocr):
if matched is None: if matched is None:
return result return result
keyword_lang = self.lang keyword_lang = self.lang
if self.lang == 'ch':
keyword_lang = 'cn'
matched = getattr(KEYWORDS_ASSIGNMENT_ENTRY, matched.lastgroup) matched = getattr(KEYWORDS_ASSIGNMENT_ENTRY, matched.lastgroup)
matched = getattr(matched, keyword_lang) matched = getattr(matched, keyword_lang)
logger.attr(name=f'{self.name} after_process', logger.attr(name=f'{self.name} after_process',

View File

@ -86,7 +86,7 @@ class UI(PopupHandler):
logger.warning("Unknown ui page") logger.warning("Unknown ui page")
logger.attr("EMULATOR__SCREENSHOT_METHOD", self.config.Emulator_ScreenshotMethod) logger.attr("EMULATOR__SCREENSHOT_METHOD", self.config.Emulator_ScreenshotMethod)
logger.attr("EMULATOR__CONTROL_METHOD", self.config.Emulator_ControlMethod) logger.attr("EMULATOR__CONTROL_METHOD", self.config.Emulator_ControlMethod)
logger.attr("SERVER", self.config.SERVER) logger.attr("Lang", self.config.LANG)
logger.warning("Starting from current page is not supported") logger.warning("Starting from current page is not supported")
logger.warning(f"Supported page: {[str(page) for page in Page.iter_pages()]}") logger.warning(f"Supported page: {[str(page) for page in Page.iter_pages()]}")
logger.warning('Supported page: Any page with a "HOME" button on the upper-right') logger.warning('Supported page: Any page with a "HOME" button on the upper-right')

View File

@ -68,7 +68,7 @@ SWITCH_BATTLE_PASS_MISSION_TAB.add_state(
class BattlePassQuestOcr(Ocr): class BattlePassQuestOcr(Ocr):
def after_process(self, result): def after_process(self, result):
result = super().after_process(result) result = super().after_process(result)
if self.lang == 'ch': if self.lang == 'cn':
result = re.sub("[jJ]", "", result) result = re.sub("[jJ]", "", result)
return result return result

View File

@ -28,7 +28,7 @@ class DailyQuestOcr(Ocr):
def after_process(self, result): def after_process(self, result):
result = super().after_process(result) result = super().after_process(result)
if self.lang == 'ch': if self.lang == 'cn':
result = result.replace("J", "") result = result.replace("J", "")
result = result.replace(";", "") result = result.replace(";", "")
result = result.replace("", "") result = result.replace("", "")

View File

@ -59,7 +59,7 @@ class OcrDungeonNav(Ocr):
def after_process(self, result): def after_process(self, result):
result = super().after_process(result) result = super().after_process(result)
result = result.replace('#', '') result = result.replace('#', '')
if self.lang == 'ch': if self.lang == 'cn':
result = result.replace('萼喜', '') result = result.replace('萼喜', '')
result = result.replace('', '') # 凝带虚影 result = result.replace('', '') # 凝带虚影
return result return result
@ -68,7 +68,7 @@ class OcrDungeonNav(Ocr):
class OcrDungeonList(Ocr): class OcrDungeonList(Ocr):
def after_process(self, result): def after_process(self, result):
result = super().after_process(result) result = super().after_process(result)
if self.lang == 'ch': if self.lang == 'cn':
result = result.replace('', '') # 巽风之形 result = result.replace('', '') # 巽风之形
result = result.replace('皖A0', '50').replace('', '') result = result.replace('皖A0', '50').replace('', '')
return result return result

View File

@ -32,7 +32,7 @@ class OcrMapPlane(Ocr):
def after_process(self, result): def after_process(self, result):
result = super().after_process(result) result = super().after_process(result)
result = re.sub(r'[+→★“”,.,、。]', '', result).strip() result = re.sub(r'[+→★“”,.,、。]', '', result).strip()
if self.lang == 'ch': if self.lang == 'cn':
result = result.replace('迎星港', '迴星港') result = result.replace('迎星港', '迴星港')
if result == '星港': if result == '星港':
result = '迴星港' result = '迴星港'