feat: predict face

This commit is contained in:
xtaodada 2024-10-30 00:05:59 +08:00
parent 966acd108e
commit c6d04539ae
Signed by: xtaodada
GPG Key ID: 4CBB3F4FA8C85659
4 changed files with 119 additions and 0 deletions

View File

@ -30,3 +30,7 @@ mys_cookie = ABCD
[bsky] [bsky]
username = 111 username = 111
password = 11 password = 11
[predict]
url = ""
token = ""

View File

@ -51,6 +51,9 @@ except ValueError:
# bsky # bsky
bsky_username = config.get("bsky", "username", fallback="") bsky_username = config.get("bsky", "username", fallback="")
bsky_password = config.get("bsky", "password", fallback="") bsky_password = config.get("bsky", "password", fallback="")
# predict
predict_url = config.get("predict", "url", fallback="")
predict_token = config.get("predict", "token", fallback="")
try: try:
ipv6 = bool(strtobool(ipv6)) ipv6 = bool(strtobool(ipv6))
except ValueError: except ValueError:

85
defs/predict.py Normal file
View File

@ -0,0 +1,85 @@
import io
import time
import traceback
from typing import Tuple, List, Optional, BinaryIO
import httpx
from pydantic import BaseModel
from defs.glover import predict_url, predict_token
from PIL import Image, ImageDraw
headers = {"x-token": predict_token}
request = httpx.AsyncClient(timeout=60.0, headers=headers, verify=False)
class FacialAreaRegion(BaseModel):
x: float
y: float
w: float
h: float
left_eye: Tuple[int, int]
right_eye: Tuple[int, int]
confidence: float
class Result(BaseModel):
code: int
msg: str
faces: List[FacialAreaRegion] = []
class Face(BaseModel):
predict_time: float
draw_time: float
async def predict_photo(img_byte_arr: BinaryIO) -> Optional[Result]:
files = {"file": ("image.png", img_byte_arr, "image/png")}
try:
req = await request.post(predict_url, files=files, headers=headers)
return Result(**req.json())
except Exception:
traceback.print_exc()
return None
async def predict(file: BinaryIO) -> Tuple[Optional[Face], Optional[BinaryIO]]:
image = Image.open(file)
file.seek(0) # 重置指针到开始位置
time1 = time.time()
data = await predict_photo(file)
time2 = time.time()
if not data or not data.faces:
return None, None
for face in data.faces:
# 框出人脸
draw = ImageDraw.Draw(image)
x1, y1 = face.x, face.y
x2, y2 = x1 + face.w, y1 + face.h
draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
# 画出眼睛
draw.ellipse(
[
face.left_eye[0] - 2,
face.left_eye[1] - 2,
face.left_eye[0] + 2,
face.left_eye[1] + 2,
],
fill="red",
)
draw.ellipse(
[
face.right_eye[0] - 2,
face.right_eye[1] - 2,
face.right_eye[0] + 2,
face.right_eye[1] + 2,
],
fill="red",
)
binary_io = io.BytesIO()
image.save(binary_io, "JPEG")
binary_io.seek(0)
time3 = time.time()
return Face(predict_time=time2 - time1, draw_time=time3 - time2), binary_io

27
modules/predict.py Normal file
View File

@ -0,0 +1,27 @@
import time
from pyrogram import Client, filters
from pyrogram.types import Message
from defs.predict import predict
from init import bot
@bot.on_message(
filters.incoming & filters.command(["predict", f"predict@{bot.me.username}"])
)
async def predict_command(_: Client, message: Message):
r = message
if message.reply_to_message and message.reply_to_message.photo:
r = message.reply_to_message
if not r.photo:
return await message.reply("请发送/回复一张图片")
time1 = time.time()
file = await r.download(in_memory=True)
download_time = time.time()
face, image = await predict(file)
if face and image:
text = f"下载耗时: {download_time - time1:.2f}s\n预测耗时: {face.predict_time:.2f}s\n绘制耗时: {face.draw_time:.2f}s"
await message.reply_photo(image, caption=text)
else:
await message.reply("未检测到人脸")