go-bot/workers/ai.go
liyp f160de4320 feat(readme): 更新项目介绍和部署指南
更新了README,加入了使用Go语言重新实现sihuan/XZZ机器人项目的介绍。由于原项目使用的go-cqhttp不再维护,本项目转向使用napcat实现。同时,更新了部署服务的步骤和配置文件示例,方便用户进行部署和使用。
2024-07-20 15:49:03 +08:00

269 lines
6.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package workers
import (
"encoding/base64"
"fmt"
"io"
"os"
"log"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/goccy/go-json"
"github.com/parnurzeal/gorequest"
)
func init() {
RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
return &AI{
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
}
})
}
// AI is the worker for the "ai" command. It has no state of its own: fields
// such as Parms, RawMsg and UID used by GetMsg are promoted from the embedded
// *StdAns.
type AI struct {
	*StdAns
}
var (
	// aiReplyIDRe captures the (possibly negative) message id from a leading
	// CQ reply code, e.g. "[CQ:reply,id=-42]" -> "-42". Compiled once rather
	// than on every message.
	aiReplyIDRe = regexp.MustCompile(`^\[CQ:reply,id=(-?\d+)\]`)
	// aiFileInfoRe captures the file and file_size fields from the replied-to
	// message (expected to contain an image CQ code).
	aiFileInfoRe = regexp.MustCompile(`file=([^,]+),.*file_size=(\d+)`)
)

// GetMsg answers an "ai" command and always returns a chat-ready string.
//
// Handled forms, based on a.Parms:
//   - fewer than 2 params            -> usage hint
//   - Parms[1] == "models" (any case) -> model list from the configured endpoint
//   - last param is a CQ reply code  -> vision request with the replied-to image
//   - anything else                  -> plain chat completion using cfg["PROMPT"]
//
// Failures are reported as short user-facing messages, never as panics.
func (a *AI) GetMsg() string {
	if len(a.Parms) < 2 {
		return "使用!ai xxx 向我提问吧"
	}
	ask := a.Parms[1]
	if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
		return "不问问题你说个屁!"
	}

	apiKey, baseURL, model := getConfig()
	if apiKey == "" {
		return "OPENAI_API_KEY 未配置"
	}
	if strings.EqualFold(ask, "models") {
		return handleModelRequest(apiKey, baseURL)
	}

	var requestBody map[string]interface{}
	if strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
		var errMsg string
		if requestBody, errMsg = a.buildVisionRequest(model); errMsg != "" {
			return errMsg
		}
	} else {
		requestBody = a.buildTextRequest(model)
	}
	return a.postChatCompletion(baseURL+"/chat/completions", apiKey, requestBody)
}

// buildTextRequest builds a plain chat-completion body. The system message is
// cfg["PROMPT"] (empty if unset). The user message is everything after the
// first space of the raw message.
func (a *AI) buildTextRequest(model string) map[string]interface{} {
	prompt, ok := cfg["PROMPT"].(string)
	if !ok {
		log.Println("PROMPT 未配置")
	}
	// Index returns -1 when there is no space, so +1 yields the whole message.
	question := a.RawMsg[strings.Index(a.RawMsg, " ")+1:]
	return map[string]interface{}{
		"model":  model,
		"stream": false,
		"messages": []map[string]string{
			{"role": "system", "content": prompt},
			{"role": "user", "content": question},
		},
		"temperature":       0.7,
		"presence_penalty":  0,
		"frequency_penalty": 0,
		"top_p":             1,
	}
}

// buildVisionRequest builds a multimodal chat-completion body. It uses the
// image in the message referenced by the trailing CQ reply code. On failure it
// returns a nil body and the user-facing error message to send back.
func (a *AI) buildVisionRequest(model string) (map[string]interface{}, string) {
	matches := aiReplyIDRe.FindStringSubmatch(a.Parms[len(a.Parms)-1])
	if len(matches) < 2 {
		log.Println("未找到回复消息")
		return nil, "未找到回复消息"
	}
	message := a.GetHisMsg(matches[1])

	matches = aiFileInfoRe.FindStringSubmatch(message)
	if len(matches) < 3 {
		log.Println("未找到文件信息")
		return nil, "未找到文件信息"
	}
	file, fileSizeStr := matches[1], matches[2]

	fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
	if err != nil {
		log.Println("获取图片大小失败:", err)
		return nil, "获取图片大小失败"
	}
	if fileSize > 2*1024*1024 { // 2 MiB upload limit
		log.Println("文件大小超过2M")
		return nil, "文件大小超过2M"
	}

	filePath := a.GetImage(file)
	if filePath == "" {
		log.Println("获取图片失败")
		return nil, "获取图片失败"
	}
	base64Img := Image2Base64(filePath)
	if base64Img == "" {
		log.Println("图片转换base64失败")
		return nil, "图片转换base64失败"
	}

	// The question is the text between the first and the last space of the
	// raw message. This assumes the reply code is its last space-separated
	// token, mirroring Parms. With fewer than two spaces there is no such span,
	// so send an empty question instead of slicing out of range.
	question := ""
	if first, last := strings.Index(a.RawMsg, " "), strings.LastIndex(a.RawMsg, " "); first >= 0 && last > first {
		question = a.RawMsg[first+1 : last]
	}

	return map[string]interface{}{
		"model":  model,
		"stream": false,
		"messages": []interface{}{
			map[string]interface{}{
				"role":    "system",
				"content": "#角色你是一名AI视觉助手任务是分析单个图像。",
			},
			map[string]interface{}{
				"role": "user",
				"content": []interface{}{
					map[string]interface{}{"type": "text", "text": question},
					map[string]interface{}{
						"type":      "image_url",
						"image_url": map[string]string{"url": base64Img},
					},
				},
			},
		},
		"temperature":       0.7,
		"presence_penalty":  0,
		"frequency_penalty": 0,
		"top_p":             1,
	}, ""
}

