feat(config): 添加OpenAI API配置并优化打印配置函数

在config.toml中添加了OPENAI_API_KEY、OPENAI_BaseURL和MODEL配置项,以支持OpenAI API的集成。
同时,优化了PrintConfig函数,使其能够递归打印嵌套的配置结构,提高了配置管理的可读性和易用性。
This commit is contained in:
liyp 2024-06-30 21:56:34 +08:00
parent cdb74588b2
commit 13483b9643
9 changed files with 150 additions and 25 deletions

View file

@ -1,6 +1,8 @@
package config package config
import ( import (
"fmt"
"reflect"
"sync" "sync"
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
@ -20,5 +22,29 @@ func loadConfig() {
} }
func GetConfig() map[string]interface{} { func GetConfig() map[string]interface{} {
once.Do(loadConfig) once.Do(loadConfig)
// print(config)
return config return config
} }
func PrintConfig(m map[string]interface{}, indent string) {
for key, value := range m {
switch v := value.(type) {
case map[string]interface{}:
fmt.Printf("%s%s (type: %s):\n", indent, key, reflect.TypeOf(v))
PrintConfig(v, indent+" ")
case []interface{}:
fmt.Printf("%s%s:\n", indent, key)
for i, item := range v {
switch itemValue := item.(type) {
case map[string]interface{}:
fmt.Printf("%s [%d] (type: %s):\n", indent, i, reflect.TypeOf(itemValue))
PrintConfig(itemValue, indent+" ")
default:
fmt.Printf("%s [%d] (type: %s): %v\n", indent, i, reflect.TypeOf(itemValue), item)
}
}
default:
fmt.Printf("%s%s (type: %s): %v\n", indent, key, reflect.TypeOf(value), value)
}
}
}

7
go.mod
View file

@ -8,3 +8,10 @@ require (
github.com/goccy/go-json v0.10.2 github.com/goccy/go-json v0.10.2
github.com/mattn/go-sqlite3 v1.14.22 github.com/mattn/go-sqlite3 v1.14.22
) )
require (
github.com/moul/http2curl v1.0.0 // indirect
github.com/parnurzeal/gorequest v0.3.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
golang.org/x/net v0.26.0 // indirect
)

8
go.sum
View file

@ -6,5 +6,13 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/moul/http2curl v1.0.0 h1:dRMWoAtb+ePxMlLkrCbAqh4TlPHXvoGUSQ323/9Zahs=
github.com/moul/http2curl v1.0.0/go.mod h1:8UbvGypXm98wA/IqH45anm5Y2Z6ep6O31QGOAZ3H0fQ=
github.com/parnurzeal/gorequest v0.3.0 h1:SoFyqCDC9COr1xuS6VA8fC8RU7XyrJZN2ona1kEX7FI=
github.com/parnurzeal/gorequest v0.3.0/go.mod h1:3Kh2QUMJoqw3icWAecsyzkpY7UzRfDhbRdTjtNwNiUE=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ=
golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE=

11
main.go
View file

