From 43af7b4623346f1a7911fdd0a9211893eecec4f3 Mon Sep 17 00:00:00 2001 From: liyp Date: Sun, 22 Sep 2024 20:05:10 +0800 Subject: [PATCH] =?UTF-8?q?refactor(ai):=20=E4=BC=98=E5=8C=96=E4=B8=8A?= =?UTF-8?q?=E4=B8=8B=E6=96=87=E7=AE=A1=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=AF=B7=E6=B1=82=E9=A2=91=E7=8E=87=E9=99=90?= =?UTF-8?q?=E5=88=B6=E5=92=8C=E4=B8=8A=E4=B8=8B=E6=96=87=E6=8C=81=E7=BB=AD?= =?UTF-8?q?=E6=97=B6=E9=97=B4=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/openai.go | 29 +++++++++++------- tools/redisClient.go | 30 +++++++++++++++++-- workers/ai.go | 71 ++++++++++++++++++++++++++++++++------------ 3 files changed, 99 insertions(+), 31 deletions(-) diff --git a/test/openai.go b/test/openai.go index 1c5ed03..84bf081 100644 --- a/test/openai.go +++ b/test/openai.go @@ -57,12 +57,12 @@ func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool } // 添加对话到上下文,超过5条则删除所有上下文 -func addToContext(rdb *redis.Client, groupID string, qqID string, message string) { +func addToContext(rdb *redis.Client, groupID string, qqID string, message string, context int64) { key := fmt.Sprintf("context:%s:%s", groupID, qqID) // 如果上下文超过5条,删除所有上下文 listLength := rdb.LLen(ctx, key).Val() - if listLength > 5 { + if listLength > context { rdb.Del(ctx, key) // 删除该用户的所有上下文 } rdb.RPush(ctx, key, message) // 添加新消息到列表 @@ -85,23 +85,23 @@ func main() { qqID := "67890" rdb := initRedis() - OPENAI_API_KEY, OPENAI_BaseURL, MODEL := tools.GetOAIConfig() + OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig() oaiconfig := openai.DefaultConfig(OPENAI_API_KEY) oaiconfig.BaseURL = OPENAI_BaseURL client := openai.NewClientWithConfig(oaiconfig) messages := make([]openai.ChatCompletionMessage, 0) + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: PROMPT, + }) reader := bufio.NewReader(os.Stdin) fmt.Println("Conversation") fmt.Println("---------------------") for { + fmt.Print("-> ") - // 检查请求频率 - if !checkRequestFrequency(rdb, groupID, qqID) { - // fmt.Println("请求太频繁,请稍后再试。") - return - } text, _ := reader.ReadString('\n') // convert CRLF to LF text = strings.Replace(text, "\n", "", -1) @@ -110,9 +110,13 @@ func main() { Role: openai.ChatMessageRoleUser, Content: text, }) + // 输出 messages 的内容 + for i, msg := range messages { + fmt.Printf("Message %d: Role: %s, Content: %s\n", i, msg.Role, msg.Content) + } resp, err := client.CreateChatCompletion( - ctx, + context.Background(), openai.ChatCompletionRequest{ Model: MODEL, Messages: messages, @@ -130,10 +134,15 @@ func main() { Content: content, }) // 添加新消息到上下文 - addToContext(rdb, groupID, qqID, text) + addToContext(rdb, groupID, qqID, text, CONTEXT) fmt.Println(content) fmt.Println("---------------------") + // 检查请求频率 + if !checkRequestFrequency(rdb, groupID, qqID) { + // fmt.Println("请求太频繁,请稍后再试。") + return + } fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID)) } } diff --git a/tools/redisClient.go b/tools/redisClient.go index 6de3fc0..47a720d 100644 --- a/tools/redisClient.go +++ b/tools/redisClient.go @@ -6,6 +6,7 @@ import ( "go-bot/config" "log" "sync" + "time" "github.com/go-redis/redis/v8" ) @@ -49,18 +50,43 @@ func GetRedisClient() *redis.Client { } // 添加对话到上下文,超过5条则删除所有上下文 -func AddToContext(key string, message string, contextlenth int64) { +func AddToContext(key string, message []byte, contextlenth int64) { // key := fmt.Sprintf("context:%s:%s:%s", worker, groupID, qqID) + // 检查上下文持续时间,如果大于一小时就删除所有列表 + lastRequestKey := fmt.Sprintf("last_request:%s", key) + now := time.Now().Unix() + exists, err := CheckKeyExists(lastRequestKey) + if err != nil || !exists { + log.Println("检查键是否存在时出错:", err) + rdb.Del(ctx, key) + rdb.Set(ctx, lastRequestKey, fmt.Sprintf("%d", now), 60*time.Minute) + } // 如果上下文超过5条,删除所有上下文 listLength := GetListLength(key) // log.Println("listLength:", listLength) - if listLength >= contextlenth { + if listLength >= contextlenth*2 { rdb.Del(ctx, key) // 删除该用户的所有上下文 } rdb.RPush(ctx, key, message) // 添加新消息到列表 } +// 检查请求频率,10秒内只能请求一次 +func CheckRequestFrequency(GID string, UID string, interval int64) bool { + key := fmt.Sprintf("time_interval:%s:%s", GID, UID) + // lastRequest, err := rdb.Get(ctx, key).Int64() + exists, err := CheckKeyExists(key) + if err != nil || !exists { + // log.Println("检查键是否存在时出错:", err) + + rdb.Set(ctx, key, fmt.Sprintf("%d", time.Now().Unix()), time.Duration(interval)*time.Second) + return true + } + + return false // 频率超限,拒绝请求 + +} + // getListLength 获取列表长度 func GetListLength(key string) int64 { return rdb.LLen(ctx, key).Val() diff --git a/workers/ai.go b/workers/ai.go index 381761d..eeb2f6a 100644 --- a/workers/ai.go +++ b/workers/ai.go @@ -3,6 +3,7 @@ package workers import ( "context" "encoding/base64" + "encoding/json" "fmt" "go-bot/tools" @@ -62,38 +63,55 @@ func (a *AI) GetMsg() string { redisClient := tools.GetRedisClient() // 如果redisClient不为空,则获取上下文 if redisClient != nil { + // 限制请求频率 + if !tools.CheckRequestFrequency(a.GID, a.UID, 10) { + log.Printf("请求过于频繁。\n") + return "请求过于频繁。" + } key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID) // 获取上下文 + message := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:], + } + // 序列化为 JSON 字符串 + jsonMessage, err := json.Marshal(message) + if err != nil { + log.Println("序列化错误:", err) + log.Println("存储的 JSON 字符串:", string(jsonMessage)) - tools.AddToContext(key, a.RawMsg[strings.Index(a.RawMsg, " ")+1:], CONTEXT) + return "序列化错误" + } + tools.AddToContext(key, jsonMessage, CONTEXT) // } // println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:]) length := tools.GetListLength(key) - if length > 0 { - for i := 0; i < int(length); i++ { - message, err := tools.GetListValue(key, int64(i)) - if err != nil { - log.Println("获取上下文失败:", err) - return "获取上下文失败" - } - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, - Content: message, - }) + // if length > 0 { + for i := 0; i < int(length); i++ { + message, err := tools.GetListValue(key, int64(i)) + if err != nil { + log.Println("获取上下文失败:", err) + return "获取上下文失败" } - } else { - messages = append(messages, openai.ChatCompletionMessage{ - Role: openai.ChatMessageRoleUser, - Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:], - }) - } + log.Println("读取的 JSON 字符串:", message) + var msg openai.ChatCompletionMessage + err = json.Unmarshal([]byte(message), &msg) + if err != nil { + log.Println("反序列化错误:", err) + + return "反序列化错误" + } + messages = append(messages, msg) + } } else { + messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:], }) + } // println("messages:", messages) // for i, msg := range messages { @@ -116,7 +134,22 @@ func (a *AI) GetMsg() string { return "请求失败" } // println(resp.Choices[0].Message.Content) - return msg + stripMarkdown(resp.Choices[0].Message.Content) + content := resp.Choices[0].Message.Content + if redisClient != nil { + + message := openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleAssistant, + Content: stripMarkdown(content)} + // 序列化为 JSON 字符串 + jsonMessage, err := json.Marshal(message) + if err != nil { + log.Println("Assistant消息序列化错误:", err) + return "Assistant消息序列化错误" + } + tools.AddToContext(key, jsonMessage, CONTEXT) + + } + return msg + stripMarkdown(content) } else { // 匹配回复消息 pattern := `^\[CQ:reply,id=(-?\d+)\]`