342 lines
9.3 KiB
Go
342 lines
9.3 KiB
Go
package workers
|
||
|
||
import (
|
||
"context"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"fmt"
|
||
"go-bot/config"
|
||
"go-bot/tools"
|
||
"os"
|
||
"slices"
|
||
|
||
"io"
|
||
"log"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/imroc/req/v3"
|
||
openai "github.com/sashabaranov/go-openai"
|
||
)
|
||
|
||
func init() {
|
||
plugins := config.GetConfig()["PLUGINS"].([]interface{})
|
||
if slices.Contains(plugins, "ai") {
|
||
RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
|
||
return &AI{
|
||
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
type AI struct {
|
||
*StdAns
|
||
}
|
||
|
||
func (a *AI) GetMsg() string {
|
||
if len(a.Parms) < 2 {
|
||
return "使用!ai xxx 向我提问吧"
|
||
}
|
||
|
||
ask := a.Parms[1]
|
||
if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
|
||
return "不问问题你说个屁!"
|
||
}
|
||
|
||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig()
|
||
if OPENAI_API_KEY == "" {
|
||
return "OPENAI_API_KEY 未配置"
|
||
}
|
||
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
||
oaiconfig.BaseURL = OPENAI_BaseURL
|
||
|
||
client := openai.NewClientWithConfig(oaiconfig)
|
||
msg := fmt.Sprintf("[CQ:at,qq=%s] ", a.UID)
|
||
if strings.ToLower(a.Parms[1]) != "models" {
|
||
|
||
// 如果非回复消息
|
||
if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
|
||
messages := make([]openai.ChatCompletionMessage, 0)
|
||
messages = append(messages, openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleSystem,
|
||
Content: PROMPT,
|
||
})
|
||
// println("messages0:", messages)
|
||
var key string
|
||
redisClient := tools.GetRedisClient()
|
||
// 如果redisClient不为空,则获取上下文
|
||
if redisClient != nil {
|
||
// 限制请求频率
|
||
if !tools.CheckRequestFrequency(a.GID, a.UID, 10) {
|
||
log.Printf("请求过于频繁。\n")
|
||
return "请求过于频繁。"
|
||
}
|
||
|
||
key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID)
|
||
// 获取上下文
|
||
message := openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleUser,
|
||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||
}
|
||
// 序列化为 JSON 字符串
|
||
jsonMessage, err := json.Marshal(message)
|
||
if err != nil {
|
||
log.Println("序列化错误:", err)
|
||
log.Println("存储的 JSON 字符串:", string(jsonMessage))
|
||
|
||
return "序列化错误"
|
||
}
|
||
tools.AddToContext(key, jsonMessage, CONTEXT)
|
||
// }
|
||
// println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:])
|
||
length := tools.GetListLength(key)
|
||
// if length > 0 {
|
||
for i := 0; i < int(length); i++ {
|
||
message, err := tools.GetListValue(key, int64(i))
|
||
if err != nil {
|
||
log.Println("获取上下文失败:", err)
|
||
return "获取上下文失败"
|
||
}
|
||
// log.Println("读取的 JSON 字符串:", message)
|
||
|
||
var msg openai.ChatCompletionMessage
|
||
err = json.Unmarshal([]byte(message), &msg)
|
||
if err != nil {
|
||
log.Println("反序列化错误:", err)
|
||
|
||
return "反序列化错误"
|
||
}
|
||
messages = append(messages, msg)
|
||
}
|
||
} else {
|
||
|
||
messages = append(messages, openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleUser,
|
||
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
||
})
|
||
|
||
}
|
||
// println("messages:", messages)
|
||
// for i, msg := range messages {
|
||
// fmt.Printf("消息 %d:\n", i+1)
|
||
// fmt.Printf(" 角色: %s\n", msg.Role)
|
||
// fmt.Printf(" 内容: %s\n", msg.Content)
|
||
// fmt.Println()
|
||
// }
|
||
resp, err := client.CreateChatCompletion(
|
||
context.Background(),
|
||
openai.ChatCompletionRequest{
|
||
Model: MODEL,
|
||
Stream: false,
|
||
Messages: messages,
|
||
},
|
||
)
|
||
if err != nil {
|
||
log.Println("ChatCompletion error: ", err)
|
||
tools.RemoveLastValueFromList(key)
|
||
return "请求失败"
|
||
}
|
||
// println(resp.Choices[0].Message.Content)
|
||
content := resp.Choices[0].Message.Content
|
||
if redisClient != nil {
|
||
|
||
message := openai.ChatCompletionMessage{
|
||
Role: openai.ChatMessageRoleAssistant,
|
||
Content: stripMarkdown(content)}
|
||
// 序列化为 JSON 字符串
|
||
jsonMessage, err := json.Marshal(message)
|
||
if err != nil {
|
||
log.Println("Assistant消息序列化错误:", err)
|
||
return "Assistant消息序列化错误"
|
||
}
|
||
tools.AddToContext(key, jsonMessage, CONTEXT)
|
||
|
||
}
|
||
return msg + stripMarkdown(content)
|
||
} else {
|
||
// 匹配回复消息
|
||
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
||
re := regexp.MustCompile(pattern)
|
||
matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1])
|
||
var msgId string
|
||
if len(matches) > 0 {
|
||
|
||
msgId = matches[1]
|
||
// println("msgId:", msgId)
|
||
} else {
|
||
msgId = ""
|
||
log.Println("未找到回复消息")
|
||
return "未找到回复消息"
|
||
}
|
||
file, picUrl, fileSizeStr := a.GetHisMsg(msgId)
|
||
// println("file:", file, "picUrl:", picUrl, "fileSizeStr:", fileSizeStr)
|
||
|
||
if picUrl == "" {
|
||
log.Println("未找到文件信息")
|
||
return "未找到文件信息"
|
||
}
|
||
fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
|
||
if err != nil {
|
||
fmt.Println("获取图片大小失败:", err)
|
||
return "获取图片大小失败"
|
||
}
|
||
if fileSize/1024/1024 > 5.0 {
|
||
log.Println("文件大小超过5M")
|
||
return "文件大小超过5M"
|
||
}
|
||
filePath := a.GetImage(file)
|
||
// println("filePath:", filePath)
|
||
if filePath == "" {
|
||
log.Println("获取图片失败")
|
||
// return "获取图片失败"
|
||
// 下载picUrl文件
|
||
|
||
}
|
||
base64Img := Image2Base64(filePath, picUrl)
|
||
if base64Img == "" {
|
||
log.Println("图片转换base64失败")
|
||
return "图片转换base64失败"
|
||
}
|
||
// 找到第一个空格的位置
|
||
firstSpaceIndex := strings.Index(a.RawMsg, " ")
|
||
|
||
// 找到最后一个空格的位置
|
||
lastSpaceIndex := strings.LastIndex(a.RawMsg, " ")
|
||
// 调用图片分析
|
||
resp, err := client.CreateChatCompletion(
|
||
context.Background(),
|
||
openai.ChatCompletionRequest{
|
||
Model: MODEL,
|
||
Messages: []openai.ChatCompletionMessage{
|
||
{
|
||
Role: "user",
|
||
MultiContent: []openai.ChatMessagePart{
|
||
{
|
||
Type: openai.ChatMessagePartTypeText,
|
||
Text: a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex],
|
||
}, {
|
||
Type: openai.ChatMessagePartTypeImageURL,
|
||
ImageURL: &openai.ChatMessageImageURL{
|
||
URL: base64Img,
|
||
Detail: openai.ImageURLDetailAuto,
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
},
|
||
)
|
||
if err != nil {
|
||
log.Println("ChatCompletion error: ", err)
|
||
return "请求失败,api可能不支持图片上传"
|
||
|
||
}
|
||
msg += resp.Choices[0].Message.Content
|
||
|
||
}
|
||
return stripMarkdown(msg)
|
||
}
|
||
models, err := client.ListModels(context.Background())
|
||
if err != nil {
|
||
log.Println("ListModels error: ", err)
|
||
return "查询支持模型失败!"
|
||
|
||
}
|
||
var modelList string
|
||
for _, model := range models.Models {
|
||
if MODEL == model.ID {
|
||
model.ID = model.ID + "\t(当前使用)"
|
||
}
|
||
modelList += model.ID + "\n"
|
||
}
|
||
return modelList
|
||
// return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
||
}
|
||
|
||
// Regular expressions used by stripMarkdown, compiled once instead of per call.
var (
	mdFencedCodeRe  = regexp.MustCompile("(?s)```(?:\\w+\\s*\n)?(.*?)```")
	mdInlineCodeRe  = regexp.MustCompile("`([^`]+)`")
	mdHeadingRe     = regexp.MustCompile(`(?m)^#{1,6}[ \t]+(.+)$`)
	mdRuleRe        = regexp.MustCompile(`(?m)^[ \t]*(?:-{3,}|_{3,}|\*{3,})[ \t]*$`)
	mdQuoteRe       = regexp.MustCompile(`(?m)^>[ \t]+`)
	mdBulletRe      = regexp.MustCompile(`(?m)^[ \t]*[\*\-+][ \t]+`)
	mdOrderedRe     = regexp.MustCompile(`(?m)^[ \t]*\d+\.[ \t]+`)
	mdImageRe       = regexp.MustCompile(`!\[([^\]]*)\]\([^\)]+\)`)
	mdLinkRe        = regexp.MustCompile(`\[([^\]]+)\]\([^\)]+\)`)
	mdBoldStarRe    = regexp.MustCompile(`\*\*(\S(?:.*?\S)?)\*\*`)
	mdBoldUnderRe   = regexp.MustCompile(`\b__(\S(?:.*?\S)?)__\b`)
	mdItalicStarRe  = regexp.MustCompile(`\*(\S(?:[^*\n]*?\S)?)\*`)
	mdItalicUnderRe = regexp.MustCompile(`\b_(\S(?:[^_\n]*?\S)?)_\b`)
	mdBlankLinesRe  = regexp.MustCompile(`\n{3,}`)
)

// stripMarkdown converts Markdown-formatted model output into plain text for
// chat clients that do not render Markdown. It unwraps code blocks and inline
// code, and removes heading, horizontal-rule, block-quote and list markers,
// image/link syntax (keeping the text) and bold/italic markers. Runs of three or
// more newlines collapse to one blank line, and the result is trimmed.
func stripMarkdown(text string) string {
	// Unwrap fenced code blocks, keeping only their trimmed contents.
	text = mdFencedCodeRe.ReplaceAllStringFunc(text, func(block string) string {
		if sub := mdFencedCodeRe.FindStringSubmatch(block); len(sub) > 1 {
			return strings.TrimSpace(sub[1])
		}
		return block
	})
	text = mdInlineCodeRe.ReplaceAllString(text, "$1")
	text = mdHeadingRe.ReplaceAllString(text, "$1")

	// Line-level markers are removed before emphasis. Otherwise "___" / "***"
	// rules and "* item" bullets would be partially eaten as italic markers.
	text = mdRuleRe.ReplaceAllString(text, "")
	text = mdQuoteRe.ReplaceAllString(text, "")
	text = mdBulletRe.ReplaceAllString(text, "")
	text = mdOrderedRe.ReplaceAllString(text, "")

	// Images before links: the link pattern also matches the "[alt](url)" tail
	// of an image and would leave a stray "!".
	text = mdImageRe.ReplaceAllString(text, "$1")
	text = mdLinkRe.ReplaceAllString(text, "$1")

	// Emphasis needs matching delimiters and must not start or end on a space.
	// Underscores also need a word boundary, so snake_case identifiers and
	// arithmetic such as "2 * 3 * 4" survive.
	text = mdBoldStarRe.ReplaceAllString(text, "$1")
	text = mdBoldUnderRe.ReplaceAllString(text, "$1")
	text = mdItalicStarRe.ReplaceAllString(text, "$1")
	text = mdItalicUnderRe.ReplaceAllString(text, "$1")

	text = mdBlankLinesRe.ReplaceAllString(text, "\n\n")
	return strings.TrimSpace(text)
}

func Image2Base64(path string, picUrl string) string {
|
||
// 尝试从文件路径读取图片
|
||
file, err := os.Open(path)
|
||
if err == nil {
|
||
defer file.Close()
|
||
if data, err := io.ReadAll(file); err == nil {
|
||
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data)
|
||
}
|
||
}
|
||
|
||
// 如果文件路径不可用,则尝试从 URL 下载图片
|
||
|
||
client := req.C().
|
||
SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36").
|
||
SetTLSFingerprintChrome() // 模拟 Chrome 浏览器的 TLS 握手指纹,让网站相信这是 Chrome 浏览器在访问,予以通行。
|
||
|
||
resp, err := client.R().Get(picUrl)
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
|
||
defer resp.Body.Close()
|
||
|
||
if data, err := io.ReadAll(resp.Body); err == nil {
|
||
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data)
|
||
}
|
||
return ""
|
||
|
||
}
|