2024-06-30 21:56:34 +08:00
|
|
|
package workers
|
|
|
|
|
|
|
|
import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	openai "github.com/sashabaranov/go-openai"
)
|
|
|
|
|
2024-07-05 22:33:55 +08:00
|
|
|
func init() {
|
|
|
|
RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
|
|
|
|
return &AI{
|
|
|
|
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
|
2024-06-30 21:56:34 +08:00
|
|
|
type AI struct {
|
|
|
|
*StdAns
|
|
|
|
}
|
|
|
|
|
|
|
|
func (a *AI) GetMsg() string {
|
|
|
|
if len(a.Parms) < 2 {
|
|
|
|
return "使用!ai xxx 向我提问吧"
|
|
|
|
}
|
2024-07-01 21:04:22 +08:00
|
|
|
|
2024-07-06 19:18:59 +08:00
|
|
|
ask := a.Parms[1]
|
2024-07-14 21:38:39 +08:00
|
|
|
if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
|
2024-06-30 23:34:00 +08:00
|
|
|
return "不问问题你说个屁!"
|
|
|
|
}
|
2024-07-06 19:18:59 +08:00
|
|
|
|
|
|
|
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
|
|
|
|
if OPENAI_API_KEY == "" {
|
|
|
|
return "OPENAI_API_KEY 未配置"
|
|
|
|
}
|
2024-08-31 16:31:06 +08:00
|
|
|
oaiconfig := openai.DefaultConfig(OPENAI_API_KEY)
|
|
|
|
oaiconfig.BaseURL = OPENAI_BaseURL
|
2024-07-06 19:18:59 +08:00
|
|
|
|
2024-08-31 16:31:06 +08:00
|
|
|
client := openai.NewClientWithConfig(oaiconfig)
|
|
|
|
msg := fmt.Sprintf("[CQ:at,qq=%s] ", a.UID)
|
|
|
|
if strings.ToLower(a.Parms[1]) != "models" {
|
|
|
|
// OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
2024-07-14 21:38:39 +08:00
|
|
|
PROMPT, ok := cfg["PROMPT"].(string)
|
|
|
|
if !ok {
|
|
|
|
log.Println("PROMPT 未配置")
|
|
|
|
PROMPT = ""
|
|
|
|
}
|
2024-08-31 16:31:06 +08:00
|
|
|
// var requestBody map[string]interface{}
|
|
|
|
// 如果非回复消息
|
2024-07-14 21:38:39 +08:00
|
|
|
if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
|
2024-08-31 16:31:06 +08:00
|
|
|
resp, err := client.CreateChatCompletion(
|
|
|
|
context.Background(),
|
|
|
|
openai.ChatCompletionRequest{
|
|
|
|
Model: MODEL,
|
|
|
|
Stream: false,
|
|
|
|
Messages: []openai.ChatCompletionMessage{
|
|
|
|
{
|
|
|
|
Role: "system",
|
|
|
|
Content: PROMPT,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
Role: "user",
|
|
|
|
Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
|
|
|
|
},
|
2024-07-14 21:38:39 +08:00
|
|
|
},
|
|
|
|
},
|
2024-08-31 16:31:06 +08:00
|
|
|
)
|
|
|
|
if err != nil {
|
|
|
|
log.Println("ChatCompletion error: ", err)
|
|
|
|
return "请求失败"
|
2024-07-14 21:38:39 +08:00
|
|
|
}
|
2024-08-31 16:31:06 +08:00
|
|
|
// println(resp.Choices[0].Message.Content)
|
|
|
|
return msg + resp.Choices[0].Message.Content
|
2024-07-14 21:38:39 +08:00
|
|
|
} else {
|
2024-08-31 16:31:06 +08:00
|
|
|
// 匹配回复消息
|
2024-07-14 21:38:39 +08:00
|
|
|
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
|
|
|
re := regexp.MustCompile(pattern)
|
|
|
|
matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1])
|
|
|
|
var msgId string
|
|
|
|
if len(matches) > 0 {
|
|
|
|
|
|
|
|
msgId = matches[1]
|
2024-08-31 19:39:19 +08:00
|
|
|
println("msgId:", msgId)
|
2024-07-14 21:38:39 +08:00
|
|
|
} else {
|
|
|
|
msgId = ""
|
|
|
|
log.Println("未找到回复消息")
|
|
|
|
return "未找到回复消息"
|
|
|
|
}
|
2024-09-01 14:28:52 +08:00
|
|
|
file, picUrl, fileSizeStr := a.GetHisMsg(msgId)
|
2024-09-01 16:36:05 +08:00
|
|
|
println("file:", file, "picUrl:", picUrl, "fileSizeStr:", fileSizeStr)
|
2024-07-14 21:38:39 +08:00
|
|
|
// 正则表达式匹配 file 和 file_size 的值
|
2024-09-01 14:28:52 +08:00
|
|
|
// re = regexp.MustCompile(`file=([^,]+),.*file_size=(\d+)`)
|
|
|
|
// matches = re.FindStringSubmatch(message)
|
|
|
|
// var file string
|
|
|
|
// var fileSizeStr string
|
|
|
|
// if len(matches) > 2 {
|
|
|
|
// file = matches[1]
|
|
|
|
// fileSizeStr = matches[2]
|
|
|
|
// } else {
|
|
|
|
// log.Println("未找到文件信息")
|
|
|
|
// return "未找到文件信息"
|
|
|
|
// }
|
|
|
|
// 将 fileSizeStr 转换为整数
|
|
|
|
if picUrl == "" {
|
2024-07-14 21:38:39 +08:00
|
|
|
log.Println("未找到文件信息")
|
|
|
|
return "未找到文件信息"
|
|
|
|
}
|
|
|
|
fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
|
|
|
|
if err != nil {
|
|
|
|
fmt.Println("获取图片大小失败:", err)
|
|
|
|
return "获取图片大小失败"
|
|
|
|
}
|
2024-09-01 14:28:52 +08:00
|
|
|
if fileSize/1024/1024 > 5.0 {
|
|
|
|
log.Println("文件大小超过5M")
|
|
|
|
return "文件大小超过5M"
|
2024-07-14 21:38:39 +08:00
|
|
|
}
|
|
|
|
filePath := a.GetImage(file)
|
2024-07-14 21:55:02 +08:00
|
|
|
// println("filePath:", filePath)
|
2024-07-14 21:38:39 +08:00
|
|
|
if filePath == "" {
|
|
|
|
log.Println("获取图片失败")
|
|
|
|
return "获取图片失败"
|
|
|
|
}
|
2024-09-01 14:28:52 +08:00
|
|
|
base64Img := Image2Base64(filePath, picUrl)
|
2024-07-14 21:38:39 +08:00
|
|
|
if base64Img == "" {
|
|
|
|
log.Println("图片转换base64失败")
|
|
|
|
return "图片转换base64失败"
|
|
|
|
}
|
|
|
|
// 找到第一个空格的位置
|
|
|
|
firstSpaceIndex := strings.Index(a.RawMsg, " ")
|
|
|
|
|
|
|
|
// 找到最后一个空格的位置
|
|
|
|
lastSpaceIndex := strings.LastIndex(a.RawMsg, " ")
|
2024-08-31 16:31:06 +08:00
|
|
|
// 调用图片分析
|
|
|
|
resp, err := client.CreateChatCompletion(
|
|
|
|
context.Background(),
|
|
|
|
openai.ChatCompletionRequest{
|
|
|
|
Model: MODEL,
|
|
|
|
Messages: []openai.ChatCompletionMessage{
|
|
|
|
{
|
|
|
|
Role: "user",
|
|
|
|
MultiContent: []openai.ChatMessagePart{
|
|
|
|
{
|
|
|
|
Type: openai.ChatMessagePartTypeText,
|
|
|
|
Text: a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex],
|
|
|
|
}, {
|
|
|
|
Type: openai.ChatMessagePartTypeImageURL,
|
|
|
|
ImageURL: &openai.ChatMessageImageURL{
|
|
|
|
URL: base64Img,
|
|
|
|
Detail: openai.ImageURLDetailAuto,
|
|
|
|
},
|
2024-07-14 21:38:39 +08:00
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
2024-08-31 16:31:06 +08:00
|
|
|
)
|
|
|
|
if err != nil {
|
|
|
|
log.Println("ChatCompletion error: ", err)
|
|
|
|
return "请求失败"
|
2024-07-14 21:38:39 +08:00
|
|
|
|
|
|
|
}
|
2024-08-31 16:31:06 +08:00
|
|
|
msg += resp.Choices[0].Message.Content
|
2024-07-14 21:38:39 +08:00
|
|
|
|
|
|
|
}
|
2024-08-31 16:31:06 +08:00
|
|
|
return msg
|
|
|
|
}
|
|
|
|
models, err := client.ListModels(context.Background())
|
|
|
|
if err != nil {
|
|
|
|
log.Println("ListModels error: ", err)
|
|
|
|
return "查询支持模型失败!"
|
2024-07-14 21:38:39 +08:00
|
|
|
|
2024-07-06 19:18:59 +08:00
|
|
|
}
|
2024-08-31 16:31:06 +08:00
|
|
|
var modelList string
|
|
|
|
for _, model := range models.Models {
|
|
|
|
if MODEL == model.ID {
|
|
|
|
model.ID = model.ID + "(当前使用)"
|
|
|
|
}
|
|
|
|
modelList += model.ID + "\n"
|
|
|
|
}
|
|
|
|
return modelList
|
|
|
|
// return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
2024-07-06 19:18:59 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
func getConfig() (string, string, string) {
|
|
|
|
var OPENAI_API_KEY, OPENAI_BaseURL, MODEL string
|
|
|
|
|
2024-06-30 21:56:34 +08:00
|
|
|
if cfg["OPENAI_API_KEY"] != nil {
|
|
|
|
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
|
|
|
|
} else {
|
|
|
|
log.Println("OPENAI_API_KEY 未配置")
|
|
|
|
}
|
2024-07-06 19:18:59 +08:00
|
|
|
|
2024-06-30 21:56:34 +08:00
|
|
|
if cfg["OPENAI_BaseURL"] != nil {
|
|
|
|
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
|
|
|
|
} else {
|
|
|
|
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
|
2024-06-30 23:34:00 +08:00
|
|
|
OPENAI_BaseURL = "https://api.openai.com/v1"
|
2024-06-30 21:56:34 +08:00
|
|
|
}
|
2024-07-06 19:18:59 +08:00
|
|
|
|
2024-06-30 21:56:34 +08:00
|
|
|
if cfg["MODEL"] != nil {
|
|
|
|
MODEL = cfg["MODEL"].(string)
|
|
|
|
} else {
|
2024-07-06 22:35:01 +08:00
|
|
|
log.Println("模型 未配置,使用默认 gpt-4o 模型")
|
|
|
|
MODEL = "gpt-4o"
|
2024-06-30 21:56:34 +08:00
|
|
|
}
|
2024-06-30 23:34:00 +08:00
|
|
|
|
2024-07-06 19:18:59 +08:00
|
|
|
return OPENAI_API_KEY, OPENAI_BaseURL, MODEL
|
|
|
|
}
|
2024-07-01 10:05:42 +08:00
|
|
|
|
2024-09-01 14:28:52 +08:00
|
|
|
// aiImageHTTPClient downloads images for Image2Base64. Unlike the default
// client used by http.Get, it has a timeout, so a stalled download cannot
// block the caller forever.
var aiImageHTTPClient = &http.Client{Timeout: 30 * time.Second}

// Image2Base64 loads an image and returns it as a data URL of the form
// "data:<mime>;base64,<payload>".
//
// The local file at path is tried first. If it cannot be read or is empty,
// the image is downloaded from picUrl. The MIME type is sniffed from the bytes
// with http.DetectContentType, falling back to "image/jpeg" when the content is
// not recognised as an image. The function returns "" when no data could be
// obtained, including when the download answers with a non-2xx status.
func Image2Base64(path string, picUrl string) string {
	if data, err := os.ReadFile(path); err == nil && len(data) > 0 {
		return aiImageDataURL(data)
	}
	if picUrl == "" {
		return ""
	}

	resp, err := aiImageHTTPClient.Get(picUrl)
	if err != nil {
		log.Println("下载图片失败:", err)
		return ""
	}
	defer resp.Body.Close()
	// Without this check an error page (e.g. a 404 body) would be encoded as an image.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		log.Println("下载图片失败, 状态码:", resp.StatusCode)
		return ""
	}

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Println("读取图片失败:", err)
		return ""
	}
	if len(data) == 0 {
		return ""
	}
	return aiImageDataURL(data)
}

// aiImageDataURL base64-encodes data into a data URL. The MIME type is the
// sniffed image type, or "image/jpeg" when the bytes are not a known image format.
func aiImageDataURL(data []byte) string {
	mime := http.DetectContentType(data)
	if !strings.HasPrefix(mime, "image/") {
		mime = "image/jpeg"
	}
	return "data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data)
}
|