feat(workers): 重构AI worker以支持模型切换和错误处理
支持通过配置切换OpenAI模型,优化了模型请求的处理逻辑,增加了对错误情况的处理,提高了代码的鲁棒性和可维护性。
This commit is contained in:
parent
c63b71b2a4
commit
31b7ab9f67
2 changed files with 126 additions and 149 deletions
108
workers/ai.go
108
workers/ai.go
|
@ -2,7 +2,6 @@ package workers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"go-bot/config"
|
|
||||||
"log"
|
"log"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -25,29 +24,41 @@ type AI struct {
|
||||||
func (a *AI) GetMsg() string {
|
func (a *AI) GetMsg() string {
|
||||||
if len(a.Parms) < 2 {
|
if len(a.Parms) < 2 {
|
||||||
return "使用!ai xxx 向我提问吧"
|
return "使用!ai xxx 向我提问吧"
|
||||||
|
|
||||||
}
|
}
|
||||||
ask := a.Parms[1]
|
|
||||||
|
|
||||||
|
ask := a.Parms[1]
|
||||||
if ask == "" {
|
if ask == "" {
|
||||||
return "不问问题你说个屁!"
|
return "不问问题你说个屁!"
|
||||||
}
|
}
|
||||||
var msg string
|
|
||||||
var OPENAI_API_KEY string
|
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
|
||||||
|
if OPENAI_API_KEY == "" {
|
||||||
|
return "OPENAI_API_KEY 未配置"
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.ToLower(a.Parms[1]) == "models" {
|
||||||
|
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
||||||
|
} else {
|
||||||
|
return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getConfig() (string, string, string) {
|
||||||
|
var OPENAI_API_KEY, OPENAI_BaseURL, MODEL string
|
||||||
|
|
||||||
if cfg["OPENAI_API_KEY"] != nil {
|
if cfg["OPENAI_API_KEY"] != nil {
|
||||||
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
|
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
|
||||||
} else {
|
} else {
|
||||||
log.Println("OPENAI_API_KEY 未配置")
|
log.Println("OPENAI_API_KEY 未配置")
|
||||||
return "OPENAI_API_KEY 未配置"
|
|
||||||
}
|
}
|
||||||
var OPENAI_BaseURL string
|
|
||||||
if cfg["OPENAI_BaseURL"] != nil {
|
if cfg["OPENAI_BaseURL"] != nil {
|
||||||
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
|
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
|
||||||
} else {
|
} else {
|
||||||
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
|
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
|
||||||
OPENAI_BaseURL = "https://api.openai.com/v1"
|
OPENAI_BaseURL = "https://api.openai.com/v1"
|
||||||
}
|
}
|
||||||
var MODEL string
|
|
||||||
if cfg["MODEL"] != nil {
|
if cfg["MODEL"] != nil {
|
||||||
MODEL = cfg["MODEL"].(string)
|
MODEL = cfg["MODEL"].(string)
|
||||||
} else {
|
} else {
|
||||||
|
@ -55,19 +66,22 @@ func (a *AI) GetMsg() string {
|
||||||
MODEL = "chatglm_pro"
|
MODEL = "chatglm_pro"
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.ToLower(a.Parms[1]) == "models" {
|
return OPENAI_API_KEY, OPENAI_BaseURL, MODEL
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string {
|
||||||
OPENAI_BaseURL = OPENAI_BaseURL + "/models"
|
OPENAI_BaseURL = OPENAI_BaseURL + "/models"
|
||||||
|
|
||||||
request := gorequest.New()
|
request := gorequest.New()
|
||||||
resp, body, errs := request.Get(OPENAI_BaseURL).
|
resp, body, errs := request.Get(OPENAI_BaseURL).
|
||||||
Set("Content-Type", "application/json").
|
Set("Content-Type", "application/json").
|
||||||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||||||
End()
|
End()
|
||||||
|
|
||||||
if errs != nil {
|
if errs != nil {
|
||||||
log.Println(errs)
|
log.Println(errs)
|
||||||
return "请求失败"
|
return "请求失败"
|
||||||
} else {
|
}
|
||||||
|
|
||||||
if resp.StatusCode == 200 {
|
if resp.StatusCode == 200 {
|
||||||
var responseBody map[string]interface{}
|
var responseBody map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||||||
|
@ -76,57 +90,31 @@ func (a *AI) GetMsg() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
choices := responseBody["data"].([]interface{})
|
choices := responseBody["data"].([]interface{})
|
||||||
var models []interface{}
|
// var models []interface{}
|
||||||
if len(choices) > 0 {
|
if len(choices) > 0 {
|
||||||
msg = "支持的模型列表:\n"
|
msg := "支持的模型列表:\\n"
|
||||||
for _, choice := range choices {
|
for _, choice := range choices {
|
||||||
model := choice.(map[string]interface{})["id"].(string)
|
model := choice.(map[string]interface{})["id"]
|
||||||
if model == MODEL {
|
msg += fmt.Sprintf("%s\\n", model)
|
||||||
msg = msg + model + "\t ✔\n"
|
|
||||||
} else {
|
|
||||||
msg = msg + model + "\n"
|
|
||||||
}
|
|
||||||
models = append(models, model)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
msg = "没查到支持模型列表"
|
|
||||||
}
|
|
||||||
if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" {
|
|
||||||
// 判断允许设置权限,需要AllowUser和发消息用户账号相同
|
|
||||||
if a.Master != nil && contains(a.Master, a.UID) {
|
|
||||||
if contains(models, a.Parms[3]) {
|
|
||||||
cfg["MODEL"] = a.Parms[3]
|
|
||||||
msg = "已设置模型为 " + a.Parms[3]
|
|
||||||
config.ModifyConfig("MODEL", a.Parms[3])
|
|
||||||
config.ReloadConfig()
|
|
||||||
config.PrintConfig(cfg, "")
|
|
||||||
} else {
|
|
||||||
msg = "不支持的模型"
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
msg = "无权限设置模型"
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
return msg
|
return msg
|
||||||
|
} else {
|
||||||
|
return "模型列表为空"
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Println("请求失败")
|
log.Println("请求失败")
|
||||||
return "请求模型列表失败"
|
return "请求模型列表失败"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
|
func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string {
|
||||||
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
||||||
PROMPT, ok := cfg["PROMPT"].(string)
|
PROMPT, ok := cfg["PROMPT"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Println("PROMRT 未配置")
|
log.Println("PROMPT 未配置")
|
||||||
PROMPT = ""
|
PROMPT = ""
|
||||||
}
|
}
|
||||||
// PROMPT = ""
|
|
||||||
// println("PROMPT:", PROMPT)
|
|
||||||
// println("ask:", ask)
|
|
||||||
requestBody := map[string]interface{}{
|
requestBody := map[string]interface{}{
|
||||||
"model": MODEL,
|
"model": MODEL,
|
||||||
"stream": false,
|
"stream": false,
|
||||||
|
@ -135,8 +123,8 @@ func (a *AI) GetMsg() string {
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": PROMPT,
|
"content": PROMPT,
|
||||||
},
|
},
|
||||||
{"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]}},
|
{"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]},
|
||||||
// "max_tokens": 200,
|
},
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
}
|
}
|
||||||
request := gorequest.New()
|
request := gorequest.New()
|
||||||
|
@ -145,10 +133,12 @@ func (a *AI) GetMsg() string {
|
||||||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||||||
Send(requestBody).
|
Send(requestBody).
|
||||||
End()
|
End()
|
||||||
|
|
||||||
if errs != nil {
|
if errs != nil {
|
||||||
log.Println(errs)
|
log.Println(errs)
|
||||||
return "请求失败"
|
return "请求失败"
|
||||||
} else {
|
}
|
||||||
|
|
||||||
if resp.StatusCode == 200 {
|
if resp.StatusCode == 200 {
|
||||||
var responseBody map[string]interface{}
|
var responseBody map[string]interface{}
|
||||||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||||||
|
@ -158,17 +148,13 @@ func (a *AI) GetMsg() string {
|
||||||
choices := responseBody["choices"].([]interface{})
|
choices := responseBody["choices"].([]interface{})
|
||||||
if len(choices) > 0 {
|
if len(choices) > 0 {
|
||||||
choice := choices[0].(map[string]interface{})
|
choice := choices[0].(map[string]interface{})
|
||||||
msg = choice["message"].(map[string]interface{})["content"].(string)
|
msg := choice["message"].(map[string]interface{})["content"].(string)
|
||||||
// println("msg:", msg)
|
return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
log.Println("choices为空")
|
log.Println("choices为空")
|
||||||
msg = "api解析失败"
|
return "api解析失败"
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return "请求失败"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
|
@ -2,12 +2,11 @@ package workers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
|
"github.com/parnurzeal/gorequest"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -40,19 +39,15 @@ func (a *Pkg) GetMsg() string {
|
||||||
}
|
}
|
||||||
// 输出请求地址
|
// 输出请求地址
|
||||||
fmt.Println("pkg url:", url)
|
fmt.Println("pkg url:", url)
|
||||||
req, err := http.Get(url)
|
request := gorequest.New()
|
||||||
if err != nil {
|
_, body, errs := request.Get(url).End()
|
||||||
|
if len(errs) > 0 {
|
||||||
return "服务器网络错误!"
|
return "服务器网络错误!"
|
||||||
}
|
}
|
||||||
defer req.Body.Close()
|
|
||||||
body, err := io.ReadAll(req.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "数据解析错误!"
|
|
||||||
}
|
|
||||||
// fmt.Println("pkg body:", string(body))
|
// fmt.Println("pkg body:", string(body))
|
||||||
// var pkg []Package
|
// var pkg []Package
|
||||||
var pkg map[string]interface{}
|
var pkg map[string]interface{}
|
||||||
err = json.Unmarshal(body, &pkg)
|
err := json.Unmarshal([]byte(body), &pkg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
@ -61,20 +56,16 @@ func (a *Pkg) GetMsg() string {
|
||||||
if len(resultSlipe) == 0 {
|
if len(resultSlipe) == 0 {
|
||||||
|
|
||||||
url := "https://aur.archlinux.org/rpc/v5/suggest/" + parms[1]
|
url := "https://aur.archlinux.org/rpc/v5/suggest/" + parms[1]
|
||||||
req, err := http.Get(url)
|
_, body, errs := request.Get(url).End()
|
||||||
if err != nil {
|
if len(errs) > 0 {
|
||||||
fmt.Println(err)
|
fmt.Println(errs)
|
||||||
}
|
}
|
||||||
defer req.Body.Close()
|
|
||||||
|
|
||||||
fmt.Println("aur suggest url:", url)
|
fmt.Println("aur suggest url:", url)
|
||||||
re, err := io.ReadAll(req.Body)
|
|
||||||
// fmt.Println(string(re))
|
// fmt.Println(string(re))
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
|
||||||
var suggestions []string
|
var suggestions []string
|
||||||
err = json.Unmarshal(re, &suggestions)
|
err := json.Unmarshal([]byte(body), &suggestions)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
|
@ -83,20 +74,16 @@ func (a *Pkg) GetMsg() string {
|
||||||
return "没有找到相关软件"
|
return "没有找到相关软件"
|
||||||
}
|
}
|
||||||
searchUrl := "https://aur.archlinux.org/rpc/v5/search/" + suggestions[0] + "?by=name"
|
searchUrl := "https://aur.archlinux.org/rpc/v5/search/" + suggestions[0] + "?by=name"
|
||||||
searchReq, err := http.Get(searchUrl)
|
_, body, errs = request.Get(searchUrl).End()
|
||||||
if err != nil {
|
if len(errs) > 0 {
|
||||||
fmt.Println(err)
|
fmt.Println("searchUrl err:", errs)
|
||||||
}
|
|
||||||
defer searchReq.Body.Close()
|
|
||||||
searchRe, err := io.ReadAll(searchReq.Body)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
}
|
}
|
||||||
var searchMap map[string]interface{}
|
var searchMap map[string]interface{}
|
||||||
err = json.Unmarshal(searchRe, &searchMap)
|
err = json.Unmarshal([]byte(body), &searchMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
searchResults := searchMap["results"].([]interface{})
|
searchResults := searchMap["results"].([]interface{})
|
||||||
// println("searchResults:", len(searchResults))
|
// println("searchResults:", len(searchResults))
|
||||||
maxVotes := 0.0
|
maxVotes := 0.0
|
||||||
|
@ -129,7 +116,11 @@ func (a *Pkg) GetMsg() string {
|
||||||
msg += "版本:" + searchResult["Version"].(string) + "\n"
|
msg += "版本:" + searchResult["Version"].(string) + "\n"
|
||||||
msg += "描述:" + searchResult["Description"].(string) + "\n"
|
msg += "描述:" + searchResult["Description"].(string) + "\n"
|
||||||
msg += "维护者:" + maintainer + "\n"
|
msg += "维护者:" + maintainer + "\n"
|
||||||
msg += "链接:" + searchResult["URL"].(string) + "\n"
|
upstream, ok := searchResult["URL"].(string)
|
||||||
|
if !ok || upstream == "" {
|
||||||
|
upstream = "无"
|
||||||
|
}
|
||||||
|
msg += "上游:" + upstream + "\n"
|
||||||
msg += "更新时间:" + last_update
|
msg += "更新时间:" + last_update
|
||||||
|
|
||||||
fmt.Println(msg)
|
fmt.Println(msg)
|
||||||
|
@ -149,7 +140,7 @@ func (a *Pkg) GetMsg() string {
|
||||||
msg += "版本:" + result["pkgver"].(string) + "\n"
|
msg += "版本:" + result["pkgver"].(string) + "\n"
|
||||||
msg += "描述:" + result["pkgdesc"].(string) + "\n"
|
msg += "描述:" + result["pkgdesc"].(string) + "\n"
|
||||||
msg += "打包:" + result["packager"].(string) + "\n"
|
msg += "打包:" + result["packager"].(string) + "\n"
|
||||||
msg += "链接:" + result["url"].(string) + "\n"
|
msg += "上游:" + result["url"].(string) + "\n"
|
||||||
msg += "更新日期:" + last_update
|
msg += "更新日期:" + last_update
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue