go-bot/workers/ai.go

342 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package workers
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"go-bot/config"
"go-bot/tools"
"os"
"slices"
"io"
"log"
"regexp"
"strconv"
"strings"
"github.com/imroc/req/v3"
openai "github.com/sashabaranov/go-openai"
)
// init registers the "ai" worker factory when "ai" is listed in the PLUGINS
// entry of the bot configuration.
func init() {
	plugins, ok := config.GetConfig()["PLUGINS"].([]interface{})
	if !ok {
		// A bare type assertion would panic during package initialisation when
		// PLUGINS is missing or not a list; report it and skip registration.
		log.Println("PLUGINS 配置缺失或格式错误,未注册 ai 插件")
		return
	}
	if !slices.Contains(plugins, "ai") {
		return
	}
	RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
		return &AI{
			StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
		}
	})
}
// AI is the worker behind the "!ai" command. It has no state of its own: the
// request fields (Parms, UID, GID, RawMsg) and helpers (GetHisMsg, GetImage)
// used by GetMsg all come from the embedded *StdAns.
type AI struct {
	*StdAns
}
func (a *AI) GetMsg() string {
if len(a.Parms) < 2 {
return "使用!ai xxx 向我提问吧"
}
ask := a.Parms[1]
if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
return "不问问题你说个屁!"
}
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
if OPENAI_API_KEY == "" {
return "OPENAI_API_KEY 未配置"
}
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
oaiconfig.BaseURL = OPENAI_BaseURL
client := openai.NewClientWithConfig(oaiconfig)
msg := fmt.Sprintf("[CQ:at,qq=%s] ", a.UID)
if strings.ToLower(a.Parms[1]) != "models" {
// 如果非回复消息
if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
messages := make([]openai.ChatCompletionMessage, 0)
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleSystem,
Content: PROMPT,
})
// println("messages0:", messages)
var key string
redisClient := tools.GetRedisClient()
// 如果redisClient不为空则获取上下文
if redisClient != nil {
// 限制请求频率
if !tools.CheckRequestFrequency(a.GID, a.UID, 10) {
log.Printf("请求过于频繁。\n")
return "请求过于频繁。"
}
key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID)
// 获取上下文
message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
}
// 序列化为 JSON 字符串
jsonMessage, err := json.Marshal(message)
if err != nil {
log.Println("序列化错误:", err)
log.Println("存储的 JSON 字符串:", string(jsonMessage))
return "序列化错误"
}
tools.AddToContext(key, jsonMessage, CONTEXT)
// }
// println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:])
length := tools.GetListLength(key)
// if length > 0 {
for i := 0; i < int(length); i++ {
message, err := tools.GetListValue(key, int64(i))
if err != nil {
log.Println("获取上下文失败:", err)
return "获取上下文失败"
}
// log.Println("读取的 JSON 字符串:", message)
var msg openai.ChatCompletionMessage
err = json.Unmarshal([]byte(message), &msg)
if err != nil {
log.Println("反序列化错误:", err)
return "反序列化错误"
}
messages = append(messages, msg)
}
} else {
messages = append(messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
})
}
// println("messages:", messages)
// for i, msg := range messages {
// fmt.Printf("消息 %d:\n", i+1)
// fmt.Printf(" 角色: %s\n", msg.Role)
// fmt.Printf(" 内容: %s\n", msg.Content)
// fmt.Println()
// }
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: MODEL,
Stream: false,
Messages: messages,
},
)
if err != nil {
log.Println("ChatCompletion error: ", err)
tools.RemoveLastValueFromList(key)
return "请求失败"
}
// println(resp.Choices[0].Message.Content)
content := resp.Choices[0].Message.Content
if redisClient != nil {
message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleAssistant,
Content: stripMarkdown(content)}
// 序列化为 JSON 字符串
jsonMessage, err := json.Marshal(message)
if err != nil {
log.Println("Assistant消息序列化错误:", err)
return "Assistant消息序列化错误"
}
tools.AddToContext(key, jsonMessage, CONTEXT)
}
return msg + stripMarkdown(content)
} else {
// 匹配回复消息
pattern := `^\[CQ:reply,id=(-?\d+)\]`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1])
var msgId string
if len(matches) > 0 {
msgId = matches[1]
// println("msgId:", msgId)
} else {
msgId = ""
log.Println("未找到回复消息")
return "未找到回复消息"
}
file, picUrl, fileSizeStr := a.GetHisMsg(msgId)
// println("file:", file, "picUrl:", picUrl, "fileSizeStr:", fileSizeStr)
if picUrl == "" {
log.Println("未找到文件信息")
return "未找到文件信息"
}
fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
if err != nil {
fmt.Println("获取图片大小失败:", err)
return "获取图片大小失败"
}
if fileSize/1024/1024 > 5.0 {
log.Println("文件大小超过5M")
return "文件大小超过5M"
}
filePath := a.GetImage(file)
// println("filePath:", filePath)
if filePath == "" {
log.Println("获取图片失败")
// return "获取图片失败"
// 下载picUrl文件
}
base64Img := Image2Base64(filePath, picUrl)
if base64Img == "" {
log.Println("图片转换base64失败")
return "图片转换base64失败"
}
// 找到第一个空格的位置
firstSpaceIndex := strings.Index(a.RawMsg, " ")
// 找到最后一个空格的位置
lastSpaceIndex := strings.LastIndex(a.RawMsg, " ")
// 调用图片分析
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: MODEL,
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
MultiContent: []openai.ChatMessagePart{
{
Type: openai.ChatMessagePartTypeText,
Text: a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex],
}, {
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: base64Img,
Detail: openai.ImageURLDetailAuto,
},
},
},
},
},
},
)
if err != nil {
log.Println("ChatCompletion error: ", err)
return "请求失败,api可能不支持图片上传"
}
msg += resp.Choices[0].Message.Content
}
return stripMarkdown(msg)
}
models, err := client.ListModels(context.Background())
if err != nil {
log.Println("ListModels error: ", err)
return "查询支持模型失败!"
}
var modelList string
for _, model := range models.Models {
if MODEL == model.ID {
model.ID = model.ID + "\t(当前使用)"
}
modelList += model.ID + "\n"
}
return modelList
// return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
}
// Patterns used by stripMarkdown, compiled once instead of on every call.
var (
	mdFenceRe      = regexp.MustCompile("(?s)```(?:\\w+\\s*\n)?(.*?)```")
	mdInlineCodeRe = regexp.MustCompile("`([^`]+)`")
	// Nested quotes (">>", "> >") are stripped in one pass.
	mdQuoteRe = regexp.MustCompile(`(?m)^(?:[ \t]*>)+[ \t]?`)
	// The whole line must consist of 3+ identical rule characters; the
	// alternation is grouped so that it is anchored at both ends.
	mdRuleRe        = regexp.MustCompile(`(?m)^[ \t]*(?:(?:-[ \t]*){3,}|(?:\*[ \t]*){3,}|(?:_[ \t]*){3,})$`)
	mdHeadingRe     = regexp.MustCompile(`(?m)^#{1,6}[ \t]+`)
	mdBulletRe      = regexp.MustCompile(`(?m)^[ \t]*[*+-][ \t]+`)
	mdOrderedRe     = regexp.MustCompile(`(?m)^[ \t]*\d+\.[ \t]+`)
	mdImageRe       = regexp.MustCompile(`!\[([^\]]*)\]\([^)]+\)`)
	mdLinkRe        = regexp.MustCompile(`\[([^\]]+)\]\([^)]+\)`)
	mdStrongStarRe  = regexp.MustCompile(`\*\*(\S(?:.*?\S)?)\*\*`)
	mdStrongUnderRe = regexp.MustCompile(`\b__(\S(?:.*?\S)?)__\b`)
	// Emphasis content must not start or end with a space, so "a * b * c" survives.
	mdEmStarRe = regexp.MustCompile(`\*(\S(?:.*?\S)?)\*`)
	// \b keeps intraword underscores (snake_case) from being read as emphasis.
	mdEmUnderRe     = regexp.MustCompile(`\b_(\S(?:.*?\S)?)_\b`)
	mdBlankLinesRe  = regexp.MustCompile(`\n(?:[ \t]*\n){2,}`)
	mdPlaceholderRe = regexp.MustCompile(`\x00(\d+)\x00`)
)

