refactor(ai): 优化上下文管理逻辑,添加请求频率限制和上下文持续时间检查
This commit is contained in:
parent
90e51bc485
commit
43af7b4623
3 changed files with 99 additions and 31 deletions
|
@ -57,12 +57,12 @@ func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool
|
|||
}
|
||||
|
||||
// 添加对话到上下文,超过5条则删除所有上下文
|
||||
func addToContext(rdb *redis.Client, groupID string, qqID string, message string) {
|
||||
func addToContext(rdb *redis.Client, groupID string, qqID string, message string, context int64) {
|
||||
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
|
||||
|
||||
// 如果上下文超过5条,删除所有上下文
|
||||
listLength := rdb.LLen(ctx, key).Val()
|
||||
if listLength > 5 {
|
||||
if listLength > context {
|
||||
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
||||
}
|
||||
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
||||
|
@ -85,23 +85,23 @@ func main() {
|
|||
qqID := "67890"
|
||||
rdb := initRedis()
|
||||
|
||||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := tools.GetOAIConfig()
|
||||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
|
||||
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
||||
oaiconfig.BaseURL = OPENAI_BaseURL
|
||||
client := openai.NewClientWithConfig(oaiconfig)
|
||||
|
||||
messages := make([]openai.ChatCompletionMessage, 0)
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: PROMPT,
|
||||
})
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
fmt.Println("Conversation")
|
||||
fmt.Println("---------------------")
|
||||
|
||||
for {
|
||||
|
||||
fmt.Print("-> ")
|
||||
// 检查请求频率
|
||||
if !checkRequestFrequency(rdb, groupID, qqID) {
|
||||
// fmt.Println("请求太频繁,请稍后再试。")
|
||||
return
|
||||
}
|
||||
text, _ := reader.ReadString('\n')
|
||||
// convert CRLF to LF
|
||||
text = strings.Replace(text, "\n", "", -1)
|
||||
|
@ -110,9 +110,13 @@ func main() {
|
|||
Role: openai.ChatMessageRoleUser,
|
||||
Content: text,
|
||||
})
|
||||
// 输出 messages 的内容
|
||||
for i, msg := range messages {
|
||||
fmt.Printf("Message %d: Role: %s, Content: %s\n", i, msg.Role, msg.Content)
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
ctx,
|
||||
context.Background(),
|
||||
openai.ChatCompletionRequest{
|
||||
Model: MODEL,
|
||||
Messages: messages,
|
||||
|
@ -130,10 +134,15 @@ func main() {
|
|||
Content: content,
|
||||
})
|
||||
// 添加新消息到上下文
|
||||
addToContext(rdb, groupID, qqID, text)
|
||||
addToContext(rdb, groupID, qqID, text, CONTEXT)
|
||||
|
||||
fmt.Println(content)
|
||||
fmt.Println("---------------------")
|
||||
// 检查请求频率
|
||||
if !checkRequestFrequency(rdb, groupID, qqID) {
|
||||
// fmt.Println("请求太频繁,请稍后再试。")
|
||||
return
|
||||
}
|
||||
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"go-bot/config"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
@ -49,18 +50,43 @@ func GetRedisClient() *redis.Client {
|
|||
}
|
||||
|
||||
// 添加对话到上下文,超过5条则删除所有上下文
|
||||
func AddToContext(key string, message string, contextlenth int64) {
|
||||
func AddToContext(key string, message []byte, contextlenth int64) {
|
||||
// key := fmt.Sprintf("context:%s:%s:%s", worker, groupID, qqID)
|
||||
// 检查上下文持续时间,如果大于一小时就删除所有列表
|
||||
lastRequestKey := fmt.Sprintf("last_request:%s", key)
|
||||
now := time.Now().Unix()
|
||||
exists, err := CheckKeyExists(lastRequestKey)
|
||||
if err != nil || !exists {
|
||||
log.Println("检查键是否存在时出错:", err)
|
||||
rdb.Del(ctx, key)
|
||||
rdb.Set(ctx, lastRequestKey, fmt.Sprintf("%d", now), 60*time.Minute)
|
||||
}
|
||||
|
||||
// 如果上下文超过5条,删除所有上下文
|
||||
listLength := GetListLength(key)
|
||||
// log.Println("listLength:", listLength)
|
||||
if listLength >= contextlenth {
|
||||
if listLength >= contextlenth*2 {
|
||||
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
||||
}
|
||||
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
||||
}
|
||||
|
||||
// 检查请求频率,10秒内只能请求一次
|
||||
func CheckRequestFrequency(GID string, UID string, interval int64) bool {
|
||||
key := fmt.Sprintf("time_interval:%s:%s", GID, UID)
|
||||
// lastRequest, err := rdb.Get(ctx, key).Int64()
|
||||
exists, err := CheckKeyExists(key)
|
||||
if err != nil || !exists {
|
||||
// log.Println("检查键是否存在时出错:", err)
|
||||
|
||||
rdb.Set(ctx, key, fmt.Sprintf("%d", time.Now().Unix()), time.Duration(interval)*time.Second)
|
||||
return true
|
||||
}
|
||||
|
||||
return false // 频率超限,拒绝请求
|
||||
|
||||
}
|
||||
|
||||
// getListLength 获取列表长度
|
||||
func GetListLength(key string) int64 {
|
||||
return rdb.LLen(ctx, key).Val()
|
||||
|
|
|
@ -3,6 +3,7 @@ package workers
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"go-bot/tools"
|
||||
|
||||
|
@ -62,38 +63,55 @@ func (a *AI) GetMsg() string {
|
|||
redisClient := tools.GetRedisClient()
|
||||
// 如果redisClient不为空,则获取上下文
|
||||
if redisClient != nil {
|
||||
// 限制请求频率
|
||||
if !tools.CheckRequestFrequency(a.GID, a.UID, 10) {
|
||||
log.Printf("请求过于频繁。\n")
|
||||
return "请求过于频繁。"
|
||||
}
|
||||
|
||||
key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID)
|
||||
// 获取上下文
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||||
}
|
||||
// 序列化为 JSON 字符串
|
||||
jsonMessage, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
log.Println("序列化错误:", err)
|
||||
log.Println("存储的 JSON 字符串:", string(jsonMessage))
|
||||
|
||||
tools.AddToContext(key, a.RawMsg[strings.Index(a.RawMsg, " ")+1:], CONTEXT)
|
||||
return "序列化错误"
|
||||
}
|
||||
tools.AddToContext(key, jsonMessage, CONTEXT)
|
||||
// }
|
||||
// println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:])
|
||||
length := tools.GetListLength(key)
|
||||
if length > 0 {
|
||||
// if length > 0 {
|
||||
for i := 0; i < int(length); i++ {
|
||||
message, err := tools.GetListValue(key, int64(i))
|
||||
if err != nil {
|
||||
log.Println("获取上下文失败:", err)
|
||||
return "获取上下文失败"
|
||||
}
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: message,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||||
})
|
||||
}
|
||||
log.Println("读取的 JSON 字符串:", message)
|
||||
|
||||
var msg openai.ChatCompletionMessage
|
||||
err = json.Unmarshal([]byte(message), &msg)
|
||||
if err != nil {
|
||||
log.Println("反序列化错误:", err)
|
||||
|
||||
return "反序列化错误"
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
} else {
|
||||
|
||||
messages = append(messages, openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||||
})
|
||||
|
||||
}
|
||||
// println("messages:", messages)
|
||||
// for i, msg := range messages {
|
||||
|
@ -116,7 +134,22 @@ func (a *AI) GetMsg() string {
|
|||
return "请求失败"
|
||||
}
|
||||
// println(resp.Choices[0].Message.Content)
|
||||
return msg + stripMarkdown(resp.Choices[0].Message.Content)
|
||||
content := resp.Choices[0].Message.Content
|
||||
if redisClient != nil {
|
||||
|
||||
message := openai.ChatCompletionMessage{
|
||||
Role: openai.ChatMessageRoleAssistant,
|
||||
Content: stripMarkdown(content)}
|
||||
// 序列化为 JSON 字符串
|
||||
jsonMessage, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
log.Println("Assistant消息序列化错误:", err)
|
||||
return "Assistant消息序列化错误"
|
||||
}
|
||||
tools.AddToContext(key, jsonMessage, CONTEXT)
|
||||
|
||||
}
|
||||
return msg + stripMarkdown(content)
|
||||
} else {
|
||||
// 匹配回复消息
|
||||
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
||||
|
|
Loading…
Reference in a new issue