@ -114,17 +114,12 @@ func handlePost(w http.ResponseWriter, r *http.Request) {
} }
func main() { func main() {
// var config map[string]interface{}
// if _, err := toml.DecodeFile("config.toml", &config); err != nil {
// println("配置文件不正确,请修改正确的配置文件!")
// log.Fatal(err)
// }
cfg := config.GetConfig() cfg := config.GetConfig()
APIURL := cfg["APIURL"].(string) APIURL := cfg["APIURL"].(string)
// config.PrintConfig(cfg, "")
// print(cfg["AllowGroup"].([]interface{})[0].(string))
// PORT := config.GlobalConfig.Server.Port
// fmt.Println(APIURL)
// fmt.Println(PORT)
http.HandleFunc("/", handlePost) http.HandleFunc("/", handlePost)
// 协程支持 // 协程支持
// http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {

View file

@ -11,12 +11,15 @@ import (
func main() { func main() {
for { for {
reader := bufio.NewReader(os.Stdin) reader := bufio.NewReader(os.Stdin)
// cfg := config.GetConfig()
// config.PrintConfig(cfg, "")
fmt.Print("输入指令(不要带/)") fmt.Print("输入指令(不要带/)")
raw_msg, _ := reader.ReadString('\n') raw_msg, _ := reader.ReadString('\n')
// 去除末尾的换行符 // 去除末尾的换行符
// raw_msg = strings.TrimRight(raw_msg, "\r\n") // raw_msg = strings.TrimRight(raw_msg, "\r\n")
if raw_msg == "" {
raw_msg = "ping"
}
parms := strings.Fields(raw_msg) parms := strings.Fields(raw_msg)
worker := workers.NewWorker(parms, "11", "111", "111", "222", raw_msg) worker := workers.NewWorker(parms, "11", "111", "111", "222", raw_msg)

78
workers/ai.go Normal file
View file

@ -0,0 +1,78 @@
package workers
import (
"log"
"github.com/goccy/go-json"
"github.com/parnurzeal/gorequest"
)
type AI struct {
*StdAns
}
func (a *AI) GetMsg() string {
if len(a.Parms) < 2 {
return "使用!ai xxx 向我提问吧"
}
var msg string
var OPENAI_API_KEY string
if cfg["OPENAI_API_KEY"] != nil {
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
} else {
log.Println("OPENAI_API_KEY 未配置")
return "OPENAI_API_KEY 未配置"
}
var OPENAI_BaseURL string
if cfg["OPENAI_BaseURL"] != nil {
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
} else {
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
OPENAI_BaseURL = "https://api.openai.com/v1/chat/completions"
}
var MODEL string
if cfg["MODEL"] != nil {
MODEL = cfg["MODEL"].(string)
} else {
log.Println("模型 未配置,使用默认chatglm_pro模型")
MODEL = "chatglm_pro"
}
ask := a.Parms[1]
if ask == "" {
return "不问问题你说个屁!"
}
requestBody := map[string]interface{}{
"model": MODEL,
"messages": []map[string]string{{"role": "user", "content": ask}},
"temperature": 0.7,
}
request := gorequest.New()
resp, body, errs := request.Post(OPENAI_BaseURL).
Set("Content-Type", "application/json").
Set("Authorization", "Bearer "+OPENAI_API_KEY).
Send(requestBody).
End()
if errs != nil {
log.Println(errs)
return "请求失败"
} else {
if resp.StatusCode == 200 {
var responseBody map[string]interface{}
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
log.Println(err)
return "解析失败"
}
choices := responseBody["choices"].([]interface{})
if len(choices) > 0 {
choice := choices[0].(map[string]interface{})
msg = choice["message"].(map[string]interface{})["content"].(string)
} else {
log.Println("choices为空")
}
}
}
return msg
}

View file

@ -9,11 +9,11 @@ import (
) )
type StdAns struct { type StdAns struct {
AllowGroup []string AllowGroup []interface{}
AllowUser []string AllowUser []interface{}
AllowRole []string AllowRole []interface{}
BlockGroup []string BlockGroup []interface{}
BlockUser []string BlockUser []interface{}
GroupNotAllow string GroupNotAllow string
UserNotAllow string UserNotAllow string
RoleNotAllow string RoleNotAllow string
@ -27,11 +27,14 @@ type StdAns struct {
var cfg map[string]interface{} var cfg map[string]interface{}
func init() { // func init() {
cfg = config.GetConfig() // cfg = config.GetConfig()
} // }
func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns { func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns {
// var cfg map[string]interface{}
cfg = config.GetConfig()
// println("AllowGroup:", cfg["AllowGroup"].([]interface{}))
return &StdAns{ return &StdAns{
Parms: parms, Parms: parms,
UID: uid, UID: uid,
@ -39,11 +42,11 @@ func NewStdAns(parms []string, uid, gid, role, mid, rawMsg string) *StdAns {
Role: role, Role: role,
MID: mid, MID: mid,
RawMsg: rawMsg, RawMsg: rawMsg,
AllowGroup: cfg["AllowGroup"].([]string), AllowGroup: cfg["AllowGroup"].([]interface{}),
AllowUser: cfg["AllowUser"].([]string), AllowUser: cfg["AllowUser"].([]interface{}),
AllowRole: cfg["AllowRole"].([]string), AllowRole: cfg["AllowRole"].([]interface{}),
BlockGroup: cfg["BlockGroup"].([]string), BlockGroup: cfg["BlockGroup"].([]interface{}),
BlockUser: cfg["BlockUser"].([]string), BlockUser: cfg["BlockUser"].([]interface{}),
GroupNotAllow: "汝所在的群组不被允许这样命令咱呢.", GroupNotAllow: "汝所在的群组不被允许这样命令咱呢.",
UserNotAllow: "汝不被允许呢.", UserNotAllow: "汝不被允许呢.",
RoleNotAllow: "汝的角色不被允许哦.", RoleNotAllow: "汝的角色不被允许哦.",
@ -62,7 +65,7 @@ func (s *StdAns) CheckPermission() string {
} }
return "0" return "0"
} }
func contains(slice []string, value string) bool { func contains(slice []interface{}, value string) bool {
for _, item := range slice { for _, item := range slice {
if item == value { if item == value {
return true return true

View file

@ -14,7 +14,7 @@ type Lsp struct {
} }
func (a *Lsp) GetMsg() string { func (a *Lsp) GetMsg() string {
a.AllowGroup = []string{"313047773"} a.AllowGroup = append(a.AllowGroup, []string{"313047773"})
url := "https://api.lolicon.app/setu/v2?r18=0&size=small" url := "https://api.lolicon.app/setu/v2?r18=0&size=small"
resp, err := http.Get(url) resp, err := http.Get(url)
if err != nil { if err != nil {

View file

@ -5,6 +5,7 @@ import "fmt"
func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker { func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker {
fmt.Println("NewWorker:", parms) fmt.Println("NewWorker:", parms)
switch parms[0] { switch parms[0] {
case "ping": case "ping":
return &Ping{ return &Ping{
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
@ -26,6 +27,10 @@ func NewWorker(parms []string, uid, gid, role, mid, rawMsg string) Worker {
return &Lsp{ return &Lsp{
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg), StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
} }
case "ai":
return &AI{
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
}
default: default:
return &Emm{ return &Emm{
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg)} StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg)}