go-bot/workers/ai.go
liyp c77378d825 refactor(ai): 移除调试打印语句
在GetMsg函数中,移除了用于调试的println语句,以保持代码的清洁和生产就绪状态。
2024-07-14 21:55:02 +08:00

269 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package workers
import (
"encoding/base64"
"fmt"
"io"
"os"
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/goccy/go-json"
"github.com/parnurzeal/gorequest"
)
func init() {
RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
return &AI{
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
}
})
}
// AI is the worker behind the "ai" command. It forwards a question, and
// optionally an image from a replied-to message, to an OpenAI-compatible
// chat completions API. Command state (Parms, RawMsg, UID) and helpers such
// as GetHisMsg and GetImage come from the embedded StdAns.
type AI struct{ *StdAns }
func (a *AI) GetMsg() string {
if len(a.Parms) < 2 {
return "使用!ai xxx 向我提问吧"
}
ask := a.Parms[1]
if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
return "不问问题你说个屁!"
}
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
if OPENAI_API_KEY == "" {
return "OPENAI_API_KEY 未配置"
}
if strings.ToLower(a.Parms[1]) == "models" {
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
} else {
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
PROMPT, ok := cfg["PROMPT"].(string)
if !ok {
log.Println("PROMPT 未配置")
PROMPT = ""
}
var requestBody map[string]interface{}
if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
requestBody = map[string]interface{}{
"model": MODEL,
"stream": false,
"messages": []map[string]string{
{
"role": "system",
"content": PROMPT,
},
{"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]},
},
"temperature": 0.7,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_p": 1,
}
} else {
pattern := `^\[CQ:reply,id=(-?\d+)\]`
re := regexp.MustCompile(pattern)
matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1])
var msgId string
if len(matches) > 0 {
msgId = matches[1]
} else {
msgId = ""
log.Println("未找到回复消息")
return "未找到回复消息"
}
message := a.GetHisMsg(msgId)
// 正则表达式匹配 file 和 file_size 的值
re = regexp.MustCompile(`file=([^,]+),.*file_size=(\d+)`)
matches = re.FindStringSubmatch(message)
var file string
var fileSizeStr string
if len(matches) > 2 {
file = matches[1]
fileSizeStr = matches[2]
} else {
log.Println("未找到文件信息")
return "未找到文件信息"
}
// 将 fileSizeStr 转换为整数
fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
if err != nil {
fmt.Println("获取图片大小失败:", err)
return "获取图片大小失败"
}
if fileSize/1024/1024 > 1.0 {
log.Println("文件大小超过1M")
return "文件大小超过1M"
}
filePath := a.GetImage(file)
// println("filePath:", filePath)
if filePath == "" {
log.Println("获取图片失败")
return "获取图片失败"
}
base64Img := Image2Base64(filePath)
if base64Img == "" {
log.Println("图片转换base64失败")
return "图片转换base64失败"
}
// 找到第一个空格的位置
firstSpaceIndex := strings.Index(a.RawMsg, " ")
// 找到最后一个空格的位置
lastSpaceIndex := strings.LastIndex(a.RawMsg, " ")
requestBody = map[string]interface{}{
"model": MODEL,
"stream": false,
"messages": []interface{}{
map[string]interface{}{
"role": "system",
"content": "#角色你是一名AI视觉助手任务是分析单个图像。",
},
map[string]interface{}{
"role": "user",
"content": []interface{}{
map[string]interface{}{
"type": "text",
"text": a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex],
},
map[string]interface{}{
"type": "image_url",
"image_url": map[string]string{
"url": base64Img,
},
},
},
},
},
"temperature": 0.7,
"presence_penalty": 0,
"frequency_penalty": 0,
"top_p": 1,
}
}
request := gorequest.New()
resp, body, errs := request.Post(OPENAI_BaseURL).
Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway).
Set("Content-Type", "application/json").
Set("Authorization", "Bearer "+OPENAI_API_KEY).
Send(requestBody).
End()
if errs != nil {
log.Println(errs)
return "请求失败"
}
println(resp.StatusCode)
if resp.StatusCode == 200 {
var responseBody map[string]interface{}
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
log.Println(err)
return "解析失败"
}
choices := responseBody["choices"].([]interface{})
if len(choices) > 0 {
choice := choices[0].(map[string]interface{})
msg := choice["message"].(map[string]interface{})["content"].(string)
return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg)
} else {
log.Println("choices为空")
return "api解析失败"
}
}
return "请求失败!"
// return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID, a.Parms)
}
}
// getConfig returns the API key, base URL and model name for the
// OpenAI-compatible backend, read from the package-level cfg map.
//
// A missing, empty or non-string value is treated as "not configured":
//   - OPENAI_API_KEY becomes "", which callers report as unconfigured.
//   - OPENAI_BaseURL defaults to https://api.openai.com/v1.
//   - MODEL defaults to gpt-4o.
//
// Trailing slashes are stripped from the base URL so callers can append
// "/models" or "/chat/completions" directly.
func getConfig() (string, string, string) {
	// lookup returns cfg[key] as a string. It uses a type switch rather than
	// an unchecked assertion, so a mistyped config value is logged instead of
	// panicking.
	lookup := func(key string) string {
		switch v := cfg[key].(type) {
		case string:
			return v
		case nil:
			return ""
		default:
			log.Printf("配置项 %s 类型错误: %T", key, v)
			return ""
		}
	}

	apiKey := lookup("OPENAI_API_KEY")
	if apiKey == "" {
		log.Println("OPENAI_API_KEY 未配置")
	}

	baseURL := strings.TrimRight(lookup("OPENAI_BaseURL"), "/")
	if baseURL == "" {
		log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
		baseURL = "https://api.openai.com/v1"
	}

	model := lookup("MODEL")
	if model == "" {
		log.Println("模型 未配置,使用默认 gpt-4o 模型")
		model = "gpt-4o"
	}
	return apiKey, baseURL, model
}
// handleModelRequest lists the models offered by the OpenAI-compatible
// endpoint (GET {baseURL}/models) as a chat-ready, newline-separated list of
// model IDs. It never returns an error: every failure is logged and mapped
// to a user-facing message, because the result goes straight back to chat.
func handleModelRequest(apiKey, baseURL string) string {
	resp, body, errs := gorequest.New().Get(baseURL+"/models").
		Set("Content-Type", "application/json").
		Set("Authorization", "Bearer "+apiKey).
		End()
	if len(errs) > 0 {
		log.Println(errs)
		return "请求失败"
	}
	if resp.StatusCode != http.StatusOK {
		log.Printf("请求模型列表失败: status=%d body=%s", resp.StatusCode, body)
		return "请求模型列表失败"
	}

	// Decode into a typed struct so a missing or malformed "data" field is an
	// Unmarshal error (or an empty list) rather than a panicking type assertion.
	var payload struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.Unmarshal([]byte(body), &payload); err != nil {
		log.Println(err)
		return "解析模型列表失败"
	}

	var sb strings.Builder
	sb.WriteString("支持的模型列表:\n")
	listed := 0
	for _, m := range payload.Data {
		if m.ID == "" {
			continue // skip entries without an id instead of printing %!s(<nil>)
		}
		sb.WriteString(m.ID)
		sb.WriteByte('\n')
		listed++
	}
	if listed == 0 {
		return "模型列表为空"
	}
	return sb.String()
}
// Image2Base64 reads the file at path and encodes it as a data URL
// ("data:<mime>;base64,<payload>") for the image_url part of a chat
// completions request.
//
// The MIME type is sniffed from the file contents with
// http.DetectContentType, so PNG, GIF and WebP images are labelled
// correctly. Content not recognised as an image falls back to image/jpeg,
// the label previously used for every file.
//
// It returns "" when the file cannot be opened or read, or is empty;
// callers treat "" as a conversion failure.
func Image2Base64(path string) string {
	file, err := os.Open(path)
	if err != nil {
		log.Printf("打开图片失败: %v", err)
		return ""
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		log.Printf("读取图片失败: %v", err)
		return ""
	}
	if len(data) == 0 {
		return ""
	}

	mimeType := http.DetectContentType(data)
	if !strings.HasPrefix(mimeType, "image/") {
		mimeType = "image/jpeg"
	}
	return "data:" + mimeType + ";base64," + base64.StdEncoding.EncodeToString(data)
}