feat(workers): 重构AI worker以支持模型切换和错误处理

支持通过配置切换OpenAI模型,优化了模型请求的处理逻辑,增加了对错误情况的处理,提高了代码的鲁棒性和可维护性。
This commit is contained in:
liyp 2024-07-06 19:18:59 +08:00
parent c63b71b2a4
commit 31b7ab9f67
2 changed files with 126 additions and 149 deletions

View file

@ -2,7 +2,6 @@ package workers
import ( import (
"fmt" "fmt"
"go-bot/config"
"log" "log"
"strings" "strings"
@ -25,29 +24,41 @@ type AI struct {
func (a *AI) GetMsg() string { func (a *AI) GetMsg() string {
if len(a.Parms) < 2 { if len(a.Parms) < 2 {
return "使用!ai xxx 向我提问吧" return "使用!ai xxx 向我提问吧"
} }
ask := a.Parms[1]
ask := a.Parms[1]
if ask == "" { if ask == "" {
return "不问问题你说个屁!" return "不问问题你说个屁!"
} }
var msg string
var OPENAI_API_KEY string OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
if OPENAI_API_KEY == "" {
return "OPENAI_API_KEY 未配置"
}
if strings.ToLower(a.Parms[1]) == "models" {
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
} else {
return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID)
}
}
func getConfig() (string, string, string) {
var OPENAI_API_KEY, OPENAI_BaseURL, MODEL string
if cfg["OPENAI_API_KEY"] != nil { if cfg["OPENAI_API_KEY"] != nil {
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string) OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
} else { } else {
log.Println("OPENAI_API_KEY 未配置") log.Println("OPENAI_API_KEY 未配置")
return "OPENAI_API_KEY 未配置"
} }
var OPENAI_BaseURL string
if cfg["OPENAI_BaseURL"] != nil { if cfg["OPENAI_BaseURL"] != nil {
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string) OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
} else { } else {
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置") log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
OPENAI_BaseURL = "https://api.openai.com/v1" OPENAI_BaseURL = "https://api.openai.com/v1"
} }
var MODEL string
if cfg["MODEL"] != nil { if cfg["MODEL"] != nil {
MODEL = cfg["MODEL"].(string) MODEL = cfg["MODEL"].(string)
} else { } else {
@ -55,19 +66,22 @@ func (a *AI) GetMsg() string {
MODEL = "chatglm_pro" MODEL = "chatglm_pro"
} }
if strings.ToLower(a.Parms[1]) == "models" { return OPENAI_API_KEY, OPENAI_BaseURL, MODEL
}
func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string {
OPENAI_BaseURL = OPENAI_BaseURL + "/models" OPENAI_BaseURL = OPENAI_BaseURL + "/models"
request := gorequest.New() request := gorequest.New()
resp, body, errs := request.Get(OPENAI_BaseURL). resp, body, errs := request.Get(OPENAI_BaseURL).
Set("Content-Type", "application/json"). Set("Content-Type", "application/json").
Set("Authorization", "Bearer "+OPENAI_API_KEY). Set("Authorization", "Bearer "+OPENAI_API_KEY).
End() End()
if errs != nil { if errs != nil {
log.Println(errs) log.Println(errs)
return "请求失败" return "请求失败"
} else { }
if resp.StatusCode == 200 { if resp.StatusCode == 200 {
var responseBody map[string]interface{} var responseBody map[string]interface{}
if err := json.Unmarshal([]byte(body), &responseBody); err != nil { if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
@ -76,57 +90,31 @@ func (a *AI) GetMsg() string {
} }
choices := responseBody["data"].([]interface{}) choices := responseBody["data"].([]interface{})
var models []interface{} // var models []interface{}
if len(choices) > 0 { if len(choices) > 0 {
msg = "支持的模型列表:\n" msg := "支持的模型列表:\\n"
for _, choice := range choices { for _, choice := range choices {
model := choice.(map[string]interface{})["id"].(string) model := choice.(map[string]interface{})["id"]
if model == MODEL { msg += fmt.Sprintf("%s\\n", model)
msg = msg + model + "\t ✔\n"
} else {
msg = msg + model + "\n"
}
models = append(models, model)
}
} else {
msg = "没查到支持模型列表"
}
if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" {
// 判断允许设置权限,需要AllowUser和发消息用户账号相同
if a.Master != nil && contains(a.Master, a.UID) {
if contains(models, a.Parms[3]) {
cfg["MODEL"] = a.Parms[3]
msg = "已设置模型为 " + a.Parms[3]
config.ModifyConfig("MODEL", a.Parms[3])
config.ReloadConfig()
config.PrintConfig(cfg, "")
} else {
msg = "不支持的模型"
}
} else {
msg = "无权限设置模型"
}
} }
return msg return msg
} else {
return "模型列表为空"
}
} else { } else {
log.Println("请求失败") log.Println("请求失败")
return "请求模型列表失败" return "请求模型列表失败"
} }
} }
} else {
func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string {
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
PROMPT, ok := cfg["PROMPT"].(string) PROMPT, ok := cfg["PROMPT"].(string)
if !ok { if !ok {
log.Println("PROMRT 未配置") log.Println("PROMPT 未配置")
PROMPT = "" PROMPT = ""
} }
// PROMPT = ""
// println("PROMPT:", PROMPT)
// println("ask:", ask)
requestBody := map[string]interface{}{ requestBody := map[string]interface{}{
"model": MODEL, "model": MODEL,
"stream": false, "stream": false,
@ -135,8 +123,8 @@ func (a *AI) GetMsg() string {
"role": "system", "role": "system",
"content": PROMPT, "content": PROMPT,
}, },
{"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]}}, {"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]},
// "max_tokens": 200, },
"temperature": 0.7, "temperature": 0.7,
} }
request := gorequest.New() request := gorequest.New()
@ -145,10 +133,12 @@ func (a *AI) GetMsg() string {
Set("Authorization", "Bearer "+OPENAI_API_KEY). Set("Authorization", "Bearer "+OPENAI_API_KEY).
Send(requestBody). Send(requestBody).
End() End()
if errs != nil { if errs != nil {
log.Println(errs) log.Println(errs)
return "请求失败" return "请求失败"
} else { }
if resp.StatusCode == 200 { if resp.StatusCode == 200 {
var responseBody map[string]interface{} var responseBody map[string]interface{}
if err := json.Unmarshal([]byte(body), &responseBody); err != nil { if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
@ -158,17 +148,13 @@ func (a *AI) GetMsg() string {
choices := responseBody["choices"].([]interface{}) choices := responseBody["choices"].([]interface{})
if len(choices) > 0 { if len(choices) > 0 {
choice := choices[0].(map[string]interface{}) choice := choices[0].(map[string]interface{})
msg = choice["message"].(map[string]interface{})["content"].(string) msg := choice["message"].(map[string]interface{})["content"].(string)
// println("msg:", msg) return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg)
} else { } else {
log.Println("choices为空") log.Println("choices为空")
msg = "api解析失败" return "api解析失败"
} }
} else {
return "请求失败"
} }
}
}
return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg)
} }