// postChatCompletion POSTs requestBody as JSON to endpoint and returns the
// first choice's message content, prefixed with an @-mention of the sender.
// 502/503 responses are retried up to 3 times, 5s apart. A generous timeout
// keeps a stalled upstream from hanging the worker indefinitely.
func (a *AI) postChatCompletion(endpoint, apiKey string, requestBody map[string]interface{}) string {
	resp, body, errs := gorequest.New().Post(endpoint).
		Timeout(5*time.Minute).
		Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway).
		Set("Content-Type", "application/json").
		Set("Authorization", "Bearer "+apiKey).
		Send(requestBody).
		End()
	if len(errs) > 0 {
		log.Println(errs)
		return "请求失败"
	}
	if resp.StatusCode != http.StatusOK {
		log.Println("chat completion request failed:", resp.StatusCode, body)
		return "请求失败!"
	}

	// Decode into a typed shape: a missing "choices" or a null "content"
	// yields zero values instead of a type-assertion panic.
	var parsed struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal([]byte(body), &parsed); err != nil {
		log.Println(err)
		return "解析失败"
	}
	if len(parsed.Choices) == 0 {
		log.Println("choices为空")
		return "api解析失败"
	}
	return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, parsed.Choices[0].Message.Content)
}
// getConfig reads the OpenAI-compatible connection settings from cfg.
//
// It returns, in order:
//   - the API key: empty when unset; callers must check it.
//   - the base URL: defaults to https://api.openai.com/v1. Trailing slashes
//     are removed so callers can append "/chat/completions" or "/models".
//   - the model name: defaults to gpt-4o.
//
// Missing, empty, or non-string values are treated as unset rather than
// causing a type-assertion panic.
func getConfig() (string, string, string) {
	lookup := func(key string) string {
		s, _ := cfg[key].(string)
		return s
	}

	apiKey := lookup("OPENAI_API_KEY")
	if apiKey == "" {
		log.Println("OPENAI_API_KEY 未配置")
	}

	baseURL := strings.TrimRight(lookup("OPENAI_BaseURL"), "/")
	if baseURL == "" {
		log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
		baseURL = "https://api.openai.com/v1"
	}

	model := lookup("MODEL")
	if model == "" {
		log.Println("模型 未配置,使用默认 gpt-4o 模型")
		model = "gpt-4o"
	}

	return apiKey, baseURL, model
}
// handleModelRequest lists the models offered by the OpenAI-compatible
// endpoint (GET {baseURL}/models) and formats their ids one per line.
//
// It always returns a user-facing message. That is either the model list, or a
// short error when the request fails, the status is not 200, the body cannot
// be decoded, or the list is empty.
func handleModelRequest(apiKey, baseURL string) string {
	resp, body, errs := gorequest.New().Get(baseURL+"/models").
		Timeout(30*time.Second).
		Set("Content-Type", "application/json").
		Set("Authorization", "Bearer "+apiKey).
		End()
	if len(errs) > 0 {
		log.Println(errs)
		return "请求失败"
	}
	if resp.StatusCode != http.StatusOK {
		log.Println("请求失败", resp.StatusCode, body)
		return "请求模型列表失败"
	}

	// A typed shape turns a missing/null "data" into an empty list instead of
	// a type-assertion panic.
	var models struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
	}
	if err := json.Unmarshal([]byte(body), &models); err != nil {
		log.Println(err)
		return "解析模型列表失败"
	}
	if len(models.Data) == 0 {
		return "模型列表为空"
	}

	var sb strings.Builder
	sb.WriteString("支持的模型列表:\n")
	for _, m := range models.Data {
		sb.WriteString(m.ID)
		sb.WriteByte('\n')
	}
	return sb.String()
}
// Image2Base64 reads the file at path and returns it as a data URL of the form
// "data:<mime>;base64,<payload>", suitable for an image_url message part.
//
// The MIME type is sniffed from the content with http.DetectContentType. When
// the content is not recognised as an image, it falls back to image/jpeg. The
// function returns "" if the file cannot be opened or read, or if it is empty.
func Image2Base64(path string) string {
	file, err := os.Open(path)
	if err != nil {
		return ""
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil || len(data) == 0 {
		return ""
	}

	mime := http.DetectContentType(data)
	if !strings.HasPrefix(mime, "image/") {
		mime = "image/jpeg"
	}
	return "data:" + mime + ";base64," + base64.StdEncoding.EncodeToString(data)
}