Refactor: Migrate to pponnxcr

This commit is contained in:
LmeSzinc 2023-09-08 22:23:57 +08:00
parent d33128c57d
commit 16f4f061c1
18 changed files with 110 additions and 122 deletions

View File

@ -9,12 +9,12 @@ from tqdm import tqdm
from module.base.code_generator import CodeGenerator
from module.base.utils import SelectedGrids, area_limit, area_pad, get_bbox, get_color, image_size, load_image
from module.config.config_manual import ManualConfig as AzurLaneConfig
from module.config.server import VALID_SERVER
from module.config.server import VALID_LANG
from module.config.utils import deep_get, deep_iter, deep_set, iter_folder
from module.logger import logger
SHARE_SERVER = 'share'
ASSET_SERVER = [SHARE_SERVER] + VALID_SERVER
ASSET_SERVER = [SHARE_SERVER] + VALID_LANG
class AssetsImage:
@ -217,7 +217,7 @@ def generate_code():
if has_share:
servers = assets_data.keys()
else:
servers = VALID_SERVER
servers = VALID_LANG
for server in servers:
frames = list(assets_data.get(server, {}).values())
if len(frames) > 1:

View File

@ -3,7 +3,6 @@ from module.base.button import Button, ButtonWrapper, ClickButton, match_templat
from module.base.timer import Timer
from module.base.utils import *
from module.config.config import AzurLaneConfig
from module.config.server import set_server, to_package
from module.device.device import Device
from module.logger import logger
@ -261,22 +260,11 @@ class ModuleBase:
self.device.image = value
def set_server(self, server):
"""
For development.
Change server and affect globally,
including assets and server specific methods.
"""
package = to_package(server)
self.device.package = package
set_server(server)
logger.attr('Server', self.config.SERVER)
def set_lang(self, lang):
"""
For development.
Change lang and affect globally,
including assets and server specific methods.
"""
server_.server = lang
logger.attr('Language', self.config.SERVER)
server_.set_lang(lang)
logger.attr('Lang', self.config.LANG)

View File

@ -143,7 +143,7 @@ class ButtonWrapper(Resource):
@cached_property
def buttons(self) -> t.List[Button]:
for trial in [server.server, 'share', 'cn']:
for trial in [server.lang, 'share', 'cn']:
assets = self.data_buttons.get(trial, None)
if assets is not None:
if isinstance(assets, Button):
@ -151,7 +151,7 @@ class ButtonWrapper(Resource):
elif isinstance(assets, list):
return assets
raise ScriptError(f'ButtonWrapper({self}) on server {server.server} has no fallback button')
raise ScriptError(f'ButtonWrapper({self}) on server {server.lang} has no fallback button')
def match_color(self, image, threshold=10) -> bool:
for assets in self.buttons:

View File

