diff --git a/.gitignore b/.gitignore index 2666aab..96107ec 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ data.db test.json +*.exe + config.toml -*.exe \ No newline at end of file diff --git a/config example.toml b/config example.toml index fc3a200..13f6038 100644 --- a/config example.toml +++ b/config example.toml @@ -1,8 +1,7 @@ -[Server] APIURL = "0.0.0.0:5580" POSTURL = "http://0.0.0.0:5700" -[Group] + AllowGroup = [] AllowUser = [] AllowRole = [] diff --git a/config/config.go b/config/config.go index 9a61d11..f989a22 100644 --- a/config/config.go +++ b/config/config.go @@ -1,35 +1,79 @@ package config import ( - "log" + "fmt" + "os" + "reflect" + "sync" "github.com/BurntSushi/toml" ) -type Config struct { - Server struct { - APIURL string - POSTURL string +var ( + config map[string]interface{} + mu sync.Mutex +) + +func loadConfig() { + mu.Lock() + defer mu.Unlock() + + if _, err := toml.DecodeFile("config.toml", &config); err != nil { + panic(err) } - Group struct { - AllowGroup []string - AllowUser []string - AllowRole []string - BlockGroup []string - BlockUser []string - GroupNotAllow []string - UserNotAllow []string - RoleNotAllow []string + +} +func GetConfig() map[string]interface{} { + // mu.Lock() + // defer mu.Unlock() + // print(config) + if config == nil { + loadConfig() + } + return config +} +func ReloadConfig() { + loadConfig() +} +func ModifyConfig(key string, value interface{}) { + mu.Lock() + defer mu.Unlock() + + // 修改配置 + config[key] = value + // fmt.Println("修改后的配置:") + // 将修改后的配置写回文件 + file, err := os.Create("config.toml") + if err != nil { + panic(err) + } + defer file.Close() + + encoder := toml.NewEncoder(file) + if err := encoder.Encode(config); err != nil { + panic(err) } } -var GlobalConfig Config - -func init() { - // var config Config - if _, err := toml.DecodeFile("config.toml", &GlobalConfig); err != nil { - println("配置文件不正确,请修改正确的配置文件!") - log.Fatal(err) +func PrintConfig(m map[string]interface{}, indent string) { + for key, value := range m { + switch v := value.(type) { + case map[string]interface{}: + fmt.Printf("%s%s (type: %s):\n", indent, key, reflect.TypeOf(v)) + PrintConfig(v, indent+" ") + case []interface{}: + fmt.Printf("%s%s:\n", indent, key) + for i, item := range v { + switch itemValue := item.(type) { + case map[string]interface{}: + fmt.Printf("%s [%d] (type: %s):\n", indent, i, reflect.TypeOf(itemValue)) + PrintConfig(itemValue, indent+" ") + default: + fmt.Printf("%s [%d] (type: %s): %v\n", indent, i, reflect.TypeOf(itemValue), item) + } + } + default: + fmt.Printf("%s%s (type: %s): %v\n", indent, key, reflect.TypeOf(value), value) + } } - // fmt.Println(config.Group) } diff --git a/go.mod b/go.mod index 91eeedc..c133698 100644 --- a/go.mod +++ b/go.mod @@ -8,3 +8,10 @@ require ( github.com/goccy/go-json v0.10.2 github.com/mattn/go-sqlite3 v1.14.22 ) + +require ( + github.com/moul/http2curl v1.0.0 // indirect + github.com/parnurzeal/gorequest v0.3.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + golang.org/x/net v0.26.0 // indirect +) diff --git a/go.sum b/go.sum index fa68fcb..665bca3 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,13 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= +github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= +github.com/parnurzeal/gorequest v0.3.0 h1:SoFyqCDC9COr1xuS6VA8fC8RU7XyrJZN2ona1kEX7FI= +github.com/parnurzeal/gorequest v0.3.0/go.mod h1:3Kh2QUMJoqw3icWAecsyzkpY7UzRfDhbRdTjtNwNiUE= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= diff --git a/main.go b/main.go index 94a1c6c..99d992b 100644 --- a/main.go +++ b/main.go @@ -114,11 +114,15 @@ func handlePost(w http.ResponseWriter, r *http.Request) { } func main() { - APIURL := config.GlobalConfig.Server.APIURL - // PORT := config.GlobalConfig.Server.Port - // fmt.Println(APIURL) - // fmt.Println(PORT) + cfg := config.GetConfig() + APIURL, ok := cfg["APIURL"].(string) + if !ok { + log.Fatal("加载配置失败!") + } + // config.PrintConfig(cfg, "") + // print(cfg["AllowGroup"].([]interface{})[0].(string)) + http.HandleFunc("/", handlePost) // 协程支持 // http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { diff --git a/test/test.go b/test/test.go index 236bd51..0d1618e 100644 --- a/test/test.go +++ b/test/test.go @@ -11,16 +11,19 @@ import ( func main() { for { reader := bufio.NewReader(os.Stdin) - + // cfg := config.GetConfig() + // config.PrintConfig(cfg, "") fmt.Print("输入指令(不要带/):") raw_msg, _ := reader.ReadString('\n') // 去除末尾的换行符 // raw_msg = strings.TrimRight(raw_msg, "\r\n") - + if raw_msg == "" { + raw_msg = "ping" + } parms := strings.Fields(raw_msg) - worker := workers.NewWorker(parms, "11", "111", "111", "222", raw_msg) - fmt.Println("Test:", worker.CheckPermission()) + worker := workers.NewWorker(parms, "794508986", "111", "111", "222", raw_msg) + fmt.Println("TestPermission:", worker.CheckPermission()) message := worker.GetMsg() fmt.Println("message:", message) diff --git a/workers/ai.go b/workers/ai.go new file mode 100644 index 0000000..a6426c8 --- /dev/null +++ b/workers/ai.go @@ -0,0 +1,158 @@ +package workers + +import ( + "go-bot/config" + "log" + "strings" + + "github.com/goccy/go-json" + "github.com/parnurzeal/gorequest" +) + +type AI struct { + *StdAns +} + +func (a *AI) GetMsg() string { + if len(a.Parms) < 2 { + return "使用!ai xxx 向我提问吧" + + } + ask := a.Parms[1] + if ask == "" { + return "不问问题你说个屁!" + } + var msg string + var OPENAI_API_KEY string + if cfg["OPENAI_API_KEY"] != nil { + OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string) + } else { + log.Println("OPENAI_API_KEY 未配置") + return "OPENAI_API_KEY 未配置" + } + var OPENAI_BaseURL string + if cfg["OPENAI_BaseURL"] != nil { + OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string) + } else { + log.Println("OPENAI_BaseURL 未配置,使用openai默认配置") + OPENAI_BaseURL = "https://api.openai.com/v1" + } + var MODEL string + if cfg["MODEL"] != nil { + MODEL = cfg["MODEL"].(string) + } else { + log.Println("模型 未配置,使用默认chatglm_pro模型") + MODEL = "chatglm_pro" + } + + if strings.ToLower(a.Parms[1]) == "models" { + + OPENAI_BaseURL = OPENAI_BaseURL + "/models" + request := gorequest.New() + resp, body, errs := request.Get(OPENAI_BaseURL). + Set("Content-Type", "application/json"). + Set("Authorization", "Bearer "+OPENAI_API_KEY). + End() + if errs != nil { + log.Println(errs) + return "请求失败" + } else { + if resp.StatusCode == 200 { + var responseBody map[string]interface{} + if err := json.Unmarshal([]byte(body), &responseBody); err != nil { + log.Println(err) + return "解析模型列表失败" + } + choices := responseBody["data"].([]interface{}) + var models []interface{} + if len(choices) > 0 { + msg = "支持的模型列表:\n" + for _, choice := range choices { + model := choice.(map[string]interface{})["id"].(string) + if model == MODEL { + msg = msg + model + "\t ✔\n" + } else { + msg = msg + model + "\n" + } + models = append(models, model) + + } + + } else { + msg = "没查到支持模型列表" + } + if len(a.Parms) > 3 && strings.ToLower(a.Parms[2]) == "set" { + // 判断允许设置权限,需要AllowUser和发消息用户账号相同 + if a.AllowUser != nil && contains(a.AllowUser, a.UID) { + if contains(models, a.Parms[3]) { + cfg["MODEL"] = a.Parms[3] + msg = "已设置模型为 " + a.Parms[3] + config.ModifyConfig("MODEL", a.Parms[3]) + config.ReloadConfig() + config.PrintConfig(cfg, "") + } else { + msg = "不支持的模型" + } + + } else { + msg = "无权限设置模型" + } + + } + return msg + } else { + log.Println("请求失败") + return "请求模型列表失败" + } + } + } else { + OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions" + PROMPT, ok := cfg["PROMPT"].(string) + if !ok { + log.Println("PROMRT 未配置") + PROMPT = "" + } + // PROMPT = "" + // println("PROMPT:", PROMPT) + requestBody := map[string]interface{}{ + "model": MODEL, + "stream": false, + "messages": []map[string]string{ + { + "role": "system", + "content": PROMPT, + }, + {"role": "user", "content": ask}}, + "max_tokens": 200, + "temperature": 0.7, + } + request := gorequest.New() + resp, body, errs := request.Post(OPENAI_BaseURL). + Set("Content-Type", "application/json"). + Set("Authorization", "Bearer "+OPENAI_API_KEY). + Send(requestBody). + End() + if errs != nil { + log.Println(errs) + return "请求失败" + } else { + if resp.StatusCode == 200 { + var responseBody map[string]interface{} + if err := json.Unmarshal([]byte(body), &responseBody); err != nil { + log.Println(err) + return "解析失败" + } + choices := responseBody["choices"].([]interface{}) + if len(choices) > 0 { + choice := choices[0].(map[string]interface{}) + msg = choice["message"].(map[string]interface{})["content"].(string) + } else { + log.Println("choices为空") + } + } + } + } + + return msg + +} diff --git a/workers/core.go b/workers/core.go index 110af79..6d341ce 100644 --- a/workers/core.go +++ b/workers/core.go @@ -9,11 +9,11 @@ import ( ) type StdAns struct { - AllowGroup []string - AllowUser []string - AllowRole []string - BlockGroup []string - BlockUser []string + AllowGroup []interface{} + AllowUser []interface{} + AllowRole []interface{} + BlockGroup []interface{} + BlockUser []interface{} GroupNotAllow string UserNotAllow string RoleNotAllow string @@ -25,7 +25,16 @@ type StdAns struct { RawMsg string } +var cfg map[string]interface{} + +// func init() { +// cfg = config.GetConfig() +// } func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns { + // var cfg map[string]interface{} + + cfg = config.GetConfig() + // println("AllowGroup:", cfg["AllowGroup"].([]interface{})) return &StdAns{ Parms: parms, UID: uid, @@ -33,11 +42,11 @@ func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns { Role: role, MID: mid, RawMsg: rawMsg, - AllowGroup: config.GlobalConfig.Group.AllowGroup, - AllowUser: config.GlobalConfig.Group.AllowUser, - AllowRole: config.GlobalConfig.Group.AllowRole, - BlockGroup: config.GlobalConfig.Group.BlockGroup, - BlockUser: config.GlobalConfig.Group.BlockUser, + AllowGroup: cfg["AllowGroup"].([]interface{}), + AllowUser: cfg["AllowUser"].([]interface{}), + AllowRole: cfg["AllowRole"].([]interface{}), + BlockGroup: cfg["BlockGroup"].([]interface{}), + BlockUser: cfg["BlockUser"].([]interface{}), GroupNotAllow: "汝所在的群组不被允许这样命令咱呢.", UserNotAllow: "汝不被允许呢.", RoleNotAllow: "汝的角色不被允许哦.", @@ -54,9 +63,9 @@ func (s *StdAns) CheckPermission() string { if len(s.AllowRole) > 0 && !contains(s.AllowRole, s.Role) { return s.RoleNotAllow } - return "0" + return "ok" } -func contains(slice []string, value string) bool { +func contains(slice []interface{}, value string) bool { for _, item := range slice { if item == value { return true @@ -99,7 +108,7 @@ func (s *StdAns) SendMsg(msg string) bool { } // fmt.Println(string(re)) - url := config.GlobalConfig.Server.POSTURL + url := cfg["POSTURL"].(string) // println("core:", url) // fmt.Println("请求地址:", url) fmt.Println("响应信息:\n", msg) diff --git a/workers/lsp.go b/workers/lsp.go index 0255b5b..ebd1ea5 100644 --- a/workers/lsp.go +++ b/workers/lsp.go @@ -14,14 +14,17 @@ type Lsp struct { } func (a *Lsp) GetMsg() string { - a.AllowGroup = []string{"313047773"} + a.AllowGroup = append(a.AllowGroup, []string{"313047773"}) url := "https://api.lolicon.app/setu/v2?r18=0&size=small" resp, err := http.Get(url) if err != nil { - return "获取失败" + return "请求图片失败" } defer resp.Body.Close() budy, err := io.ReadAll(resp.Body) + if err != nil { + return "读取失败" + } var res map[string]interface{} err = json.Unmarshal(budy, &res) if err != nil { diff --git a/workers/newworker.go b/workers/newworker.go index 9992d51..516b267 100644 --- a/workers/newworker.go +++ b/workers/newworker.go @@ -5,6 +5,7 @@ import "fmt" func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker { fmt.Println("NewWorker:", parms) switch parms[0] { + case "ping": return &Ping{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), @@ -26,6 +27,10 @@ func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker { return &Lsp{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), } + case "ai": + return &AI{ + StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), + } default: return &Emm{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg)}