go-bot/test/openai.go

139 lines
3.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"context"
"fmt"
"go-bot/tools"
"os"
"strings"
"time"
"github.com/redis/go-redis/v9"
"github.com/sashabaranov/go-openai"
)
var (
	// ctx is the single background context shared by the Redis and OpenAI
	// calls in this program; it is never cancelled.
	ctx = context.Background()
)
// 初始化 Redis 客户端
func initRedis() *redis.Client {
rdb := redis.NewClient(&redis.Options{
Addr: "localhost:6379", // Redis 地址
Password: "", // no password set
DB: 0, // use default DB
})
return rdb
}
// 检查请求频率10秒内只能请求一次
func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool {
key := fmt.Sprintf("last_request:%s:%s", groupID, qqID)
lastRequest, err := rdb.Get(ctx, key).Int64()
println("key:", key)
now := time.Now().Unix()
println("now:", now)
println("lastRequest:", lastRequest)
if err == redis.Nil {
// 键不存在,这是第一次请求
rdb.Set(ctx, key, now, 10*time.Second)
return true
} else if err != nil {
fmt.Println("获取上次请求时间时出错:", err)
return false
}
currentRequest := getContext(rdb, key)
fmt.Println("currentRequest:", currentRequest)
if now-lastRequest < 10 {
fmt.Printf("请求过于频繁。距离上次请求还有 %d 秒。\n", 10-(now-lastRequest))
return false // 频率超限,拒绝请求
}
// 更新最后请求时间,并设置 10 秒的过期时间
rdb.Set(ctx, key, now, 10*time.Second)
return true
}
// 添加对话到上下文超过5条则删除所有上下文
func addToContext(rdb *redis.Client, groupID string, qqID string, message string) {
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
// 如果上下文超过5条删除所有上下文
listLength := rdb.LLen(ctx, key).Val()
if listLength > 5 {
rdb.Del(ctx, key) // 删除该用户的所有上下文
}
rdb.RPush(ctx, key, message) // 添加新消息到列表
}
// 获取当前上下文
func getContext(rdb *redis.Client, key string) []string {
// key :=
context, err := rdb.LRange(ctx, key, 0, -1).Result()
if err != nil {
fmt.Println("Error fetching context:", err)
return nil
}
return context
}
func main() {
// ctx := context.Background()
groupID := "12345"
qqID := "67890"
rdb := initRedis()
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := tools.GetOAIConfig()
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
oaiconfig.BaseURL = OPENAI_BaseURL
client := openai.NewClientWithConfig(oaiconfig)
messages := make([]openai.ChatCompletionMessage, 0)
reader := bufio.NewReader(os.Stdin)
fmt.Println("Conversation")
fmt.Println("---------------------")
for {
fmt.Print("-> ")
// 检查请求频率
if !checkRequestFrequency(rdb, groupID, qqID) {
// fmt.Println("请求太频繁,请稍后再试。")
return
}
text, _ := reader.ReadString('\n')
// convert CRLF to LF
text = strings.Replace(text, "\n", "", -1)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: text,
})
resp, err := client.CreateChatCompletion(
ctx,
openai.ChatCompletionRequest{
Model: MODEL,
Messages: messages,
},
)
if err != nil {
fmt.Printf("ChatCompletion error: %v\n", err)
continue
}
content := resp.Choices[0].Message.Content
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: content,
})
// 添加新消息到上下文
addToContext(rdb, groupID, qqID, text)
fmt.Println(content)
fmt.Println("---------------------")
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
}
}