From c6d04539ae7ad8222e6664e6cd3f1649a99fa8b4 Mon Sep 17 00:00:00 2001 From: xtaodada Date: Wed, 30 Oct 2024 00:05:59 +0800 Subject: [PATCH] feat: predict face --- config.gen.ini | 4 +++ defs/glover.py | 3 ++ defs/predict.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++ modules/predict.py | 27 +++++++++++++++ 4 files changed, 119 insertions(+) create mode 100644 defs/predict.py create mode 100644 modules/predict.py diff --git a/config.gen.ini b/config.gen.ini index 6d0faf7..bc27797 100644 --- a/config.gen.ini +++ b/config.gen.ini @@ -30,3 +30,7 @@ mys_cookie = ABCD [bsky] username = 111 password = 11 + +[predict] +url = "" +token = "" diff --git a/defs/glover.py b/defs/glover.py index 81b5b9a..7fff272 100644 --- a/defs/glover.py +++ b/defs/glover.py @@ -51,6 +51,9 @@ except ValueError: # bsky bsky_username = config.get("bsky", "username", fallback="") bsky_password = config.get("bsky", "password", fallback="") +# predict +predict_url = config.get("predict", "url", fallback="") +predict_token = config.get("predict", "token", fallback="") try: ipv6 = bool(strtobool(ipv6)) except ValueError: diff --git a/defs/predict.py b/defs/predict.py new file mode 100644 index 0000000..1b51f08 --- /dev/null +++ b/defs/predict.py @@ -0,0 +1,85 @@ +import io +import time +import traceback +from typing import Tuple, List, Optional, BinaryIO + +import httpx +from pydantic import BaseModel + +from defs.glover import predict_url, predict_token + +from PIL import Image, ImageDraw + +headers = {"x-token": predict_token} +request = httpx.AsyncClient(timeout=60.0, headers=headers, verify=False) + + +class FacialAreaRegion(BaseModel): + x: float + y: float + w: float + h: float + left_eye: Tuple[int, int] + right_eye: Tuple[int, int] + confidence: float + + +class Result(BaseModel): + code: int + msg: str + faces: List[FacialAreaRegion] = [] + + +class Face(BaseModel): + predict_time: float + draw_time: float + + +async def predict_photo(img_byte_arr: BinaryIO) -> Optional[Result]: + files = {"file": ("image.png", img_byte_arr, "image/png")} + try: + req = await request.post(predict_url, files=files, headers=headers) + return Result(**req.json()) + except Exception: + traceback.print_exc() + return None + + +async def predict(file: BinaryIO) -> Tuple[Optional[Face], Optional[BinaryIO]]: + image = Image.open(file) + file.seek(0) # 重置指针到开始位置 + time1 = time.time() + data = await predict_photo(file) + time2 = time.time() + if not data or not data.faces: + return None, None + for face in data.faces: + # 框出人脸 + draw = ImageDraw.Draw(image) + x1, y1 = face.x, face.y + x2, y2 = x1 + face.w, y1 + face.h + draw.rectangle([x1, y1, x2, y2], outline="red", width=2) + # 画出眼睛 + draw.ellipse( + [ + face.left_eye[0] - 2, + face.left_eye[1] - 2, + face.left_eye[0] + 2, + face.left_eye[1] + 2, + ], + fill="red", + ) + draw.ellipse( + [ + face.right_eye[0] - 2, + face.right_eye[1] - 2, + face.right_eye[0] + 2, + face.right_eye[1] + 2, + ], + fill="red", + ) + binary_io = io.BytesIO() + image.save(binary_io, "JPEG") + binary_io.seek(0) + time3 = time.time() + return Face(predict_time=time2 - time1, draw_time=time3 - time2), binary_io diff --git a/modules/predict.py b/modules/predict.py new file mode 100644 index 0000000..a263f08 --- /dev/null +++ b/modules/predict.py @@ -0,0 +1,27 @@ +import time + +from pyrogram import Client, filters +from pyrogram.types import Message + +from defs.predict import predict +from init import bot + + +@bot.on_message( + filters.incoming & filters.command(["predict", f"predict@{bot.me.username}"]) +) +async def predict_command(_: Client, message: Message): + r = message + if message.reply_to_message and message.reply_to_message.photo: + r = message.reply_to_message + if not r.photo: + return await message.reply("请发送/回复一张图片") + time1 = time.time() + file = await r.download(in_memory=True) + download_time = time.time() + face, image = await predict(file) + if face and image: + text = f"下载耗时: {download_time - time1:.2f}s\n预测耗时: {face.predict_time:.2f}s\n绘制耗时: {face.draw_time:.2f}s" + await message.reply_photo(image, caption=text) + else: + await message.reply("未检测到人脸")