package workers

import (
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"log"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	openai "github.com/sashabaranov/go-openai"
)

const (
	// aiRequestTimeout bounds every call to the OpenAI-compatible API so that a
	// stalled endpoint cannot block the worker indefinitely.
	aiRequestTimeout = 2 * time.Minute

	// aiMaxImageBytes is the largest replied-to image (size as reported by
	// GetHisMsg) that will be forwarded to the model.
	aiMaxImageBytes = 5 * 1024 * 1024

	// aiReplyPrefix marks a reply CQ code inside a message parameter.
	aiReplyPrefix = "[CQ:reply,id="
)

var (
	// aiReplyPattern extracts the message id from a leading reply CQ code.
	aiReplyPattern = regexp.MustCompile(`^\[CQ:reply,id=(-?\d+)\]`)

	// Markdown patterns used by stripMarkdown, compiled once instead of per call.
	// (?s) lets a fenced code block span several lines; an optional language tag
	// directly after the opening fence (e.g. "```go\n") is dropped as well.
	aiCodeFenceRe = regexp.MustCompile("(?s)```(?:[\\w+#.-]*\\n)?(.*?)```")
	aiBoldRe      = regexp.MustCompile(`\*\*(.*?)\*\*`)
	aiUnderlineRe = regexp.MustCompile(`__(.*?)__`)

	// aiImageHTTPClient downloads images when no local copy is available; the
	// timeout keeps a slow image host from hanging the request.
	aiImageHTTPClient = &http.Client{Timeout: 30 * time.Second}
)

// init registers the AI worker under the name "ai".
func init() {
	RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
		return &AI{
			StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
		}
	})
}

// AI answers "!ai" commands using an OpenAI-compatible chat-completion API.
type AI struct {
	*StdAns
}

// GetMsg builds the reply for an "!ai" command.
//
// Parms[0] is the command word; the supported forms are:
//
//	!ai models                    list the models offered by the endpoint
//	!ai <question>                plain text chat completion
//	!ai <question> [CQ:reply,…]   ask about the image in the replied-to message
//
// Every outcome, including configuration and API errors, is reported as the
// returned chat text; errors are additionally written to the log.
func (a *AI) GetMsg() string {
	if len(a.Parms) < 2 {
		return "使用!ai xxx 向我提问吧"
	}
	ask := a.Parms[1]
	if ask == "" || strings.HasPrefix(ask, aiReplyPrefix) {
		return "不问问题你说个屁!"
	}

	apiKey, baseURL, model := getConfig()
	if apiKey == "" {
		return "OPENAI_API_KEY 未配置"
	}
	oaiConfig := openai.DefaultConfig(apiKey)
	oaiConfig.BaseURL = baseURL
	client := openai.NewClientWithConfig(oaiConfig)

	if strings.EqualFold(ask, "models") {
		return a.listModels(client, model)
	}

	// A reply CQ code, when present, is expected as the last parameter.
	last := a.Parms[len(a.Parms)-1]
	if strings.HasPrefix(last, aiReplyPrefix) {
		return a.answerImage(client, model, last)
	}
	return a.answerText(client, model)
}

// answerText sends everything after the command word in RawMsg (original
// spacing preserved) as the user message, preceded by the configured PROMPT
// as a system message when one is set.
func (a *AI) answerText(client *openai.Client, model string) string {
	prompt, ok := cfg["PROMPT"].(string)
	if !ok {
		log.Println("PROMPT 未配置")
	}

	messages := make([]openai.ChatCompletionMessage, 0, 2)
	// An empty system message would be serialized without a content field,
	// which OpenAI-compatible endpoints may reject, so only send a real prompt.
	if prompt != "" {
		messages = append(messages, openai.ChatCompletionMessage{
			Role:    openai.ChatMessageRoleSystem,
			Content: prompt,
		})
	}
	messages = append(messages, openai.ChatCompletionMessage{
		Role:    openai.ChatMessageRoleUser,
		Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:],
	})

	return a.complete(client, openai.ChatCompletionRequest{
		Model:    model,
		Messages: messages,
	})
}

// answerImage asks the model about the image in the message referenced by
// replyTag (a "[CQ:reply,id=N]" code), using imageQuestion as the prompt text.
func (a *AI) answerImage(client *openai.Client, model, replyTag string) string {
	matches := aiReplyPattern.FindStringSubmatch(replyTag)
	if matches == nil {
		log.Println("未找到回复消息")
		return "未找到回复消息"
	}
	msgID := matches[1]
	log.Println("msgId:", msgID)

	file, picURL, fileSizeStr := a.GetHisMsg(msgID)
	if picURL == "" {
		log.Println("未找到文件信息")
		return "未找到文件信息"
	}
	fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
	if err != nil {
		log.Println("获取图片大小失败:", err)
		return "获取图片大小失败"
	}
	if fileSize > aiMaxImageBytes {
		log.Println("文件大小超过5M")
		return "文件大小超过5M"
	}

	filePath := a.GetImage(file)
	if filePath == "" {
		log.Println("获取图片失败")
		return "获取图片失败"
	}
	dataURL := Image2Base64(filePath, picURL)
	if dataURL == "" {
		log.Println("图片转换base64失败")
		return "图片转换base64失败"
	}

	return a.complete(client, openai.ChatCompletionRequest{
		Model: model,
		Messages: []openai.ChatCompletionMessage{{
			Role: openai.ChatMessageRoleUser,
			MultiContent: []openai.ChatMessagePart{
				{
					Type: openai.ChatMessagePartTypeText,
					Text: a.imageQuestion(),
				},
				{
					Type: openai.ChatMessagePartTypeImageURL,
					ImageURL: &openai.ChatMessageImageURL{
						URL:    dataURL,
						Detail: openai.ImageURLDetailAuto,
					},
				},
			},
		}},
	})
}