@ -74,7 +74,7 @@ class AzurLaneConfig(ConfigUpdater, ManualConfig, GeneratedConfig, ConfigWatcher
super().__setattr__(key, value)
def __init__(self, config_name, task=None):
logger.attr("Server", self.SERVER)
logger.attr("Lang", self.LANG)
# This will read ./config/<config_name>.json
self.config_name = config_name
# Raw json data in yaml file.

View File

@ -3,8 +3,8 @@ import module.config.server as server
class ManualConfig:
@property
def SERVER(self):
return server.server
def LANG(self):
return server.lang
SCHEDULER_PRIORITY = """
Restart

View File

@ -2,9 +2,9 @@
This file stores server, such as 'cn', 'en'.
Use 'import module.config.server as server' to import, don't use 'from xxx import xxx'.
"""
server = 'cn' # Setting default to cn, will avoid errors when using dev_tools
lang = 'cn' # Setting default to cn, will avoid errors when using dev_tools
VALID_SERVER = ['cn', ]
VALID_LANG = ['cn', 'en']
VALID_PACKAGE = {
'com.miHoYo.hkrpg': 'cn',
'com.HoYoverse.hkrpgoversea': 'oversea'
@ -14,16 +14,16 @@ VALID_CHANNEL_PACKAGE = {
}
def set_server(package_or_server: str):
def set_lang(lang_: str):
"""
Change server and this will effect globally,
including assets and server specific methods.
Change language and this will affect globally,
including assets and language specific methods.
Args:
package_or_server: package name or server.
lang_: package name or server.
"""
global server
server = to_server(package_or_server)
global lang
lang = lang_
from module.base.resource import release_resources
release_resources()

View File

@ -395,7 +395,7 @@ def dict_to_kv(dictionary, allow_none=True):
def server_timezone() -> timedelta:
return SERVER_TO_TIMEZONE.get(server_.server, SERVER_TO_TIMEZONE['cn'])
return SERVER_TO_TIMEZONE.get(server_.lang, SERVER_TO_TIMEZONE['cn'])
def server_time_offset() -> timedelta:

View File

@ -109,7 +109,7 @@ class Connection(ConnectionAttr):
# else:
# set_server(self.package)
logger.attr('PackageName', self.package)
logger.attr('Server', self.config.SERVER)
logger.attr('Lang', self.config.LANG)
self.check_mumu_app_keep_alive()

View File

@ -62,7 +62,7 @@ class Keyword:
def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True):
if in_current_server:
match server.server:
match server.lang:
case 'cn':
if ignore_punctuation:
return [self.cn_parsed]

View File

@ -1,11 +1,70 @@
from pponnxcr import TextSystem
from module.base.decorator import cached_property
from module.ocr.ppocr import TextSystem
from module.exception import ScriptError
DIC_LANG_TO_MODEL = {
'cn': 'zhs',
'en': 'en',
'jp': 'ja',
'tw': 'zht',
}
def lang2model(lang: str) -> str:
"""
Args:
lang: In-game language name, defined in VALID_LANG
Returns:
str: Model name, defined in pponnxcr.utility
"""
return DIC_LANG_TO_MODEL.get(lang, lang)
def model2lang(model: str) -> str:
"""
Args:
model: Model name, defined in pponnxcr.utility
Returns:
str: In-game language name, defined in VALID_LANG
"""
for k, v in DIC_LANG_TO_MODEL.items():
if model == v:
return k
return model
class OcrModel:
def get_by_model(self, model: str) -> TextSystem:
try:
return self.__getattribute__(model)
except AttributeError:
raise ScriptError(f'OCR model "{model}" does not exists')
def get_by_lang(self, lang: str) -> TextSystem:
try:
model = lang2model(lang)
return self.__getattribute__(model)
except AttributeError:
raise ScriptError(f'OCR model under lang "{lang}" does not exists')
@cached_property
def ch(self):
return TextSystem()
def zhs(self):
return TextSystem('zhs')
@cached_property
def en(self):
return TextSystem('en')
@cached_property
def ja(self):
return TextSystem('zht')
@cached_property
def zht(self):
return TextSystem('zht')
OCR_MODEL = OcrModel()

View File

@ -4,6 +4,7 @@ from datetime import timedelta
import cv2
from ppocronnx.predict_system import BoxedResult
from pponnxcr import TextSystem
import module.config.server as server
from module.base.button import ButtonWrapper
@ -12,7 +13,6 @@ from module.base.utils import area_pad, corner2area, crop, float2str
from module.exception import ScriptError
from module.logger import logger
from module.ocr.models import OCR_MODEL
from module.ocr.ppocr import TextSystem
from module.ocr.utils import merge_buttons
@ -21,6 +21,8 @@ def enlarge_canvas(image):
Enlarge image into a square fill with black background. In the structure of PaddleOCR,
image with w:h=1:1 is the best while 3:1 rectangles takes three times as long.
Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32.
No longer needed, already included in pponnxcr.
"""
height, width = image.shape[:2]
length = int(max(width, height) // 32 * 32 + 32)
@ -95,27 +97,24 @@ class Ocr:
merge_thres_y = 0
def __init__(self, button: ButtonWrapper, lang=None, name=None):
"""
Args:
button:
lang: If None, use in-game language
name: If None, use button.name
"""
if lang is None:
lang = server.lang
if name is None:
name = button.name
self.button: ButtonWrapper = button
self._lang = lang
self.name: str = name if name is not None else button.name
@classmethod
def server2lang(cls, ser=None) -> str:
if ser is None:
ser = server.server
match ser:
case 'cn':
return 'ch'
case _:
return 'ch'
@cached_property
def lang(self) -> str:
return self._lang if self._lang is not None else Ocr.server2lang()
self.lang: str = lang
self.name: str = name
@cached_property
def model(self) -> TextSystem:
return OCR_MODEL.__getattribute__(self.lang)
return OCR_MODEL.get_by_lang(self.lang)
def pre_process(self, image):
"""
@ -188,7 +187,7 @@ class Ocr:
image = crop(image, self.button.area)
image = self.pre_process(image)
# ocr
image = enlarge_canvas(image)
# image = enlarge_canvas(image)
results: list[BoxedResult] = self.model.detect_and_ocr(image)
# after proces
for result in results:
@ -235,7 +234,7 @@ class Ocr:
class Digit(Ocr):
def __init__(self, button: ButtonWrapper, lang='ch', name=None):
def __init__(self, button: ButtonWrapper, lang='en', name=None):
super().__init__(button, lang=lang, name=name)
def format_result(self, result) -> int:
@ -255,7 +254,7 @@ class Digit(Ocr):
class DigitCounter(Ocr):
def __init__(self, button: ButtonWrapper, lang='ch', name=None):
def __init__(self, button: ButtonWrapper, lang='en', name=None):
super().__init__(button, lang=lang, name=name)
def format_result(self, result) -> tuple[int, int, int]:
@ -283,7 +282,7 @@ class Duration(Ocr):
@cached_property
def timedelta_regex(self):
regex_str = {
'ch': r'\D*((?P<days>\d{1,2})天)?((?P<hours>\d{1,2})小时)?((?P<minutes>\d{1,2})分钟)?((?P<seconds>\d{1,2})秒)?',
'cn': r'\D*((?P<days>\d{1,2})天)?((?P<hours>\d{1,2})小时)?((?P<minutes>\d{1,2})分钟)?((?P<seconds>\d{1,2})秒)?',
'en': r'\D*((?P<days>\d{1,2})d\s*)?((?P<hours>\d{1,2})h\s*)?((?P<minutes>\d{1,2})m\s*)?((?P<seconds>\d{1,2})s)?'
}[self.lang]
return re.compile(regex_str)

View File

@ -1,56 +0,0 @@
import ppocronnx.predict_system
class TextSystem(ppocronnx.predict_system.TextSystem):
def __init__(
self,
use_angle_cls=False,
box_thresh=0.6,
unclip_ratio=1.6,
rec_model_path=None,
det_model_path=None,
ort_providers=None
):
super().__init__(
use_angle_cls=use_angle_cls,
box_thresh=box_thresh,
unclip_ratio=unclip_ratio,
rec_model_path=rec_model_path,
det_model_path=det_model_path,
ort_providers=ort_providers
)
# def ocr_single_line(self, img):
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# return super().ocr_single_line(img)
#
# def detect_and_ocr(self, img: np.ndarray,**kwargs):
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
# return super().detect_and_ocr(img, **kwargs)
def sorted_boxes(dt_boxes):
"""
Sort text boxes in order from top to bottom, left to right
args:
dt_boxes(array):detected text boxes with shape [4, 2]
return:
sorted boxes(array) with shape [4, 2]
"""
num_boxes = dt_boxes.shape[0]
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
_boxes = list(sorted_boxes)
for i in range(num_boxes - 1):
for j in range(i, -1, -1):
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
tmp = _boxes[j]
_boxes[j] = _boxes[j + 1]
_boxes[j + 1] = tmp
else:
break
return _boxes
# sorted_boxes() from PaddleOCR 2.6, newer and better than the one in ppocr-onnx
ppocronnx.predict_system.sorted_boxes = sorted_boxes

View File

@ -62,8 +62,6 @@ class AssignmentOcr(Ocr):
if matched is None:
return result
keyword_lang = self.lang
if self.lang == 'ch':
keyword_lang = 'cn'
matched = getattr(KEYWORDS_ASSIGNMENT_ENTRY, matched.lastgroup)
matched = getattr(matched, keyword_lang)
logger.attr(name=f'{self.name} after_process',

View File

@ -86,7 +86,7 @@ class UI(PopupHandler):
logger.warning("Unknown ui page")
logger.attr("EMULATOR__SCREENSHOT_METHOD", self.config.Emulator_ScreenshotMethod)
logger.attr("EMULATOR__CONTROL_METHOD", self.config.Emulator_ControlMethod)
logger.attr("SERVER", self.config.SERVER)
logger.attr("Lang", self.config.LANG)
logger.warning("Starting from current page is not supported")
logger.warning(f"Supported page: {[str(page) for page in Page.iter_pages()]}")
logger.warning('Supported page: Any page with a "HOME" button on the upper-right')

View File

@ -68,7 +68,7 @@ SWITCH_BATTLE_PASS_MISSION_TAB.add_state(
class BattlePassQuestOcr(Ocr):
def after_process(self, result):
result = super().after_process(result)
if self.lang == 'ch':
if self.lang == 'cn':
result = re.sub("[jJ]", "", result)
return result

View File

@ -28,7 +28,7 @@ class DailyQuestOcr(Ocr):
def after_process(self, result):
result = super().after_process(result)
if self.lang == 'ch':
if self.lang == 'cn':
result = result.replace("J", "")
result = result.replace(";", "")
result = result.replace("", "")

View File

@ -59,7 +59,7 @@ class OcrDungeonNav(Ocr):
def after_process(self, result):
result = super().after_process(result)
result = result.replace('#', '')
if self.lang == 'ch':
if self.lang == 'cn':
result = result.replace('萼喜', '')
result = result.replace('', '') # 凝带虚影
return result
@ -68,7 +68,7 @@ class OcrDungeonNav(Ocr):
class OcrDungeonList(Ocr):
def after_process(self, result):
result = super().after_process(result)
if self.lang == 'ch':
if self.lang == 'cn':
result = result.replace('', '') # 巽风之形
result = result.replace('皖A0', '50').replace('', '')
return result

View File

@ -32,7 +32,7 @@ class OcrMapPlane(Ocr):
def after_process(self, result):
result = super().after_process(result)
result = re.sub(r'[+→★“”,.,、。]', '', result).strip()
if self.lang == 'ch':
if self.lang == 'cn':
result = result.replace('迎星港', '迴星港')
if result == '星港':
result = '迴星港'