Openaichat 限制上下文长度,并支持获取已记录的聊天话题 (#110)

* 限制上下文长度,并支持获取已记录的聊天话题

* bump version

* 过滤AI输出的开头空符
This commit is contained in:
jiechus 2022-12-22 12:11:53 +08:00 committed by GitHub
parent a06a9faf0e
commit fef7d425b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 9 deletions

View File

@ -862,13 +862,13 @@
},
{
"name": "openaichat",
"version": "1.1",
"version": "1.2",
"section": "chat",
"maintainer": "jiechus",
"size": "4.857 kb",
"size": "5.131 kb",
"supported": true,
"des-short": "openaichat",
"des": "openaichat"
"des-short": "openaichat 使用 OpenAI Chat 聊天",
"des": "openaichat 使用 OpenAI Chat 聊天\n基于 text-davinci-003 模型,与 ChatGPT 的效果有些许不同\n代码参考了原先的 ChatGPT 插件"
},
{
"name": "copy_sticker_set",

View File

@ -1,5 +1,6 @@
import contextlib
import threading
import re
from collections import defaultdict
@ -22,7 +23,7 @@ async def get_chat_response(prompt: str) -> str:
top_p=1,
frequency_penalty=0.0,
presence_penalty=0.6,
stop=[" Human:", " AI:"]
stop=["Human: ", "AI: "]
).choices[0].text
@ -54,6 +55,7 @@ def get_template() -> str:
def formatted_response(prompt: str, message: str) -> str:
if not get_template():
set_template(default_template)
message = re.sub(r'^\s+', r'', message)
try:
return get_template().format(prompt, message)
except Exception:
@ -67,6 +69,7 @@ chat_bot_help = "使用 OpenAI Chat 聊天\n" \
"代码参考了原先的 ChatGPT 插件\n\n" \
"参数:\n\n- 问题:询问 ai\n" \
"- reset重置聊天话题\n" \
"- thread获取已记录的聊天话题\n" \
"- set <api_key>:设置 OpenAI API Key获取 API Key https://beta.openai.com/account/api-keys \n" \
"- del删除 OpenAI API Key\n" \
"- template {set|get|reset} <template>: 设置/获取/重置回应模板。回应模板中的 {0} 将替换为问题,{1} 将替换为回答"
@ -107,7 +110,9 @@ async def chat_bot_func(message: Message):
elif message.arguments == "reset":
with contextlib.suppress(KeyError):
del chat_bot_session[from_id]
return await message.edit("已重置聊天状态。")
return await message.edit("已重置聊天话题。")
elif message.arguments == "thread":
return await message.edit(chat_bot_session.get(from_id, {}).get("chat_thread", "没有已记录的聊天话题。"))
elif message.arguments == "del":
if not get_api_key():
return await message.edit("没有设置 API Key。")
@ -119,12 +124,12 @@ async def chat_bot_func(message: Message):
with contextlib.suppress(Exception):
message: Message = await message.edit(formatted_response(message.arguments, "处理中..."))
try:
chat_thread = chat_bot_session[from_id]["chat_thread"] if chat_bot_session[from_id] else ""
prompt = f"{chat_thread}\n Human:{message.arguments}\n AI:"
chat_thread = chat_bot_session.get(from_id, {}).get("chat_thread", "")
prompt = f"{chat_thread}\nHuman: {message.arguments}\nAI: "[-3946:] # 4096 - 150(max_tokens)
msg = await get_chat_response(prompt)
chat_bot_session[from_id]["chat_thread"] = prompt + msg
except Exception as e:
msg = f"可能是 API Key 过期,请重新设置。\n{repr(e)}"
msg = f"可能是 API Key 过期或网络/输入错误,请重新设置。\n{repr(e)}"
if not msg:
msg = "无法获取到回复,可能是网络波动,请稍后再试。"
with contextlib.suppress(Exception):