// stripMarkdown turns a Markdown reply from the model into plain text for a
// chat client that does not render Markdown.
//
// Fenced and inline code spans are first replaced by NUL-delimited
// placeholders so that the list/heading/emphasis rules cannot mangle code
// (e.g. "a * b * c", snake_case names). Their contents are restored verbatim
// at the end. Fence markers and an optional language tag are dropped, and
// fenced content is trimmed. Runs of blank lines collapse to one blank line,
// and the result is trimmed.
func stripMarkdown(text string) string {
	var code []string
	hide := func(s string) string {
		code = append(code, s)
		return "\x00" + strconv.Itoa(len(code)-1) + "\x00"
	}
	text = mdFenceRe.ReplaceAllStringFunc(text, func(block string) string {
		if sub := mdFenceRe.FindStringSubmatch(block); len(sub) > 1 {
			return hide(strings.TrimSpace(sub[1]))
		}
		return block
	})
	text = mdInlineCodeRe.ReplaceAllStringFunc(text, func(span string) string {
		return hide(span[1 : len(span)-1])
	})

	// Block-level markers. Quotes go first because a quoted line may itself
	// hold a rule, heading or list item. Rules go before lists and emphasis,
	// which would otherwise consume "- - -" or "***".
	text = mdQuoteRe.ReplaceAllString(text, "")
	text = mdRuleRe.ReplaceAllString(text, "")
	text = mdHeadingRe.ReplaceAllString(text, "")
	text = mdBulletRe.ReplaceAllString(text, "")
	text = mdOrderedRe.ReplaceAllString(text, "")

	// Images before links, or the link rule turns "![alt](u)" into "!alt".
	text = mdImageRe.ReplaceAllString(text, "$1")
	text = mdLinkRe.ReplaceAllString(text, "$1")

	// Strong before emphasis, so "**x**" is not read as "*" + "*x*" + "*".
	text = mdStrongStarRe.ReplaceAllString(text, "$1")
	text = mdStrongUnderRe.ReplaceAllString(text, "$1")
	text = mdEmStarRe.ReplaceAllString(text, "$1")
	text = mdEmUnderRe.ReplaceAllString(text, "$1")

	text = mdBlankLinesRe.ReplaceAllString(text, "\n\n")

	// Put the protected code back. Anything that merely looks like a
	// placeholder but has no saved entry is left untouched.
	text = mdPlaceholderRe.ReplaceAllStringFunc(text, func(p string) string {
		i, err := strconv.Atoi(p[1 : len(p)-1])
		if err != nil || i >= len(code) {
			return p
		}
		return code[i]
	})
	return strings.TrimSpace(text)
}
// Image2Base64 returns the image as a base64 "data:" URL for use in an
// image_url chat message part.
//
// It first reads the local file at path. If that fails or the file is empty,
// it downloads picUrl with a Chrome-like client. An empty string is returned
// when neither source yields image bytes, including when the download answers
// with a non-2xx status (an error page must not be sent as an image).
func Image2Base64(path string, picUrl string) string {
	if file, err := os.Open(path); err == nil {
		data, readErr := io.ReadAll(file)
		file.Close()
		if readErr == nil && len(data) > 0 {
			return base64ImageDataURL(data)
		}
	}

	// Fall back to downloading picUrl.
	client := req.C().
		SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36").
		SetTLSFingerprintChrome() // Mimic Chrome's TLS handshake fingerprint so sites treat the request as a browser.
	resp, err := client.R().Get(picUrl)
	if err != nil {
		log.Println("下载图片失败:", err)
		return ""
	}
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		log.Println("下载图片失败, HTTP 状态码:", resp.StatusCode)
		return ""
	}
	// req may already have read the body into memory; ToBytes returns it
	// whether or not that happened, and closes the body.
	data, err := resp.ToBytes()
	if err != nil || len(data) == 0 {
		log.Println("读取图片内容失败:", err)
		return ""
	}
	return base64ImageDataURL(data)
}

// base64ImageDataURL encodes data as a base64 data URL. The MIME type comes
// from the leading magic bytes (PNG, GIF, WebP); anything unrecognised is
// labelled image/jpeg, as before.
func base64ImageDataURL(data []byte) string {
	head := string(data[:min(len(data), 12)])
	mime := "image/jpeg"
	switch {
	case strings.HasPrefix(head, "\x89PNG\r\n\x1a\n"):
		mime = "image/png"
	case strings.HasPrefix(head, "GIF87a"), strings.HasPrefix(head, "GIF89a"):
		mime = "image/gif"
	case len(head) >= 12 && head[:4] == "RIFF" && head[8:12] == "WEBP":
		mime = "image/webp"
	}
	return "data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data)
}