go-bot/test/openai.go

140 lines
3.5 KiB
Go
Raw Normal View History

package main
import (
"bufio"
"context"
"fmt"
"go-bot/tools"
"os"
"strings"
"time"
"github.com/redis/go-redis/v9"
"github.com/sashabaranov/go-openai"
)
var (
	// ctx is the shared, never-cancelled background context passed to every
	// Redis and OpenAI call in this program.
	ctx = context.Background()
)
// initRedis returns a go-redis client configured for the Redis server at
// localhost:6379, with an empty password and database 0.
func initRedis() *redis.Client {
	opts := redis.Options{
		Addr:     "localhost:6379", // Redis server address
		Password: "",               // no password set
		DB:       0,                // default database
	}
	return redis.NewClient(&opts)
}
// checkRequestFrequency reports whether user qqID in group groupID may send a
// request now. At most one request is accepted per 10-second window.
//
// The window is tracked with the Redis key "last_request:<groupID>:<qqID>".
// It holds the Unix time (seconds) of the last accepted request and expires
// when the window ends. The key is claimed with SET NX, so two concurrent
// callers cannot both observe a free window and both be accepted.
//
// It returns false, after printing a diagnostic, when the request is rejected
// or when Redis reports an error.
func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool {
	const window = 10 * time.Second

	key := fmt.Sprintf("last_request:%s:%s", groupID, qqID)
	now := time.Now().Unix()

	// SET key now EX 10 NX: succeeds only when no request was accepted during
	// the current window; the check and the update happen in one command.
	accepted, err := rdb.SetNX(ctx, key, now, window).Result()
	if err != nil {
		fmt.Println("获取上次请求时间时出错:", err)
		return false
	}
	if accepted {
		return true
	}

	// Rejected. Tell the caller how long to wait; this lookup is purely
	// informational, so its failure does not change the outcome.
	if last, err := rdb.Get(ctx, key).Int64(); err == nil {
		remaining := int64(window/time.Second) - (now - last)
		if remaining < 1 {
			// The key outlives the whole-second timestamp by under a second.
			remaining = 1
		}
		fmt.Printf("请求过于频繁。距离上次请求还有 %d 秒。\n", remaining)
	}
	return false
}
// addToContext appends message to the conversation history of user qqID in
// group groupID. The history is stored as the Redis list
// "context:<groupID>:<qqID>".
//
// The history is not a sliding window. Once the list already holds more than
// maxContextLen entries, it is deleted entirely before the new message is
// pushed. A Redis error is printed and the remaining steps are skipped.
func addToContext(rdb *redis.Client, groupID string, qqID string, message string) {
	const maxContextLen = 5

	key := fmt.Sprintf("context:%s:%s", groupID, qqID)

	listLength, err := rdb.LLen(ctx, key).Result()
	if err != nil {
		fmt.Println("Error reading context length:", err)
		return
	}
	if listLength > maxContextLen {
		// Drop the user's whole history rather than trimming the oldest entries.
		if err := rdb.Del(ctx, key).Err(); err != nil {
			fmt.Println("Error clearing context:", err)
			return
		}
	}
	if err := rdb.RPush(ctx, key, message).Err(); err != nil {
		fmt.Println("Error saving context:", err)
	}
}
// getContext returns every entry of the Redis list stored at key, in list
// order (LRANGE key 0 -1). On a Redis error it prints the error and returns nil.
func getContext(rdb *redis.Client, key string) []string {
	// Named "history" so the imported context package is not shadowed.
	history, err := rdb.LRange(ctx, key, 0, -1).Result()
	if err != nil {
		fmt.Println("Error fetching context:", err)
		return nil
	}
	return history
}
// main runs an interactive console chat against an OpenAI-compatible API.
//
// Each non-empty line read from stdin is sent, with the whole conversation so
// far, to the model configured by tools.GetOAIConfig. Requests are
// rate-limited per (groupID, qqID) through checkRequestFrequency. Each
// answered user message is also recorded in Redis via addToContext. The loop
// ends on EOF or on a stdin read error.
func main() {
	// Fixed test identity; presumably a real bot takes these from the incoming message.
	groupID := "12345"
	qqID := "67890"
	contextKey := fmt.Sprintf("context:%s:%s", groupID, qqID)

	rdb := initRedis()
	defer rdb.Close()

	apiKey, baseURL, model := tools.GetOAIConfig()
	oaiConfig := openai.DefaultConfig(apiKey)
	oaiConfig.BaseURL = baseURL
	client := openai.NewClientWithConfig(oaiConfig)

	var messages []openai.ChatCompletionMessage
	scanner := bufio.NewScanner(os.Stdin)
	fmt.Println("Conversation")
	fmt.Println("---------------------")

	for {
		fmt.Print("-> ")
		if !scanner.Scan() {
			break
		}
		// bufio.ScanLines already strips the trailing "\n" or "\r\n".
		text := strings.TrimSpace(scanner.Text())
		if text == "" {
			continue
		}

		// Check the rate limit only once input has arrived. A rejected
		// request lets the user try again instead of ending the program.
		if !checkRequestFrequency(rdb, groupID, qqID) {
			continue
		}

		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleUser,
			Content: text,
		})
		resp, err := client.CreateChatCompletion(
			ctx,
			openai.ChatCompletionRequest{
				Model:    model,
				Messages: messages,
			},
		)
		if err == nil && len(resp.Choices) == 0 {
			err = fmt.Errorf("response contained no choices")
		}
		if err != nil {
			fmt.Printf("ChatCompletion error: %v\n", err)
			// Drop the unanswered user turn so the history stays paired.
			messages = messages[:len(messages)-1]
			continue
		}

		content := resp.Choices[0].Message.Content
		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleAssistant,
			Content: content,
		})
		// Record the new user message in the Redis-backed context.
		addToContext(rdb, groupID, qqID, text)

		fmt.Println(content)
		fmt.Println("---------------------")
		fmt.Println(getContext(rdb, contextKey))
	}
	if err := scanner.Err(); err != nil {
		fmt.Println("reading stdin:", err)
	}
}