mirror of
https://github.com/Xtao-Labs/iShotaBot.git
synced 2024-11-21 14:48:23 +00:00
feat: predict face
This commit is contained in:
parent
966acd108e
commit
c6d04539ae
@ -30,3 +30,7 @@ mys_cookie = ABCD
|
|||||||
[bsky]
|
[bsky]
|
||||||
username = 111
|
username = 111
|
||||||
password = 11
|
password = 11
|
||||||
|
|
||||||
|
[predict]
|
||||||
|
url = ""
|
||||||
|
token = ""
|
||||||
|
@ -51,6 +51,9 @@ except ValueError:
|
|||||||
# bsky
|
# bsky
|
||||||
bsky_username = config.get("bsky", "username", fallback="")
|
bsky_username = config.get("bsky", "username", fallback="")
|
||||||
bsky_password = config.get("bsky", "password", fallback="")
|
bsky_password = config.get("bsky", "password", fallback="")
|
||||||
|
# predict
|
||||||
|
predict_url = config.get("predict", "url", fallback="")
|
||||||
|
predict_token = config.get("predict", "token", fallback="")
|
||||||
try:
|
try:
|
||||||
ipv6 = bool(strtobool(ipv6))
|
ipv6 = bool(strtobool(ipv6))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
85
defs/predict.py
Normal file
85
defs/predict.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import io
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from typing import Tuple, List, Optional, BinaryIO
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from defs.glover import predict_url, predict_token
|
||||||
|
|
||||||
|
from PIL import Image, ImageDraw
|
||||||
|
|
||||||
|
headers = {"x-token": predict_token}
|
||||||
|
request = httpx.AsyncClient(timeout=60.0, headers=headers, verify=False)
|
||||||
|
|
||||||
|
|
||||||
|
class FacialAreaRegion(BaseModel):
|
||||||
|
x: float
|
||||||
|
y: float
|
||||||
|
w: float
|
||||||
|
h: float
|
||||||
|
left_eye: Tuple[int, int]
|
||||||
|
right_eye: Tuple[int, int]
|
||||||
|
confidence: float
|
||||||
|
|
||||||
|
|
||||||
|
class Result(BaseModel):
|
||||||
|
code: int
|
||||||
|
msg: str
|
||||||
|
faces: List[FacialAreaRegion] = []
|
||||||
|
|
||||||
|
|
||||||
|
class Face(BaseModel):
|
||||||
|
predict_time: float
|
||||||
|
draw_time: float
|
||||||
|
|
||||||
|
|
||||||
|
async def predict_photo(img_byte_arr: BinaryIO) -> Optional[Result]:
|
||||||
|
files = {"file": ("image.png", img_byte_arr, "image/png")}
|
||||||
|
try:
|
||||||
|
req = await request.post(predict_url, files=files, headers=headers)
|
||||||
|
return Result(**req.json())
|
||||||
|
except Exception:
|
||||||
|
traceback.print_exc()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def predict(file: BinaryIO) -> Tuple[Optional[Face], Optional[BinaryIO]]:
|
||||||
|
image = Image.open(file)
|
||||||
|
file.seek(0) # 重置指针到开始位置
|
||||||
|
time1 = time.time()
|
||||||
|
data = await predict_photo(file)
|
||||||
|
time2 = time.time()
|
||||||
|
if not data or not data.faces:
|
||||||
|
return None, None
|
||||||
|
for face in data.faces:
|
||||||
|
# 框出人脸
|
||||||
|
draw = ImageDraw.Draw(image)
|
||||||
|
x1, y1 = face.x, face.y
|
||||||
|
x2, y2 = x1 + face.w, y1 + face.h
|
||||||
|
draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
|
||||||
|
# 画出眼睛
|
||||||
|
draw.ellipse(
|
||||||
|
[
|
||||||
|
face.left_eye[0] - 2,
|
||||||
|
face.left_eye[1] - 2,
|
||||||
|
face.left_eye[0] + 2,
|
||||||
|
face.left_eye[1] + 2,
|
||||||
|
],
|
||||||
|
fill="red",
|
||||||
|
)
|
||||||
|
draw.ellipse(
|
||||||
|
[
|
||||||
|
face.right_eye[0] - 2,
|
||||||
|
face.right_eye[1] - 2,
|
||||||
|
face.right_eye[0] + 2,
|
||||||
|
face.right_eye[1] + 2,
|
||||||
|
],
|
||||||
|
fill="red",
|
||||||
|
)
|
||||||
|
binary_io = io.BytesIO()
|
||||||
|
image.save(binary_io, "JPEG")
|
||||||
|
binary_io.seek(0)
|
||||||
|
time3 = time.time()
|
||||||
|
return Face(predict_time=time2 - time1, draw_time=time3 - time2), binary_io
|
27
modules/predict.py
Normal file
27
modules/predict.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
from pyrogram import Client, filters
|
||||||
|
from pyrogram.types import Message
|
||||||
|
|
||||||
|
from defs.predict import predict
|
||||||
|
from init import bot
|
||||||
|
|
||||||
|
|
||||||
|
@bot.on_message(
|
||||||
|
filters.incoming & filters.command(["predict", f"predict@{bot.me.username}"])
|
||||||
|
)
|
||||||
|
async def predict_command(_: Client, message: Message):
|
||||||
|
r = message
|
||||||
|
if message.reply_to_message and message.reply_to_message.photo:
|
||||||
|
r = message.reply_to_message
|
||||||
|
if not r.photo:
|
||||||
|
return await message.reply("请发送/回复一张图片")
|
||||||
|
time1 = time.time()
|
||||||
|
file = await r.download(in_memory=True)
|
||||||
|
download_time = time.time()
|
||||||
|
face, image = await predict(file)
|
||||||
|
if face and image:
|
||||||
|
text = f"下载耗时: {download_time - time1:.2f}s\n预测耗时: {face.predict_time:.2f}s\n绘制耗时: {face.draw_time:.2f}s"
|
||||||
|
await message.reply_photo(image, caption=text)
|
||||||
|
else:
|
||||||
|
await message.reply("未检测到人脸")
|
Loading…
Reference in New Issue
Block a user