mirror of
https://github.com/LmeSzinc/StarRailCopilot.git
synced 2024-11-25 10:01:10 +00:00
Refactor: Migrate to pponnxcr
This commit is contained in:
parent
d33128c57d
commit
16f4f061c1
@ -9,12 +9,12 @@ from tqdm import tqdm
|
||||
from module.base.code_generator import CodeGenerator
|
||||
from module.base.utils import SelectedGrids, area_limit, area_pad, get_bbox, get_color, image_size, load_image
|
||||
from module.config.config_manual import ManualConfig as AzurLaneConfig
|
||||
from module.config.server import VALID_SERVER
|
||||
from module.config.server import VALID_LANG
|
||||
from module.config.utils import deep_get, deep_iter, deep_set, iter_folder
|
||||
from module.logger import logger
|
||||
|
||||
SHARE_SERVER = 'share'
|
||||
ASSET_SERVER = [SHARE_SERVER] + VALID_SERVER
|
||||
ASSET_SERVER = [SHARE_SERVER] + VALID_LANG
|
||||
|
||||
|
||||
class AssetsImage:
|
||||
@ -217,7 +217,7 @@ def generate_code():
|
||||
if has_share:
|
||||
servers = assets_data.keys()
|
||||
else:
|
||||
servers = VALID_SERVER
|
||||
servers = VALID_LANG
|
||||
for server in servers:
|
||||
frames = list(assets_data.get(server, {}).values())
|
||||
if len(frames) > 1:
|
||||
|
@ -3,7 +3,6 @@ from module.base.button import Button, ButtonWrapper, ClickButton, match_templat
|
||||
from module.base.timer import Timer
|
||||
from module.base.utils import *
|
||||
from module.config.config import AzurLaneConfig
|
||||
from module.config.server import set_server, to_package
|
||||
from module.device.device import Device
|
||||
from module.logger import logger
|
||||
|
||||
@ -261,22 +260,11 @@ class ModuleBase:
|
||||
|
||||
self.device.image = value
|
||||
|
||||
def set_server(self, server):
|
||||
"""
|
||||
For development.
|
||||
Change server and affect globally,
|
||||
including assets and server specific methods.
|
||||
"""
|
||||
package = to_package(server)
|
||||
self.device.package = package
|
||||
set_server(server)
|
||||
logger.attr('Server', self.config.SERVER)
|
||||
|
||||
def set_lang(self, lang):
|
||||
"""
|
||||
For development.
|
||||
Change lang and affect globally,
|
||||
including assets and server specific methods.
|
||||
"""
|
||||
server_.server = lang
|
||||
logger.attr('Language', self.config.SERVER)
|
||||
server_.set_lang(lang)
|
||||
logger.attr('Lang', self.config.LANG)
|
||||
|
@ -143,7 +143,7 @@ class ButtonWrapper(Resource):
|
||||
|
||||
@cached_property
|
||||
def buttons(self) -> t.List[Button]:
|
||||
for trial in [server.server, 'share', 'cn']:
|
||||
for trial in [server.lang, 'share', 'cn']:
|
||||
assets = self.data_buttons.get(trial, None)
|
||||
if assets is not None:
|
||||
if isinstance(assets, Button):
|
||||
@ -151,7 +151,7 @@ class ButtonWrapper(Resource):
|
||||
elif isinstance(assets, list):
|
||||
return assets
|
||||
|
||||
raise ScriptError(f'ButtonWrapper({self}) on server {server.server} has no fallback button')
|
||||
raise ScriptError(f'ButtonWrapper({self}) on server {server.lang} has no fallback button')
|
||||
|
||||
def match_color(self, image, threshold=10) -> bool:
|
||||
for assets in self.buttons:
|
||||
|
@ -74,7 +74,7 @@ class AzurLaneConfig(ConfigUpdater, ManualConfig, GeneratedConfig, ConfigWatcher
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __init__(self, config_name, task=None):
|
||||
logger.attr("Server", self.SERVER)
|
||||
logger.attr("Lang", self.LANG)
|
||||
# This will read ./config/<config_name>.json
|
||||
self.config_name = config_name
|
||||
# Raw json data in yaml file.
|
||||
|
@ -3,8 +3,8 @@ import module.config.server as server
|
||||
|
||||
class ManualConfig:
|
||||
@property
|
||||
def SERVER(self):
|
||||
return server.server
|
||||
def LANG(self):
|
||||
return server.lang
|
||||
|
||||
SCHEDULER_PRIORITY = """
|
||||
Restart
|
||||
|
@ -2,9 +2,9 @@
|
||||
This file stores server, such as 'cn', 'en'.
|
||||
Use 'import module.config.server as server' to import, don't use 'from xxx import xxx'.
|
||||
"""
|
||||
server = 'cn' # Setting default to cn, will avoid errors when using dev_tools
|
||||
lang = 'cn' # Setting default to cn, will avoid errors when using dev_tools
|
||||
|
||||
VALID_SERVER = ['cn', ]
|
||||
VALID_LANG = ['cn', 'en']
|
||||
VALID_PACKAGE = {
|
||||
'com.miHoYo.hkrpg': 'cn',
|
||||
'com.HoYoverse.hkrpgoversea': 'oversea'
|
||||
@ -14,16 +14,16 @@ VALID_CHANNEL_PACKAGE = {
|
||||
}
|
||||
|
||||
|
||||
def set_server(package_or_server: str):
|
||||
def set_lang(lang_: str):
|
||||
"""
|
||||
Change server and this will effect globally,
|
||||
including assets and server specific methods.
|
||||
Change language and this will affect globally,
|
||||
including assets and language specific methods.
|
||||
|
||||
Args:
|
||||
package_or_server: package name or server.
|
||||
lang_: package name or server.
|
||||
"""
|
||||
global server
|
||||
server = to_server(package_or_server)
|
||||
global lang
|
||||
lang = lang_
|
||||
|
||||
from module.base.resource import release_resources
|
||||
release_resources()
|
||||
|
@ -395,7 +395,7 @@ def dict_to_kv(dictionary, allow_none=True):
|
||||
|
||||
|
||||
def server_timezone() -> timedelta:
|
||||
return SERVER_TO_TIMEZONE.get(server_.server, SERVER_TO_TIMEZONE['cn'])
|
||||
return SERVER_TO_TIMEZONE.get(server_.lang, SERVER_TO_TIMEZONE['cn'])
|
||||
|
||||
|
||||
def server_time_offset() -> timedelta:
|
||||
|
@ -109,7 +109,7 @@ class Connection(ConnectionAttr):
|
||||
# else:
|
||||
# set_server(self.package)
|
||||
logger.attr('PackageName', self.package)
|
||||
logger.attr('Server', self.config.SERVER)
|
||||
logger.attr('Lang', self.config.LANG)
|
||||
|
||||
self.check_mumu_app_keep_alive()
|
||||
|
||||
|
@ -62,7 +62,7 @@ class Keyword:
|
||||
|
||||
def _keywords_to_find(self, in_current_server=False, ignore_punctuation=True):
|
||||
if in_current_server:
|
||||
match server.server:
|
||||
match server.lang:
|
||||
case 'cn':
|
||||
if ignore_punctuation:
|
||||
return [self.cn_parsed]
|
||||
|
@ -1,11 +1,70 @@
|
||||
from pponnxcr import TextSystem
|
||||
|
||||
from module.base.decorator import cached_property
|
||||
from module.ocr.ppocr import TextSystem
|
||||
from module.exception import ScriptError
|
||||
|
||||
DIC_LANG_TO_MODEL = {
|
||||
'cn': 'zhs',
|
||||
'en': 'en',
|
||||
'jp': 'ja',
|
||||
'tw': 'zht',
|
||||
}
|
||||
|
||||
|
||||
def lang2model(lang: str) -> str:
|
||||
"""
|
||||
Args:
|
||||
lang: In-game language name, defined in VALID_LANG
|
||||
|
||||
Returns:
|
||||
str: Model name, defined in pponnxcr.utility
|
||||
"""
|
||||
return DIC_LANG_TO_MODEL.get(lang, lang)
|
||||
|
||||
|
||||
def model2lang(model: str) -> str:
|
||||
"""
|
||||
Args:
|
||||
model: Model name, defined in pponnxcr.utility
|
||||
|
||||
Returns:
|
||||
str: In-game language name, defined in VALID_LANG
|
||||
"""
|
||||
for k, v in DIC_LANG_TO_MODEL.items():
|
||||
if model == v:
|
||||
return k
|
||||
return model
|
||||
|
||||
|
||||
class OcrModel:
|
||||
def get_by_model(self, model: str) -> TextSystem:
|
||||
try:
|
||||
return self.__getattribute__(model)
|
||||
except AttributeError:
|
||||
raise ScriptError(f'OCR model "{model}" does not exists')
|
||||
|
||||
def get_by_lang(self, lang: str) -> TextSystem:
|
||||
try:
|
||||
model = lang2model(lang)
|
||||
return self.__getattribute__(model)
|
||||
except AttributeError:
|
||||
raise ScriptError(f'OCR model under lang "{lang}" does not exists')
|
||||
|
||||
@cached_property
|
||||
def ch(self):
|
||||
return TextSystem()
|
||||
def zhs(self):
|
||||
return TextSystem('zhs')
|
||||
|
||||
@cached_property
|
||||
def en(self):
|
||||
return TextSystem('en')
|
||||
|
||||
@cached_property
|
||||
def ja(self):
|
||||
return TextSystem('zht')
|
||||
|
||||
@cached_property
|
||||
def zht(self):
|
||||
return TextSystem('zht')
|
||||
|
||||
|
||||
OCR_MODEL = OcrModel()
|
||||
|
@ -4,6 +4,7 @@ from datetime import timedelta
|
||||
|
||||
import cv2
|
||||
from ppocronnx.predict_system import BoxedResult
|
||||
from pponnxcr import TextSystem
|
||||
|
||||
import module.config.server as server
|
||||
from module.base.button import ButtonWrapper
|
||||
@ -12,7 +13,6 @@ from module.base.utils import area_pad, corner2area, crop, float2str
|
||||
from module.exception import ScriptError
|
||||
from module.logger import logger
|
||||
from module.ocr.models import OCR_MODEL
|
||||
from module.ocr.ppocr import TextSystem
|
||||
from module.ocr.utils import merge_buttons
|
||||
|
||||
|
||||
@ -21,6 +21,8 @@ def enlarge_canvas(image):
|
||||
Enlarge image into a square fill with black background. In the structure of PaddleOCR,
|
||||
image with w:h=1:1 is the best while 3:1 rectangles takes three times as long.
|
||||
Also enlarge into the integer multiple of 32 cause PaddleOCR will downscale images to 1/32.
|
||||
|
||||
No longer needed, already included in pponnxcr.
|
||||
"""
|
||||
height, width = image.shape[:2]
|
||||
length = int(max(width, height) // 32 * 32 + 32)
|
||||
@ -95,27 +97,24 @@ class Ocr:
|
||||
merge_thres_y = 0
|
||||
|
||||
def __init__(self, button: ButtonWrapper, lang=None, name=None):
|
||||
"""
|
||||
Args:
|
||||
button:
|
||||
lang: If None, use in-game language
|
||||
name: If None, use button.name
|
||||
"""
|
||||
if lang is None:
|
||||
lang = server.lang
|
||||
if name is None:
|
||||
name = button.name
|
||||
|
||||
self.button: ButtonWrapper = button
|
||||
self._lang = lang
|
||||
self.name: str = name if name is not None else button.name
|
||||
|
||||
@classmethod
|
||||
def server2lang(cls, ser=None) -> str:
|
||||
if ser is None:
|
||||
ser = server.server
|
||||
match ser:
|
||||
case 'cn':
|
||||
return 'ch'
|
||||
case _:
|
||||
return 'ch'
|
||||
|
||||
@cached_property
|
||||
def lang(self) -> str:
|
||||
return self._lang if self._lang is not None else Ocr.server2lang()
|
||||
self.lang: str = lang
|
||||
self.name: str = name
|
||||
|
||||
@cached_property
|
||||
def model(self) -> TextSystem:
|
||||
return OCR_MODEL.__getattribute__(self.lang)
|
||||
return OCR_MODEL.get_by_lang(self.lang)
|
||||
|
||||
def pre_process(self, image):
|
||||
"""
|
||||
@ -188,7 +187,7 @@ class Ocr:
|
||||
image = crop(image, self.button.area)
|
||||
image = self.pre_process(image)
|
||||
# ocr
|
||||
image = enlarge_canvas(image)
|
||||
# image = enlarge_canvas(image)
|
||||
results: list[BoxedResult] = self.model.detect_and_ocr(image)
|
||||
# after proces
|
||||
for result in results:
|
||||
@ -235,7 +234,7 @@ class Ocr:
|
||||
|
||||
|
||||
class Digit(Ocr):
|
||||
def __init__(self, button: ButtonWrapper, lang='ch', name=None):
|
||||
def __init__(self, button: ButtonWrapper, lang='en', name=None):
|
||||
super().__init__(button, lang=lang, name=name)
|
||||
|
||||
def format_result(self, result) -> int:
|
||||
@ -255,7 +254,7 @@ class Digit(Ocr):
|
||||
|
||||
|
||||
class DigitCounter(Ocr):
|
||||
def __init__(self, button: ButtonWrapper, lang='ch', name=None):
|
||||
def __init__(self, button: ButtonWrapper, lang='en', name=None):
|
||||
super().__init__(button, lang=lang, name=name)
|
||||
|
||||
def format_result(self, result) -> tuple[int, int, int]:
|
||||
@ -283,7 +282,7 @@ class Duration(Ocr):
|
||||
@cached_property
|
||||
def timedelta_regex(self):
|
||||
regex_str = {
|
||||
'ch': r'\D*((?P<days>\d{1,2})天)?((?P<hours>\d{1,2})小时)?((?P<minutes>\d{1,2})分钟)?((?P<seconds>\d{1,2})秒)?',
|
||||
'cn': r'\D*((?P<days>\d{1,2})天)?((?P<hours>\d{1,2})小时)?((?P<minutes>\d{1,2})分钟)?((?P<seconds>\d{1,2})秒)?',
|
||||
'en': r'\D*((?P<days>\d{1,2})d\s*)?((?P<hours>\d{1,2})h\s*)?((?P<minutes>\d{1,2})m\s*)?((?P<seconds>\d{1,2})s)?'
|
||||
}[self.lang]
|
||||
return re.compile(regex_str)
|
||||
|
@ -1,56 +0,0 @@
|
||||
import ppocronnx.predict_system
|
||||
|
||||
|
||||
class TextSystem(ppocronnx.predict_system.TextSystem):
|
||||
def __init__(
|
||||
self,
|
||||
use_angle_cls=False,
|
||||
box_thresh=0.6,
|
||||
unclip_ratio=1.6,
|
||||
rec_model_path=None,
|
||||
det_model_path=None,
|
||||
ort_providers=None
|
||||
):
|
||||
super().__init__(
|
||||
use_angle_cls=use_angle_cls,
|
||||
box_thresh=box_thresh,
|
||||
unclip_ratio=unclip_ratio,
|
||||
rec_model_path=rec_model_path,
|
||||
det_model_path=det_model_path,
|
||||
ort_providers=ort_providers
|
||||
)
|
||||
|
||||
# def ocr_single_line(self, img):
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
# return super().ocr_single_line(img)
|
||||
#
|
||||
# def detect_and_ocr(self, img: np.ndarray,**kwargs):
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
# return super().detect_and_ocr(img, **kwargs)
|
||||
|
||||
|
||||
def sorted_boxes(dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
_boxes[j] = _boxes[j + 1]
|
||||
_boxes[j + 1] = tmp
|
||||
else:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
# sorted_boxes() from PaddleOCR 2.6, newer and better than the one in ppocr-onnx
|
||||
ppocronnx.predict_system.sorted_boxes = sorted_boxes
|
@ -62,8 +62,6 @@ class AssignmentOcr(Ocr):
|
||||
if matched is None:
|
||||
return result
|
||||
keyword_lang = self.lang
|
||||
if self.lang == 'ch':
|
||||
keyword_lang = 'cn'
|
||||
matched = getattr(KEYWORDS_ASSIGNMENT_ENTRY, matched.lastgroup)
|
||||
matched = getattr(matched, keyword_lang)
|
||||
logger.attr(name=f'{self.name} after_process',
|
||||
|
@ -86,7 +86,7 @@ class UI(PopupHandler):
|
||||
logger.warning("Unknown ui page")
|
||||
logger.attr("EMULATOR__SCREENSHOT_METHOD", self.config.Emulator_ScreenshotMethod)
|
||||
logger.attr("EMULATOR__CONTROL_METHOD", self.config.Emulator_ControlMethod)
|
||||
logger.attr("SERVER", self.config.SERVER)
|
||||
logger.attr("Lang", self.config.LANG)
|
||||
logger.warning("Starting from current page is not supported")
|
||||
logger.warning(f"Supported page: {[str(page) for page in Page.iter_pages()]}")
|
||||
logger.warning('Supported page: Any page with a "HOME" button on the upper-right')
|
||||
|
@ -68,7 +68,7 @@ SWITCH_BATTLE_PASS_MISSION_TAB.add_state(
|
||||
class BattlePassQuestOcr(Ocr):
|
||||
def after_process(self, result):
|
||||
result = super().after_process(result)
|
||||
if self.lang == 'ch':
|
||||
if self.lang == 'cn':
|
||||
result = re.sub("[jJ]", "」", result)
|
||||
return result
|
||||
|
||||
|
@ -28,7 +28,7 @@ class DailyQuestOcr(Ocr):
|
||||
|
||||
def after_process(self, result):
|
||||
result = super().after_process(result)
|
||||
if self.lang == 'ch':
|
||||
if self.lang == 'cn':
|
||||
result = result.replace("J", "」")
|
||||
result = result.replace(";", "」")
|
||||
result = result.replace("了", "」")
|
||||
|
@ -59,7 +59,7 @@ class OcrDungeonNav(Ocr):
|
||||
def after_process(self, result):
|
||||
result = super().after_process(result)
|
||||
result = result.replace('#', '')
|
||||
if self.lang == 'ch':
|
||||
if self.lang == 'cn':
|
||||
result = result.replace('萼喜', '萼')
|
||||
result = result.replace('带', '滞') # 凝带虚影
|
||||
return result
|
||||
@ -68,7 +68,7 @@ class OcrDungeonNav(Ocr):
|
||||
class OcrDungeonList(Ocr):
|
||||
def after_process(self, result):
|
||||
result = super().after_process(result)
|
||||
if self.lang == 'ch':
|
||||
if self.lang == 'cn':
|
||||
result = result.replace('翼', '巽') # 巽风之形
|
||||
result = result.replace('皖A0', '50').replace('皖', '')
|
||||
return result
|
||||
|
@ -32,7 +32,7 @@ class OcrMapPlane(Ocr):
|
||||
def after_process(self, result):
|
||||
result = super().after_process(result)
|
||||
result = re.sub(r'[+→★“”,.,、。]', '', result).strip()
|
||||
if self.lang == 'ch':
|
||||
if self.lang == 'cn':
|
||||
result = result.replace('迎星港', '迴星港')
|
||||
if result == '星港':
|
||||
result = '迴星港'
|
||||
|
Loading…
Reference in New Issue
Block a user