JM-WechatMini/service/system/sys_chatgpt.go

168 lines
4.0 KiB
Go
Raw Permalink Normal View History

2023-11-02 04:34:46 +08:00
package system
import (
"context"
"errors"
"fmt"
"github.com/sashabaranov/go-openai"
"gorm.io/gorm"
"miniapp/global"
"miniapp/model/system"
"miniapp/model/system/request"
"strings"
)
// ChatGptService is the stateless receiver for the ChatGPT helpers in this
// file: managing the stored key (CreateSK, GetSK, DeleteSK) and turning a
// chat question into an executed SQL query (GetTable).
type ChatGptService struct{}
// CreateSK stores option as the ChatGPT key record, provided none exists yet.
//
// It first looks up an existing record through GetSK:
//   - lookup succeeds: a record already exists, so an "已经存在sk" error is
//     returned and nothing is written;
//   - lookup fails with gorm.ErrRecordNotFound: option is inserted and the
//     result of that insert is returned;
//   - lookup fails with any other error: that error is returned unchanged.
func (chat *ChatGptService) CreateSK(option system.SysChatGptOption) error {
	_, err := chat.GetSK()
	if err == nil {
		return errors.New("已经存在sk")
	}
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	// GORM's Create expects a pointer to the record, so hand it the address
	// of the local copy instead of the struct value.
	return global.GVA_DB.Create(&option).Error
}
// GetSK fetches the record selected by global.GVA_DB.First into a
// SysChatGptOption and returns it together with the query's error, if any.
func (chat *ChatGptService) GetSK() (option system.SysChatGptOption, err error) {
	query := global.GVA_DB.First(&option)
	return option, query.Error
}
// DeleteSK removes the stored key record.
//
// The record is first loaded with GetSK; if that fails, its error is returned
// and no delete is attempted. Otherwise the record is deleted with the
// condition "sk = ?" bound to the loaded key, and the delete's error (if any)
// is returned.
func (chat *ChatGptService) DeleteSK() error {
	current, err := chat.GetSK()
	if err == nil {
		err = global.GVA_DB.Delete(current, "sk = ?", current.SK).Error
	}
	return err
}
// GetTable answers the question in req.Chat by having the OpenAI chat API
// write a SQL statement and then executing that statement.
//
// Steps:
//  1. Read column metadata for schema req.DBName from
//     information_schema.columns.
//  2. Pass the distinct table names and req.Chat to getTables, which returns
//     a comma-separated selection of tables.
//  3. Pass that selection, the column metadata and req.Chat to getSql, which
//     returns the SQL text.
//  4. Execute the SQL with global.GVA_DB.Raw and scan the rows into results.
//
// It returns the generated SQL, the scanned rows and an error. When a step
// before execution fails, sql is "" and results is nil; when only the
// execution fails, the generated SQL is still returned alongside the error.
// An empty req.DBName yields a "未选择db" error; a schema with no columns
// yields a "未找到表" error without contacting the API.
func (chat *ChatGptService) GetTable(req request.ChatGptRequest) (sql string, results []map[string]interface{}, err error) {
	if req.DBName == "" {
		return "", nil, errors.New("未选择db")
	}

	var tablesInfo []system.ChatField
	err = global.GVA_DB.Table("information_schema.columns").
		Where("TABLE_SCHEMA = ?", req.DBName).
		Scan(&tablesInfo).Error
	if err != nil {
		return "", nil, err
	}
	if len(tablesInfo) == 0 {
		// With no column metadata getSql could never build a prompt and would
		// fail with this same error, so stop before spending an API call.
		return "", nil, errors.New("未找到表")
	}

	// Collect every table name once, in the order the metadata query returned
	// them, so the prompt is deterministic and has no trailing comma.
	seen := make(map[string]struct{}, len(tablesInfo))
	names := make([]string, 0, len(tablesInfo))
	for i := range tablesInfo {
		name := tablesInfo[i].TABLE_NAME
		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		names = append(names, name)
	}

	option, err := chat.GetSK()
	if err != nil {
		return "", nil, err
	}

	client := openai.NewClient(option.SK)
	ctx := context.Background()
	tables, err := getTables(ctx, client, strings.Join(names, ","), req.Chat)
	if err != nil {
		return "", nil, err
	}

	// strings.Split always yields at least one element. For the first entry,
	// keep only the text after its last ':' (presumably to drop a leading
	// label in the model's answer) and strip surrounding newlines.
	tableArr := strings.Split(tables, ",")
	first := tableArr[0]
	tableArr[0] = strings.Trim(first[strings.LastIndex(first, ":")+1:], "\n")

	sql, err = getSql(ctx, client, tableArr, tablesInfo, req.Chat)
	if err != nil {
		return "", nil, err
	}

	// NOTE(security): sql is model-generated text derived from req.Chat and is
	// executed verbatim. Nothing here restricts it to read-only statements;
	// callers should make sure the database account used by global.GVA_DB has
	// only the privileges this feature needs.
	err = global.GVA_DB.Raw(sql).Scan(&results).Error
	return sql, results, err
}
// getTables asks the chat model which of the given tables are relevant to a
// question.
//
// tables is the comma-separated list of available table names and chat is the
// user's question; both are substituted into a fixed prompt that is sent as a
// single user message to the GPT-3.5 Turbo model. The content of the first
// returned choice is handed back as-is.
//
// An error is returned when the API call fails (the error is also printed to
// standard output) or when the response contains no choices.
func getTables(ctx context.Context, client *openai.Client, tables string, chat string) (string, error) {
	const tablePrompt = `You are a database administrator
Filter out the table names you might need from the tables I provided formatted as:
Table1,Table2,Table3
I will provide you with the following table configuration information:
Table1,Table2,Table3
Do not return information other than the table
Configured as:
%s
The problem is:
%s
`
	content := fmt.Sprintf(tablePrompt, tables, chat)
	chatReq := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: content,
			},
		},
	}

	resp, err := client.CreateChatCompletion(ctx, chatReq)
	if err != nil {
		fmt.Printf("Completion error: %v\n", err)
		return "", err
	}
	// Guard the index below: a response without choices would otherwise panic.
	if len(resp.Choices) == 0 {
		return "", errors.New("openai returned no choices")
	}
	return resp.Choices[0].Message.Content, nil
}
// getSql asks the chat model to write a SQL statement for a question.
//
// tables holds the table names selected for the question and ChatField the
// column metadata to choose from. A column is included in the prompt when its
// TABLE_NAME occurs as a substring of any entry in tables; each column is
// listed once as "table | column | comment". chat is the user's question.
//
// The prompt is sent as a single user message to the GPT-3.5 Turbo model. From
// the first returned choice, only the text after the last ':' is returned.
//
// Errors: "未找到表" when no column matched, the API error when the call
// fails (also printed to standard output), and an error when the response
// contains no choices.
func getSql(ctx context.Context, client *openai.Client, tables []string, ChatField []system.ChatField, chat string) (string, error) {
	const sqlPrompt = `You are a database administrator
Give me an SQL statement based on my question
I will provide you with my current database table configuration information in the form below
Table Name | Column Name | Column Description
Do not return information other than SQL
Configured as:
%s
The problem is:
%s`

	var configured strings.Builder
	for i := range ChatField {
		field := &ChatField[i]
		for _, candidate := range tables {
			if strings.Contains(candidate, field.TABLE_NAME) {
				fmt.Fprintf(&configured, "%s | %s | %s \n", field.TABLE_NAME, field.COLUMN_NAME, field.COLUMN_COMMENT)
				// One match is enough; continuing would repeat this column
				// for every further entry that also contains the table name.
				break
			}
		}
	}
	if configured.Len() == 0 {
		return "", errors.New("未找到表")
	}

	chatReq := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: fmt.Sprintf(sqlPrompt, configured.String(), chat),
			},
		},
	}

	resp, err := client.CreateChatCompletion(ctx, chatReq)
	if err != nil {
		fmt.Printf("Completion error: %v\n", err)
		return "", err
	}
	// Guard the index below: a response without choices would otherwise panic.
	if len(resp.Choices) == 0 {
		return "", errors.New("openai returned no choices")
	}

	// Keep only what follows the last ':' (the whole answer when there is
	// none), presumably to drop a leading label from the model's reply.
	// NOTE(review): this also truncates SQL that itself contains a ':' (for
	// example a time literal) — confirm whether that is acceptable.
	answer := resp.Choices[0].Message.Content
	return answer[strings.LastIndex(answer, ":")+1:], nil
}