View file

@ -2,12 +2,11 @@ package workers
import ( import (
"fmt" "fmt"
"io"
"net/http"
"strings" "strings"
"time" "time"
"github.com/goccy/go-json" "github.com/goccy/go-json"
"github.com/parnurzeal/gorequest"
) )
func init() { func init() {
@ -40,19 +39,15 @@ func (a *Pkg) GetMsg() string {
} }
// 输出请求地址 // 输出请求地址
fmt.Println("pkg url:", url) fmt.Println("pkg url:", url)
req, err := http.Get(url) request := gorequest.New()
if err != nil { _, body, errs := request.Get(url).End()
if len(errs) > 0 {
return "服务器网络错误!" return "服务器网络错误!"
} }
defer req.Body.Close()
body, err := io.ReadAll(req.Body)
if err != nil {
return "数据解析错误!"
}
// fmt.Println("pkg body:", string(body)) // fmt.Println("pkg body:", string(body))
// var pkg []Package // var pkg []Package
var pkg map[string]interface{} var pkg map[string]interface{}
err = json.Unmarshal(body, &pkg) err := json.Unmarshal([]byte(body), &pkg)
if err != nil { if err != nil {
return err.Error() return err.Error()
} }
@ -61,20 +56,16 @@ func (a *Pkg) GetMsg() string {
if len(resultSlipe) == 0 { if len(resultSlipe) == 0 {
url := "https://aur.archlinux.org/rpc/v5/suggest/" + parms[1] url := "https://aur.archlinux.org/rpc/v5/suggest/" + parms[1]
req, err := http.Get(url) _, body, errs := request.Get(url).End()
if err != nil { if len(errs) > 0 {
fmt.Println(err) fmt.Println(errs)
} }
defer req.Body.Close()
fmt.Println("aur suggest url:", url) fmt.Println("aur suggest url:", url)
re, err := io.ReadAll(req.Body)
// fmt.Println(string(re)) // fmt.Println(string(re))
if err != nil {
fmt.Println(err)
}
var suggestions []string var suggestions []string
err = json.Unmarshal(re, &suggestions) err := json.Unmarshal([]byte(body), &suggestions)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
@ -83,20 +74,16 @@ func (a *Pkg) GetMsg() string {
return "没有找到相关软件" return "没有找到相关软件"
} }
searchUrl := "https://aur.archlinux.org/rpc/v5/search/" + suggestions[0] + "?by=name" searchUrl := "https://aur.archlinux.org/rpc/v5/search/" + suggestions[0] + "?by=name"
searchReq, err := http.Get(searchUrl) _, body, errs = request.Get(searchUrl).End()
if err != nil { if len(errs) > 0 {
fmt.Println(err) fmt.Println("searchUrl err:", errs)
}
defer searchReq.Body.Close()
searchRe, err := io.ReadAll(searchReq.Body)
if err != nil {
fmt.Println(err)
} }
var searchMap map[string]interface{} var searchMap map[string]interface{}
err = json.Unmarshal(searchRe, &searchMap) err = json.Unmarshal([]byte(body), &searchMap)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
searchResults := searchMap["results"].([]interface{}) searchResults := searchMap["results"].([]interface{})
// println("searchResults:", len(searchResults)) // println("searchResults:", len(searchResults))
maxVotes := 0.0 maxVotes := 0.0
@ -129,7 +116,11 @@ func (a *Pkg) GetMsg() string {
msg += "版本:" + searchResult["Version"].(string) + "\n" msg += "版本:" + searchResult["Version"].(string) + "\n"
msg += "描述:" + searchResult["Description"].(string) + "\n" msg += "描述:" + searchResult["Description"].(string) + "\n"
msg += "维护者:" + maintainer + "\n" msg += "维护者:" + maintainer + "\n"
msg += "链接:" + searchResult["URL"].(string) + "\n" upstream, ok := searchResult["URL"].(string)
if !ok || upstream == "" {
upstream = "无"
}
msg += "上游:" + upstream + "\n"
msg += "更新时间:" + last_update msg += "更新时间:" + last_update
fmt.Println(msg) fmt.Println(msg)
@ -149,7 +140,7 @@ func (a *Pkg) GetMsg() string {
msg += "版本:" + result["pkgver"].(string) + "\n" msg += "版本:" + result["pkgver"].(string) + "\n"
msg += "描述:" + result["pkgdesc"].(string) + "\n" msg += "描述:" + result["pkgdesc"].(string) + "\n"
msg += "打包:" + result["packager"].(string) + "\n" msg += "打包:" + result["packager"].(string) + "\n"
msg += "链接" + result["url"].(string) + "\n" msg += "上游" + result["url"].(string) + "\n"
msg += "更新日期:" + last_update msg += "更新日期:" + last_update
return msg return msg