go-bot/workers/ai.go

227 lines
5.4 KiB
Go
Raw Normal View History

package workers
import (
"context"
2024-07-14 21:38:39 +08:00
"encoding/base64"
"fmt"
2024-07-14 21:38:39 +08:00
"io"
"os"
"log"
2024-07-14 21:38:39 +08:00
"regexp"
"strconv"
"strings"
openai "github.com/sashabaranov/go-openai"
)
// AI is the worker behind the "ai" command. It embeds *StdAns, which supplies
// the parsed command fields (Parms, UID, RawMsg, ...) and the GetHisMsg/GetImage
// helpers that GetMsg relies on.
type AI struct {
	*StdAns
}

// init registers a factory for the "ai" command with RegisterWorkerFactory;
// presumably the dispatcher uses it to build a fresh *AI per incoming message.
func init() {
	newAI := func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
		base := NewStdAns(parms, uid, gid, role, mid, rawMsg)
		return &AI{StdAns: base}
	}
	RegisterWorkerFactory("ai", newAI)
}
func (a *AI) GetMsg() string {
if len(a.Parms) < 2 {
return "使用!ai xxx 向我提问吧"
}
ask := a.Parms[1]
2024-07-14 21:38:39 +08:00
if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
return "不问问题你说个屁!"
}
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
if OPENAI_API_KEY == "" {
return "OPENAI_API_KEY 未配置"
}
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
oaiconfig.BaseURL = OPENAI_BaseURL
client := openai.NewClientWithConfig(oaiconfig)
msg := fmt.Sprintf("[CQ:at,qq=%s] ", a.UID)
if strings.ToLower(a.Parms[1]) != "models" {
// OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
2024-07-14 21:38:39 +08:00
PROMPT, ok := cfg["PROMPT"].(string)
if !ok {
log.Println("PROMPT 未配置")
PROMPT = ""
}
// var requestBody map[string]interface{}
// 如果非回复消息
2024-07-14 21:38:39 +08:00
if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: MODEL,
Stream: false,
Messages: []openai.ChatCompletionMessage{
{
Role: "system",
Content: PROMPT,
},
{
Role: "user",
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
},
2024-07-14 21:38:39 +08:00
},
},
)
if err != nil {
log.Println("ChatCompletion error: ", err)
return "请求失败"
2024-07-14 21:38:39 +08:00
}
// println(resp.Choices[0].Message.Content)
return msg + resp.Choices[0].Message.Content
2024-07-14 21:38:39 +08:00
} else {
// 匹配回复消息
2024-07-14 21:38:39 +08:00
pattern := `^\[CQ:reply,id=(-?\d+)\]`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1])
var msgId string
if len(matches) > 0 {
msgId = matches[1]
println("msgId:", msgId)
2024-07-14 21:38:39 +08:00
} else {
msgId = ""
log.Println("未找到回复消息")
return "未找到回复消息"
}
message := a.GetHisMsg(msgId)
// 正则表达式匹配 file 和 file_size 的值
re = regexp.MustCompile(`file=([^,]+),.*file_size=(\d+)`)
matches = re.FindStringSubmatch(message)
var file string
var fileSizeStr string
if len(matches) > 2 {
file = matches[1]
fileSizeStr = matches[2]
} else {
log.Println("未找到文件信息")
return "未找到文件信息"
}
// 将 fileSizeStr 转换为整数
fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
if err != nil {
fmt.Println("获取图片大小失败:", err)
return "获取图片大小失败"
}
if fileSize/1024/1024 > 2.0 {
log.Println("文件大小超过2M")
return "文件大小超过2M"
2024-07-14 21:38:39 +08:00
}
filePath := a.GetImage(file)
// println("filePath:", filePath)
2024-07-14 21:38:39 +08:00
if filePath == "" {
log.Println("获取图片失败")
return "获取图片失败"
}
base64Img := Image2Base64(filePath)
if base64Img == "" {
log.Println("图片转换base64失败")
return "图片转换base64失败"
}
// 找到第一个空格的位置
firstSpaceIndex := strings.Index(a.RawMsg, " ")
// 找到最后一个空格的位置
lastSpaceIndex := strings.LastIndex(a.RawMsg, " ")
// 调用图片分析
resp, err := client.CreateChatCompletion(
context.Background(),
openai.ChatCompletionRequest{
Model: MODEL,
Messages: []openai.ChatCompletionMessage{
{
Role: "user",
MultiContent: []openai.ChatMessagePart{
{
Type: openai.ChatMessagePartTypeText,
Text: a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex],
}, {
Type: openai.ChatMessagePartTypeImageURL,
ImageURL: &openai.ChatMessageImageURL{
URL: base64Img,
Detail: openai.ImageURLDetailAuto,
},
2024-07-14 21:38:39 +08:00
},
},
},
},
},
)
if err != nil {
log.Println("ChatCompletion error: ", err)
return "请求失败"
2024-07-14 21:38:39 +08:00
}
msg += resp.Choices[0].Message.Content
2024-07-14 21:38:39 +08:00
}
return msg
}
models, err := client.ListModels(context.Background())
if err != nil {
log.Println("ListModels error: ", err)
return "查询支持模型失败!"
2024-07-14 21:38:39 +08:00
}
var modelList string
for _, model := range models.Models {
if MODEL == model.ID {
model.ID = model.ID + "(当前使用)"
}
modelList += model.ID + "\n"
}
return modelList
// return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
}
// getConfig reads the OpenAI connection settings from the package-level cfg map.
//
// It returns, in order:
//   - the API key from cfg["OPENAI_API_KEY"] ("" when unset; callers treat that as fatal);
//   - the base URL from cfg["OPENAI_BaseURL"], defaulting to "https://api.openai.com/v1";
//   - the model name from cfg["MODEL"], defaulting to "gpt-4o".
//
// A value that is absent, empty, or not a string counts as unset and is logged.
// A bare .(string) assertion would panic on a non-string value instead.
func getConfig() (string, string, string) {
	// lookup returns cfg[key] when it holds a string, else "".
	lookup := func(key string) string {
		s, _ := cfg[key].(string)
		return s
	}

	apiKey := lookup("OPENAI_API_KEY")
	if apiKey == "" {
		log.Println("OPENAI_API_KEY 未配置")
	}

	baseURL := lookup("OPENAI_BaseURL")
	if baseURL == "" {
		log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
		baseURL = "https://api.openai.com/v1"
	}

	model := lookup("MODEL")
	if model == "" {
		log.Println("模型 未配置,使用默认 gpt-4o 模型")
		model = "gpt-4o"
	}

	return apiKey, baseURL, model
}
2024-07-14 21:38:39 +08:00
// Image2Base64 reads the image file at path and returns it as a data URL,
// "data:<mime>;base64,<payload>", for use as an image_url message part.
//
// The MIME type is sniffed from the file's leading magic bytes (PNG, GIF, WebP).
// Anything else, including JPEG, is labelled image/jpeg, which was the previous
// unconditional label. It returns "" (after logging the cause) if the file
// cannot be opened or read.
func Image2Base64(path string) string {
	file, err := os.Open(path)
	if err != nil {
		log.Println("Image2Base64:", err)
		return ""
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		log.Println("Image2Base64:", err)
		return ""
	}
	return "data:" + sniffImageMIME(data) + ";base64," + base64.StdEncoding.EncodeToString(data)
}

// sniffImageMIME guesses an image MIME type from data's magic bytes, falling
// back to "image/jpeg". The data URL's declared type should match the actual
// bytes, or the endpoint may reject or misread the image.
func sniffImageMIME(data []byte) string {
	head := data
	if len(head) > 12 {
		head = head[:12]
	}
	sig := string(head)
	switch {
	case strings.HasPrefix(sig, "\x89PNG\r\n\x1a\n"):
		return "image/png"
	case strings.HasPrefix(sig, "GIF87a"), strings.HasPrefix(sig, "GIF89a"):
		return "image/gif"
	case len(sig) == 12 && sig[:4] == "RIFF" && sig[8:] == "WEBP":
		return "image/webp"
	default:
		return "image/jpeg"
	}
}