148 lines
3.8 KiB
Go
148 lines
3.8 KiB
Go
package main
|
||
|
||
import (
|
||
"bufio"
|
||
"context"
|
||
"fmt"
|
||
"go-bot/tools"
|
||
|
||
"os"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/redis/go-redis/v9"
|
||
"github.com/sashabaranov/go-openai"
|
||
)
|
||
|
||
// ctx is the shared background context passed to every Redis call made by
// the helpers in this file. It carries no deadline and is never cancelled.
var ctx context.Context = context.Background()
|
||
|
||
// 初始化 Redis 客户端
|
||
func initRedis() *redis.Client {
|
||
rdb := redis.NewClient(&redis.Options{
|
||
Addr: "localhost:6379", // Redis 地址
|
||
Password: "", // no password set
|
||
DB: 0, // use default DB
|
||
})
|
||
return rdb
|
||
}
|
||
|
||
// 检查请求频率,10秒内只能请求一次
|
||
func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool {
|
||
key := fmt.Sprintf("last_request:%s:%s", groupID, qqID)
|
||
lastRequest, err := rdb.Get(ctx, key).Int64()
|
||
println("key:", key)
|
||
now := time.Now().Unix()
|
||
println("now:", now)
|
||
println("lastRequest:", lastRequest)
|
||
|
||
if err == redis.Nil {
|
||
// 键不存在,这是第一次请求
|
||
rdb.Set(ctx, key, now, 10*time.Second)
|
||
|
||
return true
|
||
} else if err != nil {
|
||
fmt.Println("获取上次请求时间时出错:", err)
|
||
return false
|
||
}
|
||
currentRequest := getContext(rdb, key)
|
||
fmt.Println("currentRequest:", currentRequest)
|
||
if now-lastRequest < 10 {
|
||
fmt.Printf("请求过于频繁。距离上次请求还有 %d 秒。\n", 10-(now-lastRequest))
|
||
return false // 频率超限,拒绝请求
|
||
}
|
||
|
||
// 更新最后请求时间,并设置 10 秒的过期时间
|
||
rdb.Set(ctx, key, now, 10*time.Second)
|
||
return true
|
||
}
|
||
|
||
// 添加对话到上下文,超过5条则删除所有上下文
|
||
func addToContext(rdb *redis.Client, groupID string, qqID string, message string, context int64) {
|
||
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
|
||
|
||
// 如果上下文超过5条,删除所有上下文
|
||
listLength := rdb.LLen(ctx, key).Val()
|
||
if listLength > context {
|
||
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
||
}
|
||
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
||
}
|
||
|
||
// 获取当前上下文
|
||
func getContext(rdb *redis.Client, key string) []string {
|
||
// key :=
|
||
context, err := rdb.LRange(ctx, key, 0, -1).Result()
|
||
if err != nil {
|
||
fmt.Println("Error fetching context:", err)
|
||
return nil
|
||
}
|
||
return context
|
||
}
|
||
|
||
func main() {
|
||
// ctx := context.Background()
|
||
groupID := "12345"
|
||
qqID := "67890"
|
||
rdb := initRedis()
|
||
|
||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
|
||
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
||
oaiconfig.BaseURL = OPENAI_BaseURL
|
||
client := openai.NewClientWithConfig(oaiconfig)
|
||
|
||
messages := make([]openai.ChatCompletionMessage, 0)
|
||
messages = append(messages, openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleUser,
|
||
Content: PROMPT,
|
||
})
|
||
reader := bufio.NewReader(os.Stdin)
|
||
fmt.Println("Conversation")
|
||
fmt.Println("---------------------")
|
||
|
||
for {
|
||
|
||
fmt.Print("-> ")
|
||
text, _ := reader.ReadString('\n')
|
||
// convert CRLF to LF
|
||
text = strings.Replace(text, "\n", "", -1)
|
||
|
||
messages = append(messages, openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleUser,
|
||
Content: text,
|
||
})
|
||
// 输出 messages 的内容
|
||
for i, msg := range messages {
|
||
fmt.Printf("Message %d: Role: %s, Content: %s\n", i, msg.Role, msg.Content)
|
||
}
|
||
|
||
resp, err := client.CreateChatCompletion(
|
||
context.Background(),
|
||
openai.ChatCompletionRequest{
|
||
Model: MODEL,
|
||
Messages: messages,
|
||
},
|
||
)
|
||
|
||
if err != nil {
|
||
fmt.Printf("ChatCompletion error: %v\n", err)
|
||
continue
|
||
}
|
||
|
||
content := resp.Choices[0].Message.Content
|
||
messages = append(messages, openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleAssistant,
|
||
Content: content,
|
||
})
|
||
// 添加新消息到上下文
|
||
addToContext(rdb, groupID, qqID, text, CONTEXT)
|
||
|
||
fmt.Println(content)
|
||
fmt.Println("---------------------")
|
||
// 检查请求频率
|
||
if !checkRequestFrequency(rdb, groupID, qqID) {
|
||
// fmt.Println("请求太频繁,请稍后再试。")
|
||
return
|
||
}
|
||
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
|
||
}
|
||
}
|