refactor(ai): 优化上下文管理逻辑,添加请求频率限制和上下文持续时间检查
This commit is contained in:
parent
90e51bc485
commit
43af7b4623
3 changed files with 99 additions and 31 deletions
|
@ -57,12 +57,12 @@ func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加对话到上下文,超过5条则删除所有上下文
|
// 添加对话到上下文,超过5条则删除所有上下文
|
||||||
func addToContext(rdb *redis.Client, groupID string, qqID string, message string) {
|
func addToContext(rdb *redis.Client, groupID string, qqID string, message string, context int64) {
|
||||||
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
|
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
|
||||||
|
|
||||||
// 如果上下文超过5条,删除所有上下文
|
// 如果上下文超过5条,删除所有上下文
|
||||||
listLength := rdb.LLen(ctx, key).Val()
|
listLength := rdb.LLen(ctx, key).Val()
|
||||||
if listLength > 5 {
|
if listLength > context {
|
||||||
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
||||||
}
|
}
|
||||||
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
||||||
|
@ -85,23 +85,23 @@ func main() {
|
||||||
qqID := "67890"
|
qqID := "67890"
|
||||||
rdb := initRedis()
|
rdb := initRedis()
|
||||||
|
|
||||||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := tools.GetOAIConfig()
|
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
|
||||||
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
||||||
oaiconfig.BaseURL = OPENAI_BaseURL
|
oaiconfig.BaseURL = OPENAI_BaseURL
|
||||||
client := openai.NewClientWithConfig(oaiconfig)
|
client := openai.NewClientWithConfig(oaiconfig)
|
||||||
|
|
||||||
messages := make([]openai.ChatCompletionMessage, 0)
|
messages := make([]openai.ChatCompletionMessage, 0)
|
||||||
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: PROMPT,
|
||||||
|
})
|
||||||
reader := bufio.NewReader(os.Stdin)
|
reader := bufio.NewReader(os.Stdin)
|
||||||
fmt.Println("Conversation")
|
fmt.Println("Conversation")
|
||||||
fmt.Println("---------------------")
|
fmt.Println("---------------------")
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
||||||
fmt.Print("-> ")
|
fmt.Print("-> ")
|
||||||
// 检查请求频率
|
|
||||||
if !checkRequestFrequency(rdb, groupID, qqID) {
|
|
||||||
// fmt.Println("请求太频繁,请稍后再试。")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
text, _ := reader.ReadString('\n')
|
text, _ := reader.ReadString('\n')
|
||||||
// convert CRLF to LF
|
// convert CRLF to LF
|
||||||
text = strings.Replace(text, "\n", "", -1)
|
text = strings.Replace(text, "\n", "", -1)
|
||||||
|
@ -110,9 +110,13 @@ func main() {
|
||||||
Role: openai.ChatMessageRoleUser,
|
Role: openai.ChatMessageRoleUser,
|
||||||
Content: text,
|
Content: text,
|
||||||
})
|
})
|
||||||
|
// 输出 messages 的内容
|
||||||
|
for i, msg := range messages {
|
||||||
|
fmt.Printf("Message %d: Role: %s, Content: %s\n", i, msg.Role, msg.Content)
|
||||||
|
}
|
||||||
|
|
||||||
resp, err := client.CreateChatCompletion(
|
resp, err := client.CreateChatCompletion(
|
||||||
ctx,
|
context.Background(),
|
||||||
openai.ChatCompletionRequest{
|
openai.ChatCompletionRequest{
|
||||||
Model: MODEL,
|
Model: MODEL,
|
||||||
Messages: messages,
|
Messages: messages,
|
||||||
|
@ -130,10 +134,15 @@ func main() {
|
||||||
Content: content,
|
Content: content,
|
||||||
})
|
})
|
||||||
// 添加新消息到上下文
|
// 添加新消息到上下文
|
||||||
addToContext(rdb, groupID, qqID, text)
|
addToContext(rdb, groupID, qqID, text, CONTEXT)
|
||||||
|
|
||||||
fmt.Println(content)
|
fmt.Println(content)
|
||||||
fmt.Println("---------------------")
|
fmt.Println("---------------------")
|
||||||
|
// 检查请求频率
|
||||||
|
if !checkRequestFrequency(rdb, groupID, qqID) {
|
||||||
|
// fmt.Println("请求太频繁,请稍后再试。")
|
||||||
|
return
|
||||||
|
}
|
||||||
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
|
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"go-bot/config"
|
"go-bot/config"
|
||||||
"log"
|
"log"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
@ -49,18 +50,43 @@ func GetRedisClient() *redis.Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 添加对话到上下文,超过5条则删除所有上下文
|
// 添加对话到上下文,超过5条则删除所有上下文
|
||||||
func AddToContext(key string, message string, contextlenth int64) {
|
func AddToContext(key string, message []byte, contextlenth int64) {
|
||||||
// key := fmt.Sprintf("context:%s:%s:%s", worker, groupID, qqID)
|
// key := fmt.Sprintf("context:%s:%s:%s", worker, groupID, qqID)
|
||||||
|
// 检查上下文持续时间,如果大于一小时就删除所有列表
|
||||||
|
lastRequestKey := fmt.Sprintf("last_request:%s", key)
|
||||||
|
now := time.Now().Unix()
|
||||||
|
exists, err := CheckKeyExists(lastRequestKey)
|
||||||
|
if err != nil || !exists {
|
||||||
|
log.Println("检查键是否存在时出错:", err)
|
||||||
|
rdb.Del(ctx, key)
|
||||||
|
rdb.Set(ctx, lastRequestKey, fmt.Sprintf("%d", now), 60*time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
// 如果上下文超过5条,删除所有上下文
|
// 如果上下文超过5条,删除所有上下文
|
||||||
listLength := GetListLength(key)
|
listLength := GetListLength(key)
|
||||||
// log.Println("listLength:", listLength)
|
// log.Println("listLength:", listLength)
|
||||||
if listLength >= contextlenth {
|
if listLength >= contextlenth*2 {
|
||||||
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
||||||
}
|
}
|
||||||
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查请求频率,10秒内只能请求一次
|
||||||
|
func CheckRequestFrequency(GID string, UID string, interval int64) bool {
|
||||||
|
key := fmt.Sprintf("time_interval:%s:%s", GID, UID)
|
||||||
|
// lastRequest, err := rdb.Get(ctx, key).Int64()
|
||||||
|
exists, err := CheckKeyExists(key)
|
||||||
|
if err != nil || !exists {
|
||||||
|
// log.Println("检查键是否存在时出错:", err)
|
||||||
|
|
||||||
|
rdb.Set(ctx, key, fmt.Sprintf("%d", time.Now().Unix()), time.Duration(interval)*time.Second)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false // 频率超限,拒绝请求
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// getListLength 获取列表长度
|
// getListLength 获取列表长度
|
||||||
func GetListLength(key string) int64 {
|
func GetListLength(key string) int64 {
|
||||||
return rdb.LLen(ctx, key).Val()
|
return rdb.LLen(ctx, key).Val()
|
||||||
|
|
|
@ -3,6 +3,7 @@ package workers
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"go-bot/tools"
|
"go-bot/tools"
|
||||||
|
|
||||||
|
@ -62,38 +63,55 @@ func (a *AI) GetMsg() string {
|
||||||
redisClient := tools.GetRedisClient()
|
redisClient := tools.GetRedisClient()
|
||||||
// 如果redisClient不为空,则获取上下文
|
// 如果redisClient不为空,则获取上下文
|
||||||
if redisClient != nil {
|
if redisClient != nil {
|
||||||
|
// 限制请求频率
|
||||||
|
if !tools.CheckRequestFrequency(a.GID, a.UID, 10) {
|
||||||
|
log.Printf("请求过于频繁。\n")
|
||||||
|
return "请求过于频繁。"
|
||||||
|
}
|
||||||
|
|
||||||
key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID)
|
key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID)
|
||||||
// 获取上下文
|
// 获取上下文
|
||||||
|
message := openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||||||
|
}
|
||||||
|
// 序列化为 JSON 字符串
|
||||||
|
jsonMessage, err := json.Marshal(message)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("序列化错误:", err)
|
||||||
|
log.Println("存储的 JSON 字符串:", string(jsonMessage))
|
||||||
|
|
||||||
tools.AddToContext(key, a.RawMsg[strings.Index(a.RawMsg, " ")+1:], CONTEXT)
|
return "序列化错误"
|
||||||
|
}
|
||||||
|
tools.AddToContext(key, jsonMessage, CONTEXT)
|
||||||
// }
|
// }
|
||||||
// println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:])
|
// println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:])
|
||||||
length := tools.GetListLength(key)
|
length := tools.GetListLength(key)
|
||||||
if length > 0 {
|
// if length > 0 {
|
||||||
for i := 0; i < int(length); i++ {
|
for i := 0; i < int(length); i++ {
|
||||||
message, err := tools.GetListValue(key, int64(i))
|
message, err := tools.GetListValue(key, int64(i))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println("获取上下文失败:", err)
|
log.Println("获取上下文失败:", err)
|
||||||
return "获取上下文失败"
|
return "获取上下文失败"
|
||||||
}
|
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
|
||||||
Role: openai.ChatMessageRoleUser,
|
|
||||||
Content: message,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
} else {
|
log.Println("读取的 JSON 字符串:", message)
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
|
||||||
Role: openai.ChatMessageRoleUser,
|
|
||||||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
|
var msg openai.ChatCompletionMessage
|
||||||
|
err = json.Unmarshal([]byte(message), &msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("反序列化错误:", err)
|
||||||
|
|
||||||
|
return "反序列化错误"
|
||||||
|
}
|
||||||
|
messages = append(messages, msg)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
messages = append(messages, openai.ChatCompletionMessage{
|
messages = append(messages, openai.ChatCompletionMessage{
|
||||||
Role: openai.ChatMessageRoleUser,
|
Role: openai.ChatMessageRoleUser,
|
||||||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
// println("messages:", messages)
|
// println("messages:", messages)
|
||||||
// for i, msg := range messages {
|
// for i, msg := range messages {
|
||||||
|
@ -116,7 +134,22 @@ func (a *AI) GetMsg() string {
|
||||||
return "请求失败"
|
return "请求失败"
|
||||||
}
|
}
|
||||||
// println(resp.Choices[0].Message.Content)
|
// println(resp.Choices[0].Message.Content)
|
||||||
return msg + stripMarkdown(resp.Choices[0].Message.Content)
|
content := resp.Choices[0].Message.Content
|
||||||
|
if redisClient != nil {
|
||||||
|
|
||||||
|
message := openai.ChatCompletionMessage{
|
||||||
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
|
Content: stripMarkdown(content)}
|
||||||
|
// 序列化为 JSON 字符串
|
||||||
|
jsonMessage, err := json.Marshal(message)
|
||||||
|
if err != nil {
|
||||||
|
log.Println("Assistant消息序列化错误:", err)
|
||||||
|
return "Assistant消息序列化错误"
|
||||||
|
}
|
||||||
|
tools.AddToContext(key, jsonMessage, CONTEXT)
|
||||||
|
|
||||||
|
}
|
||||||
|
return msg + stripMarkdown(content)
|
||||||
} else {
|
} else {
|
||||||
// 匹配回复消息
|
// 匹配回复消息
|
||||||
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
||||||
|
|
Loading…
Reference in a new issue