diff --git a/workers/ai.go b/workers/ai.go index c8945eb..4f5c505 100644 --- a/workers/ai.go +++ b/workers/ai.go @@ -2,7 +2,6 @@ package workers import ( "fmt" - "go-bot/config" "log" "strings" @@ -25,29 +24,41 @@ type AI struct { func (a *AI) GetMsg() string { if len(a.Parms) < 2 { return "使用!ai xxx 向我提问吧" - } - ask := a.Parms[1] + ask := a.Parms[1] if ask == "" { return "不问问题你说个屁!" } - var msg string - var OPENAI_API_KEY string + + OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig() + if OPENAI_API_KEY == "" { + return "OPENAI_API_KEY 未配置" + } + + if strings.ToLower(a.Parms[1]) == "models" { + return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL) + } else { + return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID) + } +} + +func getConfig() (string, string, string) { + var OPENAI_API_KEY, OPENAI_BaseURL, MODEL string + if cfg["OPENAI_API_KEY"] != nil { OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string) } else { log.Println("OPENAI_API_KEY 未配置") - return "OPENAI_API_KEY 未配置" } - var OPENAI_BaseURL string + if cfg["OPENAI_BaseURL"] != nil { OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string) } else { log.Println("OPENAI_BaseURL 未配置,使用openai默认配置") OPENAI_BaseURL = "https://api.openai.com/v1" } - var MODEL string + if cfg["MODEL"] != nil { MODEL = cfg["MODEL"].(string) } else { @@ -55,120 +66,95 @@ func (a *AI) GetMsg() string { MODEL = "chatglm_pro" } - if strings.ToLower(a.Parms[1]) == "models" { + return OPENAI_API_KEY, OPENAI_BaseURL, MODEL +} - OPENAI_BaseURL = OPENAI_BaseURL + "/models" +func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string { + OPENAI_BaseURL = OPENAI_BaseURL + "/models" + request := gorequest.New() + resp, body, errs := request.Get(OPENAI_BaseURL). + Set("Content-Type", "application/json"). + Set("Authorization", "Bearer "+OPENAI_API_KEY). + End() - request := gorequest.New() - resp, body, errs := request.Get(OPENAI_BaseURL). - Set("Content-Type", "application/json"). - Set("Authorization", "Bearer "+OPENAI_API_KEY). - End() - if errs != nil { - log.Println(errs) - return "请求失败" - } else { - if resp.StatusCode == 200 { - var responseBody map[string]interface{} - if err := json.Unmarshal([]byte(body), &responseBody); err != nil { - log.Println(err) - return "解析模型列表失败" - } - - choices := responseBody["data"].([]interface{}) - var models []interface{} - if len(choices) > 0 { - msg = "支持的模型列表:\n" - for _, choice := range choices { - model := choice.(map[string]interface{})["id"].(string) - if model == MODEL { - msg = msg + model + "\t ✔\n" - } else { - msg = msg + model + "\n" - } - models = append(models, model) - - } - - } else { - msg = "没查到支持模型列表" - } - if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" { - // 判断允许设置权限,需要AllowUser和发消息用户账号相同 - if a.Master != nil && contains(a.Master, a.UID) { - if contains(models, a.Parms[3]) { - cfg["MODEL"] = a.Parms[3] - msg = "已设置模型为 " + a.Parms[3] - config.ModifyConfig("MODEL", a.Parms[3]) - config.ReloadConfig() - config.PrintConfig(cfg, "") - } else { - msg = "不支持的模型" - } - - } else { - msg = "无权限设置模型" - } - - } - return msg - } else { - log.Println("请求失败") - return "请求模型列表失败" - } - } - } else { - OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" - PROMPT, ok := cfg["PROMPT"].(string) - if !ok { - log.Println("PROMRT 未配置") - PROMPT = "" - } - // PROMPT = "" - // println("PROMPT:", PROMPT) - // println("ask:", ask) - requestBody := map[string]interface{}{ - "model": MODEL, - "stream": false, - "messages": []map[string]string{ - { - "role": "system", - "content": PROMPT, - }, - {"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]}}, - // "max_tokens": 200, - "temperature": 0.7, - } - request := gorequest.New() - resp, body, errs := request.Post(OPENAI_BaseURL). - Set("Content-Type", "application/json"). - Set("Authorization", "Bearer "+OPENAI_API_KEY). - Send(requestBody). - End() - if errs != nil { - log.Println(errs) - return "请求失败" - } else { - if resp.StatusCode == 200 { - var responseBody map[string]interface{} - if err := json.Unmarshal([]byte(body), &responseBody); err != nil { - log.Println(err) - return "解析失败" - } - choices := responseBody["choices"].([]interface{}) - if len(choices) > 0 { - choice := choices[0].(map[string]interface{}) - msg = choice["message"].(map[string]interface{})["content"].(string) - // println("msg:", msg) - - } else { - log.Println("choices为空") - msg = "api解析失败" - } - } - } + if errs != nil { + log.Println(errs) + return "请求失败" } - return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg) + if resp.StatusCode == 200 { + var responseBody map[string]interface{} + if err := json.Unmarshal([]byte(body), &responseBody); err != nil { + log.Println(err) + return "解析模型列表失败" + } + choices := responseBody["data"].([]interface{}) + // var models []interface{} + if len(choices) > 0 { + msg := "支持的模型列表:\\n" + for _, choice := range choices { + model := choice.(map[string]interface{})["id"] + msg += fmt.Sprintf("%s\\n", model) + } + return msg + } else { + return "模型列表为空" + } + } else { + log.Println("请求失败") + return "请求模型列表失败" + } +} + +func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string { + OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" + PROMPT, ok := cfg["PROMPT"].(string) + if !ok { + log.Println("PROMPT 未配置") + PROMPT = "" + } + + requestBody := map[string]interface{}{ + "model": MODEL, + "stream": false, + "messages": []map[string]string{ + { + "role": "system", + "content": PROMPT, + }, + {"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]}, + }, + "temperature": 0.7, + } + request := gorequest.New() + resp, body, errs := request.Post(OPENAI_BaseURL). + Set("Content-Type", "application/json"). + Set("Authorization", "Bearer "+OPENAI_API_KEY). + Send(requestBody). + End() + + if errs != nil { + log.Println(errs) + return "请求失败" + } + + if resp.StatusCode == 200 { + var responseBody map[string]interface{} + if err := json.Unmarshal([]byte(body), &responseBody); err != nil { + log.Println(err) + return "解析失败" + } + choices := responseBody["choices"].([]interface{}) + if len(choices) > 0 { + choice := choices[0].(map[string]interface{}) + msg := choice["message"].(map[string]interface{})["content"].(string) + return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg) + } else { + log.Println("choices为空") + return "api解析失败" + } + } else { + return "请求失败" + } } diff --git a/workers/pkg.go b/workers/pkg.go index bbe0efd..01dd8fe 100644 --- a/workers/pkg.go +++ b/workers/pkg.go @@ -2,12 +2,11 @@ package workers import ( "fmt" - "io" - "net/http" "strings" "time" "github.com/goccy/go-json" + "github.com/parnurzeal/gorequest" ) func init() { @@ -40,19 +39,15 @@ func (a *Pkg) GetMsg() string { } // 输出请求地址 fmt.Println("pkg url:", url) - req, err := http.Get(url) - if err != nil { + request := gorequest.New() + _, body, errs := request.Get(url).End() + if len(errs) > 0 { return "服务器网络错误!" } - defer req.Body.Close() - body, err := io.ReadAll(req.Body) - if err != nil { - return "数据解析错误!" - } // fmt.Println("pkg body:", string(body)) // var pkg []Package var pkg map[string]interface{} - err = json.Unmarshal(body, &pkg) + err := json.Unmarshal([]byte(body), &pkg) if err != nil { return err.Error() } @@ -61,20 +56,16 @@ func (a *Pkg) GetMsg() string { if len(resultSlipe) == 0 { url := "https://aur.archlinux.org/rpc/v5/suggest/" + parms[1] - req, err := http.Get(url) - if err != nil { - fmt.Println(err) + _, body, errs := request.Get(url).End() + if len(errs) > 0 { + fmt.Println(errs) } - defer req.Body.Close() fmt.Println("aur suggest url:", url) - re, err := io.ReadAll(req.Body) // fmt.Println(string(re)) - if err != nil { - fmt.Println(err) - } + var suggestions []string - err = json.Unmarshal(re, &suggestions) + err := json.Unmarshal([]byte(body), &suggestions) if err != nil { fmt.Println(err) } @@ -83,20 +74,16 @@ func (a *Pkg) GetMsg() string { return "没有找到相关软件" } searchUrl := "https://aur.archlinux.org/rpc/v5/search/" + suggestions[0] + "?by=name" - searchReq, err := http.Get(searchUrl) - if err != nil { - fmt.Println(err) - } - defer searchReq.Body.Close() - searchRe, err := io.ReadAll(searchReq.Body) - if err != nil { - fmt.Println(err) + _, body, errs = request.Get(searchUrl).End() + if len(errs) > 0 { + fmt.Println("searchUrl err:", errs) } var searchMap map[string]interface{} - err = json.Unmarshal(searchRe, &searchMap) + err = json.Unmarshal([]byte(body), &searchMap) if err != nil { fmt.Println(err) } + searchResults := searchMap["results"].([]interface{}) // println("searchResults:", len(searchResults)) maxVotes := 0.0 @@ -129,7 +116,11 @@ func (a *Pkg) GetMsg() string { msg += "版本:" + searchResult["Version"].(string) + "\n" msg += "描述:" + searchResult["Description"].(string) + "\n" msg += "维护者:" + maintainer + "\n" - msg += "链接:" + searchResult["URL"].(string) + "\n" + upstream, ok := searchResult["URL"].(string) + if !ok || upstream == "" { + upstream = "无" + } + msg += "上游:" + upstream + "\n" msg += "更新时间:" + last_update fmt.Println(msg) @@ -149,7 +140,7 @@ func (a *Pkg) GetMsg() string { msg += "版本:" + result["pkgver"].(string) + "\n" msg += "描述:" + result["pkgdesc"].(string) + "\n" msg += "打包:" + result["packager"].(string) + "\n" - msg += "链接:" + result["url"].(string) + "\n" + msg += "上游:" + result["url"].(string) + "\n" msg += "更新日期:" + last_update return msg