feat(config): 动态修改AI模型配置

通过新增的ModifyConfig函数,现在可以在运行时动态修改AI的配置,包括更换模型。
This commit is contained in:
liyp 2024-07-01 10:05:42 +08:00
parent eedb68a282
commit 8ca0d27976
4 changed files with 68 additions and 14 deletions

View file

@ -2,6 +2,7 @@ package config
import ( import (
"fmt" "fmt"
"os"
"reflect" "reflect"
"sync" "sync"
@ -10,10 +11,12 @@ import (
var ( var (
config map[string]interface{} config map[string]interface{}
once sync.Once mu sync.Mutex
) )
func loadConfig() { func loadConfig() {
mu.Lock()
defer mu.Unlock()
if _, err := toml.DecodeFile("config.toml", &config); err != nil { if _, err := toml.DecodeFile("config.toml", &config); err != nil {
panic(err) panic(err)
@ -21,10 +24,36 @@ func loadConfig() {
} }
func GetConfig() map[string]interface{} { func GetConfig() map[string]interface{} {
once.Do(loadConfig) // mu.Lock()
// defer mu.Unlock()
// print(config) // print(config)
if config == nil {
loadConfig()
}
return config return config
} }
func ReloadConfig() {
loadConfig()
}
func ModifyConfig(key string, value interface{}) {
mu.Lock()
defer mu.Unlock()
// 修改配置
config[key] = value
// fmt.Println("修改后的配置:")
// 将修改后的配置写回文件
file, err := os.Create("config.toml")
if err != nil {
panic(err)
}
defer file.Close()
encoder := toml.NewEncoder(file)
if err := encoder.Encode(config); err != nil {
panic(err)
}
}
func PrintConfig(m map[string]interface{}, indent string) { func PrintConfig(m map[string]interface{}, indent string) {
for key, value := range m { for key, value := range m {

View file

@ -22,8 +22,8 @@ func main() {
} }
parms := strings.Fields(raw_msg) parms := strings.Fields(raw_msg)
worker := workers.NewWorker(parms, "11", "111", "111", "222", raw_msg) worker := workers.NewWorker(parms, "794508986", "111", "111", "222", raw_msg)
fmt.Println("Test:", worker.CheckPermission()) fmt.Println("TestPermission:", worker.CheckPermission())
message := worker.GetMsg() message := worker.GetMsg()
fmt.Println("message:", message) fmt.Println("message:", message)

View file

@ -1,6 +1,7 @@
package workers package workers
import ( import (
"go-bot/config"
"log" "log"
"strings" "strings"
@ -45,6 +46,7 @@ func (a *AI) GetMsg() string {
} }
if strings.ToLower(a.Parms[1]) == "models" { if strings.ToLower(a.Parms[1]) == "models" {
OPENAI_BaseURL = OPENAI_BaseURL + "/models" OPENAI_BaseURL = OPENAI_BaseURL + "/models"
request := gorequest.New() request := gorequest.New()
resp, body, errs := request.Get(OPENAI_BaseURL). resp, body, errs := request.Get(OPENAI_BaseURL).
@ -62,20 +64,41 @@ func (a *AI) GetMsg() string {
return "解析模型列表失败" return "解析模型列表失败"
} }
choices := responseBody["data"].([]interface{}) choices := responseBody["data"].([]interface{})
var models []interface{}
if len(choices) > 0 { if len(choices) > 0 {
msg = "支持的模型列表:\n" msg = "支持的模型列表:\n"
for _, choice := range choices { for _, choice := range choices {
models := choice.(map[string]interface{})["id"].(string) model := choice.(map[string]interface{})["id"].(string)
if models == MODEL { if model == MODEL {
msg = msg + models + "\t ✔\n" msg = msg + model + "\t ✔\n"
} else { } else {
msg = msg + models + "\n" msg = msg + model + "\n"
} }
models = append(models, model)
} }
} else { } else {
msg = "没查到支持模型列表" msg = "没查到支持模型列表"
} }
if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" {
// 判断允许设置权限,需要AllowUser和发消息用户账号相同
if a.AllowUser != nil && contains(a.AllowUser, a.UID) {
if contains(models, a.Parms[3]) {
cfg["MODEL"] = a.Parms[3]
msg = "已设置模型为 " + a.Parms[3]
config.ModifyConfig("MODEL", a.Parms[3])
config.ReloadConfig()
config.PrintConfig(cfg, "")
} else {
msg = "不支持的模型"
}
} else {
msg = "无权限设置模型"
}
}
return msg return msg
} else { } else {
log.Println("请求失败") log.Println("请求失败")
@ -89,16 +112,18 @@ func (a *AI) GetMsg() string {
log.Println("PROMRT 未配置") log.Println("PROMRT 未配置")
PROMPT = "" PROMPT = ""
} }
// PROMPT = ""
// println("PROMPT:", PROMPT) // println("PROMPT:", PROMPT)
requestBody := map[string]interface{}{ requestBody := map[string]interface{}{
"model": MODEL, "model": MODEL,
"stream": false, "stream": false,
"messages": []map[string]string{{ "messages": []map[string]string{
"role": "system", {
"content": PROMPT, "role": "system",
}, "content": PROMPT,
},
{"role": "user", "content": ask}}, {"role": "user", "content": ask}},
"max_tokens": 75, "max_tokens": 200,
"temperature": 0.7, "temperature": 0.7,
} }
request := gorequest.New() request := gorequest.New()

View file

@ -63,7 +63,7 @@ func (s *StdAns) CheckPermission() string {
if len(s.AllowRole) > 0 && !contains(s.AllowRole, s.Role) { if len(s.AllowRole) > 0 && !contains(s.AllowRole, s.Role) {
return s.RoleNotAllow return s.RoleNotAllow
} }
return "0" return "ok"
} }
func contains(slice []interface{}, value string) bool { func contains(slice []interface{}, value string) bool {
for _, item := range slice { for _, item := range slice {