refactor(ai): 优化上下文管理逻辑,添加请求频率限制和上下文持续时间检查

This commit is contained in:
liyp 2024-09-22 20:05:10 +08:00
parent 90e51bc485
commit 43af7b4623
3 changed files with 99 additions and 31 deletions

View file

@ -57,12 +57,12 @@ func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool
}
// 添加对话到上下文超过5条则删除所有上下文
func addToContext(rdb *redis.Client, groupID string, qqID string, message string) {
func addToContext(rdb *redis.Client, groupID string, qqID string, message string, context int64) {
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
// 如果上下文超过5条删除所有上下文
listLength := rdb.LLen(ctx, key).Val()
if listLength > 5 {
if listLength > context {
rdb.Del(ctx, key) // 删除该用户的所有上下文
}
rdb.RPush(ctx, key, message) // 添加新消息到列表
@ -85,23 +85,23 @@ func main() {
qqID := "67890"
rdb := initRedis()
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := tools.GetOAIConfig()
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
oaiconfig.BaseURL = OPENAI_BaseURL
client := openai.NewClientWithConfig(oaiconfig)
messages := make([]openai.ChatCompletionMessage, 0)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: PROMPT,
})
reader := bufio.NewReader(os.Stdin)
fmt.Println("Conversation")
fmt.Println("---------------------")
for {
fmt.Print("-> ")
// 检查请求频率
if !checkRequestFrequency(rdb, groupID, qqID) {
// fmt.Println("请求太频繁,请稍后再试。")
return
}
text, _ := reader.ReadString('\n')
// convert CRLF to LF
text = strings.Replace(text, "\n", "", -1)
@ -110,9 +110,13 @@ func main() {
Role: openai.ChatMessageRoleUser,
Content: text,
})
// 输出 messages 的内容
for i, msg := range messages {
fmt.Printf("Message %d: Role: %s, Content: %s\n", i, msg.Role, msg.Content)
}
resp, err := client.CreateChatCompletion(
ctx,
context.Background(),
openai.ChatCompletionRequest{
Model: MODEL,
Messages: messages,
@ -130,10 +134,15 @@ func main() {
Content: content,
})
// 添加新消息到上下文
addToContext(rdb, groupID, qqID, text)
addToContext(rdb, groupID, qqID, text, CONTEXT)
fmt.Println(content)
fmt.Println("---------------------")
// 检查请求频率
if !checkRequestFrequency(rdb, groupID, qqID) {
// fmt.Println("请求太频繁,请稍后再试。")
return
}
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
}
}

View file

@ -6,6 +6,7 @@ import (
"go-bot/config"
"log"
"sync"
"time"
"github.com/go-redis/redis/v8"
)
@ -49,18 +50,43 @@ func GetRedisClient() *redis.Client {
}
// 添加对话到上下文超过5条则删除所有上下文
func AddToContext(key string, message string, contextlenth int64) {
func AddToContext(key string, message []byte, contextlenth int64) {
// key := fmt.Sprintf("context:%s:%s:%s", worker, groupID, qqID)
// 检查上下文持续时间,如果大于一小时就删除所有列表
lastRequestKey := fmt.Sprintf("last_request:%s", key)
now := time.Now().Unix()
exists, err := CheckKeyExists(lastRequestKey)
if err != nil || !exists {
log.Println("检查键是否存在时出错:", err)
rdb.Del(ctx, key)
rdb.Set(ctx, lastRequestKey, fmt.Sprintf("%d", now), 60*time.Minute)
}
// 如果上下文超过5条删除所有上下文
listLength := GetListLength(key)
// log.Println("listLength:", listLength)
if listLength >= contextlenth {
if listLength >= contextlenth*2 {
rdb.Del(ctx, key) // 删除该用户的所有上下文
}
rdb.RPush(ctx, key, message) // 添加新消息到列表
}
// 检查请求频率10秒内只能请求一次
func CheckRequestFrequency(GID string, UID string, interval int64) bool {
key := fmt.Sprintf("time_interval:%s:%s", GID, UID)
// lastRequest, err := rdb.Get(ctx, key).Int64()
exists, err := CheckKeyExists(key)
if err != nil || !exists {
// log.Println("检查键是否存在时出错:", err)
rdb.Set(ctx, key, fmt.Sprintf("%d", time.Now().Unix()), time.Duration(interval)*time.Second)
return true
}
return false // 频率超限,拒绝请求
}
// getListLength 获取列表长度
func GetListLength(key string) int64 {
return rdb.LLen(ctx, key).Val()

View file

@ -3,6 +3,7 @@ package workers
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"go-bot/tools"
@ -62,38 +63,55 @@ func (a *AI) GetMsg() string {
redisClient := tools.GetRedisClient()
// 如果redisClient不为空则获取上下文
if redisClient != nil {
// 限制请求频率
if !tools.CheckRequestFrequency(a.GID, a.UID, 10) {
log.Printf("请求过于频繁。\n")
return "请求过于频繁。"
}
key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID)
// 获取上下文
message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
}
// 序列化为 JSON 字符串
jsonMessage, err := json.Marshal(message)
if err != nil {
log.Println("序列化错误:", err)
log.Println("存储的 JSON 字符串:", string(jsonMessage))
tools.AddToContext(key, a.RawMsg[strings.Index(a.RawMsg, " ")+1:], CONTEXT)
return "序列化错误"
}
tools.AddToContext(key, jsonMessage, CONTEXT)
// }
// println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:])
length := tools.GetListLength(key)
if length > 0 {
// if length > 0 {
for i := 0; i < int(length); i++ {
message, err := tools.GetListValue(key, int64(i))
if err != nil {
log.Println("获取上下文失败:", err)
return "获取上下文失败"
}
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: message,
})
}
} else {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
})
}
log.Println("读取的 JSON 字符串:", message)
var msg openai.ChatCompletionMessage
err = json.Unmarshal([]byte(message), &msg)
if err != nil {
log.Println("反序列化错误:", err)
return "反序列化错误"
}
messages = append(messages, msg)
}
} else {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
})
}
// println("messages:", messages)
// for i, msg := range messages {
@ -116,7 +134,22 @@ func (a *AI) GetMsg() string {
return "请求失败"
}
// println(resp.Choices[0].Message.Content)
return msg + stripMarkdown(resp.Choices[0].Message.Content)
content := resp.Choices[0].Message.Content
if redisClient != nil {
message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: stripMarkdown(content)}
// 序列化为 JSON 字符串
jsonMessage, err := json.Marshal(message)
if err != nil {
log.Println("Assistant消息序列化错误:", err)
return "Assistant消息序列化错误"
}
tools.AddToContext(key, jsonMessage, CONTEXT)
}
return msg + stripMarkdown(content)
} else {
// 匹配回复消息
pattern := `^\[CQ:reply,id=(-?\d+)\]`