// imageQuestion returns the question text of an image request: the part of
// RawMsg between its first and last space, which is expected to sit between
// the command word and the trailing reply CQ code. If RawMsg lacks two
// distinct spaces, the parameters between the command and the reply code are
// joined instead, so the slice expression can never go out of range.
func (a *AI) imageQuestion() string {
	first := strings.Index(a.RawMsg, " ")
	last := strings.LastIndex(a.RawMsg, " ")
	if first >= 0 && last > first {
		return a.RawMsg[first+1 : last]
	}
	if len(a.Parms) < 3 {
		return ""
	}
	return strings.Join(a.Parms[1:len(a.Parms)-1], " ")
}

// complete runs one chat completion (bounded by aiRequestTimeout) and returns
// the first choice, prefixed with an @-mention of the asker and with markdown
// markers stripped. Request failures and empty answers yield "请求失败".
func (a *AI) complete(client *openai.Client, req openai.ChatCompletionRequest) string {
	ctx, cancel := context.WithTimeout(context.Background(), aiRequestTimeout)
	defer cancel()

	resp, err := client.CreateChatCompletion(ctx, req)
	if err != nil {
		log.Println("ChatCompletion error: ", err)
		return "请求失败"
	}
	if len(resp.Choices) == 0 {
		log.Println("ChatCompletion error: response contains no choices")
		return "请求失败"
	}
	return fmt.Sprintf("[CQ:at,qq=%s] ", a.UID) + stripMarkdown(resp.Choices[0].Message.Content)
}

// listModels returns the IDs of the models offered by the endpoint, one per
// line, with "(当前使用)" appended to the currently configured model.
func (a *AI) listModels(client *openai.Client, current string) string {
	ctx, cancel := context.WithTimeout(context.Background(), aiRequestTimeout)
	defer cancel()

	models, err := client.ListModels(ctx)
	if err != nil {
		log.Println("ListModels error: ", err)
		return "查询支持模型失败!"
	}

	var b strings.Builder
	for _, m := range models.Models {
		b.WriteString(m.ID)
		if m.ID == current {
			b.WriteString("(当前使用)")
		}
		b.WriteByte('\n')
	}
	return b.String()
}

// stripMarkdown removes fenced code-block markers (keeping the code itself,
// including multi-line blocks), **bold** markers and __underline__ markers.
func stripMarkdown(text string) string {
	text = aiCodeFenceRe.ReplaceAllString(text, "$1")
	text = aiBoldRe.ReplaceAllString(text, "$1")
	return aiUnderlineRe.ReplaceAllString(text, "$1")
}

// getConfig reads the OpenAI settings from the package-level cfg map and
// returns, in order, the API key, the API base URL and the model name.
//
// A missing entry or a non-string value is treated as "not configured". The
// API key is then returned empty (callers must check it), while the base URL
// falls back to https://api.openai.com/v1 and the model to gpt-4o. Trailing
// slashes are trimmed from the base URL because the client appends paths
// that start with "/".
func getConfig() (string, string, string) {
	apiKey, _ := cfg["OPENAI_API_KEY"].(string)
	if apiKey == "" {
		log.Println("OPENAI_API_KEY 未配置")
	}

	baseURL, _ := cfg["OPENAI_BaseURL"].(string)
	baseURL = strings.TrimRight(baseURL, "/")
	if baseURL == "" {
		log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
		baseURL = "https://api.openai.com/v1"
	}

	model, _ := cfg["MODEL"].(string)
	if model == "" {
		log.Println("模型 未配置,使用默认 gpt-4o 模型")
		model = "gpt-4o"
	}

	return apiKey, baseURL, model
}

// Image2Base64 returns the image as a base64 data URL ("data:<type>;base64,...").
//
// The file at path is used when it can be read and is non-empty; otherwise the
// image is downloaded from picUrl (with a timeout, and only a 200 response is
// accepted). The content type is sniffed from the bytes, falling back to
// image/jpeg. An empty string is returned when neither source yields data.
func Image2Base64(path string, picUrl string) string {
	if data, err := os.ReadFile(path); err == nil && len(data) > 0 {
		return encodeImageDataURL(data)
	}

	resp, err := aiImageHTTPClient.Get(picUrl)
	if err != nil {
		log.Println("下载图片失败:", err)
		return ""
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Println("下载图片失败, 状态码:", resp.StatusCode)
		return ""
	}
	data, err := io.ReadAll(resp.Body)
	if err != nil || len(data) == 0 {
		return ""
	}
	return encodeImageDataURL(data)
}

// encodeImageDataURL wraps raw image bytes in a base64 data URL. The content
// type comes from http.DetectContentType; anything not recognised as an image
// is labelled image/jpeg.
func encodeImageDataURL(data []byte) string {
	contentType := http.DetectContentType(data)
	if !strings.HasPrefix(contentType, "image/") {
		contentType = "image/jpeg"
	}
	return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(data)
}