diff --git a/config/config.go b/config/config.go index af01f86..f989a22 100644 --- a/config/config.go +++ b/config/config.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "os" "reflect" "sync" @@ -10,10 +11,12 @@ import ( var ( config map[string]interface{} - once sync.Once + mu sync.Mutex ) func loadConfig() { + mu.Lock() + defer mu.Unlock() if _, err := toml.DecodeFile("config.toml", &config); err != nil { panic(err) @@ -21,10 +24,36 @@ func loadConfig() { } func GetConfig() map[string]interface{} { - once.Do(loadConfig) + // mu.Lock() + // defer mu.Unlock() // print(config) + if config == nil { + loadConfig() + } return config } +func ReloadConfig() { + loadConfig() +} +func ModifyConfig(key string, value interface{}) { + mu.Lock() + defer mu.Unlock() + + // 修改配置 + config[key] = value + // fmt.Println("修改后的配置:") + // 将修改后的配置写回文件 + file, err := os.Create("config.toml") + if err != nil { + panic(err) + } + defer file.Close() + + encoder := toml.NewEncoder(file) + if err := encoder.Encode(config); err != nil { + panic(err) + } +} func PrintConfig(m map[string]interface{}, indent string) { for key, value := range m { diff --git a/test/test.go b/test/test.go index e8d1c0b..0d1618e 100644 --- a/test/test.go +++ b/test/test.go @@ -22,8 +22,8 @@ func main() { } parms := strings.Fields(raw_msg) - worker := workers.NewWorker(parms, "11", "111", "111", "222", raw_msg) - fmt.Println("Test:", worker.CheckPermission()) + worker := workers.NewWorker(parms, "794508986", "111", "111", "222", raw_msg) + fmt.Println("TestPermission:", worker.CheckPermission()) message := worker.GetMsg() fmt.Println("message:", message) diff --git a/workers/ai.go b/workers/ai.go index 45dbf9c..a6426c8 100644 --- a/workers/ai.go +++ b/workers/ai.go @@ -1,6 +1,7 @@ package workers import ( + "go-bot/config" "log" "strings" @@ -45,6 +46,7 @@ func (a *AI) GetMsg() string { } if strings.ToLower(a.Parms[1]) == "models" { + OPENAI_BaseURL = OPENAI_BaseURL + "/models" request := gorequest.New() resp, body, errs := request.Get(OPENAI_BaseURL). @@ -62,20 +64,41 @@ func (a *AI) GetMsg() string { return "解析模型列表失败" } choices := responseBody["data"].([]interface{}) + var models []interface{} if len(choices) > 0 { msg = "支持的模型列表:\n" for _, choice := range choices { - models := choice.(map[string]interface{})["id"].(string) - if models == MODEL { - msg = msg + models + "\t ✔\n" + model := choice.(map[string]interface{})["id"].(string) + if model == MODEL { + msg = msg + model + "\t ✔\n" } else { - msg = msg + models + "\n" + msg = msg + model + "\n" } + models = append(models, model) } + } else { msg = "没查到支持模型列表" } + if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" { + // 判断允许设置权限,需要AllowUser和发消息用户账号相同 + if a.AllowUser != nil && contains(a.AllowUser, a.UID) { + if contains(models, a.Parms[3]) { + cfg["MODEL"] = a.Parms[3] + msg = "已设置模型为 " + a.Parms[3] + config.ModifyConfig("MODEL", a.Parms[3]) + config.ReloadConfig() + config.PrintConfig(cfg, "") + } else { + msg = "不支持的模型" + } + + } else { + msg = "无权限设置模型" + } + + } return msg } else { log.Println("请求失败") @@ -89,16 +112,18 @@ func (a *AI) GetMsg() string { log.Println("PROMRT 未配置") PROMPT = "" } + // PROMPT = "" // println("PROMPT:", PROMPT) requestBody := map[string]interface{}{ "model": MODEL, "stream": false, - "messages": []map[string]string{{ - "role": "system", - "content": PROMPT, - }, + "messages": []map[string]string{ + { + "role": "system", + "content": PROMPT, + }, {"role": "user", "content": ask}}, - "max_tokens": 75, + "max_tokens": 200, "temperature": 0.7, } request := gorequest.New() diff --git a/workers/core.go b/workers/core.go index 23842fd..6d341ce 100644 --- a/workers/core.go +++ b/workers/core.go @@ -63,7 +63,7 @@ func (s *StdAns) CheckPermission() string { if len(s.AllowRole) > 0 && !contains(s.AllowRole, s.Role) { return s.RoleNotAllow } - return "0" + return "ok" } func contains(slice []interface{}, value string) bool { for _, item := range slice {