package workers import ( "context" "encoding/base64" "encoding/json" "fmt" "go-bot/config" "go-bot/tools" "os" "slices" "io" "log" "regexp" "strconv" "strings" "github.com/imroc/req/v3" openai "github.com/sashabaranov/go-openai" ) func init() { plugins := config.GetConfig()["PLUGINS"].([]interface{}) if slices.Contains(plugins, "ai") { RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker { return &AI{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), } }) } } type AI struct { *StdAns } func (a *AI) GetMsg() string { if len(a.Parms) < 2 { return "使用!ai xxx 向我提问吧" } ask := a.Parms[1] if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") { return "不问问题你说个屁!" } OPENAI_API_KEY, OPENAI_BaseURL, MODEL, PROMPT, CONTEXT := tools.GetOAIConfig() if OPENAI_API_KEY == "" { return "OPENAI_API_KEY 未配置" } oaiconfig := openai.DefaultConfig(OPENAI_API_KEY) oaiconfig.BaseURL = OPENAI_BaseURL client := openai.NewClientWithConfig(oaiconfig) msg := fmt.Sprintf("[CQ:at,qq=%s] ", a.UID) if strings.ToLower(a.Parms[1]) != "models" { // 如果非回复消息 if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") { messages := make([]openai.ChatCompletionMessage, 0) messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleSystem, Content: PROMPT, }) // println("messages0:", messages) var key string redisClient := tools.GetRedisClient() // 如果redisClient不为空,则获取上下文 if redisClient != nil { // 限制请求频率 if !tools.CheckRequestFrequency(a.GID, a.UID, 10) { log.Printf("请求过于频繁。\n") return "请求过于频繁。" } key = fmt.Sprintf("context:%s:%s:%s", "ai", a.GID, a.UID) // 获取上下文 message := openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:], } // 序列化为 JSON 字符串 jsonMessage, err := json.Marshal(message) if err != nil { log.Println("序列化错误:", err) log.Println("存储的 JSON 字符串:", string(jsonMessage)) return "序列化错误" } tools.AddToContext(key, jsonMessage, CONTEXT) // } // println("RawMsg:", a.RawMsg[strings.Index(a.RawMsg, " ")+1:]) length := tools.GetListLength(key) // if length > 0 { for i := 0; i < int(length); i++ { message, err := tools.GetListValue(key, int64(i)) if err != nil { log.Println("获取上下文失败:", err) return "获取上下文失败" } // log.Println("读取的 JSON 字符串:", message) var msg openai.ChatCompletionMessage err = json.Unmarshal([]byte(message), &msg) if err != nil { log.Println("反序列化错误:", err) return "反序列化错误" } messages = append(messages, msg) } } else { messages = append(messages, openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleUser, Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:], }) } // println("messages:", messages) // for i, msg := range messages { // fmt.Printf("消息 %d:\n", i+1) // fmt.Printf(" 角色: %s\n", msg.Role) // fmt.Printf(" 内容: %s\n", msg.Content) // fmt.Println() // } resp, err := client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ Model: MODEL, Stream: false, Messages: messages, }, ) if err != nil { log.Println("ChatCompletion error: ", err) tools.RemoveLastValueFromList(key) return "请求失败" } // println(resp.Choices[0].Message.Content) content := resp.Choices[0].Message.Content if redisClient != nil { message := openai.ChatCompletionMessage{ Role: openai.ChatMessageRoleAssistant, Content: stripMarkdown(content)} // 序列化为 JSON 字符串 jsonMessage, err := json.Marshal(message) if err != nil { log.Println("Assistant消息序列化错误:", err) return "Assistant消息序列化错误" } tools.AddToContext(key, jsonMessage, CONTEXT) } return msg + stripMarkdown(content) } else { // 匹配回复消息 pattern := `^\[CQ:reply,id=(-?\d+)\]` re := regexp.MustCompile(pattern) matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1]) var msgId string if len(matches) > 0 { msgId = matches[1] // println("msgId:", msgId) } else { msgId = "" log.Println("未找到回复消息") return "未找到回复消息" } file, picUrl, fileSizeStr := a.GetHisMsg(msgId) // println("file:", file, "picUrl:", picUrl, "fileSizeStr:", fileSizeStr) if picUrl == "" { log.Println("未找到文件信息") return "未找到文件信息" } fileSize, err := strconv.ParseFloat(fileSizeStr, 64) if err != nil { fmt.Println("获取图片大小失败:", err) return "获取图片大小失败" } if fileSize/1024/1024 > 5.0 { log.Println("文件大小超过5M") return "文件大小超过5M" } filePath := a.GetImage(file) // println("filePath:", filePath) if filePath == "" { log.Println("获取图片失败") // return "获取图片失败" // 下载picUrl文件 } base64Img := Image2Base64(filePath, picUrl) if base64Img == "" { log.Println("图片转换base64失败") return "图片转换base64失败" } // 找到第一个空格的位置 firstSpaceIndex := strings.Index(a.RawMsg, " ") // 找到最后一个空格的位置 lastSpaceIndex := strings.LastIndex(a.RawMsg, " ") // 调用图片分析 resp, err := client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ Model: MODEL, Messages: []openai.ChatCompletionMessage{ { Role: "user", MultiContent: []openai.ChatMessagePart{ { Type: openai.ChatMessagePartTypeText, Text: a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex], }, { Type: openai.ChatMessagePartTypeImageURL, ImageURL: &openai.ChatMessageImageURL{ URL: base64Img, Detail: openai.ImageURLDetailAuto, }, }, }, }, }, }, ) if err != nil { log.Println("ChatCompletion error: ", err) return "请求失败,api可能不支持图片上传" } msg += resp.Choices[0].Message.Content } return stripMarkdown(msg) } models, err := client.ListModels(context.Background()) if err != nil { log.Println("ListModels error: ", err) return "查询支持模型失败!" } var modelList string for _, model := range models.Models { if MODEL == model.ID { model.ID = model.ID + "\t(当前使用)" } modelList += model.ID + "\n" } return modelList // return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL) } func stripMarkdown(text string) string { // println("before:", text) // 移除代码块,但保留代码内容 re := regexp.MustCompile("(?s)```(?:\\w+\\s*\n)?(.*?)```") text = re.ReplaceAllStringFunc(text, func(match string) string { submatches := re.FindStringSubmatch(match) if len(submatches) > 1 { return strings.TrimSpace(submatches[1]) } return match }) // 移除行内代码 re = regexp.MustCompile("`([^`]+)`") text = re.ReplaceAllString(text, "$1") // 移除标题 re = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) text = re.ReplaceAllString(text, "$1") // 移除粗体和斜体 re = regexp.MustCompile(`(\*\*|__)(.+?)(\*\*|__)`) text = re.ReplaceAllString(text, "$2") re = regexp.MustCompile(`(\*|_)(.+?)(\*|_)`) text = re.ReplaceAllString(text, "$2") // 移除链接,保留链接文本 re = regexp.MustCompile(`\[([^\]]+)\]\([^\)]+\)`) text = re.ReplaceAllString(text, "$1") // 移除图片 re = regexp.MustCompile(`!\[([^\]]*)\]\([^\)]+\)`) text = re.ReplaceAllString(text, "$1") // 移除水平线 re = regexp.MustCompile(`(?m)^-{3,}|_{3,}|\*{3,}$`) text = re.ReplaceAllString(text, "") // 移除块引用 re = regexp.MustCompile(`(?m)^>\s+(.+)`) text = re.ReplaceAllString(text, "$1") // 移除列表标记 re = regexp.MustCompile(`(?m)^[ \t]*[\*\-+][ \t]+`) text = re.ReplaceAllString(text, "") re = regexp.MustCompile(`(?m)^[ \t]*\d+\.[ \t]+`) text = re.ReplaceAllString(text, "") // 移除多余的空行 re = regexp.MustCompile(`\n{3,}`) text = re.ReplaceAllString(text, "\n\n") // println("after:", text) return strings.TrimSpace(text) } func Image2Base64(path string, picUrl string) string { // 尝试从文件路径读取图片 file, err := os.Open(path) if err == nil { defer file.Close() if data, err := io.ReadAll(file); err == nil { return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data) } } // 如果文件路径不可用,则尝试从 URL 下载图片 client := req.C(). SetUserAgent("Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/114.0.0.0 Safari/537.36"). SetTLSFingerprintChrome() // 模拟 Chrome 浏览器的 TLS 握手指纹,让网站相信这是 Chrome 浏览器在访问,予以通行。 resp, err := client.R().Get(picUrl) if err != nil { return "" } defer resp.Body.Close() if data, err := io.ReadAll(resp.Body); err == nil { return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data) } return "" }