From 55bdc0615db3e61bf95303445413af674c61ffcf Mon Sep 17 00:00:00 2001 From: liyp Date: Sun, 14 Jul 2024 21:38:39 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E6=B7=BB=E5=8A=A0ai=E5=9B=9E?= =?UTF-8?q?=E5=A4=8D=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.go | 11 ++- utils/image2base64.go | 21 ----- utils/router.go | 16 +++- workers/ai.go | 207 +++++++++++++++++++++++++++++++----------- workers/core.go | 52 +++++++++++ 5 files changed, 228 insertions(+), 79 deletions(-) delete mode 100644 utils/image2base64.go diff --git a/main.go b/main.go index 6b789db..8fd2bfd 100644 --- a/main.go +++ b/main.go @@ -37,7 +37,7 @@ func insertMessage(db *sql.DB, data map[string]interface{}) error { sender_role := sender["role"].(string) message_seq := data["message_seq"].(float64) - fmt.Println(post_type, message_time, int64(group_id), int64(message_id), raw_message, sender_user_id, sender_nickname, sender_card, sender_role, int64(message_seq)) + fmt.Println("消息类型:", post_type, " 发送时间:", message_time, " 群号:", int64(group_id), " 消息id:", int64(message_id), " 原始消息:", raw_message, " 发送者id:", sender_user_id, " 发送者昵称:", sender_nickname, " 发送者名片:", sender_card, " 发送者角色:", sender_role, " 消息序列:", int64(message_seq)) _, err = db.Exec("INSERT INTO messages ( post_type, message_type, time, group_id, message_id, raw_message, sender_user_id, sender_nickname, sender_card, sender_role, message_seq) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", post_type, message_type, message_time, group_id, message_id, raw_message, sender_user_id, sender_nickname, sender_card, sender_role, message_seq) fmt.Println("Data inserted successfully!") @@ -66,7 +66,8 @@ func handlePost(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{"error": "Error decoding JSON"}) return } - + // 输出解析后的 JSON 数据 + fmt.Printf("data: %s\n\n", string(body)) // 打开数据库 db, err := sql.Open("sqlite3", "./data.db") if err != nil { @@ -103,8 +104,10 @@ func handlePost(c *gin.Context) { } // 调用路由处理函数 - utils.Router(data) - c.JSON(http.StatusOK, gin.H{"message": "JSON data received successfully!"}) + if data["post_type"] == "message" || data["post_type"] == "message_sent" { + utils.Router(data) + c.JSON(http.StatusOK, gin.H{"message": "JSON data received successfully!"}) + } } func main() { diff --git a/utils/image2base64.go b/utils/image2base64.go deleted file mode 100644 index 11ca0c1..0000000 --- a/utils/image2base64.go +++ /dev/null @@ -1,21 +0,0 @@ -package utils - -import ( - "encoding/base64" - "io" - "os" -) - -func Image2Base64(path string) string { - file, err := os.Open(path) - if err != nil { - return "" - } - defer file.Close() - - if data, err := io.ReadAll(file); err == nil { - return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data) - } - return "" - -} diff --git a/utils/router.go b/utils/router.go index 1d47040..a00a310 100644 --- a/utils/router.go +++ b/utils/router.go @@ -3,6 +3,7 @@ package utils import ( "fmt" "go-bot/workers" + "regexp" "strings" ) @@ -19,7 +20,20 @@ func Router(data map[string]interface{}) { //包含发送消息的'!' raw_msg := data["raw_message"].(string) - // fmt.Println("raw_msg:", string(raw_msg[0])) + + // 匹配回复消息 + if strings.HasPrefix(raw_msg, "[CQ:reply,id=") { + pattern := `^\[CQ:reply,id=(-?\d+)\]` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(raw_msg) + if len(matches) > 0 { + fullMatch := matches[0] + raw_msg = re.ReplaceAllString(raw_msg, "") + raw_msg = raw_msg + " " + fullMatch + } + } + + fmt.Println("raw_msg:", string(raw_msg)) if len(raw_msg) > 1 && raw_msg[0] == '!' { // 去除'!' raw_msg = raw_msg[1:] diff --git a/workers/ai.go b/workers/ai.go index 3fbe3c5..bdda97b 100644 --- a/workers/ai.go +++ b/workers/ai.go @@ -1,9 +1,15 @@ package workers import ( + "encoding/base64" "fmt" + "io" + "os" + "log" "net/http" + "regexp" + "strconv" "strings" "time" @@ -29,7 +35,7 @@ func (a *AI) GetMsg() string { } ask := a.Parms[1] - if ask == "" { + if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") { return "不问问题你说个屁!" } @@ -41,7 +47,145 @@ func (a *AI) GetMsg() string { if strings.ToLower(a.Parms[1]) == "models" { return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL) } else { - return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID) + OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" + PROMPT, ok := cfg["PROMPT"].(string) + if !ok { + log.Println("PROMPT 未配置") + PROMPT = "" + } + var requestBody map[string]interface{} + if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") { + + requestBody = map[string]interface{}{ + "model": MODEL, + "stream": false, + "messages": []map[string]string{ + { + "role": "system", + "content": PROMPT, + }, + {"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]}, + }, + "temperature": 0.7, + "presence_penalty": 0, + "frequency_penalty": 0, + "top_p": 1, + } + } else { + pattern := `^\[CQ:reply,id=(-?\d+)\]` + re := regexp.MustCompile(pattern) + matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1]) + var msgId string + if len(matches) > 0 { + + msgId = matches[1] + } else { + msgId = "" + log.Println("未找到回复消息") + return "未找到回复消息" + } + message := a.GetHisMsg(msgId) + + // 正则表达式匹配 file 和 file_size 的值 + re = regexp.MustCompile(`file=([^,]+),.*file_size=(\d+)`) + matches = re.FindStringSubmatch(message) + var file string + var fileSizeStr string + if len(matches) > 2 { + file = matches[1] + fileSizeStr = matches[2] + } else { + log.Println("未找到文件信息") + return "未找到文件信息" + } + // 将 fileSizeStr 转换为整数 + fileSize, err := strconv.ParseFloat(fileSizeStr, 64) + if err != nil { + fmt.Println("获取图片大小失败:", err) + return "获取图片大小失败" + } + if fileSize/1024/1024 > 1.0 { + log.Println("文件大小超过1M") + return "文件大小超过1M" + } + filePath := a.GetImage(file) + if filePath == "" { + log.Println("获取图片失败") + return "获取图片失败" + } + base64Img := Image2Base64(filePath) + if base64Img == "" { + log.Println("图片转换base64失败") + return "图片转换base64失败" + } + // 找到第一个空格的位置 + firstSpaceIndex := strings.Index(a.RawMsg, " ") + + // 找到最后一个空格的位置 + lastSpaceIndex := strings.LastIndex(a.RawMsg, " ") + requestBody = map[string]interface{}{ + "model": MODEL, + "stream": false, + "messages": []interface{}{ + map[string]interface{}{ + "role": "system", + "content": "#角色你是一名AI视觉助手,任务是分析单个图像。", + }, + map[string]interface{}{ + "role": "user", + "content": []interface{}{ + map[string]interface{}{ + "type": "text", + "text": a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex], + }, + map[string]interface{}{ + "type": "image_url", + "image_url": map[string]string{ + "url": base64Img, + }, + }, + }, + }, + }, + "temperature": 0.7, + "presence_penalty": 0, + "frequency_penalty": 0, + "top_p": 1, + } + } + request := gorequest.New() + resp, body, errs := request.Post(OPENAI_BaseURL). + Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway). + Set("Content-Type", "application/json"). + Set("Authorization", "Bearer "+OPENAI_API_KEY). + Send(requestBody). + End() + + if errs != nil { + log.Println(errs) + return "请求失败" + } + println(resp.StatusCode) + if resp.StatusCode == 200 { + var responseBody map[string]interface{} + if err := json.Unmarshal([]byte(body), &responseBody); err != nil { + log.Println(err) + return "解析失败" + } + choices := responseBody["choices"].([]interface{}) + if len(choices) > 0 { + choice := choices[0].(map[string]interface{}) + msg := choice["message"].(map[string]interface{})["content"].(string) + return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg) + } else { + log.Println("choices为空") + return "api解析失败" + } + + } + return "请求失败!" + // return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID, a.Parms) + } } @@ -109,59 +253,16 @@ func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string { } } -func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string { - OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" - PROMPT, ok := cfg["PROMPT"].(string) - if !ok { - log.Println("PROMPT 未配置") - PROMPT = "" +func Image2Base64(path string) string { + file, err := os.Open(path) + if err != nil { + return "" } + defer file.Close() - requestBody := map[string]interface{}{ - "model": MODEL, - "stream": false, - "messages": []map[string]string{ - { - "role": "system", - "content": PROMPT, - }, - {"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]}, - }, - "temperature": 0.7, - "presence_penalty": 0, - "frequency_penalty": 0, - "top_p": 1, + if data, err := io.ReadAll(file); err == nil { + return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data) } + return "" - request := gorequest.New() - resp, body, errs := request.Post(OPENAI_BaseURL). - Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway). - Set("Content-Type", "application/json"). - Set("Authorization", "Bearer "+OPENAI_API_KEY). - Send(requestBody). - End() - - if errs != nil { - log.Println(errs) - return "请求失败" - } - println(resp.StatusCode) - if resp.StatusCode == 200 { - var responseBody map[string]interface{} - if err := json.Unmarshal([]byte(body), &responseBody); err != nil { - log.Println(err) - return "解析失败" - } - choices := responseBody["choices"].([]interface{}) - if len(choices) > 0 { - choice := choices[0].(map[string]interface{}) - msg := choice["message"].(map[string]interface{})["content"].(string) - return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg) - } else { - log.Println("choices为空") - return "api解析失败" - } - - } - return "请求失败!" } diff --git a/workers/core.go b/workers/core.go index f976169..07d7b40 100644 --- a/workers/core.go +++ b/workers/core.go @@ -12,6 +12,8 @@ type Worker interface { CheckPermission() string GetMsg() string SendMsg(msg string) bool + GetImage(file string) string + GetHisMsg(id string) string } type StdAns struct { AllowGroup []interface{} @@ -115,7 +117,31 @@ func (s *StdAns) GetMsg() string { // } } +func (s *StdAns) GetHisMsg(id string) string { + if id == "" { + return "" + } + url := cfg["POSTURL"].(string) + "/get_msg?message_id=" + id + request := gorequest.New() + resp, body, errs := request.Post(url).End() + if len(errs) > 0 { + + fmt.Println("Error sending request:", errs) + return "" + } + defer resp.Body.Close() + data := make(map[string]interface{}) + err := json.Unmarshal([]byte(body), &data) + if err != nil { + fmt.Println("解析JSON失败:", err) + return "" + } + + // fmt.Println("响应返回:", body) + return data["data"].(map[string]interface{})["message"].(string) + +} func (s *StdAns) SendMsg(msg string) bool { if msg == "-1" { return false @@ -156,3 +182,29 @@ func (s *StdAns) SendMsg(msg string) bool { fmt.Println("响应返回:", body) return true } + +func (s *StdAns) GetImage(file string) string { + if file == "" { + return "" + } + url := cfg["POSTURL"].(string) + "/get_image?file=" + file + + request := gorequest.New() + resp, body, errs := request.Post(url).End() + if len(errs) > 0 { + + fmt.Println("Error sending request:", errs) + return "" + } + defer resp.Body.Close() + data := make(map[string]interface{}) + err := json.Unmarshal([]byte(body), &data) + if err != nil { + fmt.Println("解析JSON失败:", err) + return "" + } + path := data["data"].(map[string]interface{})["file"].(string) + + // fmt.Println("响应返回:", body) + return path +}