2024-09-22 16:22:02 +08:00
|
|
|
|
package main
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bufio"
|
|
|
|
|
"context"
|
|
|
|
|
"fmt"
|
|
|
|
|
"go-bot/tools"
|
|
|
|
|
|
|
|
|
|
"os"
|
|
|
|
|
"strings"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
|
|
|
"github.com/sashabaranov/go-openai"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
var (
	// ctx is the shared background context passed to the Redis calls in this
	// file. It never has a deadline and is never cancelled.
	ctx = context.Background()
)
|
|
|
|
|
|
|
|
|
|
// 初始化 Redis 客户端
|
|
|
|
|
func initRedis() *redis.Client {
|
|
|
|
|
rdb := redis.NewClient(&redis.Options{
|
|
|
|
|
Addr: "localhost:6379", // Redis 地址
|
|
|
|
|
Password: "", // no password set
|
|
|
|
|
DB: 0, // use default DB
|
|
|
|
|
})
|
|
|
|
|
return rdb
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 检查请求频率,10秒内只能请求一次
|
|
|
|
|
func checkRequestFrequency(rdb *redis.Client, groupID string, qqID string) bool {
|
|
|
|
|
key := fmt.Sprintf("last_request:%s:%s", groupID, qqID)
|
|
|
|
|
lastRequest, err := rdb.Get(ctx, key).Int64()
|
|
|
|
|
println("key:", key)
|
|
|
|
|
now := time.Now().Unix()
|
|
|
|
|
println("now:", now)
|
|
|
|
|
println("lastRequest:", lastRequest)
|
|
|
|
|
|
|
|
|
|
if err == redis.Nil {
|
|
|
|
|
// 键不存在,这是第一次请求
|
|
|
|
|
rdb.Set(ctx, key, now, 10*time.Second)
|
|
|
|
|
|
|
|
|
|
return true
|
|
|
|
|
} else if err != nil {
|
|
|
|
|
fmt.Println("获取上次请求时间时出错:", err)
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
currentRequest := getContext(rdb, key)
|
|
|
|
|
fmt.Println("currentRequest:", currentRequest)
|
|
|
|
|
if now-lastRequest < 10 {
|
|
|
|
|
fmt.Printf("请求过于频繁。距离上次请求还有 %d 秒。\n", 10-(now-lastRequest))
|
|
|
|
|
return false // 频率超限,拒绝请求
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 更新最后请求时间,并设置 10 秒的过期时间
|
|
|
|
|
rdb.Set(ctx, key, now, 10*time.Second)
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 添加对话到上下文,超过5条则删除所有上下文
|
2024-09-22 20:05:10 +08:00
|
|
|
|
func addToContext(rdb *redis.Client, groupID string, qqID string, message string, context int64) {
|
2024-09-22 16:22:02 +08:00
|
|
|
|
key := fmt.Sprintf("context:%s:%s", groupID, qqID)
|
|
|
|
|
|
|
|
|
|
// 如果上下文超过5条,删除所有上下文
|
|
|
|
|
listLength := rdb.LLen(ctx, key).Val()
|
2024-09-22 20:05:10 +08:00
|
|
|
|
if listLength > context {
|
2024-09-22 16:22:02 +08:00
|
|
|
|
rdb.Del(ctx, key) // 删除该用户的所有上下文
|
|
|
|
|
}
|
|
|
|
|
rdb.RPush(ctx, key, message) // 添加新消息到列表
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 获取当前上下文
|
|
|
|
|
func getContext(rdb *redis.Client, key string) []string {
|
|
|
|
|
// key :=
|
|
|
|
|
context, err := rdb.LRange(ctx, key, 0, -1).Result()
|
|
|
|
|
if err != nil {
|
|
|
|
|
fmt.Println("Error fetching context:", err)
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
return context
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func main() {
|
|
|
|
|
// ctx := context.Background()
|
|
|
|
|
groupID := "12345"
|
|
|
|
|
qqID := "67890"
|
|
|
|
|
rdb := initRedis()
|
|
|
|
|
|
2024-09-22 20:05:10 +08:00
|
|
|
|
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
|
2024-09-22 16:22:02 +08:00
|
|
|
|
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
|
|
|
|
oaiconfig.BaseURL = OPENAI_BaseURL
|
|
|
|
|
client := openai.NewClientWithConfig(oaiconfig)
|
|
|
|
|
|
|
|
|
|
messages := make([]openai.ChatCompletionMessage, 0)
|
2024-09-22 20:05:10 +08:00
|
|
|
|
messages = append(messages, openai.ChatCompletionMessage{
|
|
|
|
|
Role: openai.ChatMessageRoleUser,
|
|
|
|
|
Content: PROMPT,
|
|
|
|
|
})
|
2024-09-22 16:22:02 +08:00
|
|
|
|
reader := bufio.NewReader(os.Stdin)
|
|
|
|
|
fmt.Println("Conversation")
|
|
|
|
|
fmt.Println("---------------------")
|
|
|
|
|
|
|
|
|
|
for {
|
2024-09-22 20:05:10 +08:00
|
|
|
|
|
2024-09-22 16:22:02 +08:00
|
|
|
|
fmt.Print("-> ")
|
|
|
|
|
text, _ := reader.ReadString('\n')
|
|
|
|
|
// convert CRLF to LF
|
|
|
|
|
text = strings.Replace(text, "\n", "", -1)
|
|
|
|
|
|
|
|
|
|
messages = append(messages, openai.ChatCompletionMessage{
|
|
|
|
|
Role: openai.ChatMessageRoleUser,
|
|
|
|
|
Content: text,
|
|
|
|
|
})
|
2024-09-22 20:05:10 +08:00
|
|
|
|
// 输出 messages 的内容
|
|
|
|
|
for i, msg := range messages {
|
|
|
|
|
fmt.Printf("Message %d: Role: %s, Content: %s\n", i, msg.Role, msg.Content)
|
|
|
|
|
}
|
2024-09-22 16:22:02 +08:00
|
|
|
|
|
|
|
|
|
resp, err := client.CreateChatCompletion(
|
2024-09-22 20:05:10 +08:00
|
|
|
|
context.Background(),
|
2024-09-22 16:22:02 +08:00
|
|
|
|
openai.ChatCompletionRequest{
|
|
|
|
|
Model: MODEL,
|
|
|
|
|
Messages: messages,
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
fmt.Printf("ChatCompletion error: %v\n", err)
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
content := resp.Choices[0].Message.Content
|
|
|
|
|
messages = append(messages, openai.ChatCompletionMessage{
|
|
|
|
|
Role: openai.ChatMessageRoleAssistant,
|
|
|
|
|
Content: content,
|
|
|
|
|
})
|
|
|
|
|
// 添加新消息到上下文
|
2024-09-22 20:05:10 +08:00
|
|
|
|
addToContext(rdb, groupID, qqID, text, CONTEXT)
|
2024-09-22 16:22:02 +08:00
|
|
|
|
|
|
|
|
|
fmt.Println(content)
|
|
|
|
|
fmt.Println("---------------------")
|
2024-09-22 20:05:10 +08:00
|
|
|
|
// 检查请求频率
|
|
|
|
|
if !checkRequestFrequency(rdb, groupID, qqID) {
|
|
|
|
|
// fmt.Println("请求太频繁,请稍后再试。")
|
|
|
|
|
return
|
|
|
|
|
}
|
2024-09-22 16:22:02 +08:00
|
|
|
|
fmt.Println(getContext(rdb, "context:"+groupID+":"+qqID))
|
|
|
|
|
}
|
|
|
|
|
}
|