170 lines
4.1 KiB
Go
170 lines
4.1 KiB
Go
package workers
|
||
|
||
import (
|
||
"fmt"
|
||
"log"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/goccy/go-json"
|
||
"github.com/parnurzeal/gorequest"
|
||
)
|
||
|
||
func init() {
|
||
RegisterWorkerFactory("ai", func(parms []string, uid, gid, role, mid, rawMsg string) Worker {
|
||
return &AI{
|
||
StdAns: NewStdAns(parms, uid, gid, role, mid, rawMsg),
|
||
}
|
||
})
|
||
}
|
||
|
||
type AI struct {
|
||
*StdAns
|
||
}
|
||
|
||
func (a *AI) GetMsg() string {
|
||
if len(a.Parms) < 2 {
|
||
return "使用!ai xxx 向我提问吧"
|
||
}
|
||
|
||
ask := a.Parms[1]
|
||
if ask == "" {
|
||
return "不问问题你说个屁!"
|
||
}
|
||
|
||
OPENAI_API_KEY, OPENAI_BaseURL, MODEL := getConfig()
|
||
if OPENAI_API_KEY == "" {
|
||
return "OPENAI_API_KEY 未配置"
|
||
}
|
||
|
||
if strings.ToLower(a.Parms[1]) == "models" {
|
||
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
||
} else {
|
||
return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID)
|
||
}
|
||
}
|
||
|
||
func getConfig() (string, string, string) {
|
||
var OPENAI_API_KEY, OPENAI_BaseURL, MODEL string
|
||
|
||
if cfg["OPENAI_API_KEY"] != nil {
|
||
OPENAI_API_KEY = cfg["OPENAI_API_KEY"].(string)
|
||
} else {
|
||
log.Println("OPENAI_API_KEY 未配置")
|
||
}
|
||
|
||
if cfg["OPENAI_BaseURL"] != nil {
|
||
OPENAI_BaseURL = cfg["OPENAI_BaseURL"].(string)
|
||
} else {
|
||
log.Println("OPENAI_BaseURL 未配置,使用openai默认配置")
|
||
OPENAI_BaseURL = "https://api.openai.com/v1"
|
||
}
|
||
|
||
if cfg["MODEL"] != nil {
|
||
MODEL = cfg["MODEL"].(string)
|
||
} else {
|
||
log.Println("模型 未配置,使用默认 gpt-4o 模型")
|
||
MODEL = "gpt-4o"
|
||
}
|
||
|
||
return OPENAI_API_KEY, OPENAI_BaseURL, MODEL
|
||
}
|
||
|
||
func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string {
|
||
OPENAI_BaseURL = OPENAI_BaseURL + "/models"
|
||
request := gorequest.New()
|
||
resp, body, errs := request.Get(OPENAI_BaseURL).
|
||
Set("Content-Type", "application/json").
|
||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||
End()
|
||
|
||
if errs != nil {
|
||
log.Println(errs)
|
||
return "请求失败"
|
||
}
|
||
|
||
if resp.StatusCode == 200 {
|
||
var responseBody map[string]interface{}
|
||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||
log.Println(err)
|
||
return "解析模型列表失败"
|
||
}
|
||
|
||
choices := responseBody["data"].([]interface{})
|
||
// var models []interface{}
|
||
if len(choices) > 0 {
|
||
msg := "支持的模型列表:\n"
|
||
for _, choice := range choices {
|
||
model := choice.(map[string]interface{})["id"]
|
||
msg += fmt.Sprintf("%s\n", model)
|
||
}
|
||
return msg
|
||
} else {
|
||
return "模型列表为空"
|
||
}
|
||
} else {
|
||
log.Println("请求失败")
|
||
return "请求模型列表失败"
|
||
}
|
||
}
|
||
|
||
func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string {
|
||
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
||
PROMPT, ok := cfg["PROMPT"].(string)
|
||
if !ok {
|
||
log.Println("PROMPT 未配置")
|
||
PROMPT = ""
|
||
}
|
||
|
||
requestBody := map[string]interface{}{
|
||
"model": MODEL,
|
||
"stream": false,
|
||
"messages": []map[string]string{
|
||
{
|
||
"role": "system",
|
||
"content": PROMPT,
|
||
},
|
||
{"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]},
|
||
},
|
||
"temperature": 0.7,
|
||
"presence_penalty": 0,
|
||
"frequency_penalty": 0,
|
||
"top_p": 1,
|
||
}
|
||
const maxRetry = 2
|
||
for retries := 0; retries <= maxRetry; retries++ {
|
||
request := gorequest.New()
|
||
resp, body, errs := request.Post(OPENAI_BaseURL).
|
||
Set("Content-Type", "application/json").
|
||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||
Send(requestBody).
|
||
End()
|
||
|
||
if errs != nil {
|
||
log.Println(errs)
|
||
return "请求失败"
|
||
}
|
||
println(resp.StatusCode)
|
||
if resp.StatusCode == 200 {
|
||
var responseBody map[string]interface{}
|
||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||
log.Println(err)
|
||
return "解析失败"
|
||
}
|
||
choices := responseBody["choices"].([]interface{})
|
||
if len(choices) > 0 {
|
||
choice := choices[0].(map[string]interface{})
|
||
msg := choice["message"].(map[string]interface{})["content"].(string)
|
||
return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg)
|
||
} else {
|
||
log.Println("choices为空")
|
||
return "api解析失败"
|
||
}
|
||
} else {
|
||
log.Printf("请求失败,状态码:%d,重试中...(%d/%d)\n", resp.StatusCode, retries+1, maxRetry)
|
||
time.Sleep(time.Second * 1)
|
||
// return "请求失败: " + fmt.Sprintf("%d", resp.StatusCode)
|
||
}
|
||
}
|
||
return "请求失败!"
|
||
}
|