From 13483b964324065fd493175dd5db11329a51750f Mon Sep 17 00:00:00 2001 From: liyp Date: Sun, 30 Jun 2024 21:56:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(config):=20=E6=B7=BB=E5=8A=A0OpenAI=20API?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=B9=B6=E4=BC=98=E5=8C=96=E6=89=93=E5=8D=B0?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 在config.toml中添加了OPENAI_API_KEY、OPENAI_BaseURL和MODEL配置项,以支持OpenAI API的集成。 同时,优化了PrintConfig函数,使其能够递归打印嵌套的配置结构,提高了配置管理的可读性和易用性。 --- config/config.go | 26 +++++++++++++++ go.mod | 7 ++++ go.sum | 8 +++++ main.go | 11 ++----- test/test.go | 7 ++-- workers/ai.go | 78 ++++++++++++++++++++++++++++++++++++++++++++ workers/core.go | 31 ++++++++++-------- workers/lsp.go | 2 +- workers/newworker.go | 5 +++ 9 files changed, 150 insertions(+), 25 deletions(-) create mode 100644 workers/ai.go diff --git a/config/config.go b/config/config.go index 945f063..af01f86 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,8 @@ package config import ( + "fmt" + "reflect" "sync" "github.com/BurntSushi/toml" @@ -20,5 +22,29 @@ func loadConfig() { } func GetConfig() map[string]interface{} { once.Do(loadConfig) + // print(config) return config } + +func PrintConfig(m map[string]interface{}, indent string) { + for key, value := range m { + switch v := value.(type) { + case map[string]interface{}: + fmt.Printf("%s%s (type: %s):\n", indent, key, reflect.TypeOf(v)) + PrintConfig(v, indent+" ") + case []interface{}: + fmt.Printf("%s%s:\n", indent, key) + for i, item := range v { + switch itemValue := item.(type) { + case map[string]interface{}: + fmt.Printf("%s [%d] (type: %s):\n", indent, i, reflect.TypeOf(itemValue)) + PrintConfig(itemValue, indent+" ") + default: + fmt.Printf("%s [%d] (type: %s): %v\n", indent, i, reflect.TypeOf(itemValue), item) + } + } + default: + fmt.Printf("%s%s (type: %s): %v\n", indent, key, reflect.TypeOf(value), value) + } + } +} diff --git a/go.mod b/go.mod index 91eeedc..c133698 100644 --- a/go.mod +++ b/go.mod @@ -8,3 +8,10 @@ require ( github.com/goccy/go-json v0.10.2 github.com/mattn/go-sqlite3 v1.14.22 ) + +require ( + github.com/moul/http2curl v1.0.0 // indirect + github.com/parnurzeal/gorequest v0.3.0 // indirect + github.com/pkg/errors v0.9.1 // indirect + golang.org/x/net v0.26.0 // indirect +) diff --git a/go.sum b/go.sum index fa68fcb..665bca3 100644 --- a/go.sum +++ b/go.sum @@ -6,5 +6,13 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs= +github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ= +github.com/parnurzeal/gorequest v0.3.0 h1:SoFyqCDC9COr1xuS6VA8fC8RU7XyrJZN2ona1kEX7FI= +github.com/parnurzeal/gorequest v0.3.0/go.mod h1:3Kh2QUMJoqw3icWAecsyzkpY7UzRfDhbRdTjtNwNiUE= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= +golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= +golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= diff --git a/main.go b/main.go index 39d8ba3..2820249 100644 --- a/main.go +++ b/main.go @@ -114,17 +114,12 @@ func handlePost(w http.ResponseWriter, r *http.Request) { } func main() { - // var config map[string]interface{} - // if _, err := toml.DecodeFile("config.toml", &config); err != nil { - // println("配置文件不正确,请修改正确的配置文件!") - // log.Fatal(err) - // } + cfg := config.GetConfig() APIURL := cfg["APIURL"].(string) + // config.PrintConfig(cfg, "") + // print(cfg["AllowGroup"].([]interface{})[0].(string)) - // PORT := config.GlobalConfig.Server.Port - // fmt.Println(APIURL) - // fmt.Println(PORT) http.HandleFunc("/", handlePost) // 协程支持 // http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { diff --git a/test/test.go b/test/test.go index 236bd51..e8d1c0b 100644 --- a/test/test.go +++ b/test/test.go @@ -11,12 +11,15 @@ import ( func main() { for { reader := bufio.NewReader(os.Stdin) - + // cfg := config.GetConfig() + // config.PrintConfig(cfg, "") fmt.Print("输入指令(不要带/):") raw_msg, _ := reader.ReadString('\n') // 去除末尾的换行符 // raw_msg = strings.TrimRight(raw_msg, "\r\n") - + if raw_msg == "" { + raw_msg = "ping" + } parms := strings.Fields(raw_msg) worker := workers.NewWorker(parms, "11", "111", "111", "222", raw_msg) diff --git a/workers/ai.go b/workers/ai.go new file mode 100644 index 0000000..457b2fe --- /dev/null +++ b/workers/ai.go @@ -0,0 +1,78 @@ +package workers + +import ( + "log" + + "github.com/goccy/go-json" + "github.com/parnurzeal/gorequest" +) + +type AI struct { + *StdAns +} + +func (a *AI) GetMsg() string { + if len(a.Parms) < 2 { + return "使用!ai xxx 向我提问吧" + + } + var msg string + var OPENAI_API_KEY string + if cfg["OPENAI_API_KEY"] != nil { + OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string) + } else { + log.Println("OPENAI_API_KEY 未配置") + return "OPENAI_API_KEY 未配置" + } + var OPENAI_BaseURL string + if cfg["OPENAI_BaseURL"] != nil { + OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string) + } else { + log.Println("OPENAI_BaseURL 未配置,使用openai默认配置") + OPENAI_BaseURL = "https://api.openai.com/v1/chat/completions" + } + var MODEL string + if cfg["MODEL"] != nil { + MODEL = cfg["MODEL"].(string) + } else { + log.Println("模型 未配置,使用默认chatglm_pro模型") + MODEL = "chatglm_pro" + } + ask := a.Parms[1] + if ask == "" { + return "不问问题你说个屁!" + } + requestBody := map[string]interface{}{ + "model": MODEL, + "messages": []map[string]string{{"role": "user", "content": ask}}, + "temperature": 0.7, + } + request := gorequest.New() + resp, body, errs := request.Post(OPENAI_BaseURL). + Set("Content-Type", "application/json"). + Set("Authorization", "Bearer "+OPENAI_API_KEY). + Send(requestBody). + End() + if errs != nil { + log.Println(errs) + return "请求失败" + } else { + if resp.StatusCode == 200 { + var responseBody map[string]interface{} + if err := json.Unmarshal([]byte(body), &responseBody); err != nil { + log.Println(err) + return "解析失败" + } + choices := responseBody["choices"].([]interface{}) + if len(choices) > 0 { + choice := choices[0].(map[string]interface{}) + msg = choice["message"].(map[string]interface{})["content"].(string) + } else { + log.Println("choices为空") + } + } + } + + return msg + +} diff --git a/workers/core.go b/workers/core.go index 8c51662..23842fd 100644 --- a/workers/core.go +++ b/workers/core.go @@ -9,11 +9,11 @@ import ( ) type StdAns struct { - AllowGroup []string - AllowUser []string - AllowRole []string - BlockGroup []string - BlockUser []string + AllowGroup []interface{} + AllowUser []interface{} + AllowRole []interface{} + BlockGroup []interface{} + BlockUser []interface{} GroupNotAllow string UserNotAllow string RoleNotAllow string @@ -27,11 +27,14 @@ type StdAns struct { var cfg map[string]interface{} -func init() { - cfg = config.GetConfig() -} +// func init() { +// cfg = config.GetConfig() +// } func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns { + // var cfg map[string]interface{} + cfg = config.GetConfig() + // println("AllowGroup:", cfg["AllowGroup"].([]interface{})) return &StdAns{ Parms: parms, UID: uid, @@ -39,11 +42,11 @@ func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns { Role: role, MID: mid, RawMsg: rawMsg, - AllowGroup: cfg["AllowGroup"].([]string), - AllowUser: cfg["AllowUser"].([]string), - AllowRole: cfg["AllowRole"].([]string), - BlockGroup: cfg["BlockGroup"].([]string), - BlockUser: cfg["BlockUser"].([]string), + AllowGroup: cfg["AllowGroup"].([]interface{}), + AllowUser: cfg["AllowUser"].([]interface{}), + AllowRole: cfg["AllowRole"].([]interface{}), + BlockGroup: cfg["BlockGroup"].([]interface{}), + BlockUser: cfg["BlockUser"].([]interface{}), GroupNotAllow: "汝所在的群组不被允许这样命令咱呢.", UserNotAllow: "汝不被允许呢.", RoleNotAllow: "汝的角色不被允许哦.", @@ -62,7 +65,7 @@ func (s *StdAns) CheckPermission() string { } return "0" } -func contains(slice []string, value string) bool { +func contains(slice []interface{}, value string) bool { for _, item := range slice { if item == value { return true diff --git a/workers/lsp.go b/workers/lsp.go index 0255b5b..0845549 100644 --- a/workers/lsp.go +++ b/workers/lsp.go @@ -14,7 +14,7 @@ type Lsp struct { } func (a *Lsp) GetMsg() string { - a.AllowGroup = []string{"313047773"} + a.AllowGroup = append(a.AllowGroup, []string{"313047773"}) url := "https://api.lolicon.app/setu/v2?r18=0&size=small" resp, err := http.Get(url) if err != nil { diff --git a/workers/newworker.go b/workers/newworker.go index 9992d51..516b267 100644 --- a/workers/newworker.go +++ b/workers/newworker.go @@ -5,6 +5,7 @@ import "fmt" func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker { fmt.Println("NewWorker:", parms) switch parms[0] { + case "ping": return &Ping{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), @@ -26,6 +27,10 @@ func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker { return &Lsp{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), } + case "ai": + return &AI{ + StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), + } default: return &Emm{ StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg)}