feat(workers): 重构AI worker以支持模型切换和错误处理
支持通过配置切换OpenAI模型,优化了模型请求的处理逻辑,增加了对错误情况的处理,提高了代码的鲁棒性和可维护性。
This commit is contained in:
parent
c63b71b2a4
commit
31b7ab9f67
2 changed files with 126 additions and 149 deletions
108
workers/ai.go
108
workers/ai.go
|
@ -2,7 +2,6 @@ package workers
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"go-bot/config"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
|
@ -25,29 +24,41 @@ type AI struct {
|
|||
func (a *AI) GetMsg() string {
|
||||
if len(a.Parms) < 2 {
|
||||
return "使用!ai xxx 向我提问吧"
|
||||
|
||||
}
|
||||
ask := a.Parms[1]
|
||||
|
||||
ask := a.Parms[1]
|
||||
if ask == "" {
|
||||
return "不问问题你说个屁!"
|
||||
}
|
||||
var msg string
|
||||
var OPENAI_API_KEY string
|
||||
|
||||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
|
||||
if OPENAI_API_KEY == "" {
|
||||
return "OPENAI_API_KEY 未配置"
|
||||
}
|
||||
|
||||
if strings.ToLower(a.Parms[1]) == "models" {
|
||||
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
||||
} else {
|
||||
return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID)
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig() (string, string, string) {
|
||||
var OPENAI_API_KEY, OPENAI_BaseURL, MODEL string
|
||||
|
||||
if cfg["OPENAI_API_KEY"] != nil {
|
||||
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
|
||||
} else {
|
||||
log.Println("OPENAI_API_KEY 未配置")
|
||||
return "OPENAI_API_KEY 未配置"
|
||||
}
|
||||
var OPENAI_BaseURL string
|
||||
|
||||
if cfg["OPENAI_BaseURL"] != nil {
|
||||
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
|
||||
} else {
|
||||
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
|
||||
OPENAI_BaseURL = "https://api.openai.com/v1"
|
||||
}
|
||||
var MODEL string
|
||||
|
||||
if cfg["MODEL"] != nil {
|
||||
MODEL = cfg["MODEL"].(string)
|
||||
} else {
|
||||
|
@ -55,19 +66,22 @@ func (a *AI) GetMsg() string {
|
|||
MODEL = "chatglm_pro"
|
||||
}
|
||||
|
||||
if strings.ToLower(a.Parms[1]) == "models" {
|
||||
return OPENAI_API_KEY, OPENAI_BaseURL, MODEL
|
||||
}
|
||||
|
||||
func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string {
|
||||
OPENAI_BaseURL = OPENAI_BaseURL + "/models"
|
||||
|
||||
request := gorequest.New()
|
||||
resp, body, errs := request.Get(OPENAI_BaseURL).
|
||||
Set("Content-Type", "application/json").
|
||||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||||
End()
|
||||
|
||||
if errs != nil {
|
||||
log.Println(errs)
|
||||
return "请求失败"
|
||||
} else {
|
||||
}
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
var responseBody map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||||
|
@ -76,57 +90,31 @@ func (a *AI) GetMsg() string {
|
|||
}
|
||||
|
||||
choices := responseBody["data"].([]interface{})
|
||||
var models []interface{}
|
||||
// var models []interface{}
|
||||
if len(choices) > 0 {
|
||||
msg = "支持的模型列表:\n"
|
||||
msg := "支持的模型列表:\\n"
|
||||
for _, choice := range choices {
|
||||
model := choice.(map[string]interface{})["id"].(string)
|
||||
if model == MODEL {
|
||||
msg = msg + model + "\t ✔\n"
|
||||
} else {
|
||||
msg = msg + model + "\n"
|
||||
}
|
||||
models = append(models, model)
|
||||
|
||||
}
|
||||
|
||||
} else {
|
||||
msg = "没查到支持模型列表"
|
||||
}
|
||||
if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" {
|
||||
// 判断允许设置权限,需要AllowUser和发消息用户账号相同
|
||||
if a.Master != nil && contains(a.Master, a.UID) {
|
||||
if contains(models, a.Parms[3]) {
|
||||
cfg["MODEL"] = a.Parms[3]
|
||||
msg = "已设置模型为 " + a.Parms[3]
|
||||
config.ModifyConfig("MODEL", a.Parms[3])
|
||||
config.ReloadConfig()
|
||||
config.PrintConfig(cfg, "")
|
||||
} else {
|
||||
msg = "不支持的模型"
|
||||
}
|
||||
|
||||
} else {
|
||||
msg = "无权限设置模型"
|
||||
}
|
||||
|
||||
model := choice.(map[string]interface{})["id"]
|
||||
msg += fmt.Sprintf("%s\\n", model)
|
||||
}
|
||||
return msg
|
||||
} else {
|
||||
return "模型列表为空"
|
||||
}
|
||||
} else {
|
||||
log.Println("请求失败")
|
||||
return "请求模型列表失败"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
||||
func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string {
|
||||
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
||||
PROMPT, ok := cfg["PROMPT"].(string)
|
||||
if !ok {
|
||||
log.Println("PROMRT 未配置")
|
||||
log.Println("PROMPT 未配置")
|
||||
PROMPT = ""
|
||||
}
|
||||
// PROMPT = ""
|
||||
// println("PROMPT:", PROMPT)
|
||||
// println("ask:", ask)
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"model": MODEL,
|
||||
"stream": false,
|
||||
|
@ -135,8 +123,8 @@ func (a *AI) GetMsg() string {
|
|||
"role": "system",
|
||||
"content": PROMPT,
|
||||
},
|
||||
{"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]}},
|
||||
// "max_tokens": 200,
|
||||
{"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]},
|
||||
},
|
||||
"temperature": 0.7,
|
||||
}
|
||||
request := gorequest.New()
|
||||
|
@ -145,10 +133,12 @@ func (a *AI) GetMsg() string {
|
|||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||||
Send(requestBody).
|
||||
End()
|
||||
|
||||
if errs != nil {
|
||||
log.Println(errs)
|
||||
return "请求失败"
|
||||
} else {
|
||||
}
|
||||
|
||||
if resp.StatusCode == 200 {
|
||||
var responseBody map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||||
|
@ -158,17 +148,13 @@ func (a *AI) GetMsg() string {
|
|||
choices := responseBody["choices"].([]interface{})
|
||||
if len(choices) > 0 {
|
||||
choice := choices[0].(map[string]interface{})
|
||||
msg = choice["message"].(map[string]interface{})["content"].(string)
|
||||
// println("msg:", msg)
|
||||
|
||||
msg := choice["message"].(map[string]interface{})["content"].(string)
|
||||
return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg)
|
||||
} else {
|
||||
log.Println("choices为空")
|
||||
msg = "api解析失败"
|
||||
return "api解析失败"
|
||||
}
|
||||
} else {
|
||||
return "请求失败"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg)
|
||||
|
||||
}
|
||||
|
|
|
@ -2,12 +2,11 @@ package workers
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/goccy/go-json"
|
||||
"github.com/parnurzeal/gorequest"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -40,19 +39,15 @@ func (a *Pkg) GetMsg() string {
|
|||
}
|
||||
// 输出请求地址
|
||||
fmt.Println("pkg url:", url)
|
||||
req, err := http.Get(url)
|
||||
if err != nil {
|
||||
request := gorequest.New()
|
||||
_, body, errs := request.Get(url).End()
|
||||
if len(errs) > 0 {
|
||||
return "服务器网络错误!"
|
||||
}
|
||||
defer req.Body.Close()
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return "数据解析错误!"
|
||||
}
|
||||
// fmt.Println("pkg body:", string(body))
|
||||
// var pkg []Package
|
||||
var pkg map[string]interface{}
|
||||
err = json.Unmarshal(body, &pkg)
|
||||
err := json.Unmarshal([]byte(body), &pkg)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
@ -61,20 +56,16 @@ func (a *Pkg) GetMsg() string {
|
|||
if len(resultSlipe) == 0 {
|
||||
|
||||
url := "https://aur.archlinux.org/rpc/v5/suggest/" + parms[1]
|
||||
req, err := http.Get(url)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
_, body, errs := request.Get(url).End()
|
||||
if len(errs) > 0 {
|
||||
fmt.Println(errs)
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
fmt.Println("aur suggest url:", url)
|
||||
re, err := io.ReadAll(req.Body)
|
||||
// fmt.Println(string(re))
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
var suggestions []string
|
||||
err = json.Unmarshal(re, &suggestions)
|
||||
err := json.Unmarshal([]byte(body), &suggestions)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
@ -83,20 +74,16 @@ func (a *Pkg) GetMsg() string {
|
|||
return "没有找到相关软件"
|
||||
}
|
||||
searchUrl := "https://aur.archlinux.org/rpc/v5/search/" + suggestions[0] + "?by=name"
|
||||
searchReq, err := http.Get(searchUrl)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
defer searchReq.Body.Close()
|
||||
searchRe, err := io.ReadAll(searchReq.Body)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
_, body, errs = request.Get(searchUrl).End()
|
||||
if len(errs) > 0 {
|
||||
fmt.Println("searchUrl err:", errs)
|
||||
}
|
||||
var searchMap map[string]interface{}
|
||||
err = json.Unmarshal(searchRe, &searchMap)
|
||||
err = json.Unmarshal([]byte(body), &searchMap)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
|
||||
searchResults := searchMap["results"].([]interface{})
|
||||
// println("searchResults:", len(searchResults))
|
||||
maxVotes := 0.0
|
||||
|
@ -129,7 +116,11 @@ func (a *Pkg) GetMsg() string {
|
|||
msg += "版本:" + searchResult["Version"].(string) + "\n"
|
||||
msg += "描述:" + searchResult["Description"].(string) + "\n"
|
||||
msg += "维护者:" + maintainer + "\n"
|
||||
msg += "链接:" + searchResult["URL"].(string) + "\n"
|
||||
upstream, ok := searchResult["URL"].(string)
|
||||
if !ok || upstream == "" {
|
||||
upstream = "无"
|
||||
}
|
||||
msg += "上游:" + upstream + "\n"
|
||||
msg += "更新时间:" + last_update
|
||||
|
||||
fmt.Println(msg)
|
||||
|
@ -149,7 +140,7 @@ func (a *Pkg) GetMsg() string {
|
|||
msg += "版本:" + result["pkgver"].(string) + "\n"
|
||||
msg += "描述:" + result["pkgdesc"].(string) + "\n"
|
||||
msg += "打包:" + result["packager"].(string) + "\n"
|
||||
msg += "链接:" + result["url"].(string) + "\n"
|
||||
msg += "上游:" + result["url"].(string) + "\n"
|
||||
msg += "更新日期:" + last_update
|
||||
return msg
|
||||
|
||||
|
|
Loading…
Reference in a new issue