feat(ai): 添加ai回复支持
This commit is contained in:
parent
9ad5b17438
commit
55bdc0615d
5 changed files with 228 additions and 79 deletions
7
main.go
7
main.go
|
@ -37,7 +37,7 @@ func insertMessage(db *sql.DB, data map[string]interface{}) error {
|
||||||
sender_role := sender["role"].(string)
|
sender_role := sender["role"].(string)
|
||||||
message_seq := data["message_seq"].(float64)
|
message_seq := data["message_seq"].(float64)
|
||||||
|
|
||||||
fmt.Println(post_type, message_time, int64(group_id), int64(message_id), raw_message, sender_user_id, sender_nickname, sender_card, sender_role, int64(message_seq))
|
fmt.Println("消息类型:", post_type, " 发送时间:", message_time, " 群号:", int64(group_id), " 消息id:", int64(message_id), " 原始消息:", raw_message, " 发送者id:", sender_user_id, " 发送者昵称:", sender_nickname, " 发送者名片:", sender_card, " 发送者角色:", sender_role, " 消息序列:", int64(message_seq))
|
||||||
_, err = db.Exec("INSERT INTO messages ( post_type, message_type, time, group_id, message_id, raw_message, sender_user_id, sender_nickname, sender_card, sender_role, message_seq) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
_, err = db.Exec("INSERT INTO messages ( post_type, message_type, time, group_id, message_id, raw_message, sender_user_id, sender_nickname, sender_card, sender_role, message_seq) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
post_type, message_type, message_time, group_id, message_id, raw_message, sender_user_id, sender_nickname, sender_card, sender_role, message_seq)
|
post_type, message_type, message_time, group_id, message_id, raw_message, sender_user_id, sender_nickname, sender_card, sender_role, message_seq)
|
||||||
fmt.Println("Data inserted successfully!")
|
fmt.Println("Data inserted successfully!")
|
||||||
|
@ -66,7 +66,8 @@ func handlePost(c *gin.Context) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Error decoding JSON"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Error decoding JSON"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// 输出解析后的 JSON 数据
|
||||||
|
fmt.Printf("data: %s\n\n", string(body))
|
||||||
// 打开数据库
|
// 打开数据库
|
||||||
db, err := sql.Open("sqlite3", "./data.db")
|
db, err := sql.Open("sqlite3", "./data.db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -103,9 +104,11 @@ func handlePost(c *gin.Context) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// 调用路由处理函数
|
// 调用路由处理函数
|
||||||
|
if data["post_type"] == "message" || data["post_type"] == "message_sent" {
|
||||||
utils.Router(data)
|
utils.Router(data)
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "JSON data received successfully!"})
|
c.JSON(http.StatusOK, gin.H{"message": "JSON data received successfully!"})
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cfg := config.GetConfig()
|
cfg := config.GetConfig()
|
||||||
|
|
|
@ -1,21 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/base64"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Image2Base64(path string) string {
|
|
||||||
file, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
if data, err := io.ReadAll(file); err == nil {
|
|
||||||
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data)
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
|
|
||||||
}
|
|
|
@ -3,6 +3,7 @@ package utils
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"go-bot/workers"
|
"go-bot/workers"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,7 +20,20 @@ func Router(data map[string]interface{}) {
|
||||||
|
|
||||||
//包含发送消息的'!'
|
//包含发送消息的'!'
|
||||||
raw_msg := data["raw_message"].(string)
|
raw_msg := data["raw_message"].(string)
|
||||||
// fmt.Println("raw_msg:", string(raw_msg[0]))
|
|
||||||
|
// 匹配回复消息
|
||||||
|
if strings.HasPrefix(raw_msg, "[CQ:reply,id=") {
|
||||||
|
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
matches := re.FindStringSubmatch(raw_msg)
|
||||||
|
if len(matches) > 0 {
|
||||||
|
fullMatch := matches[0]
|
||||||
|
raw_msg = re.ReplaceAllString(raw_msg, "")
|
||||||
|
raw_msg = raw_msg + " " + fullMatch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("raw_msg:", string(raw_msg))
|
||||||
if len(raw_msg) > 1 && raw_msg[0] == '!' {
|
if len(raw_msg) > 1 && raw_msg[0] == '!' {
|
||||||
// 去除'!'
|
// 去除'!'
|
||||||
raw_msg = raw_msg[1:]
|
raw_msg = raw_msg[1:]
|
||||||
|
|
207
workers/ai.go
207
workers/ai.go
|
@ -1,9 +1,15 @@
|
||||||
package workers
|
package workers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -29,7 +35,7 @@ func (a *AI) GetMsg() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
ask := a.Parms[1]
|
ask := a.Parms[1]
|
||||||
if ask == "" {
|
if ask == "" || strings.HasPrefix(ask, "[CQ:reply,id=") {
|
||||||
return "不问问题你说个屁!"
|
return "不问问题你说个屁!"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,7 +47,145 @@ func (a *AI) GetMsg() string {
|
||||||
if strings.ToLower(a.Parms[1]) == "models" {
|
if strings.ToLower(a.Parms[1]) == "models" {
|
||||||
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
return handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL)
|
||||||
} else {
|
} else {
|
||||||
return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID)
|
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
||||||
|
PROMPT, ok := cfg["PROMPT"].(string)
|
||||||
|
if !ok {
|
||||||
|
log.Println("PROMPT 未配置")
|
||||||
|
PROMPT = ""
|
||||||
|
}
|
||||||
|
var requestBody map[string]interface{}
|
||||||
|
if !strings.HasPrefix(a.Parms[len(a.Parms)-1], "[CQ:reply,id=") {
|
||||||
|
|
||||||
|
requestBody = map[string]interface{}{
|
||||||
|
"model": MODEL,
|
||||||
|
"stream": false,
|
||||||
|
"messages": []map[string]string{
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": PROMPT,
|
||||||
|
},
|
||||||
|
{"role": "user", "content": a.RawMsg[strings.Index(a.RawMsg, " ")+1:]},
|
||||||
|
},
|
||||||
|
"temperature": 0.7,
|
||||||
|
"presence_penalty": 0,
|
||||||
|
"frequency_penalty": 0,
|
||||||
|
"top_p": 1,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pattern := `^\[CQ:reply,id=(-?\d+)\]`
|
||||||
|
re := regexp.MustCompile(pattern)
|
||||||
|
matches := re.FindStringSubmatch(a.Parms[len(a.Parms)-1])
|
||||||
|
var msgId string
|
||||||
|
if len(matches) > 0 {
|
||||||
|
|
||||||
|
msgId = matches[1]
|
||||||
|
} else {
|
||||||
|
msgId = ""
|
||||||
|
log.Println("未找到回复消息")
|
||||||
|
return "未找到回复消息"
|
||||||
|
}
|
||||||
|
message := a.GetHisMsg(msgId)
|
||||||
|
|
||||||
|
// 正则表达式匹配 file 和 file_size 的值
|
||||||
|
re = regexp.MustCompile(`file=([^,]+),.*file_size=(\d+)`)
|
||||||
|
matches = re.FindStringSubmatch(message)
|
||||||
|
var file string
|
||||||
|
var fileSizeStr string
|
||||||
|
if len(matches) > 2 {
|
||||||
|
file = matches[1]
|
||||||
|
fileSizeStr = matches[2]
|
||||||
|
} else {
|
||||||
|
log.Println("未找到文件信息")
|
||||||
|
return "未找到文件信息"
|
||||||
|
}
|
||||||
|
// 将 fileSizeStr 转换为整数
|
||||||
|
fileSize, err := strconv.ParseFloat(fileSizeStr, 64)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("获取图片大小失败:", err)
|
||||||
|
return "获取图片大小失败"
|
||||||
|
}
|
||||||
|
if fileSize/1024/1024 > 1.0 {
|
||||||
|
log.Println("文件大小超过1M")
|
||||||
|
return "文件大小超过1M"
|
||||||
|
}
|
||||||
|
filePath := a.GetImage(file)
|
||||||
|
if filePath == "" {
|
||||||
|
log.Println("获取图片失败")
|
||||||
|
return "获取图片失败"
|
||||||
|
}
|
||||||
|
base64Img := Image2Base64(filePath)
|
||||||
|
if base64Img == "" {
|
||||||
|
log.Println("图片转换base64失败")
|
||||||
|
return "图片转换base64失败"
|
||||||
|
}
|
||||||
|
// 找到第一个空格的位置
|
||||||
|
firstSpaceIndex := strings.Index(a.RawMsg, " ")
|
||||||
|
|
||||||
|
// 找到最后一个空格的位置
|
||||||
|
lastSpaceIndex := strings.LastIndex(a.RawMsg, " ")
|
||||||
|
requestBody = map[string]interface{}{
|
||||||
|
"model": MODEL,
|
||||||
|
"stream": false,
|
||||||
|
"messages": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"role": "system",
|
||||||
|
"content": "#角色你是一名AI视觉助手,任务是分析单个图像。",
|
||||||
|
},
|
||||||
|
map[string]interface{}{
|
||||||
|
"role": "user",
|
||||||
|
"content": []interface{}{
|
||||||
|
map[string]interface{}{
|
||||||
|
"type": "text",
|
||||||
|
"text": a.RawMsg[firstSpaceIndex+1 : lastSpaceIndex],
|
||||||
|
},
|
||||||
|
map[string]interface{}{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": map[string]string{
|
||||||
|
"url": base64Img,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"temperature": 0.7,
|
||||||
|
"presence_penalty": 0,
|
||||||
|
"frequency_penalty": 0,
|
||||||
|
"top_p": 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
request := gorequest.New()
|
||||||
|
resp, body, errs := request.Post(OPENAI_BaseURL).
|
||||||
|
Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway).
|
||||||
|
Set("Content-Type", "application/json").
|
||||||
|
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
||||||
|
Send(requestBody).
|
||||||
|
End()
|
||||||
|
|
||||||
|
if errs != nil {
|
||||||
|
log.Println(errs)
|
||||||
|
return "请求失败"
|
||||||
|
}
|
||||||
|
println(resp.StatusCode)
|
||||||
|
if resp.StatusCode == 200 {
|
||||||
|
var responseBody map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
||||||
|
log.Println(err)
|
||||||
|
return "解析失败"
|
||||||
|
}
|
||||||
|
choices := responseBody["choices"].([]interface{})
|
||||||
|
if len(choices) > 0 {
|
||||||
|
choice := choices[0].(map[string]interface{})
|
||||||
|
msg := choice["message"].(map[string]interface{})["content"].(string)
|
||||||
|
return fmt.Sprintf("[CQ:at,qq=%s] %s", a.UID, msg)
|
||||||
|
} else {
|
||||||
|
log.Println("choices为空")
|
||||||
|
return "api解析失败"
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return "请求失败!"
|
||||||
|
// return handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, a.RawMsg, a.UID, a.Parms)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,59 +253,16 @@ func handleModelRequest(OPENAI_API_KEY, OPENAI_BaseURL string) string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleChatRequest(OPENAI_API_KEY, OPENAI_BaseURL, MODEL, rawMsg, UID string) string {
|
func Image2Base64(path string) string {
|
||||||
OPENAI_BaseURL = OPENAI_BaseURL + "/chat/completions"
|
file, err := os.Open(path)
|
||||||
PROMPT, ok := cfg["PROMPT"].(string)
|
if err != nil {
|
||||||
if !ok {
|
return ""
|
||||||
log.Println("PROMPT 未配置")
|
|
||||||
PROMPT = ""
|
|
||||||
}
|
}
|
||||||
|
defer file.Close()
|
||||||
|
|
||||||
requestBody := map[string]interface{}{
|
if data, err := io.ReadAll(file); err == nil {
|
||||||
"model": MODEL,
|
return "data:image/jpeg;base64," + base64.StdEncoding.EncodeToString(data)
|
||||||
"stream": false,
|
|
||||||
"messages": []map[string]string{
|
|
||||||
{
|
|
||||||
"role": "system",
|
|
||||||
"content": PROMPT,
|
|
||||||
},
|
|
||||||
{"role": "user", "content": rawMsg[strings.Index(rawMsg, " ")+1:]},
|
|
||||||
},
|
|
||||||
"temperature": 0.7,
|
|
||||||
"presence_penalty": 0,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"top_p": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
request := gorequest.New()
|
|
||||||
resp, body, errs := request.Post(OPENAI_BaseURL).
|
|
||||||
Retry(3, 5*time.Second, http.StatusServiceUnavailable, http.StatusBadGateway).
|
|
||||||
Set("Content-Type", "application/json").
|
|
||||||
Set("Authorization", "Bearer "+OPENAI_API_KEY).
|
|
||||||
Send(requestBody).
|
|
||||||
End()
|
|
||||||
|
|
||||||
if errs != nil {
|
|
||||||
log.Println(errs)
|
|
||||||
return "请求失败"
|
|
||||||
}
|
|
||||||
println(resp.StatusCode)
|
|
||||||
if resp.StatusCode == 200 {
|
|
||||||
var responseBody map[string]interface{}
|
|
||||||
if err := json.Unmarshal([]byte(body), &responseBody); err != nil {
|
|
||||||
log.Println(err)
|
|
||||||
return "解析失败"
|
|
||||||
}
|
|
||||||
choices := responseBody["choices"].([]interface{})
|
|
||||||
if len(choices) > 0 {
|
|
||||||
choice := choices[0].(map[string]interface{})
|
|
||||||
msg := choice["message"].(map[string]interface{})["content"].(string)
|
|
||||||
return fmt.Sprintf("[CQ:at,qq=%s] %s", UID, msg)
|
|
||||||
} else {
|
|
||||||
log.Println("choices为空")
|
|
||||||
return "api解析失败"
|
|
||||||
}
|
}
|
||||||
|
return ""
|
||||||
|
|
||||||
}
|
}
|
||||||
return "请求失败!"
|
|
||||||
}
|
|
||||||
|
|
|
@ -12,6 +12,8 @@ type Worker interface {
|
||||||
CheckPermission() string
|
CheckPermission() string
|
||||||
GetMsg() string
|
GetMsg() string
|
||||||
SendMsg(msg string) bool
|
SendMsg(msg string) bool
|
||||||
|
GetImage(file string) string
|
||||||
|
GetHisMsg(id string) string
|
||||||
}
|
}
|
||||||
type StdAns struct {
|
type StdAns struct {
|
||||||
AllowGroup []interface{}
|
AllowGroup []interface{}
|
||||||
|
@ -115,7 +117,31 @@ func (s *StdAns) GetMsg() string {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
}
|
}
|
||||||
|
func (s *StdAns) GetHisMsg(id string) string {
|
||||||
|
if id == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
url := cfg["POSTURL"].(string) + "/get_msg?message_id=" + id
|
||||||
|
|
||||||
|
request := gorequest.New()
|
||||||
|
resp, body, errs := request.Post(url).End()
|
||||||
|
if len(errs) > 0 {
|
||||||
|
|
||||||
|
fmt.Println("Error sending request:", errs)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data := make(map[string]interface{})
|
||||||
|
err := json.Unmarshal([]byte(body), &data)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("解析JSON失败:", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// fmt.Println("响应返回:", body)
|
||||||
|
return data["data"].(map[string]interface{})["message"].(string)
|
||||||
|
|
||||||
|
}
|
||||||
func (s *StdAns) SendMsg(msg string) bool {
|
func (s *StdAns) SendMsg(msg string) bool {
|
||||||
if msg == "-1" {
|
if msg == "-1" {
|
||||||
return false
|
return false
|
||||||
|
@ -156,3 +182,29 @@ func (s *StdAns) SendMsg(msg string) bool {
|
||||||
fmt.Println("响应返回:", body)
|
fmt.Println("响应返回:", body)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *StdAns) GetImage(file string) string {
|
||||||
|
if file == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
url := cfg["POSTURL"].(string) + "/get_image?file=" + file
|
||||||
|
|
||||||
|
request := gorequest.New()
|
||||||
|
resp, body, errs := request.Post(url).End()
|
||||||
|
if len(errs) > 0 {
|
||||||
|
|
||||||
|
fmt.Println("Error sending request:", errs)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
data := make(map[string]interface{})
|
||||||
|
err := json.Unmarshal([]byte(body), &data)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println("解析JSON失败:", err)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
path := data["data"].(map[string]interface{})["file"].(string)
|
||||||
|
|
||||||
|
// fmt.Println("响应返回:", body)
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue