From 9b2c187566d5960c0310ad2cf439a6f55789b94f Mon Sep 17 00:00:00 2001 From: liyp Date: Sat, 31 Aug 2024 16:31:06 +0800 Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E4=BD=BF=E7=94=A8=E7=AC=AC?= =?UTF-8?q?=E4=B8=89=E6=96=B9OpenAI=E5=BA=93=E6=94=AF=E6=8C=81=E5=B9=B6?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- go.mod | 1 + go.sum | 2 + utils/openai.go | 1 + workers/ai.go | 182 ++++++++++++++++++------------------------------ 5 files changed, 74 insertions(+), 114 deletions(-) create mode 100644 utils/openai.go diff --git a/README.md b/README.md index e5ba469..57fb976 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ # 使用Go语言重新实现 [sihuan/XZZ](https://github.com/sihuan/XZZ) 机器人项目 本项目是一个使用Go语言重新实现 [sihuan/XZZ](https://github.com/sihuan/XZZ) 机器人项目。原使用go-cqhttp的机器人项目,由于go-cqhttp不再维护,有众多bug,无法使用,所以更换使用 [napcat](https://github.com/NapNeko/NapCatQQ) 实现。 -当前项目的功能都在`workers`目录下。同时所有接收到的消息都保存在一个sqllite数据库中,文件名为`data.db`保存在项目根目录。 +当前项目的功能都在`workers`目录下。同时所有接收到的消息都保存在一个sqlite数据库中,文件名为`data.db`保存在项目根目录。 ## 部署服务: 1. 先使用docker部署[napcat](https://github.com/NapNeko/NapCatQQ),然后修改配置文件,将机器人的token替换为napcat的token,然后运行项目即可。 部署[napcat](https://github.com/NapNeko/NapCatQQ)可参考下面的docker-compose.yml文件: diff --git a/go.mod b/go.mod index 94e01ab..afb3548 100644 --- a/go.mod +++ b/go.mod @@ -45,6 +45,7 @@ require ( github.com/gin-gonic/gin v1.10.0 github.com/moul/http2curl v1.0.0 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/sashabaranov/go-openai v1.29.0 github.com/sirupsen/logrus v1.9.3 github.com/smartystreets/goconvey v1.8.1 // indirect golang.org/x/net v0.26.0 // indirect diff --git a/go.sum b/go.sum index 7d51c1a..c1ba61f 100644 --- a/go.sum +++ b/go.sum @@ -71,6 +71,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-charset v0.0.0-20180617210344-2471d30d28b4/go.mod h1:qgYeAmZ5ZIpBWTGllZSQnw97Dj+woV0toclVaRGI8pc= +github.com/sashabaranov/go-openai v1.29.0 h1:eBH6LSjtX4md5ImDCX8hNhHQvaRf22zujiERoQpsvLo= +github.com/sashabaranov/go-openai v1.29.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= diff --git a/utils/openai.go b/utils/openai.go new file mode 100644 index 0000000..d4b585b --- /dev/null +++ b/utils/openai.go @@ -0,0 +1 @@ +package utils diff --git a/workers/ai.go b/workers/ai.go index 237011b..5ea7575 100644 --- a/workers/ai.go +++ b/workers/ai.go @@ -1,20 +1,18 @@ package workers import ( + "context" "encoding/base64" "fmt" "io" "os" "log" - "net/http" "regexp" "strconv" "strings" - "time" - "github.com/goccy/go-json" - "github.com/parnurzeal/gorequest" + openai "github.com/sashabaranov/go-openai" ) func init() { @@ -43,35 +41,46 @@ func (a *AI) GetMsg() string { if OPENAI_API_KEY == "" { return "OPENAI_API_KEY 未配置" } + oaiconfig := openai.DefaultConfig(OPENAI_API_KEY) + oaiconfig.BaseURL = OPENAI_BaseURL - if strings.ToLower(a.Parms[1]) == "models" { - return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL) - } else { - OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" + client := openai.NewClientWithConfig(oaiconfig) + msg := fmt.Sprintf("[CQ:at,qq=%s] ", a.UID) + if strings.ToLower(a.Parms[1]) != "models" { + // OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" PROMPT, ok := cfg["PROMPT"].(string) if !ok { log.Println("PROMPT 未配置") PROMPT = "" } - var requestBody map[string]interface{} + // var requestBody map[string]interface{} + // 如果非回复消息 if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") { - - requestBody = map[string]interface{}{ - "model": MODEL, - "stream": false, - "messages": []map[string]string{ - { - "role": "system", - "content": PROMPT, + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: MODEL, + Stream: false, + Messages: []openai.ChatCompletionMessage{ + { + Role: "system", + Content: PROMPT, + }, + { + Role: "user", + Content: a.RawMsg[strings.Index(a.RawMsg, " ")+1:], + }, }, - {"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]}, }, - "temperature": 0.7, - "presence_penalty": 0, - "frequency_penalty": 0, - "top_p": 1, + ) + if err != nil { + log.Println("ChatCompletion error: ", err) + return "请求失败" } + // println(resp.Choices[0].Message.Content) + return msg + resp.Choices[0].Message.Content } else { + // 匹配回复消息 pattern := `^\[CQ:reply,id=(-?\d+)\]` re := regexp.MustCompile(pattern) matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1]) @@ -124,70 +133,55 @@ func (a *AI) GetMsg() string { // 找到最后一个空格的位置 lastSpaceIndex := strings.LastIndex(a.RawMsg, " ") - requestBody = map[string]interface{}{ - "model": MODEL, - "stream": false, - "messages": []interface{}{ - map[string]interface{}{ - "role": "system", - "content": "#角色你是一名AI视觉助手,任务是分析单个图像。", - }, - map[string]interface{}{ - "role": "user", - "content": []interface{}{ - map[string]interface{}{ - "type": "text", - "text": a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex], - }, - map[string]interface{}{ - "type": "image_url", - "image_url": map[string]string{ - "url": base64Img, + // 调用图片分析 + resp, err := client.CreateChatCompletion( + context.Background(), + openai.ChatCompletionRequest{ + Model: MODEL, + Messages: []openai.ChatCompletionMessage{ + { + Role: "user", + MultiContent: []openai.ChatMessagePart{ + { + Type: openai.ChatMessagePartTypeText, + Text: a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex], + }, { + Type: openai.ChatMessagePartTypeImageURL, + ImageURL: &openai.ChatMessageImageURL{ + URL: base64Img, + Detail: openai.ImageURLDetailAuto, + }, }, }, }, }, }, - "temperature": 0.7, - "presence_penalty": 0, - "frequency_penalty": 0, - "top_p": 1, - } - } - request := gorequest.New() - resp, body, errs := request.Post(OPENAI_BaseURL). - Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway). - Set("Content-Type", "application/json"). - Set("Authorization", "Bearer "+OPENAI_API_KEY). - Send(requestBody). - End() + ) + if err != nil { + log.Println("ChatCompletion error: ", err) + return "请求失败" - if errs != nil { - log.Println(errs) - return "请求失败" - } - println(resp.StatusCode) - if resp.StatusCode == 200 { - var responseBody map[string]interface{} - if err := json.Unmarshal([]byte(body), &responseBody); err != nil { - log.Println(err) - return "解析失败" - } - choices := responseBody["choices"].([]interface{}) - if len(choices) > 0 { - choice := choices[0].(map[string]interface{}) - msg := choice["message"].(map[string]interface{})["content"].(string) - return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg) - } else { - log.Println("choices为空") - return "api解析失败" } + msg += resp.Choices[0].Message.Content } - return "请求失败!" - // return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID, a.Parms) + return msg + } + models, err := client.ListModels(context.Background()) + if err != nil { + log.Println("ListModels error: ", err) + return "查询支持模型失败!" } + var modelList string + for _, model := range models.Models { + if MODEL == model.ID { + model.ID = model.ID + "(当前使用)" + } + modelList += model.ID + "\n" + } + return modelList + // return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL) } func getConfig() (string, string, string) { @@ -216,44 +210,6 @@ func getConfig() (string, string, string) { return OPENAI_API_KEY, OPENAI_BaseURL, MODEL } -func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string { - OPENAI_BaseURL = OPENAI_BaseURL + "/models" - request := gorequest.New() - resp, body, errs := request.Get(OPENAI_BaseURL). - Set("Content-Type", "application/json"). - Set("Authorization", "Bearer "+OPENAI_API_KEY). - End() - - if errs != nil { - log.Println(errs) - return "请求失败" - } - - if resp.StatusCode == 200 { - var responseBody map[string]interface{} - if err := json.Unmarshal([]byte(body), &responseBody); err != nil { - log.Println(err) - return "解析模型列表失败" - } - - choices := responseBody["data"].([]interface{}) - // var models []interface{} - if len(choices) > 0 { - msg := "支持的模型列表:\n" - for _, choice := range choices { - model := choice.(map[string]interface{})["id"] - msg += fmt.Sprintf("%s\n", model) - } - return msg - } else { - return "模型列表为空" - } - } else { - log.Println("请求失败") - return "请求模型列表失败" - } -} - func Image2Base64(path string) string { file, err := os.Open(path) if err != nil {