🎨 优化项目结构 && 完善ai配置

This commit is contained in:
2026-03-03 15:39:23 +08:00
parent 557c865948
commit 2714e63d2a
585 changed files with 62223 additions and 100018 deletions

View File

@@ -1,273 +1,43 @@
package app
import (
"encoding/json"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
req "git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
type AiPresetService struct{}
// CreateAiPreset 创建AI预设
func (s *AiPresetService) CreateAiPreset(userId uint, req *request.CreateAiPresetRequest) (preset app.AiPreset, err error) {
preset = app.AiPreset{
UserID: userId,
Name: req.Name,
Description: req.Description,
Prompts: req.Prompts,
RegexScripts: req.RegexScripts,
Temperature: req.Temperature,
TopP: req.TopP,
MaxTokens: req.MaxTokens,
FrequencyPenalty: req.FrequencyPenalty,
PresencePenalty: req.PresencePenalty,
StreamEnabled: req.StreamEnabled,
IsDefault: req.IsDefault,
IsPublic: req.IsPublic,
}
err = global.GVA_DB.Create(&preset).Error
return preset, err
// CreateAiPreset 创建预设
func (s *AiPresetService) CreateAiPreset(preset *app.AiPreset) error {
return global.GVA_DB.Create(preset).Error
}
// DeleteAiPreset 删除AI预设
func (s *AiPresetService) DeleteAiPreset(id uint, userId uint) (err error) {
// 如果 userId 为 0未登录不允许删除
if userId == 0 {
return global.GVA_DB.Where("id = ?", id).Delete(&app.AiPreset{}).Error
}
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userId).Delete(&app.AiPreset{}).Error
// DeleteAiPreset 删除预设
func (s *AiPresetService) DeleteAiPreset(id uint, userID uint) error {
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiPreset{}).Error
}
// UpdateAiPreset 更新AI预设
func (s *AiPresetService) UpdateAiPreset(userId uint, req *request.UpdateAiPresetRequest) (preset app.AiPreset, err error) {
// 如果 userId 为 0未登录允许更新任何预设
if userId == 0 {
err = global.GVA_DB.Where("id = ?", req.ID).First(&preset).Error
} else {
err = global.GVA_DB.Where("id = ? AND user_id = ?", req.ID, userId).First(&preset).Error
}
if err != nil {
return preset, err
}
if req.Name != "" {
preset.Name = req.Name
}
preset.Description = req.Description
if req.Prompts != nil {
preset.Prompts = req.Prompts
}
if req.RegexScripts != nil {
preset.RegexScripts = req.RegexScripts
}
preset.Temperature = req.Temperature
preset.TopP = req.TopP
preset.MaxTokens = req.MaxTokens
preset.FrequencyPenalty = req.FrequencyPenalty
preset.PresencePenalty = req.PresencePenalty
preset.StreamEnabled = req.StreamEnabled
preset.IsDefault = req.IsDefault
preset.IsPublic = req.IsPublic
err = global.GVA_DB.Save(&preset).Error
return preset, err
// UpdateAiPreset 更新预设
func (s *AiPresetService) UpdateAiPreset(preset *app.AiPreset, userID uint) error {
return global.GVA_DB.Where("user_id = ?", userID).Updates(preset).Error
}
// GetAiPreset 获取AI预设详情
func (s *AiPresetService) GetAiPreset(id uint, userId uint) (preset app.AiPreset, err error) {
// 如果 userId 为 0未登录只能获取公开的预设
if userId == 0 {
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
} else {
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
}
return preset, err
// GetAiPreset 查询预设
func (s *AiPresetService) GetAiPreset(id uint, userID uint) (preset app.AiPreset, err error) {
err = global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&preset).Error
return
}
// GetAiPresetList 获取AI预设列表
func (s *AiPresetService) GetAiPresetList(userId uint, info req.PageInfo) (list []app.AiPreset, total int64, err error) {
// GetAiPresetList 获取预设列表
func (s *AiPresetService) GetAiPresetList(info request.PageInfo, userID uint) (list []app.AiPreset, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&app.AiPreset{})
// 如果 userId 为 0未登录只返回公开的预设
if userId == 0 {
db = db.Where("is_public = ?", true)
} else {
db = db.Where("user_id = ? OR is_public = ?", userId, true)
}
db := global.GVA_DB.Model(&app.AiPreset{}).Where("user_id = ?", userID)
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Limit(limit).Offset(offset).Order("created_at DESC").Find(&list).Error
return list, total, err
}
// ImportAiPreset 导入AI预设支持SillyTavern格式
func (s *AiPresetService) ImportAiPreset(userId uint, req *request.ImportAiPresetRequest) (preset app.AiPreset, err error) {
// 解析 SillyTavern JSON 格式
var stData map[string]interface{}
var jsonData []byte
// 类型断言处理 req.Data
switch v := req.Data.(type) {
case string:
jsonData = []byte(v)
case []byte:
jsonData = v
default:
return preset, fmt.Errorf("不支持的数据类型")
}
if err := json.Unmarshal(jsonData, &stData); err != nil {
return preset, fmt.Errorf("JSON 解析失败: %w", err)
}
// 提取基本信息
preset = app.AiPreset{
UserID: userId,
Name: req.Name,
Description: getStringValue(stData, "description"),
IsPublic: false,
}
// 提取参数
if temp, ok := stData["temperature"].(float64); ok {
preset.Temperature = temp
}
if topP, ok := stData["top_p"].(float64); ok {
preset.TopP = topP
}
if maxTokens, ok := stData["openai_max_tokens"].(float64); ok {
preset.MaxTokens = int(maxTokens)
} else if maxTokens, ok := stData["max_tokens"].(float64); ok {
preset.MaxTokens = int(maxTokens)
}
if freqPenalty, ok := stData["frequency_penalty"].(float64); ok {
preset.FrequencyPenalty = freqPenalty
}
if presPenalty, ok := stData["presence_penalty"].(float64); ok {
preset.PresencePenalty = presPenalty
}
if stream, ok := stData["stream_openai"].(bool); ok {
preset.StreamEnabled = stream
}
// 提取提示词
prompts := make([]app.Prompt, 0)
if promptsData, ok := stData["prompts"].([]interface{}); ok {
for i, p := range promptsData {
if promptMap, ok := p.(map[string]interface{}); ok {
prompt := app.Prompt{
Name: getStringValue(promptMap, "name"),
Role: getStringValue(promptMap, "role"),
Content: getStringValue(promptMap, "content"),
SystemPrompt: getBoolValue(promptMap, "system_prompt"),
Marker: getBoolValue(promptMap, "marker"),
InjectionOrder: i,
InjectionDepth: int(getFloatValue(promptMap, "injection_depth")),
InjectionPosition: int(getFloatValue(promptMap, "injection_position")),
}
prompts = append(prompts, prompt)
}
}
}
preset.Prompts = prompts
// 提取正则脚本
regexScripts := make([]app.RegexScript, 0)
if extensions, ok := stData["extensions"].(map[string]interface{}); ok {
if scripts, ok := extensions["regex_scripts"].([]interface{}); ok {
for _, s := range scripts {
if scriptMap, ok := s.(map[string]interface{}); ok {
script := app.RegexScript{
ScriptName: getStringValue(scriptMap, "script_name"),
FindRegex: getStringValue(scriptMap, "find_regex"),
ReplaceString: getStringValue(scriptMap, "replace_string"),
Disabled: getBoolValue(scriptMap, "disabled"),
Placement: getIntArray(scriptMap, "placement"),
}
regexScripts = append(regexScripts, script)
}
}
}
}
preset.RegexScripts = regexScripts
// 保存到数据库
err = global.GVA_DB.Create(&preset).Error
return preset, err
}
// 辅助函数
func getStringValue(m map[string]interface{}, key string) string {
if v, ok := m[key].(string); ok {
return v
}
return ""
}
func getBoolValue(m map[string]interface{}, key string) bool {
if v, ok := m[key].(bool); ok {
return v
}
return false
}
func getFloatValue(m map[string]interface{}, key string) float64 {
if v, ok := m[key].(float64); ok {
return v
}
return 0
}
func getIntArray(m map[string]interface{}, key string) []int {
result := make([]int, 0)
if arr, ok := m[key].([]interface{}); ok {
for _, v := range arr {
if num, ok := v.(float64); ok {
result = append(result, int(num))
}
}
}
return result
}
// ExportAiPreset 导出AI预设
func (s *AiPresetService) ExportAiPreset(id uint, userId uint) (data map[string]interface{}, err error) {
var preset app.AiPreset
// 如果 userId 为 0未登录只能导出公开的预设
if userId == 0 {
err = global.GVA_DB.Where("id = ? AND is_public = ?", id, true).First(&preset).Error
} else {
err = global.GVA_DB.Where("id = ? AND (user_id = ? OR is_public = ?)", id, userId, true).First(&preset).Error
}
if err != nil {
return nil, err
}
data = map[string]interface{}{
"prompts": preset.Prompts,
"extensions": map[string]interface{}{
"regex_scripts": preset.RegexScripts,
},
"temperature": preset.Temperature,
"top_p": preset.TopP,
"openai_max_tokens": preset.MaxTokens,
"frequency_penalty": preset.FrequencyPenalty,
"presence_penalty": preset.PresencePenalty,
"stream_openai": preset.StreamEnabled,
}
return data, nil
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
return
}

View File

@@ -1,154 +1,43 @@
package app
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"gorm.io/gorm"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
type PresetBindingService struct{}
type AiPresetBindingService struct{}
// CreateBinding 创建预设绑定
func (s *PresetBindingService) CreateBinding(req *request.CreateBindingRequest) error {
// 检查预设是否存在
var preset app.AiPreset
if err := global.GVA_DB.First(&preset, req.PresetID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("预设不存在")
}
return err
}
// 检查提供商是否存在
var provider app.AiProvider
if err := global.GVA_DB.First(&provider, req.ProviderID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("提供商不存在")
}
return err
}
// 检查是否已存在相同的绑定
var count int64
err := global.GVA_DB.Model(&app.AiPresetBinding{}).
Where("preset_id = ? AND provider_id = ?", req.PresetID, req.ProviderID).
Count(&count).Error
if err != nil {
return err
}
if count > 0 {
return errors.New("该绑定已存在")
}
binding := app.AiPresetBinding{
PresetID: req.PresetID,
ProviderID: req.ProviderID,
Priority: req.Priority,
IsActive: true,
}
return global.GVA_DB.Create(&binding).Error
// CreateAiPresetBinding 创建绑定
func (s *AiPresetBindingService) CreateAiPresetBinding(binding *app.AiPresetBinding) error {
return global.GVA_DB.Create(binding).Error
}
// UpdateBinding 更新预设绑定
func (s *PresetBindingService) UpdateBinding(req *request.UpdateBindingRequest) error {
var binding app.AiPresetBinding
if err := global.GVA_DB.First(&binding, req.ID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("绑定不存在")
}
return err
}
updates := map[string]interface{}{
"priority": req.Priority,
"is_active": req.IsActive,
}
return global.GVA_DB.Model(&binding).Updates(updates).Error
// DeleteAiPresetBinding 删除绑定
func (s *AiPresetBindingService) DeleteAiPresetBinding(id uint, userID uint) error {
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiPresetBinding{}).Error
}
// DeleteBinding 删除预设绑定
func (s *PresetBindingService) DeleteBinding(id uint) error {
var binding app.AiPresetBinding
if err := global.GVA_DB.First(&binding, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("绑定不存在")
}
return err
}
return global.GVA_DB.Delete(&binding).Error
// UpdateAiPresetBinding 更新绑定
func (s *AiPresetBindingService) UpdateAiPresetBinding(binding *app.AiPresetBinding, userID uint) error {
return global.GVA_DB.Where("user_id = ?", userID).Updates(binding).Error
}
// GetBindingList 获取绑定列表
func (s *PresetBindingService) GetBindingList(req *request.GetBindingListRequest) (list []response.BindingInfo, total int64, err error) {
db := global.GVA_DB.Model(&app.AiPresetBinding{})
// GetAiPresetBinding 查询绑定
func (s *AiPresetBindingService) GetAiPresetBinding(id uint, userID uint) (binding app.AiPresetBinding, err error) {
err = global.GVA_DB.Preload("Preset").Preload("Provider").Where("id = ? AND user_id = ?", id, userID).First(&binding).Error
return
}
// 条件查询
if req.ProviderID > 0 {
db = db.Where("provider_id = ?", req.ProviderID)
}
if req.PresetID > 0 {
db = db.Where("preset_id = ?", req.PresetID)
}
// 获取总数
// GetAiPresetBindingList 获取绑定列表
func (s *AiPresetBindingService) GetAiPresetBindingList(info request.PageInfo, userID uint) (list []app.AiPresetBinding, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&app.AiPresetBinding{}).Preload("Preset").Preload("Provider").Where("user_id = ?", userID)
err = db.Count(&total).Error
if err != nil {
return nil, 0, err
return
}
// 分页查询
if req.Page > 0 && req.PageSize > 0 {
offset := (req.Page - 1) * req.PageSize
db = db.Offset(offset).Limit(req.PageSize)
}
var bindings []app.AiPresetBinding
err = db.Preload("Preset").Preload("Provider").Order("priority ASC, created_at DESC").Find(&bindings).Error
if err != nil {
return nil, 0, err
}
// 转换为响应格式
list = make([]response.BindingInfo, len(bindings))
for i, binding := range bindings {
list[i] = response.BindingInfo{
ID: binding.ID,
PresetID: binding.PresetID,
PresetName: binding.Preset.Name,
ProviderID: binding.ProviderID,
ProviderName: binding.Provider.Name,
Priority: binding.Priority,
IsActive: binding.IsActive,
CreatedAt: binding.CreatedAt,
UpdatedAt: binding.UpdatedAt,
}
}
return list, total, nil
}
// GetBindingsByProvider 根据提供商获取绑定的预设列表
func (s *PresetBindingService) GetBindingsByProvider(providerID uint) ([]app.AiPreset, error) {
var bindings []app.AiPresetBinding
err := global.GVA_DB.Where("provider_id = ? AND is_active = ?", providerID, true).
Preload("Preset").
Order("priority ASC").
Find(&bindings).Error
if err != nil {
return nil, err
}
presets := make([]app.AiPreset, len(bindings))
for i, binding := range bindings {
presets[i] = binding.Preset
}
return presets, nil
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
return
}

View File

@@ -0,0 +1,305 @@
package app
import (
"fmt"
"regexp"
"sort"
"strings"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
)
// PresetInjector 预设注入器
type PresetInjector struct {
preset *app.AiPreset
}
// NewPresetInjector 创建预设注入器
func NewPresetInjector(preset *app.AiPreset) *PresetInjector {
return &PresetInjector{preset: preset}
}
// InjectMessages 注入预设到消息列表
func (p *PresetInjector) InjectMessages(messages []request.ChatMessage) []request.ChatMessage {
if p.preset == nil || len(p.preset.Prompts) == 0 {
return messages
}
// 1. 应用用户输入前的正则替换
messages = p.applyRegexScripts(messages, 1)
// 2. 获取启用的提示词并排序
enabledPrompts := p.getEnabledPrompts()
// 3. 构建注入后的消息列表
injectedMessages := p.buildInjectedMessages(messages, enabledPrompts)
return injectedMessages
}
// getEnabledPrompts 获取启用的提示词并按注入顺序排序
func (p *PresetInjector) getEnabledPrompts() []app.PresetPrompt {
var prompts []app.PresetPrompt
for _, prompt := range p.preset.Prompts {
if prompt.Enabled && !prompt.Marker {
prompts = append(prompts, prompt)
}
}
// 按 injection_order 和 injection_depth 排序
sort.Slice(prompts, func(i, j int) bool {
if prompts[i].InjectionOrder != prompts[j].InjectionOrder {
return prompts[i].InjectionOrder < prompts[j].InjectionOrder
}
return prompts[i].InjectionDepth < prompts[j].InjectionDepth
})
return prompts
}
// buildInjectedMessages 构建注入后的消息列表
func (p *PresetInjector) buildInjectedMessages(messages []request.ChatMessage, prompts []app.PresetPrompt) []request.ChatMessage {
result := make([]request.ChatMessage, 0)
// 分离系统提示词和对话消息
var systemPrompts []app.PresetPrompt
var otherPrompts []app.PresetPrompt
for _, prompt := range prompts {
if prompt.Role == "system" {
systemPrompts = append(systemPrompts, prompt)
} else {
otherPrompts = append(otherPrompts, prompt)
}
}
// 1. 先添加系统提示词
for _, prompt := range systemPrompts {
result = append(result, request.ChatMessage{
Role: "system",
Content: p.processPromptContent(prompt.Content),
})
}
// 2. 处理对话历史注入
chatHistoryIndex := p.findMarkerIndex("chatHistory")
if chatHistoryIndex >= 0 {
// 在 chatHistory 标记位置注入原始消息
result = append(result, messages...)
} else {
// 如果没有 chatHistory 标记,直接添加到末尾
result = append(result, messages...)
}
// 3. 添加其他角色的提示词(assistant等)
for _, prompt := range otherPrompts {
result = append(result, request.ChatMessage{
Role: prompt.Role,
Content: p.processPromptContent(prompt.Content),
})
}
return result
}
// findMarkerIndex 查找标记位置
func (p *PresetInjector) findMarkerIndex(identifier string) int {
for i, prompt := range p.preset.Prompts {
if prompt.Identifier == identifier && prompt.Marker {
return i
}
}
return -1
}
// processPromptContent 处理提示词内容(变量替换等)
func (p *PresetInjector) processPromptContent(content string) string {
// 处理 {{user}} 和 {{char}} 等变量
content = strings.ReplaceAll(content, "{{user}}", "User")
content = strings.ReplaceAll(content, "{{char}}", "Assistant")
// 处理 {{getvar::key}} 语法
getvarRegex := regexp.MustCompile(`\{\{getvar::(\w+)\}\}`)
content = getvarRegex.ReplaceAllString(content, "")
// 处理 {{setvar::key::value}} 语法
setvarRegex := regexp.MustCompile(`\{\{setvar::(\w+)::(.*?)\}\}`)
content = setvarRegex.ReplaceAllString(content, "")
// 处理注释 {{//...}}
commentRegex := regexp.MustCompile(`\{\{//.*?\}\}`)
content = commentRegex.ReplaceAllString(content, "")
return strings.TrimSpace(content)
}
// applyRegexScripts 应用正则替换脚本
func (p *PresetInjector) applyRegexScripts(messages []request.ChatMessage, placement int) []request.ChatMessage {
if p.preset.Extensions.RegexBinding == nil {
return messages
}
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
if script.Disabled {
continue
}
// 检查 placement
hasPlacement := false
for _, p := range script.Placement {
if p == placement {
hasPlacement = true
break
}
}
if !hasPlacement {
continue
}
// 应用正则替换
messages = p.applyRegexScript(messages, script)
}
return messages
}
// applyRegexScript 应用单个正则脚本
func (p *PresetInjector) applyRegexScript(messages []request.ChatMessage, script app.RegexScript) []request.ChatMessage {
// 解析正则表达式
pattern := script.FindRegex
// 移除正则标志(如 /pattern/g)
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
pattern = pattern[1 : len(pattern)-2]
} else if strings.HasPrefix(pattern, "/") {
lastSlash := strings.LastIndex(pattern, "/")
if lastSlash > 0 {
pattern = pattern[1:lastSlash]
}
}
re, err := regexp.Compile(pattern)
if err != nil {
return messages
}
// 对每条消息应用替换
for i := range messages {
if script.PromptOnly && messages[i].Role != "user" {
continue
}
messages[i].Content = re.ReplaceAllString(messages[i].Content, script.ReplaceString)
}
return messages
}
// ProcessResponse 处理AI响应(应用输出后的正则)
func (p *PresetInjector) ProcessResponse(content string) string {
if p.preset == nil || p.preset.Extensions.RegexBinding == nil {
return content
}
for _, script := range p.preset.Extensions.RegexBinding.Regexes {
if script.Disabled {
continue
}
// 检查是否应用于输出(placement=2)
hasPlacement := false
for _, placement := range script.Placement {
if placement == 2 {
hasPlacement = true
break
}
}
if !hasPlacement {
continue
}
// 解析正则表达式
pattern := script.FindRegex
if strings.HasPrefix(pattern, "/") && strings.HasSuffix(pattern, "/g") {
pattern = pattern[1 : len(pattern)-2]
} else if strings.HasPrefix(pattern, "/") {
lastSlash := strings.LastIndex(pattern, "/")
if lastSlash > 0 {
pattern = pattern[1:lastSlash]
}
}
re, err := regexp.Compile(pattern)
if err != nil {
continue
}
content = re.ReplaceAllString(content, script.ReplaceString)
}
return content
}
// ApplyPresetParameters 应用预设参数到请求
func (p *PresetInjector) ApplyPresetParameters(req *request.ChatCompletionRequest) {
if p.preset == nil {
return
}
// 如果请求中没有指定参数,使用预设的参数
if req.Temperature == nil && p.preset.Temperature > 0 {
temp := p.preset.Temperature
req.Temperature = &temp
}
if req.TopP == nil && p.preset.TopP > 0 {
topP := p.preset.TopP
req.TopP = &topP
}
if req.MaxTokens == nil && p.preset.MaxTokens > 0 {
maxTokens := p.preset.MaxTokens
req.MaxTokens = &maxTokens
}
if req.PresencePenalty == nil && p.preset.PresencePenalty != 0 {
pp := p.preset.PresencePenalty
req.PresencePenalty = &pp
}
if req.FrequencyPenalty == nil && p.preset.FrequencyPenalty != 0 {
fp := p.preset.FrequencyPenalty
req.FrequencyPenalty = &fp
}
}
// ValidatePreset 验证预设配置
func ValidatePreset(preset *app.AiPreset) error {
if preset == nil {
return fmt.Errorf("预设不能为空")
}
if preset.Name == "" {
return fmt.Errorf("预设名称不能为空")
}
// 验证正则表达式
if preset.Extensions.RegexBinding != nil {
for _, script := range preset.Extensions.RegexBinding.Regexes {
pattern := script.FindRegex
if strings.HasPrefix(pattern, "/") {
lastSlash := strings.LastIndex(pattern, "/")
if lastSlash > 0 {
pattern = pattern[1:lastSlash]
}
}
_, err := regexp.Compile(pattern)
if err != nil {
return fmt.Errorf("正则表达式 '%s' 无效: %v", script.ScriptName, err)
}
}
}
return nil
}

View File

@@ -1,230 +1,43 @@
package app
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
)
type AiProviderService struct{}
// CreateAiProvider 创建AI提供商
func (s *AiProviderService) CreateAiProvider(req *request.CreateAiProviderRequest) (provider app.AiProvider, err error) {
provider = app.AiProvider{
Name: req.Name,
Type: req.Type,
BaseURL: req.BaseURL,
Endpoint: req.Endpoint,
UpstreamKey: req.UpstreamKey,
Model: req.Model,
ProxyKey: req.ProxyKey,
Config: req.Config,
IsActive: req.IsActive,
}
err = global.GVA_DB.Create(&provider).Error
return provider, err
// CreateAiProvider 创建提供商
func (s *AiProviderService) CreateAiProvider(provider *app.AiProvider) error {
return global.GVA_DB.Create(provider).Error
}
// DeleteAiProvider 删除AI提供商
func (s *AiProviderService) DeleteAiProvider(id uint) (err error) {
return global.GVA_DB.Delete(&app.AiProvider{}, id).Error
// DeleteAiProvider 删除提供商
func (s *AiProviderService) DeleteAiProvider(id uint, userID uint) error {
return global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).Delete(&app.AiProvider{}).Error
}
// UpdateAiProvider 更新AI提供商
func (s *AiProviderService) UpdateAiProvider(req *request.UpdateAiProviderRequest) (provider app.AiProvider, err error) {
err = global.GVA_DB.First(&provider, req.ID).Error
// UpdateAiProvider 更新提供商
func (s *AiProviderService) UpdateAiProvider(provider *app.AiProvider, userID uint) error {
return global.GVA_DB.Where("user_id = ?", userID).Updates(provider).Error
}
// GetAiProvider 查询提供商
func (s *AiProviderService) GetAiProvider(id uint, userID uint) (provider app.AiProvider, err error) {
err = global.GVA_DB.Where("id = ? AND user_id = ?", id, userID).First(&provider).Error
return
}
// GetAiProviderList 获取提供商列表
func (s *AiProviderService) GetAiProviderList(info request.PageInfo, userID uint) (list []app.AiProvider, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&app.AiProvider{}).Where("user_id = ?", userID)
err = db.Count(&total).Error
if err != nil {
return provider, err
return
}
if req.Name != "" {
provider.Name = req.Name
}
if req.Type != "" {
provider.Type = req.Type
}
if req.BaseURL != "" {
provider.BaseURL = req.BaseURL
}
if req.Endpoint != "" {
provider.Endpoint = req.Endpoint
}
if req.UpstreamKey != "" {
provider.UpstreamKey = req.UpstreamKey
}
if req.Model != "" {
provider.Model = req.Model
}
if req.ProxyKey != "" {
provider.ProxyKey = req.ProxyKey
}
if req.Config != nil {
provider.Config = req.Config
}
provider.IsActive = req.IsActive
err = global.GVA_DB.Save(&provider).Error
return provider, err
}
// GetAiProvider 获取AI提供商详情
func (s *AiProviderService) GetAiProvider(id uint) (provider app.AiProvider, err error) {
err = global.GVA_DB.First(&provider, id).Error
return provider, err
}
// GetAiProviderList 获取AI提供商列表
func (s *AiProviderService) GetAiProviderList() (list []app.AiProvider, err error) {
err = global.GVA_DB.Where("is_active = ?", true).Find(&list).Error
return list, err
}
// TestConnection 测试连接
func (s *AiProviderService) TestConnection(req *request.TestConnectionRequest) (resp response.TestConnectionResponse, err error) {
startTime := time.Now()
// 根据类型构建测试 URL
var testURL string
switch strings.ToLower(req.Type) {
case "openai":
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
case "claude":
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/messages"
default:
testURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
}
// 创建 HTTP 请求
httpReq, err := http.NewRequest("GET", testURL, nil)
if err != nil {
return response.TestConnectionResponse{
Success: false,
Message: fmt.Sprintf("创建请求失败: %v", err),
Latency: 0,
}, nil
}
// 设置请求头
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
httpReq.Header.Set("Content-Type", "application/json")
// 发送请求
client := &http.Client{
Timeout: 10 * time.Second,
}
httpResp, err := client.Do(httpReq)
if err != nil {
return response.TestConnectionResponse{
Success: false,
Message: fmt.Sprintf("连接失败: %v", err),
Latency: time.Since(startTime).Milliseconds(),
}, nil
}
defer httpResp.Body.Close()
latency := time.Since(startTime).Milliseconds()
// 检查响应状态
if httpResp.StatusCode == http.StatusOK || httpResp.StatusCode == http.StatusCreated {
return response.TestConnectionResponse{
Success: true,
Message: "连接成功",
Latency: latency,
}, nil
}
// 读取错误响应
body, _ := io.ReadAll(httpResp.Body)
return response.TestConnectionResponse{
Success: false,
Message: fmt.Sprintf("连接失败 (状态码: %d): %s", httpResp.StatusCode, string(body)),
Latency: latency,
}, nil
}
// GetModels 获取模型列表
func (s *AiProviderService) GetModels(req *request.GetModelsRequest) (models []response.ModelInfo, err error) {
// 根据类型构建 URL
var modelsURL string
switch strings.ToLower(req.Type) {
case "openai":
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
case "claude":
// Claude API 不提供模型列表接口,返回预定义的模型
return []response.ModelInfo{
{ID: "claude-opus-4-6", Name: "Claude Opus 4.6", OwnedBy: "anthropic"},
{ID: "claude-sonnet-4-6", Name: "Claude Sonnet 4.6", OwnedBy: "anthropic"},
{ID: "claude-haiku-4-5-20251001", Name: "Claude Haiku 4.5", OwnedBy: "anthropic"},
{ID: "claude-3-5-sonnet-20241022", Name: "Claude 3.5 Sonnet", OwnedBy: "anthropic"},
{ID: "claude-3-opus-20240229", Name: "Claude 3 Opus", OwnedBy: "anthropic"},
}, nil
default:
modelsURL = strings.TrimSuffix(req.BaseURL, "/") + "/v1/models"
}
// 创建 HTTP 请求
httpReq, err := http.NewRequest("GET", modelsURL, nil)
if err != nil {
return nil, errors.New("创建请求失败: " + err.Error())
}
// 设置请求头
httpReq.Header.Set("Authorization", "Bearer "+req.UpstreamKey)
httpReq.Header.Set("Content-Type", "application/json")
// 发送请求
client := &http.Client{
Timeout: 10 * time.Second,
}
httpResp, err := client.Do(httpReq)
if err != nil {
return nil, errors.New("请求失败: " + err.Error())
}
defer httpResp.Body.Close()
// 检查响应状态
if httpResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(httpResp.Body)
return nil, fmt.Errorf("获取模型列表失败 (状态码: %d): %s", httpResp.StatusCode, string(body))
}
// 解析响应
body, err := io.ReadAll(httpResp.Body)
if err != nil {
return nil, errors.New("读取响应失败: " + err.Error())
}
// OpenAI 格式的响应
var modelsResp struct {
Data []struct {
ID string `json:"id"`
Object string `json:"object"`
OwnedBy string `json:"owned_by"`
} `json:"data"`
}
if err := json.Unmarshal(body, &modelsResp); err != nil {
return nil, errors.New("解析响应失败: " + err.Error())
}
// 转换为响应格式
models = make([]response.ModelInfo, len(modelsResp.Data))
for i, model := range modelsResp.Data {
models[i] = response.ModelInfo{
ID: model.ID,
Name: model.ID,
OwnedBy: model.OwnedBy,
}
}
return models, nil
err = db.Limit(limit).Offset(offset).Order("priority desc, id desc").Find(&list).Error
return
}

View File

@@ -1,14 +1,13 @@
package app
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"sort"
"strings"
"time"
@@ -16,264 +15,252 @@ import (
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"git.echol.cn/loser/ai_proxy/server/model/app/response"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
type AiProxyService struct{}
// ProcessChatCompletion 处理聊天补全请求
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userId uint, req *request.ChatCompletionRequest) (resp response.ChatCompletionResponse, err error) {
func (s *AiProxyService) ProcessChatCompletion(ctx context.Context, userID uint, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
startTime := time.Now()
// 1. 获取预设配置
var preset app.AiPreset
if req.PresetID > 0 {
err = global.GVA_DB.First(&preset, req.PresetID).Error
if err != nil {
return resp, fmt.Errorf("预设不存在: %w", err)
}
}
// 2. 获取提供商配置
var provider app.AiProvider
// 根据 binding_key 或预设绑定获取 provider
if req.BindingKey != "" {
// 通过 binding_key 查找绑定关系
var binding app.AiPresetBinding
err = global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
Order("priority ASC").
First(&binding).Error
if err == nil {
err = global.GVA_DB.First(&provider, binding.ProviderID).Error
}
}
// 如果没有找到,使用默认的活跃提供商
if provider.ID == 0 {
err = global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
if err != nil {
return resp, fmt.Errorf("未找到可用的AI提供商: %w", err)
}
}
// 3. 构建注入后的消息
messages, err := s.buildInjectedMessages(req, &preset)
// 1. 获取绑定配置
binding, err := s.getBinding(userID, req)
if err != nil {
return resp, fmt.Errorf("构建消息失败: %w", err)
return nil, fmt.Errorf("获取绑定配置失败: %w", err)
}
// 4. 转发到上游AI
resp, err = s.forwardToAI(ctx, &provider, &preset, messages)
// 2. 注入预设
injector := NewPresetInjector(&binding.Preset)
req.Messages = injector.InjectMessages(req.Messages)
injector.ApplyPresetParameters(req)
// 3. 转发请求到上游
resp, err := s.forwardRequest(ctx, &binding.Provider, req)
if err != nil {
// 记录失败日志
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, "", err, time.Since(startTime))
return resp, err
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
return nil, err
}
// 5. 应用输出正则脚本
resp.Choices[0].Message.Content = s.applyOutputRegex(resp.Choices[0].Message.Content, preset.RegexScripts)
// 4. 处理响
if len(resp.Choices) > 0 {
resp.Choices[0].Message.Content = injector.ProcessResponse(resp.Choices[0].Message.Content)
}
// 6. 记录成功日志
s.logRequest(userId, &preset, &provider, req.Messages[0].Content, resp.Choices[0].Message.Content, nil, time.Since(startTime))
// 5. 记录日志
s.logRequest(userID, binding, req, resp, nil, time.Since(startTime))
return resp, nil
}
// buildInjectedMessages 构建注入预设后的消息数组
func (s *AiProxyService) buildInjectedMessages(req *request.ChatCompletionRequest, preset *app.AiPreset) ([]request.Message, error) {
if preset == nil || preset.ID == 0 {
return req.Messages, nil
}
// ProcessChatCompletionStream 处理流式聊天补全请求
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userID uint, req *request.ChatCompletionRequest) {
startTime := time.Now()
// 1. 按 injection_order 排序 prompts
sortedPrompts := make([]app.Prompt, len(preset.Prompts))
copy(sortedPrompts, preset.Prompts)
sort.Slice(sortedPrompts, func(i, j int) bool {
return sortedPrompts[i].InjectionOrder < sortedPrompts[j].InjectionOrder
})
messages := make([]request.Message, 0)
// 2. 根据 injection_depth 插入到对话历史中
for _, prompt := range sortedPrompts {
if prompt.Marker {
continue // 跳过标记提示词
}
// 替换变量
content := s.replaceVariables(prompt.Content, req.Variables, req.CharacterCard)
// 根据 injection_depth 决定插入位置
// depth=0: 插入到最前面(系统提示词)
// depth>0: 从对话历史末尾往前数 depth 条消息的位置插入
if prompt.InjectionDepth == 0 || prompt.SystemPrompt {
messages = append(messages, request.Message{
Role: prompt.Role,
Content: content,
})
} else {
// 先添加用户消息,稍后根据 depth 插入
// 这里简化处理,将非系统提示词也添加到前面
messages = append(messages, request.Message{
Role: prompt.Role,
Content: content,
})
}
}
// 添加用户消息
messages = append(messages, req.Messages...)
// 4. 应用输入正则脚本 (placement=1)
for i := range messages {
messages[i].Content = s.applyInputRegex(messages[i].Content, preset.RegexScripts)
}
return messages, nil
}
// replaceVariables 替换变量
func (s *AiProxyService) replaceVariables(content string, vars map[string]string, card *request.CharacterCard) string {
result := content
// 替换自定义变量
for key, value := range vars {
placeholder := fmt.Sprintf("{{%s}}", key)
result = replaceAll(result, placeholder, value)
}
// 替换角色卡片变量
if card != nil {
result = replaceAll(result, "{{char}}", card.Name)
result = replaceAll(result, "{{char_name}}", card.Name)
}
return result
}
// applyInputRegex 应用输入正则脚本
func (s *AiProxyService) applyInputRegex(content string, scripts []app.RegexScript) string {
for _, script := range scripts {
if script.Disabled {
continue
}
if !containsPlacement(script.Placement, 1) {
continue
}
// 编译正则表达式
re, err := regexp.Compile(script.FindRegex)
if err != nil {
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
continue
}
// 执行替换
content = re.ReplaceAllString(content, script.ReplaceString)
}
return content
}
// applyOutputRegex 应用输出正则脚本
func (s *AiProxyService) applyOutputRegex(content string, scripts []app.RegexScript) string {
for _, script := range scripts {
if script.Disabled {
continue
}
if !containsPlacement(script.Placement, 2) {
continue
}
// 编译正则表达式
re, err := regexp.Compile(script.FindRegex)
if err != nil {
global.GVA_LOG.Error(fmt.Sprintf("正则表达式编译失败: %s", script.ScriptName))
continue
}
// 执行替换
content = re.ReplaceAllString(content, script.ReplaceString)
}
return content
}
// forwardToAI 转发请求到上游AI
func (s *AiProxyService) forwardToAI(ctx context.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message) (response.ChatCompletionResponse, error) {
// 构建请求体
reqBody := map[string]interface{}{
"model": provider.Model,
"messages": messages,
}
if preset != nil {
reqBody["temperature"] = preset.Temperature
reqBody["top_p"] = preset.TopP
reqBody["max_tokens"] = preset.MaxTokens
reqBody["frequency_penalty"] = preset.FrequencyPenalty
reqBody["presence_penalty"] = preset.PresencePenalty
}
jsonData, err := json.Marshal(reqBody)
// 1. 获取绑定配置
binding, err := s.getBinding(userID, req)
if err != nil {
return response.ChatCompletionResponse{}, err
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// 创建HTTP请求
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
// 2. 注入预设
injector := NewPresetInjector(&binding.Preset)
req.Messages = injector.InjectMessages(req.Messages)
injector.ApplyPresetParameters(req)
// 3. 设置 SSE 响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 4. 转发流式请求
err = s.forwardStreamRequest(c, &binding.Provider, req, injector)
if err != nil {
return response.ChatCompletionResponse{}, err
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
s.logRequest(userID, binding, req, nil, err, time.Since(startTime))
}
}
// getBinding 获取绑定配置
func (s *AiProxyService) getBinding(userID uint, req *request.ChatCompletionRequest) (*app.AiPresetBinding, error) {
var binding app.AiPresetBinding
query := global.GVA_DB.Preload("Preset").Preload("Provider").Where("user_id = ? AND enabled = ?", userID, true)
// 优先使用 binding_name
if req.BindingName != "" {
query = query.Where("name = ?", req.BindingName)
} else if req.PresetName != "" && req.ProviderName != "" {
// 使用 preset_name 和 provider_name
query = query.Joins("JOIN ai_presets ON ai_presets.id = ai_preset_bindings.preset_id").
Joins("JOIN ai_providers ON ai_providers.id = ai_preset_bindings.provider_id").
Where("ai_presets.name = ? AND ai_providers.name = ?", req.PresetName, req.ProviderName)
} else {
// 使用默认绑定(第一个启用的)
query = query.Order("id ASC")
}
req.Header.Set("Content-Type", "application/json")
if provider.UpstreamKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
if err := query.First(&binding).Error; err != nil {
return nil, fmt.Errorf("未找到可用的绑定配置")
}
if !binding.Provider.Enabled {
return nil, fmt.Errorf("提供商已禁用")
}
if !binding.Preset.Enabled {
return nil, fmt.Errorf("预设已禁用")
}
return &binding, nil
}
// forwardRequest 转发请求到上游 AI 服务
func (s *AiProxyService) forwardRequest(ctx context.Context, provider *app.AiProvider, req *request.ChatCompletionRequest) (*response.ChatCompletionResponse, error) {
// 使用提供商的默认模型(如果请求中没有指定)
if req.Model == "" && provider.Model != "" {
req.Model = provider.Model
}
// 构建请求
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(reqBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
// 发送请求
client := &http.Client{Timeout: 120 * time.Second}
resp, err := client.Do(req)
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
httpResp, err := client.Do(httpReq)
if err != nil {
return response.ChatCompletionResponse{}, err
return nil, fmt.Errorf("请求失败: %w", err)
}
defer resp.Body.Close()
defer httpResp.Body.Close()
// 读取响应
body, err := io.ReadAll(resp.Body)
if err != nil {
return response.ChatCompletionResponse{}, err
}
if resp.StatusCode != http.StatusOK {
return response.ChatCompletionResponse{}, fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
if httpResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(httpResp.Body)
return nil, fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
}
// 解析响应
var aiResp response.ChatCompletionResponse
if err := json.Unmarshal(body, &aiResp); err != nil {
return response.ChatCompletionResponse{}, err
var resp response.ChatCompletionResponse
if err := json.NewDecoder(httpResp.Body).Decode(&resp); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
return aiResp, nil
return &resp, nil
}
// forwardStreamRequest 转发流式请求
func (s *AiProxyService) forwardStreamRequest(c *gin.Context, provider *app.AiProvider, req *request.ChatCompletionRequest, injector *PresetInjector) error {
// 使用提供商的默认模型
if req.Model == "" && provider.Model != "" {
req.Model = provider.Model
}
reqBody, err := json.Marshal(req)
if err != nil {
return err
}
url := strings.TrimRight(provider.BaseURL, "/") + "/v1/chat/completions"
httpReq, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewReader(reqBody))
if err != nil {
return err
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+provider.APIKey)
client := &http.Client{Timeout: time.Duration(provider.Timeout) * time.Second}
httpResp, err := client.Do(httpReq)
if err != nil {
return err
}
defer httpResp.Body.Close()
if httpResp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(httpResp.Body)
return fmt.Errorf("上游返回错误: %d - %s", httpResp.StatusCode, string(body))
}
// 读取并转发流式响应
reader := bufio.NewReader(httpResp.Body)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return fmt.Errorf("不支持流式响应")
}
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return err
}
// 跳过空行
if len(bytes.TrimSpace(line)) == 0 {
continue
}
// 处理 SSE 数据
if bytes.HasPrefix(line, []byte("data: ")) {
data := bytes.TrimPrefix(line, []byte("data: "))
data = bytes.TrimSpace(data)
// 检查是否是结束标记
if string(data) == "[DONE]" {
c.Writer.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
break
}
// 解析并处理响应
var chunk response.ChatCompletionStreamResponse
if err := json.Unmarshal(data, &chunk); err != nil {
continue
}
// 应用输出正则处理
if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" {
chunk.Choices[0].Delta.Content = injector.ProcessResponse(chunk.Choices[0].Delta.Content)
}
// 重新序列化并发送
processedData, _ := json.Marshal(chunk)
c.Writer.Write([]byte("data: "))
c.Writer.Write(processedData)
c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
}
return nil
}
// logRequest 记录请求日志
func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider *app.AiProvider, originalMsg, responseText string, err error, latency time.Duration) {
func (s *AiProxyService) logRequest(userID uint, binding *app.AiPresetBinding, req *request.ChatCompletionRequest, resp *response.ChatCompletionResponse, err error, duration time.Duration) {
log := app.AiRequestLog{
UserID: &userId,
OriginalMessage: originalMsg,
ResponseText: responseText,
LatencyMs: int(latency.Milliseconds()),
}
if preset != nil {
presetID := preset.ID
log.PresetID = &presetID
}
if provider != nil {
providerID := provider.ID
log.ProviderID = &providerID
UserID: userID,
BindingID: binding.ID,
ProviderID: binding.ProviderID,
PresetID: binding.PresetID,
Model: req.Model,
Duration: duration.Milliseconds(),
RequestTime: time.Now(),
}
if err != nil {
@@ -281,21 +268,12 @@ func (s *AiProxyService) logRequest(userId uint, preset *app.AiPreset, provider
log.ErrorMessage = err.Error()
} else {
log.Status = "success"
if resp != nil {
log.PromptTokens = resp.Usage.PromptTokens
log.CompletionTokens = resp.Usage.CompletionTokens
log.TotalTokens = resp.Usage.TotalTokens
}
}
global.GVA_DB.Create(&log)
}
// 辅助函数
func replaceAll(s, old, new string) string {
return strings.ReplaceAll(s, old, new)
}
func containsPlacement(placements []int, target int) bool {
for _, p := range placements {
if p == target {
return true
}
}
return false
}

View File

@@ -1,193 +0,0 @@
package app
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/app"
"git.echol.cn/loser/ai_proxy/server/model/app/request"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ProcessChatCompletionStream 处理流式聊天补全请求
func (s *AiProxyService) ProcessChatCompletionStream(c *gin.Context, userId uint, req *request.ChatCompletionRequest) {
startTime := time.Now()
// 1. 获取预设配置
var preset app.AiPreset
if req.PresetID > 0 {
err := global.GVA_DB.First(&preset, req.PresetID).Error
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "预设不存在"})
return
}
}
// 2. 获取提供商配置
var provider app.AiProvider
if req.BindingKey != "" {
var binding app.AiPresetBinding
err := global.GVA_DB.Where("preset_id = ? AND is_active = ?", req.PresetID, true).
Order("priority ASC").
First(&binding).Error
if err == nil {
global.GVA_DB.First(&provider, binding.ProviderID)
}
}
if provider.ID == 0 {
err := global.GVA_DB.Where("is_active = ?", true).First(&provider).Error
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "未找到可用的AI提供商"})
return
}
}
// 3. 构建注入后的消息
messages, err := s.buildInjectedMessages(req, &preset)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "构建消息失败"})
return
}
// 4. 转发流式请求到上游AI
err = s.forwardStreamToAI(c, &provider, &preset, messages, userId, startTime)
if err != nil {
global.GVA_LOG.Error("流式请求失败", zap.Error(err))
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
// forwardStreamToAI 转发流式请求到上游AI
func (s *AiProxyService) forwardStreamToAI(c *gin.Context, provider *app.AiProvider, preset *app.AiPreset, messages []request.Message, userId uint, startTime time.Time) error {
// 构建请求体
reqBody := map[string]interface{}{
"model": provider.Model,
"messages": messages,
"stream": true,
}
if preset != nil {
reqBody["temperature"] = preset.Temperature
reqBody["top_p"] = preset.TopP
reqBody["max_tokens"] = preset.MaxTokens
reqBody["frequency_penalty"] = preset.FrequencyPenalty
reqBody["presence_penalty"] = preset.PresencePenalty
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return err
}
// 创建HTTP请求
url := fmt.Sprintf("%s/chat/completions", provider.BaseURL)
req, err := http.NewRequestWithContext(c.Request.Context(), "POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "text/event-stream")
if provider.UpstreamKey != "" {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", provider.UpstreamKey))
}
// 发送请求
client := &http.Client{Timeout: 300 * time.Second}
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf("API错误: %s - %s", resp.Status, string(body))
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Transfer-Encoding", "chunked")
// 读取并转发流式响应
reader := bufio.NewReader(resp.Body)
flusher, ok := c.Writer.(http.Flusher)
if !ok {
return fmt.Errorf("streaming not supported")
}
var fullResponse strings.Builder
for {
line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
break
}
return err
}
// 跳过空行
if len(bytes.TrimSpace(line)) == 0 {
continue
}
// 解析SSE数据
lineStr := string(line)
if strings.HasPrefix(lineStr, "data: ") {
data := strings.TrimPrefix(lineStr, "data: ")
data = strings.TrimSpace(data)
// 检查是否是结束标记
if data == "[DONE]" {
c.Writer.Write([]byte("data: [DONE]\n\n"))
flusher.Flush()
break
}
// 解析JSON并提取内容
var chunk map[string]interface{}
if err := json.Unmarshal([]byte(data), &chunk); err == nil {
if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
if choice, ok := choices[0].(map[string]interface{}); ok {
if delta, ok := choice["delta"].(map[string]interface{}); ok {
if content, ok := delta["content"].(string); ok {
fullResponse.WriteString(content)
}
}
}
}
}
// 转发原始数据
c.Writer.Write(line)
flusher.Flush()
}
}
// 应用输出正则脚本
finalContent := fullResponse.String()
if preset != nil {
finalContent = s.applyOutputRegex(finalContent, preset.RegexScripts)
}
// 记录日志
var originalMsg string
if len(messages) > 0 {
originalMsg = messages[len(messages)-1].Content
}
s.logRequest(userId, preset, provider, originalMsg, finalContent, nil, time.Since(startTime))
return nil
}

View File

@@ -1,8 +1,8 @@
package app
type AppServiceGroup struct {
AiPresetService AiPresetService
AiProviderService AiProviderService
AiProxyService AiProxyService
PresetBindingService PresetBindingService
AiProxyService
AiPresetService
AiProviderService
AiPresetBindingService
}

View File

@@ -2,12 +2,14 @@ package service
import (
"git.echol.cn/loser/ai_proxy/server/service/app"
"git.echol.cn/loser/ai_proxy/server/service/example"
"git.echol.cn/loser/ai_proxy/server/service/system"
)
var ServiceGroupApp = new(ServiceGroup)
type ServiceGroup struct {
SystemServiceGroup system.ServiceGroup
AppServiceGroup app.AppServiceGroup
SystemServiceGroup system.ServiceGroup
ExampleServiceGroup example.ServiceGroup
AppServiceGroup app.AppServiceGroup
}

View File

@@ -0,0 +1,7 @@
package example
type ServiceGroup struct {
CustomerService
FileUploadAndDownloadService
AttachmentCategoryService
}

View File

@@ -0,0 +1,66 @@
package example
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/example"
"gorm.io/gorm"
)
type AttachmentCategoryService struct{}
// AddCategory 创建/更新的分类
func (a *AttachmentCategoryService) AddCategory(req *example.ExaAttachmentCategory) (err error) {
// 检查是否已存在相同名称的分类
if (!errors.Is(global.GVA_DB.Take(&example.ExaAttachmentCategory{}, "name = ? and pid = ?", req.Name, req.Pid).Error, gorm.ErrRecordNotFound)) {
return errors.New("分类名称已存在")
}
if req.ID > 0 {
if err = global.GVA_DB.Model(&example.ExaAttachmentCategory{}).Where("id = ?", req.ID).Updates(&example.ExaAttachmentCategory{
Name: req.Name,
Pid: req.Pid,
}).Error; err != nil {
return err
}
} else {
if err = global.GVA_DB.Create(&example.ExaAttachmentCategory{
Name: req.Name,
Pid: req.Pid,
}).Error; err != nil {
return err
}
}
return nil
}
// DeleteCategory 删除分类
func (a *AttachmentCategoryService) DeleteCategory(id *int) error {
var childCount int64
global.GVA_DB.Model(&example.ExaAttachmentCategory{}).Where("pid = ?", id).Count(&childCount)
if childCount > 0 {
return errors.New("请先删除子级")
}
return global.GVA_DB.Where("id = ?", id).Unscoped().Delete(&example.ExaAttachmentCategory{}).Error
}
// GetCategoryList 分类列表
func (a *AttachmentCategoryService) GetCategoryList() (res []*example.ExaAttachmentCategory, err error) {
var fileLists []example.ExaAttachmentCategory
err = global.GVA_DB.Model(&example.ExaAttachmentCategory{}).Find(&fileLists).Error
if err != nil {
return res, err
}
return a.getChildrenList(fileLists, 0), nil
}
// getChildrenList 子类
func (a *AttachmentCategoryService) getChildrenList(categories []example.ExaAttachmentCategory, parentID uint) []*example.ExaAttachmentCategory {
var tree []*example.ExaAttachmentCategory
for _, category := range categories {
if category.Pid == parentID {
category.Children = a.getChildrenList(categories, category.ID)
tree = append(tree, &category)
}
}
return tree
}

View File

@@ -0,0 +1,71 @@
package example
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/example"
"gorm.io/gorm"
)
type FileUploadAndDownloadService struct{}
var FileUploadAndDownloadServiceApp = new(FileUploadAndDownloadService)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: FindOrCreateFile
//@description: 上传文件时检测当前文件属性,如果没有文件则创建,有则返回文件的当前切片
//@param: fileMd5 string, fileName string, chunkTotal int
//@return: file model.ExaFile, err error
func (e *FileUploadAndDownloadService) FindOrCreateFile(fileMd5 string, fileName string, chunkTotal int) (file example.ExaFile, err error) {
var cfile example.ExaFile
cfile.FileMd5 = fileMd5
cfile.FileName = fileName
cfile.ChunkTotal = chunkTotal
if errors.Is(global.GVA_DB.Where("file_md5 = ? AND file_name = ? AND is_finish = ?", fileMd5, fileName, true).First(&file).Error, gorm.ErrRecordNotFound) {
err = global.GVA_DB.Where("file_md5 = ? AND file_name = ?", fileMd5, fileName).Preload("ExaFileChunk").FirstOrCreate(&file, cfile).Error
return file, err
}
cfile.IsFinish = true
cfile.FilePath = file.FilePath
err = global.GVA_DB.Create(&cfile).Error
return cfile, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateFileChunk
//@description: 创建文件切片记录
//@param: id uint, fileChunkPath string, fileChunkNumber int
//@return: error
func (e *FileUploadAndDownloadService) CreateFileChunk(id uint, fileChunkPath string, fileChunkNumber int) error {
var chunk example.ExaFileChunk
chunk.FileChunkPath = fileChunkPath
chunk.ExaFileID = id
chunk.FileChunkNumber = fileChunkNumber
err := global.GVA_DB.Create(&chunk).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteFileChunk
//@description: 删除文件切片记录
//@param: fileMd5 string, fileName string, filePath string
//@return: error
func (e *FileUploadAndDownloadService) DeleteFileChunk(fileMd5 string, filePath string) error {
var chunks []example.ExaFileChunk
var file example.ExaFile
err := global.GVA_DB.Where("file_md5 = ?", fileMd5).First(&file).
Updates(map[string]interface{}{
"IsFinish": true,
"file_path": filePath,
}).Error
if err != nil {
return err
}
err = global.GVA_DB.Where("exa_file_id = ?", file.ID).Delete(&chunks).Unscoped().Error
return err
}

View File

@@ -0,0 +1,87 @@
package example
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/example"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemService "git.echol.cn/loser/ai_proxy/server/service/system"
)
type CustomerService struct{}
var CustomerServiceApp = new(CustomerService)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateExaCustomer
//@description: 创建客户
//@param: e model.ExaCustomer
//@return: err error
func (exa *CustomerService) CreateExaCustomer(e example.ExaCustomer) (err error) {
err = global.GVA_DB.Create(&e).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteFileChunk
//@description: 删除客户
//@param: e model.ExaCustomer
//@return: err error
func (exa *CustomerService) DeleteExaCustomer(e example.ExaCustomer) (err error) {
err = global.GVA_DB.Delete(&e).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateExaCustomer
//@description: 更新客户
//@param: e *model.ExaCustomer
//@return: err error
func (exa *CustomerService) UpdateExaCustomer(e *example.ExaCustomer) (err error) {
err = global.GVA_DB.Save(e).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetExaCustomer
//@description: 获取客户信息
//@param: id uint
//@return: customer model.ExaCustomer, err error
func (exa *CustomerService) GetExaCustomer(id uint) (customer example.ExaCustomer, err error) {
err = global.GVA_DB.Where("id = ?", id).First(&customer).Error
return
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetCustomerInfoList
//@description: 分页获取客户列表
//@param: sysUserAuthorityID string, info request.PageInfo
//@return: list interface{}, total int64, err error
func (exa *CustomerService) GetCustomerInfoList(sysUserAuthorityID uint, info request.PageInfo) (list interface{}, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&example.ExaCustomer{})
var a system.SysAuthority
a.AuthorityId = sysUserAuthorityID
auth, err := systemService.AuthorityServiceApp.GetAuthorityInfo(a)
if err != nil {
return
}
var dataId []uint
for _, v := range auth.DataAuthorityId {
dataId = append(dataId, v.AuthorityId)
}
var CustomerList []example.ExaCustomer
err = db.Where("sys_user_authority_id in ?", dataId).Count(&total).Error
if err != nil {
return CustomerList, total, err
} else {
err = db.Limit(limit).Offset(offset).Preload("SysUser").Where("sys_user_authority_id in ?", dataId).Find(&CustomerList).Error
}
return CustomerList, total, err
}

View File

@@ -0,0 +1,130 @@
package example
import (
"errors"
"mime/multipart"
"strings"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/example"
"git.echol.cn/loser/ai_proxy/server/model/example/request"
"git.echol.cn/loser/ai_proxy/server/utils/upload"
"gorm.io/gorm"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Upload
//@description: 创建文件上传记录
//@param: file model.ExaFileUploadAndDownload
//@return: error
func (e *FileUploadAndDownloadService) Upload(file example.ExaFileUploadAndDownload) error {
return global.GVA_DB.Create(&file).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: FindFile
//@description: 查询文件记录
//@param: id uint
//@return: model.ExaFileUploadAndDownload, error
func (e *FileUploadAndDownloadService) FindFile(id uint) (example.ExaFileUploadAndDownload, error) {
var file example.ExaFileUploadAndDownload
err := global.GVA_DB.Where("id = ?", id).First(&file).Error
return file, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteFile
//@description: 删除文件记录
//@param: file model.ExaFileUploadAndDownload
//@return: err error
func (e *FileUploadAndDownloadService) DeleteFile(file example.ExaFileUploadAndDownload) (err error) {
var fileFromDb example.ExaFileUploadAndDownload
fileFromDb, err = e.FindFile(file.ID)
if err != nil {
return
}
oss := upload.NewOss()
if err = oss.DeleteFile(fileFromDb.Key); err != nil {
return errors.New("文件删除失败")
}
err = global.GVA_DB.Where("id = ?", file.ID).Unscoped().Delete(&file).Error
return err
}
// EditFileName 编辑文件名或者备注
func (e *FileUploadAndDownloadService) EditFileName(file example.ExaFileUploadAndDownload) (err error) {
var fileFromDb example.ExaFileUploadAndDownload
return global.GVA_DB.Where("id = ?", file.ID).First(&fileFromDb).Update("name", file.Name).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetFileRecordInfoList
//@description: 分页获取数据
//@param: info request.ExaAttachmentCategorySearch
//@return: list interface{}, total int64, err error
func (e *FileUploadAndDownloadService) GetFileRecordInfoList(info request.ExaAttachmentCategorySearch) (list []example.ExaFileUploadAndDownload, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&example.ExaFileUploadAndDownload{})
if len(info.Keyword) > 0 {
db = db.Where("name LIKE ?", "%"+info.Keyword+"%")
}
if info.ClassId > 0 {
db = db.Where("class_id = ?", info.ClassId)
}
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Limit(limit).Offset(offset).Order("id desc").Find(&list).Error
return list, total, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UploadFile
//@description: 根据配置文件判断是文件上传到本地或者七牛云
//@param: header *multipart.FileHeader, noSave string
//@return: file model.ExaFileUploadAndDownload, err error
func (e *FileUploadAndDownloadService) UploadFile(header *multipart.FileHeader, noSave string, classId int) (file example.ExaFileUploadAndDownload, err error) {
oss := upload.NewOss()
filePath, key, uploadErr := oss.UploadFile(header)
if uploadErr != nil {
return file, uploadErr
}
s := strings.Split(header.Filename, ".")
f := example.ExaFileUploadAndDownload{
Url: filePath,
Name: header.Filename,
ClassId: classId,
Tag: s[len(s)-1],
Key: key,
}
if noSave == "0" {
// 检查是否已存在相同key的记录
var existingFile example.ExaFileUploadAndDownload
err = global.GVA_DB.Where(&example.ExaFileUploadAndDownload{Key: key}).First(&existingFile).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return f, e.Upload(f)
}
return f, err
}
return f, nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: ImportURL
//@description: 导入URL
//@param: file model.ExaFileUploadAndDownload
//@return: error
func (e *FileUploadAndDownloadService) ImportURL(file *[]example.ExaFileUploadAndDownload) error {
return global.GVA_DB.Create(&file).Error
}

View File

@@ -0,0 +1,217 @@
package system
import (
"context"
"encoding/json"
"fmt"
"git.echol.cn/loser/ai_proxy/server/utils/ast"
"github.com/pkg/errors"
"path"
"path/filepath"
"strconv"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
common "git.echol.cn/loser/ai_proxy/server/model/common/request"
model "git.echol.cn/loser/ai_proxy/server/model/system"
request "git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
"go.uber.org/zap"
)
var AutocodeHistory = new(autoCodeHistory)
type autoCodeHistory struct{}
// Create 创建代码生成器历史记录
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [songzhibin97](https://github.com/songzhibin97)
func (s *autoCodeHistory) Create(ctx context.Context, info request.SysAutoHistoryCreate) error {
create := info.Create()
err := global.GVA_DB.WithContext(ctx).Create(&create).Error
if err != nil {
return errors.Wrap(err, "创建失败!")
}
return nil
}
// First 根据id获取代码生成器历史的数据
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [songzhibin97](https://github.com/songzhibin97)
func (s *autoCodeHistory) First(ctx context.Context, info common.GetById) (string, error) {
var meta string
err := global.GVA_DB.WithContext(ctx).Model(model.SysAutoCodeHistory{}).Where("id = ?", info.ID).Pluck("request", &meta).Error
if err != nil {
return "", errors.Wrap(err, "获取失败!")
}
return meta, nil
}
// Repeat 检测重复
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [songzhibin97](https://github.com/songzhibin97)
func (s *autoCodeHistory) Repeat(businessDB, structName, abbreviation, Package string) bool {
var count int64
global.GVA_DB.Model(&model.SysAutoCodeHistory{}).Where("business_db = ? and (struct_name = ? OR abbreviation = ?) and package = ? and flag = ?", businessDB, structName, abbreviation, Package, 0).Count(&count).Debug()
return count > 0
}
// RollBack 回滚
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [songzhibin97](https://github.com/songzhibin97)
func (s *autoCodeHistory) RollBack(ctx context.Context, info request.SysAutoHistoryRollBack) error {
var history model.SysAutoCodeHistory
err := global.GVA_DB.Where("id = ?", info.ID).First(&history).Error
if err != nil {
return err
}
if history.ExportTemplateID != 0 {
err = global.GVA_DB.Delete(&model.SysExportTemplate{}, "id = ?", history.ExportTemplateID).Error
if err != nil {
return err
}
}
if info.DeleteApi {
ids := info.ApiIds(history)
err = ApiServiceApp.DeleteApisByIds(ids)
if err != nil {
global.GVA_LOG.Error("ClearTag DeleteApiByIds:", zap.Error(err))
}
} // 清除API表
if info.DeleteMenu {
err = BaseMenuServiceApp.DeleteBaseMenu(int(history.MenuID))
if err != nil {
return errors.Wrap(err, "删除菜单失败!")
}
} // 清除菜单表
if info.DeleteTable {
err = s.DropTable(history.BusinessDB, history.Table)
if err != nil {
return errors.Wrap(err, "删除表失败!")
}
} // 删除表
templates := make(map[string]string, len(history.Templates))
for key, template := range history.Templates {
{
server := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server)
keys := strings.Split(key, "/")
key = filepath.Join(keys...)
key = strings.TrimPrefix(key, server)
} // key
{
web := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.WebRoot())
server := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server)
slices := strings.Split(template, "/")
template = filepath.Join(slices...)
ext := path.Ext(template)
switch ext {
case ".js", ".vue":
template = filepath.Join(web, template)
case ".go":
template = filepath.Join(server, template)
}
} // value
templates[key] = template
}
history.Templates = templates
for key, value := range history.Injections {
var injection ast.Ast
switch key {
case ast.TypePackageApiEnter, ast.TypePackageRouterEnter, ast.TypePackageServiceEnter:
case ast.TypePackageApiModuleEnter, ast.TypePackageRouterModuleEnter, ast.TypePackageServiceModuleEnter:
var entity ast.PackageModuleEnter
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
case ast.TypePackageInitializeGorm:
var entity ast.PackageInitializeGorm
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
case ast.TypePackageInitializeRouter:
var entity ast.PackageInitializeRouter
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
case ast.TypePluginGen:
var entity ast.PluginGen
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
case ast.TypePluginApiEnter, ast.TypePluginRouterEnter, ast.TypePluginServiceEnter:
var entity ast.PluginEnter
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
case ast.TypePluginInitializeGorm:
var entity ast.PluginInitializeGorm
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
case ast.TypePluginInitializeRouter:
var entity ast.PluginInitializeRouter
_ = json.Unmarshal([]byte(value), &entity)
injection = &entity
}
if injection == nil {
continue
}
file, _ := injection.Parse("", nil)
if file != nil {
_ = injection.Rollback(file)
err = injection.Format("", nil, file)
if err != nil {
return err
}
fmt.Printf("[filepath:%s]回滚注入代码成功!\n", key)
}
} // 清除注入代码
removeBasePath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, "rm_file", strconv.FormatInt(int64(time.Now().Nanosecond()), 10))
for _, value := range history.Templates {
if !filepath.IsAbs(value) {
continue
}
removePath := filepath.Join(removeBasePath, strings.TrimPrefix(value, global.GVA_CONFIG.AutoCode.Root))
err = utils.FileMove(value, removePath)
if err != nil {
return errors.Wrapf(err, "[src:%s][dst:%s]文件移动失败!", value, removePath)
}
} // 移动文件
err = global.GVA_DB.WithContext(ctx).Model(&model.SysAutoCodeHistory{}).Where("id = ?", info.ID).Update("flag", 1).Error
if err != nil {
return errors.Wrap(err, "更新失败!")
}
return nil
}
// Delete 删除历史数据
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [songzhibin97](https://github.com/songzhibin97)
func (s *autoCodeHistory) Delete(ctx context.Context, info common.GetById) error {
err := global.GVA_DB.WithContext(ctx).Where("id = ?", info.Uint()).Delete(&model.SysAutoCodeHistory{}).Error
if err != nil {
return errors.Wrap(err, "删除失败!")
}
return nil
}
// GetList 获取系统历史数据
// Author [SliverHorn](https://github.com/SliverHorn)
// Author [songzhibin97](https://github.com/songzhibin97)
func (s *autoCodeHistory) GetList(ctx context.Context, info common.PageInfo) (list []model.SysAutoCodeHistory, total int64, err error) {
var entities []model.SysAutoCodeHistory
db := global.GVA_DB.WithContext(ctx).Model(&model.SysAutoCodeHistory{})
err = db.Count(&total).Error
if err != nil {
return nil, total, err
}
err = db.Scopes(info.Paginate()).Order("updated_at desc").Find(&entities).Error
return entities, total, err
}
// DropTable 获取指定数据库和指定数据表的所有字段名,类型值等
// @author: [piexlmax](https://github.com/piexlmax)
func (s *autoCodeHistory) DropTable(BusinessDb, tableName string) error {
if BusinessDb != "" {
return global.MustGetGlobalDBByDBName(BusinessDb).Exec("DROP TABLE " + tableName).Error
} else {
return global.GVA_DB.Exec("DROP TABLE " + tableName).Error
}
}

View File

@@ -0,0 +1,51 @@
package system
import (
"context"
"errors"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common"
commonResp "git.echol.cn/loser/ai_proxy/server/model/common/response"
"git.echol.cn/loser/ai_proxy/server/utils/request"
"github.com/goccy/go-json"
"io"
"strings"
)
// LLMAuto 调用大模型服务,返回生成结果数据
// 入参为通用 JSONMap需包含 mode例如 ai/butler/eye/painter 等)以及业务 prompt/payload
func (s *AutoCodeService) LLMAuto(ctx context.Context, llm common.JSONMap) (interface{}, error) {
if global.GVA_CONFIG.AutoCode.AiPath == "" {
return nil, errors.New("请先前往插件市场个人中心获取AiPath并填入config.yaml中")
}
// 构建调用路径:{AiPath} 中的 {FUNC} 由 mode 替换
mode := fmt.Sprintf("%v", llm["mode"]) // 统一转字符串,避免 nil 造成路径异常
path := strings.ReplaceAll(global.GVA_CONFIG.AutoCode.AiPath, "{FUNC}", mode)
res, err := request.HttpRequest(
path,
"POST",
nil,
nil,
llm,
)
if err != nil {
return nil, fmt.Errorf("大模型生成失败: %w", err)
}
defer res.Body.Close()
var resStruct commonResp.Response
b, err := io.ReadAll(res.Body)
if err != nil {
return nil, fmt.Errorf("读取大模型响应失败: %w", err)
}
if err = json.Unmarshal(b, &resStruct); err != nil {
return nil, fmt.Errorf("解析大模型响应失败: %w", err)
}
if resStruct.Code == 7 { // 业务约定7 表示模型生成失败
return nil, fmt.Errorf("大模型生成失败: %s", resStruct.Msg)
}
return resStruct.Data, nil
}

View File

@@ -0,0 +1,45 @@
package system
import (
"context"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
"git.echol.cn/loser/ai_proxy/server/utils/autocode"
"os"
"path/filepath"
"text/template"
)
func (s *autoCodeTemplate) CreateMcp(ctx context.Context, info request.AutoMcpTool) (toolFilePath string, err error) {
mcpTemplatePath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "resource", "mcp", "tools.tpl")
mcpToolPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "mcp")
var files *template.Template
templateName := filepath.Base(mcpTemplatePath)
files, err = template.New(templateName).Funcs(autocode.GetTemplateFuncMap()).ParseFiles(mcpTemplatePath)
if err != nil {
return
}
fileName := utils.HumpToUnderscore(info.Name)
toolFilePath = filepath.Join(mcpToolPath, fileName+".go")
f, err := os.Create(toolFilePath)
if err != nil {
return
}
defer f.Close()
// 执行模板,将内容写入文件
err = files.Execute(f, info)
if err != nil {
return
}
return
}

View File

@@ -0,0 +1,743 @@
package system
import (
"context"
"fmt"
"go/token"
"os"
"path/filepath"
"strings"
"text/template"
"git.echol.cn/loser/ai_proxy/server/global"
common "git.echol.cn/loser/ai_proxy/server/model/common/request"
model "git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
"git.echol.cn/loser/ai_proxy/server/utils/ast"
"git.echol.cn/loser/ai_proxy/server/utils/autocode"
"github.com/pkg/errors"
"gorm.io/gorm"
)
var AutoCodePackage = new(autoCodePackage)
type autoCodePackage struct{}
// Create 创建包信息
// @author: [piexlmax](https://github.com/piexlmax)
// @author: [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodePackage) Create(ctx context.Context, info *request.SysAutoCodePackageCreate) error {
switch {
case info.Template == "":
return errors.New("模板不能为空!")
case info.Template == "page":
return errors.New("page为表单生成器!")
case info.PackageName == "":
return errors.New("PackageName不能为空!")
case token.IsKeyword(info.PackageName):
return errors.Errorf("%s为go的关键字!", info.PackageName)
case info.Template == "package":
if info.PackageName == "system" || info.PackageName == "example" {
return errors.New("不能使用已保留的package name")
}
default:
break
}
if !errors.Is(global.GVA_DB.Where("package_name = ? and template = ?", info.PackageName, info.Template).First(&model.SysAutoCodePackage{}).Error, gorm.ErrRecordNotFound) {
return errors.New("存在相同PackageName")
}
create := info.Create()
return global.GVA_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
err := tx.Create(&create).Error
if err != nil {
return errors.Wrap(err, "创建失败!")
}
code := info.AutoCode()
_, asts, creates, err := s.templates(ctx, create, code, true)
if err != nil {
return err
}
for key, value := range creates { // key 为 模版绝对路径
var files *template.Template
files, err = template.New(filepath.Base(key)).Funcs(autocode.GetTemplateFuncMap()).ParseFiles(key)
if err != nil {
return errors.Wrapf(err, "[filepath:%s]读取模版文件失败!", key)
}
err = os.MkdirAll(filepath.Dir(value), os.ModePerm)
if err != nil {
return errors.Wrapf(err, "[filepath:%s]创建文件夹失败!", value)
}
var file *os.File
file, err = os.Create(value)
if err != nil {
return errors.Wrapf(err, "[filepath:%s]创建文件夹失败!", value)
}
err = files.Execute(file, code)
_ = file.Close()
if err != nil {
return errors.Wrapf(err, "[filepath:%s]生成失败!", value)
}
fmt.Printf("[template:%s][filepath:%s]生成成功!\n", key, value)
}
for key, value := range asts {
keys := strings.Split(key, "=>")
if len(keys) == 2 {
switch keys[1] {
case ast.TypePluginInitializeV2, ast.TypePackageApiEnter, ast.TypePackageRouterEnter, ast.TypePackageServiceEnter:
file, _ := value.Parse("", nil)
if file != nil {
err = value.Injection(file)
if err != nil {
return err
}
err = value.Format("", nil, file)
if err != nil {
return err
}
}
fmt.Printf("[type:%s]注入成功!\n", key)
}
}
}
return nil
})
}
// Delete 删除包记录
// @author: [piexlmax](https://github.com/piexlmax)
// @author: [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodePackage) Delete(ctx context.Context, info common.GetById) error {
err := global.GVA_DB.WithContext(ctx).Delete(&model.SysAutoCodePackage{}, info.Uint()).Error
if err != nil {
return errors.Wrap(err, "删除失败!")
}
return nil
}
// DeleteByNames
// @author: [piexlmax](https://github.com/piexlmax)
// @author: [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodePackage) DeleteByNames(ctx context.Context, names []string) error {
if len(names) == 0 {
return nil
}
err := global.GVA_DB.WithContext(ctx).Where("package_name IN ?", names).Delete(&model.SysAutoCodePackage{}).Error
if err != nil {
return errors.Wrap(err, "删除失败!")
}
return nil
}
// All 获取所有包
// @author: [piexlmax](https://github.com/piexlmax)
// @author: [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodePackage) All(ctx context.Context) (entities []model.SysAutoCodePackage, err error) {
server := make([]model.SysAutoCodePackage, 0)
plugin := make([]model.SysAutoCodePackage, 0)
serverPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service")
pluginPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin")
serverDir, err := os.ReadDir(serverPath)
if err != nil {
return nil, errors.Wrap(err, "读取service文件夹失败!")
}
pluginDir, err := os.ReadDir(pluginPath)
if err != nil {
return nil, errors.Wrap(err, "读取plugin文件夹失败!")
}
for i := 0; i < len(serverDir); i++ {
if serverDir[i].IsDir() {
serverPackage := model.SysAutoCodePackage{
PackageName: serverDir[i].Name(),
Template: "package",
Label: serverDir[i].Name() + "包",
Desc: "系统自动读取" + serverDir[i].Name() + "包",
Module: global.GVA_CONFIG.AutoCode.Module,
}
server = append(server, serverPackage)
}
}
for i := 0; i < len(pluginDir); i++ {
if pluginDir[i].IsDir() {
dirNameMap := map[string]bool{
"api": true,
"config": true,
"initialize": true,
"plugin": true,
"router": true,
"service": true,
}
dir, e := os.ReadDir(filepath.Join(pluginPath, pluginDir[i].Name()))
if e != nil {
return nil, errors.Wrap(err, "读取plugin文件夹失败!")
}
//dir目录需要包含所有的dirNameMap
for k := 0; k < len(dir); k++ {
if dir[k].IsDir() {
if ok := dirNameMap[dir[k].Name()]; ok {
delete(dirNameMap, dir[k].Name())
}
}
}
var desc string
if len(dirNameMap) == 0 {
// 完全符合标准结构
desc = "系统自动读取" + pluginDir[i].Name() + "插件使用前请确认是否为v2版本插件"
} else {
// 缺少某些结构,生成警告描述
var missingDirs []string
for dirName := range dirNameMap {
missingDirs = append(missingDirs, dirName)
}
desc = fmt.Sprintf("系统自动读取,但是缺少 %s 结构不建议自动化和mcp使用", strings.Join(missingDirs, "、"))
}
pluginPackage := model.SysAutoCodePackage{
PackageName: pluginDir[i].Name(),
Template: "plugin",
Label: pluginDir[i].Name() + "插件",
Desc: desc,
Module: global.GVA_CONFIG.AutoCode.Module,
}
plugin = append(plugin, pluginPackage)
}
}
err = global.GVA_DB.WithContext(ctx).Find(&entities).Error
if err != nil {
return nil, errors.Wrap(err, "获取所有包失败!")
}
entitiesMap := make(map[string]model.SysAutoCodePackage)
for i := 0; i < len(entities); i++ {
entitiesMap[entities[i].PackageName] = entities[i]
}
createEntity := []model.SysAutoCodePackage{}
for i := 0; i < len(server); i++ {
if _, ok := entitiesMap[server[i].PackageName]; !ok {
if server[i].Template == "package" {
createEntity = append(createEntity, server[i])
}
}
}
for i := 0; i < len(plugin); i++ {
if _, ok := entitiesMap[plugin[i].PackageName]; !ok {
if plugin[i].Template == "plugin" {
createEntity = append(createEntity, plugin[i])
}
}
}
if len(createEntity) > 0 {
err = global.GVA_DB.WithContext(ctx).Create(&createEntity).Error
if err != nil {
return nil, errors.Wrap(err, "同步失败!")
}
entities = append(entities, createEntity...)
}
// 处理数据库存在但实体文件不存在的情况 - 删除数据库中对应的数据
existingPackageNames := make(map[string]bool)
// 收集所有存在的包名
for i := 0; i < len(server); i++ {
existingPackageNames[server[i].PackageName] = true
}
for i := 0; i < len(plugin); i++ {
existingPackageNames[plugin[i].PackageName] = true
}
// 找出需要删除的数据库记录
deleteEntityIDs := []uint{}
for i := 0; i < len(entities); i++ {
if !existingPackageNames[entities[i].PackageName] {
deleteEntityIDs = append(deleteEntityIDs, entities[i].ID)
}
}
// 删除数据库中不存在文件的记录
if len(deleteEntityIDs) > 0 {
err = global.GVA_DB.WithContext(ctx).Delete(&model.SysAutoCodePackage{}, deleteEntityIDs).Error
if err != nil {
return nil, errors.Wrap(err, "删除不存在的包记录失败!")
}
// 从返回结果中移除已删除的记录
filteredEntities := []model.SysAutoCodePackage{}
for i := 0; i < len(entities); i++ {
if existingPackageNames[entities[i].PackageName] {
filteredEntities = append(filteredEntities, entities[i])
}
}
entities = filteredEntities
}
return entities, nil
}
// Templates 获取所有模版文件夹
// @author: [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodePackage) Templates(ctx context.Context) ([]string, error) {
templates := make([]string, 0)
entries, err := os.ReadDir("resource")
if err != nil {
return nil, errors.Wrap(err, "读取模版文件夹失败!")
}
for i := 0; i < len(entries); i++ {
if entries[i].IsDir() {
if entries[i].Name() == "page" {
continue
} // page 为表单生成器
if entries[i].Name() == "function" {
continue
} // function 为函数生成器
if entries[i].Name() == "preview" {
continue
} // preview 为预览代码生成器的代码
if entries[i].Name() == "mcp" {
continue
} // preview 为mcp生成器的代码
templates = append(templates, entries[i].Name())
}
}
return templates, nil
}
func (s *autoCodePackage) templates(ctx context.Context, entity model.SysAutoCodePackage, info request.AutoCode, isPackage bool) (code map[string]string, asts map[string]ast.Ast, creates map[string]string, err error) {
code = make(map[string]string)
asts = make(map[string]ast.Ast)
creates = make(map[string]string)
templateDir := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "resource", entity.Template)
templateDirs, err := os.ReadDir(templateDir)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", templateDir)
}
for i := 0; i < len(templateDirs); i++ {
second := filepath.Join(templateDir, templateDirs[i].Name())
switch templateDirs[i].Name() {
case "server":
if !info.GenerateServer && !isPackage {
break
}
var secondDirs []os.DirEntry
secondDirs, err = os.ReadDir(second)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", second)
}
for j := 0; j < len(secondDirs); j++ {
if secondDirs[j].Name() == ".DS_Store" {
continue
}
three := filepath.Join(second, secondDirs[j].Name())
if !secondDirs[j].IsDir() {
ext := filepath.Ext(secondDirs[j].Name())
if ext != ".tpl" {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版后缀!", three)
}
name := strings.TrimSuffix(secondDirs[j].Name(), ext)
if name == "main.go" || name == "plugin.go" {
pluginInitialize := &ast.PluginInitializeV2{
Type: ast.TypePluginInitializeV2,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, name),
PluginPath: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "register.go"),
ImportPath: fmt.Sprintf(`"%s/plugin/%s"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
PackageName: entity.PackageName,
}
asts[pluginInitialize.PluginPath+"=>"+pluginInitialize.Type.String()] = pluginInitialize
creates[three] = pluginInitialize.Path
continue
}
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", three)
}
switch secondDirs[j].Name() {
case "api", "router", "service":
var threeDirs []os.DirEntry
threeDirs, err = os.ReadDir(three)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", three)
}
for k := 0; k < len(threeDirs); k++ {
if threeDirs[k].Name() == ".DS_Store" {
continue
}
four := filepath.Join(three, threeDirs[k].Name())
if threeDirs[k].IsDir() {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件夹!", four)
}
ext := filepath.Ext(four)
if ext != ".tpl" {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版后缀!", four)
}
api := strings.Index(threeDirs[k].Name(), "api")
hasEnter := strings.Index(threeDirs[k].Name(), "enter")
router := strings.Index(threeDirs[k].Name(), "router")
service := strings.Index(threeDirs[k].Name(), "service")
if router == -1 && api == -1 && service == -1 && hasEnter == -1 {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", four)
}
if entity.Template == "package" {
create := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), entity.PackageName, info.HumpPackageName+".go")
if api != -1 {
create = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), "v1", entity.PackageName, info.HumpPackageName+".go")
}
if hasEnter != -1 {
isApi := strings.Index(secondDirs[j].Name(), "api")
isRouter := strings.Index(secondDirs[j].Name(), "router")
isService := strings.Index(secondDirs[j].Name(), "service")
if isApi != -1 {
packageApiEnter := &ast.PackageEnter{
Type: ast.TypePackageApiEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), "v1", "enter.go"),
ImportPath: fmt.Sprintf(`"%s/%s/%s/%s"`, global.GVA_CONFIG.AutoCode.Module, "api", "v1", entity.PackageName),
StructName: utils.FirstUpper(entity.PackageName) + "ApiGroup",
PackageName: entity.PackageName,
PackageStructName: "ApiGroup",
}
asts[packageApiEnter.Path+"=>"+packageApiEnter.Type.String()] = packageApiEnter
packageApiModuleEnter := &ast.PackageModuleEnter{
Type: ast.TypePackageApiModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), "v1", entity.PackageName, "enter.go"),
ImportPath: fmt.Sprintf(`"%s/service"`, global.GVA_CONFIG.AutoCode.Module),
StructName: info.StructName + "Api",
AppName: "ServiceGroupApp",
GroupName: utils.FirstUpper(entity.PackageName) + "ServiceGroup",
ModuleName: info.Abbreviation + "Service",
PackageName: "service",
ServiceName: info.StructName + "Service",
}
asts[packageApiModuleEnter.Path+"=>"+packageApiModuleEnter.Type.String()] = packageApiModuleEnter
creates[four] = packageApiModuleEnter.Path
}
if isRouter != -1 {
packageRouterEnter := &ast.PackageEnter{
Type: ast.TypePackageRouterEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), "enter.go"),
ImportPath: fmt.Sprintf(`"%s/%s/%s"`, global.GVA_CONFIG.AutoCode.Module, secondDirs[j].Name(), entity.PackageName),
StructName: utils.FirstUpper(entity.PackageName),
PackageName: entity.PackageName,
PackageStructName: "RouterGroup",
}
asts[packageRouterEnter.Path+"=>"+packageRouterEnter.Type.String()] = packageRouterEnter
packageRouterModuleEnter := &ast.PackageModuleEnter{
Type: ast.TypePackageRouterModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), entity.PackageName, "enter.go"),
ImportPath: fmt.Sprintf(`api "%s/api/v1"`, global.GVA_CONFIG.AutoCode.Module),
StructName: info.StructName + "Router",
AppName: "ApiGroupApp",
GroupName: utils.FirstUpper(entity.PackageName) + "ApiGroup",
ModuleName: info.Abbreviation + "Api",
PackageName: "api",
ServiceName: info.StructName + "Api",
}
creates[four] = packageRouterModuleEnter.Path
asts[packageRouterModuleEnter.Path+"=>"+packageRouterModuleEnter.Type.String()] = packageRouterModuleEnter
packageInitializeRouter := &ast.PackageInitializeRouter{
Type: ast.TypePackageInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router_biz.go"),
ImportPath: fmt.Sprintf(`"%s/router"`, global.GVA_CONFIG.AutoCode.Module),
AppName: "RouterGroupApp",
GroupName: utils.FirstUpper(entity.PackageName),
ModuleName: entity.PackageName + "Router",
PackageName: "router",
FunctionName: "Init" + info.StructName + "Router",
LeftRouterGroupName: "privateGroup",
RightRouterGroupName: "publicGroup",
}
asts[packageInitializeRouter.Path+"=>"+packageInitializeRouter.Type.String()] = packageInitializeRouter
}
if isService != -1 {
path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext))
importPath := fmt.Sprintf(`"%s/service/%s"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName)
packageServiceEnter := &ast.PackageEnter{
Type: ast.TypePackageServiceEnter,
Path: path,
ImportPath: importPath,
StructName: utils.FirstUpper(entity.PackageName) + "ServiceGroup",
PackageName: entity.PackageName,
PackageStructName: "ServiceGroup",
}
asts[packageServiceEnter.Path+"=>"+packageServiceEnter.Type.String()] = packageServiceEnter
packageServiceModuleEnter := &ast.PackageModuleEnter{
Type: ast.TypePackageServiceModuleEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), entity.PackageName, "enter.go"),
StructName: info.StructName + "Service",
}
asts[packageServiceModuleEnter.Path+"=>"+packageServiceModuleEnter.Type.String()] = packageServiceModuleEnter
creates[four] = packageServiceModuleEnter.Path
}
continue
}
code[four] = create
continue
}
if hasEnter != -1 {
isApi := strings.Index(secondDirs[j].Name(), "api")
isRouter := strings.Index(secondDirs[j].Name(), "router")
isService := strings.Index(secondDirs[j].Name(), "service")
if isRouter != -1 {
pluginRouterEnter := &ast.PluginEnter{
Type: ast.TypePluginRouterEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
ImportPath: fmt.Sprintf(`"%s/plugin/%s/api"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
StructName: info.StructName,
StructCamelName: info.Abbreviation,
ModuleName: "api" + info.StructName,
GroupName: "Api",
PackageName: "api",
ServiceName: info.StructName,
}
asts[pluginRouterEnter.Path+"=>"+pluginRouterEnter.Type.String()] = pluginRouterEnter
creates[four] = pluginRouterEnter.Path
}
if isApi != -1 {
pluginApiEnter := &ast.PluginEnter{
Type: ast.TypePluginApiEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
ImportPath: fmt.Sprintf(`"%s/plugin/%s/service"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
StructName: info.StructName,
StructCamelName: info.Abbreviation,
ModuleName: "service" + info.StructName,
GroupName: "Service",
PackageName: "service",
ServiceName: info.StructName,
}
asts[pluginApiEnter.Path+"=>"+pluginApiEnter.Type.String()] = pluginApiEnter
creates[four] = pluginApiEnter.Path
}
if isService != -1 {
pluginServiceEnter := &ast.PluginEnter{
Type: ast.TypePluginServiceEnter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
StructName: info.StructName,
StructCamelName: info.Abbreviation,
}
asts[pluginServiceEnter.Path+"=>"+pluginServiceEnter.Type.String()] = pluginServiceEnter
creates[four] = pluginServiceEnter.Path
}
continue
} // enter.go
create := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), info.HumpPackageName+".go")
code[four] = create
}
case "gen", "config", "initialize", "plugin", "response":
if entity.Template == "package" {
continue
} // package模板不需要生成gen, config, initialize
var threeDirs []os.DirEntry
threeDirs, err = os.ReadDir(three)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", three)
}
for k := 0; k < len(threeDirs); k++ {
if threeDirs[k].Name() == ".DS_Store" {
continue
}
four := filepath.Join(three, threeDirs[k].Name())
if threeDirs[k].IsDir() {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件夹!", four)
}
ext := filepath.Ext(four)
if ext != ".tpl" {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版后缀!", four)
}
gen := strings.Index(threeDirs[k].Name(), "gen")
api := strings.Index(threeDirs[k].Name(), "api")
menu := strings.Index(threeDirs[k].Name(), "menu")
viper := strings.Index(threeDirs[k].Name(), "viper")
plugin := strings.Index(threeDirs[k].Name(), "plugin")
config := strings.Index(threeDirs[k].Name(), "config")
router := strings.Index(threeDirs[k].Name(), "router")
hasGorm := strings.Index(threeDirs[k].Name(), "gorm")
response := strings.Index(threeDirs[k].Name(), "response")
dictionary := strings.Index(threeDirs[k].Name(), "dictionary")
if gen != -1 && api != -1 && menu != -1 && viper != -1 && plugin != -1 && config != -1 && router != -1 && hasGorm != -1 && response != -1 && dictionary != -1 {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", four)
}
if api != -1 || menu != -1 || viper != -1 || response != -1 || plugin != -1 || config != -1 || dictionary != -1 {
creates[four] = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext))
}
if gen != -1 {
pluginGen := &ast.PluginGen{
Type: ast.TypePluginGen,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
ImportPath: fmt.Sprintf(`"%s/plugin/%s/model"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
StructName: info.StructName,
PackageName: "model",
IsNew: true,
}
asts[pluginGen.Path+"=>"+pluginGen.Type.String()] = pluginGen
creates[four] = pluginGen.Path
}
if hasGorm != -1 {
pluginInitializeGorm := &ast.PluginInitializeGorm{
Type: ast.TypePluginInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
ImportPath: fmt.Sprintf(`"%s/plugin/%s/model"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
StructName: info.StructName,
PackageName: "model",
IsNew: true,
}
asts[pluginInitializeGorm.Path+"=>"+pluginInitializeGorm.Type.String()] = pluginInitializeGorm
creates[four] = pluginInitializeGorm.Path
}
if router != -1 {
pluginInitializeRouter := &ast.PluginInitializeRouter{
Type: ast.TypePluginInitializeRouter,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), strings.TrimSuffix(threeDirs[k].Name(), ext)),
ImportPath: fmt.Sprintf(`"%s/plugin/%s/router"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
AppName: "Router",
GroupName: info.StructName,
PackageName: "router",
FunctionName: "Init",
LeftRouterGroupName: "public",
RightRouterGroupName: "private",
}
asts[pluginInitializeRouter.Path+"=>"+pluginInitializeRouter.Type.String()] = pluginInitializeRouter
creates[four] = pluginInitializeRouter.Path
}
}
case "model":
var threeDirs []os.DirEntry
threeDirs, err = os.ReadDir(three)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", three)
}
for k := 0; k < len(threeDirs); k++ {
if threeDirs[k].Name() == ".DS_Store" {
continue
}
four := filepath.Join(three, threeDirs[k].Name())
if threeDirs[k].IsDir() {
var fourDirs []os.DirEntry
fourDirs, err = os.ReadDir(four)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", four)
}
for l := 0; l < len(fourDirs); l++ {
if fourDirs[l].Name() == ".DS_Store" {
continue
}
five := filepath.Join(four, fourDirs[l].Name())
if fourDirs[l].IsDir() {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件夹!", five)
}
ext := filepath.Ext(five)
if ext != ".tpl" {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版后缀!", five)
}
hasRequest := strings.Index(fourDirs[l].Name(), "request")
if hasRequest == -1 {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", five)
}
create := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), threeDirs[k].Name(), info.HumpPackageName+".go")
if entity.Template == "package" {
create = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), entity.PackageName, threeDirs[k].Name(), info.HumpPackageName+".go")
}
code[five] = create
}
continue
}
ext := filepath.Ext(threeDirs[k].Name())
if ext != ".tpl" {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版后缀!", four)
}
hasModel := strings.Index(threeDirs[k].Name(), "model")
if hasModel == -1 {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", four)
}
create := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", entity.PackageName, secondDirs[j].Name(), info.HumpPackageName+".go")
if entity.Template == "package" {
packageInitializeGorm := &ast.PackageInitializeGorm{
Type: ast.TypePackageInitializeGorm,
Path: filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm_biz.go"),
ImportPath: fmt.Sprintf(`"%s/model/%s"`, global.GVA_CONFIG.AutoCode.Module, entity.PackageName),
Business: info.BusinessDB,
StructName: info.StructName,
PackageName: entity.PackageName,
IsNew: true,
}
code[four] = packageInitializeGorm.Path
asts[packageInitializeGorm.Path+"=>"+packageInitializeGorm.Type.String()] = packageInitializeGorm
create = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, secondDirs[j].Name(), entity.PackageName, info.HumpPackageName+".go")
}
code[four] = create
}
default:
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件夹!", three)
}
}
case "web":
if !info.GenerateWeb && !isPackage {
break
}
var secondDirs []os.DirEntry
secondDirs, err = os.ReadDir(second)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", second)
}
for j := 0; j < len(secondDirs); j++ {
if secondDirs[j].Name() == ".DS_Store" {
continue
}
three := filepath.Join(second, secondDirs[j].Name())
if !secondDirs[j].IsDir() {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", three)
}
switch secondDirs[j].Name() {
case "api", "form", "view", "table":
var threeDirs []os.DirEntry
threeDirs, err = os.ReadDir(three)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "读取模版文件夹[%s]失败!", three)
}
for k := 0; k < len(threeDirs); k++ {
if threeDirs[k].Name() == ".DS_Store" {
continue
}
four := filepath.Join(three, threeDirs[k].Name())
if threeDirs[k].IsDir() {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件夹!", four)
}
ext := filepath.Ext(four)
if ext != ".tpl" {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版后缀!", four)
}
api := strings.Index(threeDirs[k].Name(), "api")
form := strings.Index(threeDirs[k].Name(), "form")
view := strings.Index(threeDirs[k].Name(), "view")
table := strings.Index(threeDirs[k].Name(), "table")
if api == -1 && form == -1 && view == -1 && table == -1 {
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", four)
}
if entity.Template == "package" {
if view != -1 || table != -1 {
formPath := filepath.Join(three, "form.vue"+ext)
value, ok := code[formPath]
if ok {
value = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.WebRoot(), secondDirs[j].Name(), entity.PackageName, info.PackageName, info.PackageName+"Form"+filepath.Ext(strings.TrimSuffix(threeDirs[k].Name(), ext)))
code[formPath] = value
}
}
create := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.WebRoot(), secondDirs[j].Name(), entity.PackageName, info.PackageName, info.PackageName+filepath.Ext(strings.TrimSuffix(threeDirs[k].Name(), ext)))
if api != -1 {
create = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.WebRoot(), secondDirs[j].Name(), entity.PackageName, info.PackageName+filepath.Ext(strings.TrimSuffix(threeDirs[k].Name(), ext)))
}
code[four] = create
continue
}
create := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.WebRoot(), "plugin", entity.PackageName, secondDirs[j].Name(), info.PackageName+filepath.Ext(strings.TrimSuffix(threeDirs[k].Name(), ext)))
code[four] = create
}
default:
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件夹!", three)
}
}
case "readme.txt.tpl", "readme.txt.template":
continue
default:
if templateDirs[i].Name() == ".DS_Store" {
continue
}
return nil, nil, nil, errors.Errorf("[filpath:%s]非法模版文件!", second)
}
}
return code, asts, creates, nil
}

View File

@@ -0,0 +1,108 @@
package system
import (
"context"
"reflect"
"testing"
model "git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
)
func Test_autoCodePackage_Create(t *testing.T) {
type args struct {
ctx context.Context
info *request.SysAutoCodePackageCreate
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "测试 package",
args: args{
ctx: context.Background(),
info: &request.SysAutoCodePackageCreate{
Template: "package",
PackageName: "gva",
},
},
wantErr: false,
},
{
name: "测试 plugin",
args: args{
ctx: context.Background(),
info: &request.SysAutoCodePackageCreate{
Template: "plugin",
PackageName: "gva",
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &autoCodePackage{}
if err := a.Create(tt.args.ctx, tt.args.info); (err != nil) != tt.wantErr {
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_autoCodePackage_templates(t *testing.T) {
type args struct {
ctx context.Context
entity model.SysAutoCodePackage
info request.AutoCode
isPackage bool
}
tests := []struct {
name string
args args
wantCode map[string]string
wantEnter map[string]map[string]string
wantErr bool
}{
{
name: "测试1",
args: args{
ctx: context.Background(),
entity: model.SysAutoCodePackage{
Desc: "描述",
Label: "展示名",
Template: "plugin",
PackageName: "preview",
},
info: request.AutoCode{
Abbreviation: "user",
HumpPackageName: "user",
},
isPackage: false,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &autoCodePackage{}
gotCode, gotEnter, gotCreates, err := s.templates(tt.args.ctx, tt.args.entity, tt.args.info, tt.args.isPackage)
if (err != nil) != tt.wantErr {
t.Errorf("templates() error = %v, wantErr %v", err, tt.wantErr)
return
}
for key, value := range gotCode {
t.Logf("\n")
t.Logf(key)
t.Logf(value)
t.Logf("\n")
}
t.Log(gotCreates)
if !reflect.DeepEqual(gotEnter, tt.wantEnter) {
t.Errorf("templates() gotEnter = %v, want %v", gotEnter, tt.wantEnter)
}
})
}
}

View File

@@ -0,0 +1,512 @@
package system
import (
"bytes"
"context"
"fmt"
goast "go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
"mime/multipart"
"os"
"path/filepath"
"strings"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
pluginUtils "git.echol.cn/loser/ai_proxy/server/plugin/plugin-tool/utils"
"git.echol.cn/loser/ai_proxy/server/utils"
ast "git.echol.cn/loser/ai_proxy/server/utils/ast"
"github.com/mholt/archives"
cp "github.com/otiai10/copy"
"github.com/pkg/errors"
"go.uber.org/zap"
)
var AutoCodePlugin = new(autoCodePlugin)
type autoCodePlugin struct{}
// Install 插件安装
func (s *autoCodePlugin) Install(file *multipart.FileHeader) (web, server int, err error) {
const GVAPLUGPINATH = "./gva-plug-temp/"
defer os.RemoveAll(GVAPLUGPINATH)
_, err = os.Stat(GVAPLUGPINATH)
if os.IsNotExist(err) {
os.Mkdir(GVAPLUGPINATH, os.ModePerm)
}
src, err := file.Open()
if err != nil {
return -1, -1, err
}
defer src.Close()
// 在临时目录创建目标文件
// 使用完整路径拼接的好处:明确文件位置,避免路径混乱
out, err := os.Create(GVAPLUGPINATH + file.Filename)
if err != nil {
return -1, -1, err
}
// 将上传的文件内容复制到临时文件
// 使用io.Copy的好处高效处理大文件自动管理缓冲区避免内存溢出
_, err = io.Copy(out, src)
if err != nil {
out.Close()
return -1, -1, err
}
// 立即关闭文件,确保数据写入磁盘并释放文件句柄
// 必须在解压前关闭否则在Windows系统上会导致文件被占用无法解压
err = out.Close()
if err != nil {
return -1, -1, err
}
paths, err := utils.Unzip(GVAPLUGPINATH+file.Filename, GVAPLUGPINATH)
paths = filterFile(paths)
var webIndex = -1
var serverIndex = -1
webPlugin := ""
serverPlugin := ""
serverPackage := ""
serverRootName := ""
for i := range paths {
paths[i] = filepath.ToSlash(paths[i])
pathArr := strings.Split(paths[i], "/")
ln := len(pathArr)
if ln < 4 {
continue
}
if pathArr[2]+"/"+pathArr[3] == `server/plugin` {
if len(serverPlugin) == 0 {
serverPlugin = filepath.Join(pathArr[0], pathArr[1], pathArr[2], pathArr[3])
}
if serverRootName == "" && ln > 1 && pathArr[1] != "" {
serverRootName = pathArr[1]
}
if ln > 4 && serverPackage == "" && pathArr[4] != "" {
serverPackage = pathArr[4]
}
}
if pathArr[2]+"/"+pathArr[3] == `web/plugin` && len(webPlugin) == 0 {
webPlugin = filepath.Join(pathArr[0], pathArr[1], pathArr[2], pathArr[3])
}
}
if len(serverPlugin) == 0 && len(webPlugin) == 0 {
zap.L().Error("非标准插件,请按照文档自动迁移使用")
return webIndex, serverIndex, errors.New("非标准插件,请按照文档自动迁移使用")
}
if len(serverPlugin) != 0 {
if serverPackage == "" {
serverPackage = serverRootName
}
err = installation(serverPlugin, global.GVA_CONFIG.AutoCode.Server, global.GVA_CONFIG.AutoCode.Server)
if err != nil {
return webIndex, serverIndex, err
}
err = ensurePluginRegisterImport(serverPackage)
if err != nil {
return webIndex, serverIndex, err
}
}
if len(webPlugin) != 0 {
err = installation(webPlugin, global.GVA_CONFIG.AutoCode.Server, global.GVA_CONFIG.AutoCode.Web)
if err != nil {
return webIndex, serverIndex, err
}
}
return 1, 1, err
}
func installation(path string, formPath string, toPath string) error {
arr := strings.Split(filepath.ToSlash(path), "/")
ln := len(arr)
if ln < 3 {
return errors.New("arr")
}
name := arr[ln-3]
var form = filepath.Join(global.GVA_CONFIG.AutoCode.Root, formPath, path)
var to = filepath.Join(global.GVA_CONFIG.AutoCode.Root, toPath, "plugin")
_, err := os.Stat(to + name)
if err == nil {
zap.L().Error("autoPath 已存在同名插件,请自行手动安装", zap.String("to", to))
return errors.New(toPath + "已存在同名插件,请自行手动安装")
}
return cp.Copy(form, to, cp.Options{Skip: skipMacSpecialDocument})
}
func ensurePluginRegisterImport(packageName string) error {
module := strings.TrimSpace(global.GVA_CONFIG.AutoCode.Module)
if module == "" {
return errors.New("autocode module is empty")
}
if packageName == "" {
return errors.New("plugin package is empty")
}
registerPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "register.go")
src, err := os.ReadFile(registerPath)
if err != nil {
return err
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, registerPath, src, parser.ParseComments)
if err != nil {
return err
}
importPath := fmt.Sprintf("%s/plugin/%s", module, packageName)
if ast.CheckImport(astFile, importPath) {
return nil
}
importSpec := &goast.ImportSpec{
Name: goast.NewIdent("_"),
Path: &goast.BasicLit{Kind: token.STRING, Value: fmt.Sprintf("%q", importPath)},
}
var importDecl *goast.GenDecl
for _, decl := range astFile.Decls {
genDecl, ok := decl.(*goast.GenDecl)
if !ok {
continue
}
if genDecl.Tok == token.IMPORT {
importDecl = genDecl
break
}
}
if importDecl == nil {
astFile.Decls = append([]goast.Decl{
&goast.GenDecl{
Tok: token.IMPORT,
Specs: []goast.Spec{importSpec},
},
}, astFile.Decls...)
} else {
importDecl.Specs = append(importDecl.Specs, importSpec)
}
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
return os.WriteFile(registerPath, bf.Bytes(), 0666)
}
func filterFile(paths []string) []string {
np := make([]string, 0, len(paths))
for _, path := range paths {
if ok, _ := skipMacSpecialDocument(nil, path, ""); ok {
continue
}
np = append(np, path)
}
return np
}
func skipMacSpecialDocument(_ os.FileInfo, src, _ string) (bool, error) {
if strings.Contains(src, ".DS_Store") || strings.Contains(src, "__MACOSX") {
return true, nil
}
return false, nil
}
func (s *autoCodePlugin) PubPlug(plugName string) (zipPath string, err error) {
if plugName == "" {
return "", errors.New("插件名称不能为空")
}
// 防止路径穿越
plugName = filepath.Clean(plugName)
webPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Web, "plugin", plugName)
serverPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", plugName)
// 创建一个新的zip文件
// 判断目录是否存在
_, err = os.Stat(webPath)
if err != nil {
return "", errors.New("web路径不存在")
}
_, err = os.Stat(serverPath)
if err != nil {
return "", errors.New("server路径不存在")
}
fileName := plugName + ".zip"
// 创建一个新的zip文件
files, err := archives.FilesFromDisk(context.Background(), nil, map[string]string{
webPath: plugName + "/web/plugin/" + plugName,
serverPath: plugName + "/server/plugin/" + plugName,
})
// create the output file we'll write to
out, err := os.Create(fileName)
if err != nil {
return
}
defer out.Close()
// we can use the CompressedArchive type to gzip a tarball
// (compression is not required; you could use Tar directly)
format := archives.CompressedArchive{
//Compression: archives.Gz{},
Archival: archives.Zip{},
}
// create the archive
err = format.Archive(context.Background(), out, files)
if err != nil {
return
}
return filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, fileName), nil
}
func (s *autoCodePlugin) InitMenu(menuInfo request.InitMenu) (err error) {
menuPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", menuInfo.PlugName, "initialize", "menu.go")
src, err := os.ReadFile(menuPath)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
arrayAst := ast.FindArray(astFile, "model", "SysBaseMenu")
var menus []system.SysBaseMenu
parentMenu := []system.SysBaseMenu{
{
ParentId: 0,
Path: menuInfo.PlugName + "Menu",
Name: menuInfo.PlugName + "Menu",
Hidden: false,
Component: "view/routerHolder.vue",
Sort: 0,
Meta: system.Meta{
Title: menuInfo.ParentMenu,
Icon: "school",
},
},
}
// 查询菜单及其关联的参数和按钮
err = global.GVA_DB.Preload("Parameters").Preload("MenuBtn").Find(&menus, "id in (?)", menuInfo.Menus).Error
if err != nil {
return err
}
menus = append(parentMenu, menus...)
menuExpr := ast.CreateMenuStructAst(menus)
arrayAst.Elts = *menuExpr
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.WriteFile(menuPath, bf.Bytes(), 0666)
return nil
}
func (s *autoCodePlugin) InitAPI(apiInfo request.InitApi) (err error) {
apiPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", apiInfo.PlugName, "initialize", "api.go")
src, err := os.ReadFile(apiPath)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
arrayAst := ast.FindArray(astFile, "model", "SysApi")
var apis []system.SysApi
err = global.GVA_DB.Find(&apis, "id in (?)", apiInfo.APIs).Error
if err != nil {
return err
}
apisExpr := ast.CreateApiStructAst(apis)
arrayAst.Elts = *apisExpr
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.WriteFile(apiPath, bf.Bytes(), 0666)
return nil
}
func (s *autoCodePlugin) InitDictionary(dictInfo request.InitDictionary) (err error) {
dictPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", dictInfo.PlugName, "initialize", "dictionary.go")
src, err := os.ReadFile(dictPath)
if err != nil {
fmt.Println(err)
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
arrayAst := ast.FindArray(astFile, "model", "SysDictionary")
var dictionaries []system.SysDictionary
err = global.GVA_DB.Preload("SysDictionaryDetails").Find(&dictionaries, "id in (?)", dictInfo.Dictionaries).Error
if err != nil {
return err
}
dictExpr := ast.CreateDictionaryStructAst(dictionaries)
arrayAst.Elts = *dictExpr
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
os.WriteFile(dictPath, bf.Bytes(), 0666)
return nil
}
func (s *autoCodePlugin) Remove(pluginName string, pluginType string) (err error) {
// 1. 删除前端代码
if pluginType == "web" || pluginType == "full" {
webDir := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Web, "plugin", pluginName)
err = os.RemoveAll(webDir)
if err != nil {
return errors.Wrap(err, "删除前端插件目录失败")
}
}
// 2. 删除后端代码
if pluginType == "server" || pluginType == "full" {
serverDir := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", pluginName)
err = os.RemoveAll(serverDir)
if err != nil {
return errors.Wrap(err, "删除后端插件目录失败")
}
// 移除注册
removePluginRegisterImport(pluginName)
}
// 通过utils 获取 api 菜单 字典
apis, menus, dicts := pluginUtils.GetPluginData(pluginName)
// 3. 删除菜单 (递归删除)
if len(menus) > 0 {
for _, menu := range menus {
var dbMenu system.SysBaseMenu
if err := global.GVA_DB.Where("name = ?", menu.Name).First(&dbMenu).Error; err == nil {
// 获取该菜单及其所有子菜单的ID
var menuIds []int
GetMenuIds(dbMenu, &menuIds)
// 逆序删除,先删除子菜单
for i := len(menuIds) - 1; i >= 0; i-- {
err := BaseMenuServiceApp.DeleteBaseMenu(menuIds[i])
if err != nil {
zap.L().Error("删除菜单失败", zap.Int("id", menuIds[i]), zap.Error(err))
}
}
}
}
}
// 4. 删除API
if len(apis) > 0 {
for _, api := range apis {
var dbApi system.SysApi
if err := global.GVA_DB.Where("path = ? AND method = ?", api.Path, api.Method).First(&dbApi).Error; err == nil {
err := ApiServiceApp.DeleteApi(dbApi)
if err != nil {
zap.L().Error("删除API失败", zap.String("path", api.Path), zap.Error(err))
}
}
}
}
// 5. 删除字典
if len(dicts) > 0 {
for _, dict := range dicts {
var dbDict system.SysDictionary
if err := global.GVA_DB.Where("type = ?", dict.Type).First(&dbDict).Error; err == nil {
err := DictionaryServiceApp.DeleteSysDictionary(dbDict)
if err != nil {
zap.L().Error("删除字典失败", zap.String("type", dict.Type), zap.Error(err))
}
}
}
}
return nil
}
func GetMenuIds(menu system.SysBaseMenu, ids *[]int) {
*ids = append(*ids, int(menu.ID))
var children []system.SysBaseMenu
global.GVA_DB.Where("parent_id = ?", menu.ID).Find(&children)
for _, child := range children {
// 先递归收集子菜单
GetMenuIds(child, ids)
}
}
func removePluginRegisterImport(packageName string) error {
module := strings.TrimSpace(global.GVA_CONFIG.AutoCode.Module)
if module == "" {
return errors.New("autocode module is empty")
}
if packageName == "" {
return errors.New("plugin package is empty")
}
registerPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", "register.go")
src, err := os.ReadFile(registerPath)
if err != nil {
return err
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, registerPath, src, parser.ParseComments)
if err != nil {
return err
}
importPath := fmt.Sprintf("%s/plugin/%s", module, packageName)
importLit := fmt.Sprintf("%q", importPath)
// 移除 import
var newDecls []goast.Decl
for _, decl := range astFile.Decls {
genDecl, ok := decl.(*goast.GenDecl)
if !ok {
newDecls = append(newDecls, decl)
continue
}
if genDecl.Tok == token.IMPORT {
var newSpecs []goast.Spec
for _, spec := range genDecl.Specs {
importSpec, ok := spec.(*goast.ImportSpec)
if !ok {
newSpecs = append(newSpecs, spec)
continue
}
if importSpec.Path.Value != importLit {
newSpecs = append(newSpecs, spec)
}
}
// 如果还有其他import保留该 decl
if len(newSpecs) > 0 {
genDecl.Specs = newSpecs
newDecls = append(newDecls, genDecl)
}
} else {
newDecls = append(newDecls, decl)
}
}
astFile.Decls = newDecls
var out []byte
bf := bytes.NewBuffer(out)
printer.Fprint(bf, fileSet, astFile)
return os.WriteFile(registerPath, bf.Bytes(), 0666)
}

View File

@@ -0,0 +1,453 @@
package system
import (
"context"
"encoding/json"
"fmt"
"git.echol.cn/loser/ai_proxy/server/utils/autocode"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"
"text/template"
"git.echol.cn/loser/ai_proxy/server/global"
model "git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
utilsAst "git.echol.cn/loser/ai_proxy/server/utils/ast"
"github.com/pkg/errors"
"gorm.io/gorm"
)
var AutoCodeTemplate = new(autoCodeTemplate)
type autoCodeTemplate struct{}
func (s *autoCodeTemplate) checkPackage(Pkg string, template string) (err error) {
switch template {
case "package":
apiEnter := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", Pkg, "enter.go")
_, err = os.Stat(apiEnter)
if err != nil {
return fmt.Errorf("package结构异常,缺少api/v1/%s/enter.go", Pkg)
}
serviceEnter := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", Pkg, "enter.go")
_, err = os.Stat(serviceEnter)
if err != nil {
return fmt.Errorf("package结构异常,缺少service/%s/enter.go", Pkg)
}
routerEnter := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", Pkg, "enter.go")
_, err = os.Stat(routerEnter)
if err != nil {
return fmt.Errorf("package结构异常,缺少router/%s/enter.go", Pkg)
}
case "plugin":
pluginEnter := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", Pkg, "plugin.go")
_, err = os.Stat(pluginEnter)
if err != nil {
return fmt.Errorf("plugin结构异常,缺少plugin/%s/plugin.go", Pkg)
}
}
return nil
}
// Create 创建生成自动化代码
func (s *autoCodeTemplate) Create(ctx context.Context, info request.AutoCode) error {
history := info.History()
var autoPkg model.SysAutoCodePackage
err := global.GVA_DB.WithContext(ctx).Where("package_name = ?", info.Package).First(&autoPkg).Error
if err != nil {
return errors.Wrap(err, "查询包失败!")
}
err = s.checkPackage(info.Package, autoPkg.Template)
if err != nil {
return err
}
// 增加判断: 重复创建struct 或者重复的简称
if AutocodeHistory.Repeat(info.BusinessDB, info.StructName, info.Abbreviation, info.Package) {
return errors.New("已经创建过此数据结构,请勿重复创建!")
}
generate, templates, injections, err := s.generate(ctx, info, autoPkg)
if err != nil {
return err
}
for key, builder := range generate {
err = os.MkdirAll(filepath.Dir(key), os.ModePerm)
if err != nil {
return errors.Wrapf(err, "[filepath:%s]创建文件夹失败!", key)
}
err = os.WriteFile(key, []byte(builder.String()), 0666)
if err != nil {
return errors.Wrapf(err, "[filepath:%s]写入文件失败!", key)
}
}
// 自动创建api
if info.AutoCreateApiToSql && !info.OnlyTemplate {
apis := info.Apis()
err := global.GVA_DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, v := range apis {
var api model.SysApi
var id uint
err := tx.Where("path = ? AND method = ?", v.Path, v.Method).First(&api).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
if err = tx.Create(&v).Error; err != nil { // 遇到错误时回滚事务
return err
}
id = v.ID
} else {
id = api.ID
}
history.ApiIDs = append(history.ApiIDs, id)
}
return nil
})
if err != nil {
return err
}
}
// 自动创建menu
if info.AutoCreateMenuToSql {
var entity model.SysBaseMenu
var id uint
err := global.GVA_DB.WithContext(ctx).First(&entity, "name = ?", info.Abbreviation).Error
if err == nil {
id = entity.ID
} else {
entity = info.Menu(autoPkg.Template)
if info.AutoCreateBtnAuth && !info.OnlyTemplate {
entity.MenuBtn = []model.SysBaseMenuBtn{
{SysBaseMenuID: entity.ID, Name: "add", Desc: "新增"},
{SysBaseMenuID: entity.ID, Name: "batchDelete", Desc: "批量删除"},
{SysBaseMenuID: entity.ID, Name: "delete", Desc: "删除"},
{SysBaseMenuID: entity.ID, Name: "edit", Desc: "编辑"},
{SysBaseMenuID: entity.ID, Name: "info", Desc: "详情"},
}
if info.HasExcel {
excelBtn := []model.SysBaseMenuBtn{
{SysBaseMenuID: entity.ID, Name: "exportTemplate", Desc: "导出模板"},
{SysBaseMenuID: entity.ID, Name: "exportExcel", Desc: "导出Excel"},
{SysBaseMenuID: entity.ID, Name: "importExcel", Desc: "导入Excel"},
}
entity.MenuBtn = append(entity.MenuBtn, excelBtn...)
}
}
err = global.GVA_DB.WithContext(ctx).Create(&entity).Error
id = entity.ID
if err != nil {
return errors.Wrap(err, "创建菜单失败!")
}
}
history.MenuID = id
}
if info.HasExcel {
dbName := info.BusinessDB
name := info.Package + "_" + info.StructName
tableName := info.TableName
fieldsMap := make(map[string]string, len(info.Fields))
for _, field := range info.Fields {
if field.Excel {
fieldsMap[field.ColumnName] = field.FieldDesc
}
}
templateInfo, _ := json.Marshal(fieldsMap)
sysExportTemplate := model.SysExportTemplate{
DBName: dbName,
Name: name,
TableName: tableName,
TemplateID: name,
TemplateInfo: string(templateInfo),
}
err = SysExportTemplateServiceApp.CreateSysExportTemplate(&sysExportTemplate)
if err != nil {
return err
}
history.ExportTemplateID = sysExportTemplate.ID
}
// 创建历史记录
history.Templates = templates
history.Injections = make(map[string]string, len(injections))
for key, value := range injections {
bytes, _ := json.Marshal(value)
history.Injections[key] = string(bytes)
}
err = AutocodeHistory.Create(ctx, history)
if err != nil {
return err
}
return nil
}
// Preview 预览自动化代码
func (s *autoCodeTemplate) Preview(ctx context.Context, info request.AutoCode) (map[string]string, error) {
var entity model.SysAutoCodePackage
err := global.GVA_DB.WithContext(ctx).Where("package_name = ?", info.Package).First(&entity).Error
if err != nil {
return nil, errors.Wrap(err, "查询包失败!")
}
// 增加判断: 重复创建struct 或者重复的简称
if AutocodeHistory.Repeat(info.BusinessDB, info.StructName, info.Abbreviation, info.Package) && !info.IsAdd {
return nil, errors.New("已经创建过此数据结构或重复简称,请勿重复创建!")
}
preview := make(map[string]string)
codes, _, _, err := s.generate(ctx, info, entity)
if err != nil {
return nil, err
}
for key, writer := range codes {
if len(key) > len(global.GVA_CONFIG.AutoCode.Root) {
key, _ = filepath.Rel(global.GVA_CONFIG.AutoCode.Root, key)
}
// 获取key的后缀 取消.
suffix := filepath.Ext(key)[1:]
var builder strings.Builder
builder.WriteString("```" + suffix + "\n\n")
builder.WriteString(writer.String())
builder.WriteString("\n\n```")
preview[key] = builder.String()
}
return preview, nil
}
func (s *autoCodeTemplate) generate(ctx context.Context, info request.AutoCode, entity model.SysAutoCodePackage) (map[string]strings.Builder, map[string]string, map[string]utilsAst.Ast, error) {
templates, asts, _, err := AutoCodePackage.templates(ctx, entity, info, false)
if err != nil {
return nil, nil, nil, err
}
code := make(map[string]strings.Builder)
for key, create := range templates {
var files *template.Template
files, err = template.New(filepath.Base(key)).Funcs(autocode.GetTemplateFuncMap()).ParseFiles(key)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "[filpath:%s]读取模版文件失败!", key)
}
var builder strings.Builder
err = files.Execute(&builder, info)
if err != nil {
return nil, nil, nil, errors.Wrapf(err, "[filpath:%s]生成文件失败!", create)
}
code[create] = builder
} // 生成文件
injections := make(map[string]utilsAst.Ast, len(asts))
for key, value := range asts {
keys := strings.Split(key, "=>")
if len(keys) == 2 {
if keys[1] == utilsAst.TypePluginInitializeV2 {
continue
}
if info.OnlyTemplate {
if keys[1] == utilsAst.TypePackageInitializeGorm || keys[1] == utilsAst.TypePluginInitializeGorm {
continue
}
}
if !info.AutoMigrate {
if keys[1] == utilsAst.TypePackageInitializeGorm || keys[1] == utilsAst.TypePluginInitializeGorm {
continue
}
}
var builder strings.Builder
parse, _ := value.Parse("", &builder)
if parse != nil {
_ = value.Injection(parse)
err = value.Format("", &builder, parse)
if err != nil {
return nil, nil, nil, err
}
code[keys[0]] = builder
injections[keys[1]] = value
fmt.Println(keys[0], "注入成功!")
}
}
}
// 注入代码
return code, templates, injections, nil
}
func (s *autoCodeTemplate) AddFunc(info request.AutoFunc) error {
autoPkg := model.SysAutoCodePackage{}
err := global.GVA_DB.First(&autoPkg, "package_name = ?", info.Package).Error
if err != nil {
return err
}
if autoPkg.Template != "package" {
info.IsPlugin = true
}
err = s.addTemplateToFile("api.go", info)
if err != nil {
return err
}
err = s.addTemplateToFile("server.go", info)
if err != nil {
return err
}
err = s.addTemplateToFile("api.js", info)
if err != nil {
return err
}
return s.addTemplateToAst("router", info)
}
func (s *autoCodeTemplate) GetApiAndServer(info request.AutoFunc) (map[string]string, error) {
autoPkg := model.SysAutoCodePackage{}
err := global.GVA_DB.First(&autoPkg, "package_name = ?", info.Package).Error
if err != nil {
return nil, err
}
if autoPkg.Template != "package" {
info.IsPlugin = true
}
apiStr, err := s.getTemplateStr("api.go", info)
if err != nil {
return nil, err
}
serverStr, err := s.getTemplateStr("server.go", info)
if err != nil {
return nil, err
}
jsStr, err := s.getTemplateStr("api.js", info)
if err != nil {
return nil, err
}
return map[string]string{"api": apiStr, "server": serverStr, "js": jsStr}, nil
}
func (s *autoCodeTemplate) getTemplateStr(t string, info request.AutoFunc) (string, error) {
tempPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "resource", "function", t+".tpl")
files, err := template.New(filepath.Base(tempPath)).Funcs(autocode.GetTemplateFuncMap()).ParseFiles(tempPath)
if err != nil {
return "", errors.Wrapf(err, "[filepath:%s]读取模版文件失败!", tempPath)
}
var builder strings.Builder
err = files.Execute(&builder, info)
if err != nil {
fmt.Println(err.Error())
return "", errors.Wrapf(err, "[filpath:%s]生成文件失败!", tempPath)
}
return builder.String(), nil
}
func (s *autoCodeTemplate) addTemplateToAst(t string, info request.AutoFunc) error {
tPath := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "router", info.Package, info.HumpPackageName+".go")
funcName := fmt.Sprintf("Init%sRouter", info.StructName)
routerStr := "RouterWithoutAuth"
if info.IsAuth {
routerStr = "Router"
}
stmtStr := fmt.Sprintf("%s%s.%s(\"%s\", %sApi.%s)", info.Abbreviation, routerStr, info.Method, info.Router, info.Abbreviation, info.FuncName)
if info.IsPlugin {
tPath = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", info.Package, "router", info.HumpPackageName+".go")
stmtStr = fmt.Sprintf("group.%s(\"%s\", api%s.%s)", info.Method, info.Router, info.StructName, info.FuncName)
funcName = "Init"
}
src, err := os.ReadFile(tPath)
if err != nil {
return err
}
fileSet := token.NewFileSet()
astFile, err := parser.ParseFile(fileSet, "", src, 0)
if err != nil {
return err
}
funcDecl := utilsAst.FindFunction(astFile, funcName)
stmtNode := utilsAst.CreateStmt(stmtStr)
if info.IsAuth {
for i := 0; i < len(funcDecl.Body.List); i++ {
st := funcDecl.Body.List[i]
// 使用类型断言来检查stmt是否是一个块语句
if blockStmt, ok := st.(*ast.BlockStmt); ok {
// 如果是,插入代码 跳出
blockStmt.List = append(blockStmt.List, stmtNode)
break
}
}
} else {
for i := len(funcDecl.Body.List) - 1; i >= 0; i-- {
st := funcDecl.Body.List[i]
// 使用类型断言来检查stmt是否是一个块语句
if blockStmt, ok := st.(*ast.BlockStmt); ok {
// 如果是,插入代码 跳出
blockStmt.List = append(blockStmt.List, stmtNode)
break
}
}
}
// 创建一个新的文件
f, err := os.Create(tPath)
if err != nil {
return err
}
defer f.Close()
if err := format.Node(f, fileSet, astFile); err != nil {
return err
}
return err
}
func (s *autoCodeTemplate) addTemplateToFile(t string, info request.AutoFunc) error {
getTemplateStr, err := s.getTemplateStr(t, info)
if err != nil {
return err
}
var target string
switch t {
case "api.go":
if info.IsAi && info.ApiFunc != "" {
getTemplateStr = info.ApiFunc
}
target = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "api", "v1", info.Package, info.HumpPackageName+".go")
case "server.go":
if info.IsAi && info.ServerFunc != "" {
getTemplateStr = info.ServerFunc
}
target = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "service", info.Package, info.HumpPackageName+".go")
case "api.js":
if info.IsAi && info.JsFunc != "" {
getTemplateStr = info.JsFunc
}
target = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Web, "api", info.Package, info.PackageName+".js")
}
if info.IsPlugin {
switch t {
case "api.go":
target = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", info.Package, "api", info.HumpPackageName+".go")
case "server.go":
target = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "plugin", info.Package, "service", info.HumpPackageName+".go")
case "api.js":
target = filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Web, "plugin", info.Package, "api", info.PackageName+".js")
}
}
// 打开文件,如果不存在则返回错误
file, err := os.OpenFile(target, os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
return err
}
defer file.Close()
// 写入内容
_, err = fmt.Fprintln(file, getTemplateStr)
if err != nil {
fmt.Printf("写入文件失败: %s\n", err.Error())
return err
}
return nil
}

View File

@@ -0,0 +1,84 @@
package system
import (
"context"
"encoding/json"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"reflect"
"testing"
)
func Test_autoCodeTemplate_Create(t *testing.T) {
type args struct {
ctx context.Context
info request.AutoCode
}
tests := []struct {
name string
args args
wantErr bool
}{
// TODO: Add test cases.
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &autoCodeTemplate{}
if err := s.Create(tt.args.ctx, tt.args.info); (err != nil) != tt.wantErr {
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func Test_autoCodeTemplate_Preview(t *testing.T) {
type args struct {
ctx context.Context
info request.AutoCode
}
tests := []struct {
name string
args args
want map[string]string
wantErr bool
}{
{
name: "测试 package",
args: args{
ctx: context.Background(),
info: request.AutoCode{},
},
wantErr: false,
},
{
name: "测试 plugin",
args: args{
ctx: context.Background(),
info: request.AutoCode{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testJson := `{"structName":"SysUser","tableName":"sys_users","packageName":"sysUsers","package":"gva","abbreviation":"sysUsers","description":"sysUsers表","businessDB":"","autoCreateApiToSql":true,"autoCreateMenuToSql":true,"autoMigrate":true,"gvaModel":true,"autoCreateResource":false,"fields":[{"fieldName":"Uuid","fieldDesc":"用户UUID","fieldType":"string","dataType":"varchar","fieldJson":"uuid","primaryKey":false,"dataTypeLong":"191","columnName":"uuid","comment":"用户UUID","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"Username","fieldDesc":"用户登录名","fieldType":"string","dataType":"varchar","fieldJson":"username","primaryKey":false,"dataTypeLong":"191","columnName":"username","comment":"用户登录名","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"Password","fieldDesc":"用户登录密码","fieldType":"string","dataType":"varchar","fieldJson":"password","primaryKey":false,"dataTypeLong":"191","columnName":"password","comment":"用户登录密码","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"NickName","fieldDesc":"用户昵称","fieldType":"string","dataType":"varchar","fieldJson":"nickName","primaryKey":false,"dataTypeLong":"191","columnName":"nick_name","comment":"用户昵称","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"SideMode","fieldDesc":"用户侧边主题","fieldType":"string","dataType":"varchar","fieldJson":"sideMode","primaryKey":false,"dataTypeLong":"191","columnName":"side_mode","comment":"用户侧边主题","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"HeaderImg","fieldDesc":"用户头像","fieldType":"string","dataType":"varchar","fieldJson":"headerImg","primaryKey":false,"dataTypeLong":"191","columnName":"header_img","comment":"用户头像","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"BaseColor","fieldDesc":"基础颜色","fieldType":"string","dataType":"varchar","fieldJson":"baseColor","primaryKey":false,"dataTypeLong":"191","columnName":"base_color","comment":"基础颜色","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"AuthorityId","fieldDesc":"用户角色ID","fieldType":"int","dataType":"bigint","fieldJson":"authorityId","primaryKey":false,"dataTypeLong":"20","columnName":"authority_id","comment":"用户角色ID","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"Phone","fieldDesc":"用户手机号","fieldType":"string","dataType":"varchar","fieldJson":"phone","primaryKey":false,"dataTypeLong":"191","columnName":"phone","comment":"用户手机号","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"Email","fieldDesc":"用户邮箱","fieldType":"string","dataType":"varchar","fieldJson":"email","primaryKey":false,"dataTypeLong":"191","columnName":"email","comment":"用户邮箱","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}},{"fieldName":"Enable","fieldDesc":"用户是否被冻结 1正常 2冻结","fieldType":"int","dataType":"bigint","fieldJson":"enable","primaryKey":false,"dataTypeLong":"19","columnName":"enable","comment":"用户是否被冻结 1正常 2冻结","require":false,"errorText":"","clearable":true,"fieldSearchType":"","fieldIndexType":"","dictType":"","front":true,"dataSource":{"association":1,"table":"","label":"","value":""}}],"humpPackageName":"sys_users"}`
err := json.Unmarshal([]byte(testJson), &tt.args.info)
if err != nil {
t.Error(err)
return
}
err = tt.args.info.Pretreatment()
if err != nil {
t.Error(err)
return
}
got, err := AutoCodeTemplate.Preview(tt.args.ctx, tt.args.info)
if (err != nil) != tt.wantErr {
t.Errorf("Preview() error = %+v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Preview() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -1,6 +1,29 @@
package system
type ServiceGroup struct {
UserService UserService
ApiService ApiService
JwtService
ApiService
MenuService
UserService
CasbinService
InitDBService
AutoCodeService
BaseMenuService
AuthorityService
DictionaryService
SystemConfigService
OperationRecordService
DictionaryDetailService
AuthorityBtnService
SysExportTemplateService
SysParamsService
SysVersionService
SkillsService
AutoCodePlugin autoCodePlugin
AutoCodePackage autoCodePackage
AutoCodeHistory autoCodeHistory
AutoCodeTemplate autoCodeTemplate
SysErrorService
LoginLogService
ApiTokenService
}

View File

@@ -0,0 +1,52 @@
package system
import (
"context"
"go.uber.org/zap"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
)
type JwtService struct{}
var JwtServiceApp = new(JwtService)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: JsonInBlacklist
//@description: 拉黑jwt
//@param: jwtList model.JwtBlacklist
//@return: err error
func (jwtService *JwtService) JsonInBlacklist(jwtList system.JwtBlacklist) (err error) {
err = global.GVA_DB.Create(&jwtList).Error
if err != nil {
return
}
global.BlackCache.SetDefault(jwtList.Jwt, struct{}{})
return
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetRedisJWT
//@description: 从redis取jwt
//@param: userName string
//@return: redisJWT string, err error
func (jwtService *JwtService) GetRedisJWT(userName string) (redisJWT string, err error) {
redisJWT, err = global.GVA_REDIS.Get(context.Background(), userName).Result()
return redisJWT, err
}
func LoadAll() {
var data []string
err := global.GVA_DB.Model(&system.JwtBlacklist{}).Select("jwt").Find(&data).Error
if err != nil {
global.GVA_LOG.Error("加载数据库jwt黑名单失败!", zap.Error(err))
return
}
for i := 0; i < len(data); i++ {
global.BlackCache.SetDefault(data[i], struct{}{})
} // jwt黑名单 加入 BlackCache 中
}

View File

@@ -2,154 +2,325 @@ package system
import (
"errors"
"fmt"
"strings"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
systemRes "git.echol.cn/loser/ai_proxy/server/model/system/response"
"gorm.io/gorm"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateApi
//@description: 新增基础api
//@param: api model.SysApi
//@return: err error
type ApiService struct{}
// CreateApi 创建API
func (s *ApiService) CreateApi(req *request.CreateApiRequest) error {
// 检查是否已存在相同的 API
var count int64
err := global.GVA_DB.Model(&system.SysApi{}).
Where("path = ? AND method = ?", req.Path, req.Method).
Count(&count).Error
if err != nil {
return err
}
if count > 0 {
return errors.New("API已存在")
}
var ApiServiceApp = new(ApiService)
api := system.SysApi{
Path: req.Path,
Description: req.Description,
ApiGroup: req.ApiGroup,
Method: req.Method,
func (apiService *ApiService) CreateApi(api system.SysApi) (err error) {
if !errors.Is(global.GVA_DB.Where("path = ? AND method = ?", api.Path, api.Method).First(&system.SysApi{}).Error, gorm.ErrRecordNotFound) {
return errors.New("存在相同api")
}
return global.GVA_DB.Create(&api).Error
}
// UpdateApi 更新API
func (s *ApiService) UpdateApi(req *request.UpdateApiRequest) error {
var api system.SysApi
if err := global.GVA_DB.First(&api, req.ID).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("API不存在")
}
return err
}
// 检查是否有其他相同的 API
var count int64
err := global.GVA_DB.Model(&system.SysApi{}).
Where("path = ? AND method = ? AND id != ?", req.Path, req.Method, req.ID).
Count(&count).Error
if err != nil {
return err
}
if count > 0 {
return errors.New("API已存在")
}
updates := map[string]interface{}{
"path": req.Path,
"description": req.Description,
"api_group": req.ApiGroup,
"method": req.Method,
}
return global.GVA_DB.Model(&api).Updates(updates).Error
}
// DeleteApi 删除API
func (s *ApiService) DeleteApi(id uint) error {
var api system.SysApi
if err := global.GVA_DB.First(&api, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("API不存在")
}
return err
}
return global.GVA_DB.Delete(&api).Error
}
// GetApiList 获取API列表
func (s *ApiService) GetApiList(req *request.GetApiListRequest) (list []response.ApiInfo, total int64, err error) {
db := global.GVA_DB.Model(&system.SysApi{})
// 条件查询
if req.Path != "" {
db = db.Where("path LIKE ?", "%"+req.Path+"%")
}
if req.ApiGroup != "" {
db = db.Where("api_group = ?", req.ApiGroup)
}
if req.Method != "" {
db = db.Where("method = ?", req.Method)
}
// 获取总数
err = db.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 分页查询
if req.Page > 0 && req.PageSize > 0 {
offset := (req.Page - 1) * req.PageSize
db = db.Offset(offset).Limit(req.PageSize)
}
func (apiService *ApiService) GetApiGroups() (groups []string, groupApiMap map[string]string, err error) {
var apis []system.SysApi
err = db.Order("created_at DESC").Find(&apis).Error
err = global.GVA_DB.Find(&apis).Error
if err != nil {
return nil, 0, err
return
}
groupApiMap = make(map[string]string, 0)
for i := range apis {
pathArr := strings.Split(apis[i].Path, "/")
newGroup := true
for i2 := range groups {
if groups[i2] == apis[i].ApiGroup {
newGroup = false
}
}
if newGroup {
groups = append(groups, apis[i].ApiGroup)
}
groupApiMap[pathArr[1]] = apis[i].ApiGroup
}
return
}
func (apiService *ApiService) SyncApi() (newApis, deleteApis, ignoreApis []system.SysApi, err error) {
newApis = make([]system.SysApi, 0)
deleteApis = make([]system.SysApi, 0)
ignoreApis = make([]system.SysApi, 0)
var apis []system.SysApi
err = global.GVA_DB.Find(&apis).Error
if err != nil {
return
}
var ignores []system.SysIgnoreApi
err = global.GVA_DB.Find(&ignores).Error
if err != nil {
return
}
// 转换为响应格式
list = make([]response.ApiInfo, len(apis))
for i, api := range apis {
list[i] = response.ApiInfo{
ID: api.ID,
Path: api.Path,
Description: api.Description,
ApiGroup: api.ApiGroup,
Method: api.Method,
CreatedAt: api.CreatedAt,
UpdatedAt: api.UpdatedAt,
for i := range ignores {
ignoreApis = append(ignoreApis, system.SysApi{
Path: ignores[i].Path,
Description: "",
ApiGroup: "",
Method: ignores[i].Method,
})
}
var cacheApis []system.SysApi
for i := range global.GVA_ROUTERS {
ignoresFlag := false
for j := range ignores {
if ignores[j].Path == global.GVA_ROUTERS[i].Path && ignores[j].Method == global.GVA_ROUTERS[i].Method {
ignoresFlag = true
}
}
if !ignoresFlag {
cacheApis = append(cacheApis, system.SysApi{
Path: global.GVA_ROUTERS[i].Path,
Method: global.GVA_ROUTERS[i].Method,
})
}
}
return list, total, nil
}
// GetApiById 根据ID获取API
func (s *ApiService) GetApiById(id uint) (info response.ApiInfo, err error) {
var api system.SysApi
if err = global.GVA_DB.First(&api, id).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return info, errors.New("API不存在")
//对比数据库中的api和内存中的api如果数据库中的api不存在于内存中则把api放入删除数组如果内存中的api不存在于数据库中则把api放入新增数组
for i := range cacheApis {
var flag bool
// 如果存在于内存不存在于api数组中
for j := range apis {
if cacheApis[i].Path == apis[j].Path && cacheApis[i].Method == apis[j].Method {
flag = true
}
}
if !flag {
newApis = append(newApis, system.SysApi{
Path: cacheApis[i].Path,
Description: "",
ApiGroup: "",
Method: cacheApis[i].Method,
})
}
return info, err
}
info = response.ApiInfo{
ID: api.ID,
Path: api.Path,
Description: api.Description,
ApiGroup: api.ApiGroup,
Method: api.Method,
CreatedAt: api.CreatedAt,
UpdatedAt: api.UpdatedAt,
for i := range apis {
var flag bool
// 如果存在于api数组不存在于内存
for j := range cacheApis {
if cacheApis[j].Path == apis[i].Path && cacheApis[j].Method == apis[i].Method {
flag = true
}
}
if !flag {
deleteApis = append(deleteApis, apis[i])
}
}
return info, nil
return
}
func (apiService *ApiService) IgnoreApi(ignoreApi system.SysIgnoreApi) (err error) {
if ignoreApi.Flag {
return global.GVA_DB.Create(&ignoreApi).Error
}
return global.GVA_DB.Unscoped().Delete(&ignoreApi, "path = ? AND method = ?", ignoreApi.Path, ignoreApi.Method).Error
}
func (apiService *ApiService) EnterSyncApi(syncApis systemRes.SysSyncApis) (err error) {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
var txErr error
if len(syncApis.NewApis) > 0 {
txErr = tx.Create(&syncApis.NewApis).Error
if txErr != nil {
return txErr
}
}
for i := range syncApis.DeleteApis {
CasbinServiceApp.ClearCasbin(1, syncApis.DeleteApis[i].Path, syncApis.DeleteApis[i].Method)
txErr = tx.Delete(&system.SysApi{}, "path = ? AND method = ?", syncApis.DeleteApis[i].Path, syncApis.DeleteApis[i].Method).Error
if txErr != nil {
return txErr
}
}
return nil
})
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteApi
//@description: 删除基础api
//@param: api model.SysApi
//@return: err error
func (apiService *ApiService) DeleteApi(api system.SysApi) (err error) {
var entity system.SysApi
err = global.GVA_DB.First(&entity, "id = ?", api.ID).Error // 根据id查询api记录
if errors.Is(err, gorm.ErrRecordNotFound) { // api记录不存在
return err
}
err = global.GVA_DB.Delete(&entity).Error
if err != nil {
return err
}
CasbinServiceApp.ClearCasbin(1, entity.Path, entity.Method)
return nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetAPIInfoList
//@description: 分页获取数据,
//@param: api model.SysApi, info request.PageInfo, order string, desc bool
//@return: list interface{}, total int64, err error
func (apiService *ApiService) GetAPIInfoList(api system.SysApi, info request.PageInfo, order string, desc bool) (list interface{}, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&system.SysApi{})
var apiList []system.SysApi
if api.Path != "" {
db = db.Where("path LIKE ?", "%"+api.Path+"%")
}
if api.Description != "" {
db = db.Where("description LIKE ?", "%"+api.Description+"%")
}
if api.Method != "" {
db = db.Where("method = ?", api.Method)
}
if api.ApiGroup != "" {
db = db.Where("api_group = ?", api.ApiGroup)
}
err = db.Count(&total).Error
if err != nil {
return apiList, total, err
}
db = db.Limit(limit).Offset(offset)
OrderStr := "id desc"
if order != "" {
orderMap := make(map[string]bool, 5)
orderMap["id"] = true
orderMap["path"] = true
orderMap["api_group"] = true
orderMap["description"] = true
orderMap["method"] = true
if !orderMap[order] {
err = fmt.Errorf("非法的排序字段: %v", order)
return apiList, total, err
}
OrderStr = order
if desc {
OrderStr = order + " desc"
}
}
err = db.Order(OrderStr).Find(&apiList).Error
return apiList, total, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetAllApis
//@description: 获取所有的api
//@return: apis []model.SysApi, err error
func (apiService *ApiService) GetAllApis(authorityID uint) (apis []system.SysApi, err error) {
parentAuthorityID, err := AuthorityServiceApp.GetParentAuthorityID(authorityID)
if err != nil {
return nil, err
}
err = global.GVA_DB.Order("id desc").Find(&apis).Error
if parentAuthorityID == 0 || !global.GVA_CONFIG.System.UseStrictAuth {
return
}
paths := CasbinServiceApp.GetPolicyPathByAuthorityId(authorityID)
// 挑选 apis里面的path和method也在paths里面的api
var authApis []system.SysApi
for i := range apis {
for j := range paths {
if paths[j].Path == apis[i].Path && paths[j].Method == apis[i].Method {
authApis = append(authApis, apis[i])
}
}
}
return authApis, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetApiById
//@description: 根据id获取api
//@param: id float64
//@return: api model.SysApi, err error
func (apiService *ApiService) GetApiById(id int) (api system.SysApi, err error) {
err = global.GVA_DB.First(&api, "id = ?", id).Error
return
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateApi
//@description: 根据id更新api
//@param: api model.SysApi
//@return: err error
func (apiService *ApiService) UpdateApi(api system.SysApi) (err error) {
var oldA system.SysApi
err = global.GVA_DB.First(&oldA, "id = ?", api.ID).Error
if oldA.Path != api.Path || oldA.Method != api.Method {
var duplicateApi system.SysApi
if ferr := global.GVA_DB.First(&duplicateApi, "path = ? AND method = ?", api.Path, api.Method).Error; ferr != nil {
if !errors.Is(ferr, gorm.ErrRecordNotFound) {
return ferr
}
} else {
if duplicateApi.ID != api.ID {
return errors.New("存在相同api路径")
}
}
}
if err != nil {
return err
}
err = CasbinServiceApp.UpdateCasbinApi(oldA.Path, api.Path, oldA.Method, api.Method)
if err != nil {
return err
}
return global.GVA_DB.Save(&api).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteApisByIds
//@description: 删除选中API
//@param: apis []model.SysApi
//@return: err error
func (apiService *ApiService) DeleteApisByIds(ids request.IdsReq) (err error) {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
var apis []system.SysApi
err = tx.Find(&apis, "id in ?", ids.Ids).Error
if err != nil {
return err
}
err = tx.Delete(&[]system.SysApi{}, "id in ?", ids.Ids).Error
if err != nil {
return err
}
for _, sysApi := range apis {
CasbinServiceApp.ClearCasbin(1, sysApi.Path, sysApi.Method)
}
return err
})
}

View File

@@ -0,0 +1,106 @@
package system
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
sysReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/golang-jwt/jwt/v5"
"time"
)
type ApiTokenService struct{}
func (apiVersion *ApiTokenService) CreateApiToken(apiToken system.SysApiToken, days int) (string, error) {
var user system.SysUser
if err := global.GVA_DB.Where("id = ?", apiToken.UserID).First(&user).Error; err != nil {
return "", errors.New("用户不存在")
}
hasAuth := false
for _, auth := range user.Authorities {
if auth.AuthorityId == apiToken.AuthorityID {
hasAuth = true
break
}
}
if !hasAuth && user.AuthorityId != apiToken.AuthorityID {
return "", errors.New("用户不具备该角色权限")
}
j := &utils.JWT{SigningKey: []byte(global.GVA_CONFIG.JWT.SigningKey)} // 唯一不同的部分是过期时间
expireTime := time.Duration(days) * 24 * time.Hour
if days == -1 {
expireTime = 100 * 365 * 24 * time.Hour
}
bf, _ := utils.ParseDuration(global.GVA_CONFIG.JWT.BufferTime)
claims := sysReq.CustomClaims{
BaseClaims: sysReq.BaseClaims{
UUID: user.UUID,
ID: user.ID,
Username: user.Username,
NickName: user.NickName,
AuthorityId: apiToken.AuthorityID,
},
BufferTime: int64(bf / time.Second), // 缓冲时间
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{"GVA"},
NotBefore: jwt.NewNumericDate(time.Now().Add(-1000)),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expireTime)),
Issuer: global.GVA_CONFIG.JWT.Issuer,
},
}
token, err := j.CreateToken(claims)
if err != nil {
return "", err
}
apiToken.Token = token
apiToken.Status = true
apiToken.ExpiresAt = time.Now().Add(expireTime)
err = global.GVA_DB.Create(&apiToken).Error
return token, err
}
func (apiVersion *ApiTokenService) GetApiTokenList(info sysReq.SysApiTokenSearch) (list []system.SysApiToken, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&system.SysApiToken{})
db = db.Preload("User")
if info.UserID != 0 {
db = db.Where("user_id = ?", info.UserID)
}
if info.Status != nil {
db = db.Where("status = ?", *info.Status)
}
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Limit(limit).Offset(offset).Order("created_at desc").Find(&list).Error
return list, total, err
}
func (apiVersion *ApiTokenService) DeleteApiToken(id uint) error {
var apiToken system.SysApiToken
err := global.GVA_DB.First(&apiToken, id).Error
if err != nil {
return err
}
jwtService := JwtService{}
err = jwtService.JsonInBlacklist(system.JwtBlacklist{Jwt: apiToken.Token})
if err != nil {
return err
}
return global.GVA_DB.Model(&apiToken).Update("status", false).Error
}

View File

@@ -0,0 +1,333 @@
package system
import (
"errors"
"strconv"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
"gorm.io/gorm"
)
var ErrRoleExistence = errors.New("存在相同角色id")
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateAuthority
//@description: 创建一个角色
//@param: auth model.SysAuthority
//@return: authority system.SysAuthority, err error
type AuthorityService struct{}
var AuthorityServiceApp = new(AuthorityService)
func (authorityService *AuthorityService) CreateAuthority(auth system.SysAuthority) (authority system.SysAuthority, err error) {
if err = global.GVA_DB.Where("authority_id = ?", auth.AuthorityId).First(&system.SysAuthority{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) {
return auth, ErrRoleExistence
}
e := global.GVA_DB.Transaction(func(tx *gorm.DB) error {
if err = tx.Create(&auth).Error; err != nil {
return err
}
auth.SysBaseMenus = systemReq.DefaultMenu()
if err = tx.Model(&auth).Association("SysBaseMenus").Replace(&auth.SysBaseMenus); err != nil {
return err
}
casbinInfos := systemReq.DefaultCasbin()
authorityId := strconv.Itoa(int(auth.AuthorityId))
rules := [][]string{}
for _, v := range casbinInfos {
rules = append(rules, []string{authorityId, v.Path, v.Method})
}
return CasbinServiceApp.AddPolicies(tx, rules)
})
return auth, e
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CopyAuthority
//@description: 复制一个角色
//@param: copyInfo response.SysAuthorityCopyResponse
//@return: authority system.SysAuthority, err error
func (authorityService *AuthorityService) CopyAuthority(adminAuthorityID uint, copyInfo response.SysAuthorityCopyResponse) (authority system.SysAuthority, err error) {
var authorityBox system.SysAuthority
if !errors.Is(global.GVA_DB.Where("authority_id = ?", copyInfo.Authority.AuthorityId).First(&authorityBox).Error, gorm.ErrRecordNotFound) {
return authority, ErrRoleExistence
}
copyInfo.Authority.Children = []system.SysAuthority{}
menus, err := MenuServiceApp.GetMenuAuthority(&request.GetAuthorityId{AuthorityId: copyInfo.OldAuthorityId})
if err != nil {
return
}
var baseMenu []system.SysBaseMenu
for _, v := range menus {
intNum := v.MenuId
v.SysBaseMenu.ID = uint(intNum)
baseMenu = append(baseMenu, v.SysBaseMenu)
}
copyInfo.Authority.SysBaseMenus = baseMenu
err = global.GVA_DB.Create(&copyInfo.Authority).Error
if err != nil {
return
}
var btns []system.SysAuthorityBtn
err = global.GVA_DB.Find(&btns, "authority_id = ?", copyInfo.OldAuthorityId).Error
if err != nil {
return
}
if len(btns) > 0 {
for i := range btns {
btns[i].AuthorityId = copyInfo.Authority.AuthorityId
}
err = global.GVA_DB.Create(&btns).Error
if err != nil {
return
}
}
paths := CasbinServiceApp.GetPolicyPathByAuthorityId(copyInfo.OldAuthorityId)
err = CasbinServiceApp.UpdateCasbin(adminAuthorityID, copyInfo.Authority.AuthorityId, paths)
if err != nil {
_ = authorityService.DeleteAuthority(&copyInfo.Authority)
}
return copyInfo.Authority, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateAuthority
//@description: 更改一个角色
//@param: auth model.SysAuthority
//@return: authority system.SysAuthority, err error
func (authorityService *AuthorityService) UpdateAuthority(auth system.SysAuthority) (authority system.SysAuthority, err error) {
var oldAuthority system.SysAuthority
err = global.GVA_DB.Where("authority_id = ?", auth.AuthorityId).First(&oldAuthority).Error
if err != nil {
global.GVA_LOG.Debug(err.Error())
return system.SysAuthority{}, errors.New("查询角色数据失败")
}
err = global.GVA_DB.Model(&oldAuthority).Updates(&auth).Error
return auth, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteAuthority
//@description: 删除角色
//@param: auth *model.SysAuthority
//@return: err error
func (authorityService *AuthorityService) DeleteAuthority(auth *system.SysAuthority) error {
if errors.Is(global.GVA_DB.Debug().Preload("Users").First(&auth).Error, gorm.ErrRecordNotFound) {
return errors.New("该角色不存在")
}
if len(auth.Users) != 0 {
return errors.New("此角色有用户正在使用禁止删除")
}
if !errors.Is(global.GVA_DB.Where("authority_id = ?", auth.AuthorityId).First(&system.SysUser{}).Error, gorm.ErrRecordNotFound) {
return errors.New("此角色有用户正在使用禁止删除")
}
if !errors.Is(global.GVA_DB.Where("parent_id = ?", auth.AuthorityId).First(&system.SysAuthority{}).Error, gorm.ErrRecordNotFound) {
return errors.New("此角色存在子角色不允许删除")
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
var err error
if err = tx.Preload("SysBaseMenus").Preload("DataAuthorityId").Where("authority_id = ?", auth.AuthorityId).First(auth).Unscoped().Delete(auth).Error; err != nil {
return err
}
if len(auth.SysBaseMenus) > 0 {
if err = tx.Model(auth).Association("SysBaseMenus").Delete(auth.SysBaseMenus); err != nil {
return err
}
// err = db.Association("SysBaseMenus").Delete(&auth)
}
if len(auth.DataAuthorityId) > 0 {
if err = tx.Model(auth).Association("DataAuthorityId").Delete(auth.DataAuthorityId); err != nil {
return err
}
}
if err = tx.Delete(&system.SysUserAuthority{}, "sys_authority_authority_id = ?", auth.AuthorityId).Error; err != nil {
return err
}
if err = tx.Where("authority_id = ?", auth.AuthorityId).Delete(&[]system.SysAuthorityBtn{}).Error; err != nil {
return err
}
authorityId := strconv.Itoa(int(auth.AuthorityId))
if err = CasbinServiceApp.RemoveFilteredPolicy(tx, authorityId); err != nil {
return err
}
return nil
})
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetAuthorityInfoList
//@description: 分页获取数据
//@param: info request.PageInfo
//@return: list interface{}, total int64, err error
func (authorityService *AuthorityService) GetAuthorityInfoList(authorityID uint) (list []system.SysAuthority, err error) {
var authority system.SysAuthority
err = global.GVA_DB.Where("authority_id = ?", authorityID).First(&authority).Error
if err != nil {
return nil, err
}
var authorities []system.SysAuthority
db := global.GVA_DB.Model(&system.SysAuthority{})
if global.GVA_CONFIG.System.UseStrictAuth {
// 当开启了严格树形结构后
if *authority.ParentId == 0 {
// 只有顶级角色可以修改自己的权限和以下权限
err = db.Preload("DataAuthorityId").Where("authority_id = ?", authorityID).Find(&authorities).Error
} else {
// 非顶级角色只能修改以下权限
err = db.Debug().Preload("DataAuthorityId").Where("parent_id = ?", authorityID).Find(&authorities).Error
}
} else {
err = db.Preload("DataAuthorityId").Where("parent_id = ?", "0").Find(&authorities).Error
}
for k := range authorities {
err = authorityService.findChildrenAuthority(&authorities[k])
}
return authorities, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetAuthorityInfoList
//@description: 分页获取数据
//@param: info request.PageInfo
//@return: list interface{}, total int64, err error
func (authorityService *AuthorityService) GetStructAuthorityList(authorityID uint) (list []uint, err error) {
var auth system.SysAuthority
_ = global.GVA_DB.First(&auth, "authority_id = ?", authorityID).Error
var authorities []system.SysAuthority
err = global.GVA_DB.Preload("DataAuthorityId").Where("parent_id = ?", authorityID).Find(&authorities).Error
if len(authorities) > 0 {
for k := range authorities {
list = append(list, authorities[k].AuthorityId)
childrenList, err := authorityService.GetStructAuthorityList(authorities[k].AuthorityId)
if err == nil {
list = append(list, childrenList...)
}
}
}
if *auth.ParentId == 0 {
list = append(list, authorityID)
}
return list, err
}
func (authorityService *AuthorityService) CheckAuthorityIDAuth(authorityID, targetID uint) (err error) {
if !global.GVA_CONFIG.System.UseStrictAuth {
return nil
}
authIDS, err := authorityService.GetStructAuthorityList(authorityID)
if err != nil {
return err
}
hasAuth := false
for _, v := range authIDS {
if v == targetID {
hasAuth = true
break
}
}
if !hasAuth {
return errors.New("您提交的角色ID不合法")
}
return nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetAuthorityInfo
//@description: 获取所有角色信息
//@param: auth model.SysAuthority
//@return: sa system.SysAuthority, err error
func (authorityService *AuthorityService) GetAuthorityInfo(auth system.SysAuthority) (sa system.SysAuthority, err error) {
err = global.GVA_DB.Preload("DataAuthorityId").Where("authority_id = ?", auth.AuthorityId).First(&sa).Error
return sa, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetDataAuthority
//@description: 设置角色资源权限
//@param: auth model.SysAuthority
//@return: error
func (authorityService *AuthorityService) SetDataAuthority(adminAuthorityID uint, auth system.SysAuthority) error {
var checkIDs []uint
checkIDs = append(checkIDs, auth.AuthorityId)
for i := range auth.DataAuthorityId {
checkIDs = append(checkIDs, auth.DataAuthorityId[i].AuthorityId)
}
for i := range checkIDs {
err := authorityService.CheckAuthorityIDAuth(adminAuthorityID, checkIDs[i])
if err != nil {
return err
}
}
var s system.SysAuthority
global.GVA_DB.Preload("DataAuthorityId").First(&s, "authority_id = ?", auth.AuthorityId)
err := global.GVA_DB.Model(&s).Association("DataAuthorityId").Replace(&auth.DataAuthorityId)
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetMenuAuthority
//@description: 菜单与角色绑定
//@param: auth *model.SysAuthority
//@return: error
func (authorityService *AuthorityService) SetMenuAuthority(auth *system.SysAuthority) error {
var s system.SysAuthority
global.GVA_DB.Preload("SysBaseMenus").First(&s, "authority_id = ?", auth.AuthorityId)
err := global.GVA_DB.Model(&s).Association("SysBaseMenus").Replace(&auth.SysBaseMenus)
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: findChildrenAuthority
//@description: 查询子角色
//@param: authority *model.SysAuthority
//@return: err error
func (authorityService *AuthorityService) findChildrenAuthority(authority *system.SysAuthority) (err error) {
err = global.GVA_DB.Preload("DataAuthorityId").Where("parent_id = ?", authority.AuthorityId).Find(&authority.Children).Error
if len(authority.Children) > 0 {
for k := range authority.Children {
err = authorityService.findChildrenAuthority(&authority.Children[k])
}
}
return err
}
func (authorityService *AuthorityService) GetParentAuthorityID(authorityID uint) (parentID uint, err error) {
var authority system.SysAuthority
err = global.GVA_DB.Where("authority_id = ?", authorityID).First(&authority).Error
if err != nil {
return
}
return *authority.ParentId, nil
}

View File

@@ -0,0 +1,60 @@
package system
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
"gorm.io/gorm"
)
type AuthorityBtnService struct{}
var AuthorityBtnServiceApp = new(AuthorityBtnService)
func (a *AuthorityBtnService) GetAuthorityBtn(req request.SysAuthorityBtnReq) (res response.SysAuthorityBtnRes, err error) {
var authorityBtn []system.SysAuthorityBtn
err = global.GVA_DB.Find(&authorityBtn, "authority_id = ? and sys_menu_id = ?", req.AuthorityId, req.MenuID).Error
if err != nil {
return
}
var selected []uint
for _, v := range authorityBtn {
selected = append(selected, v.SysBaseMenuBtnID)
}
res.Selected = selected
return res, err
}
func (a *AuthorityBtnService) SetAuthorityBtn(req request.SysAuthorityBtnReq) (err error) {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
var authorityBtn []system.SysAuthorityBtn
err = tx.Delete(&[]system.SysAuthorityBtn{}, "authority_id = ? and sys_menu_id = ?", req.AuthorityId, req.MenuID).Error
if err != nil {
return err
}
for _, v := range req.Selected {
authorityBtn = append(authorityBtn, system.SysAuthorityBtn{
AuthorityId: req.AuthorityId,
SysMenuID: req.MenuID,
SysBaseMenuBtnID: v,
})
}
if len(authorityBtn) > 0 {
err = tx.Create(&authorityBtn).Error
}
if err != nil {
return err
}
return err
})
}
func (a *AuthorityBtnService) CanRemoveAuthorityBtn(ID string) (err error) {
fErr := global.GVA_DB.First(&system.SysAuthorityBtn{}, "sys_base_menu_btn_id = ?", ID).Error
if errors.Is(fErr, gorm.ErrRecordNotFound) {
return nil
}
return errors.New("此按钮正在被使用无法删除")
}

View File

@@ -0,0 +1,55 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
)
type AutoCodeService struct{}
type Database interface {
GetDB(businessDB string) (data []response.Db, err error)
GetTables(businessDB string, dbName string) (data []response.Table, err error)
GetColumn(businessDB string, tableName string, dbName string) (data []response.Column, err error)
}
func (autoCodeService *AutoCodeService) Database(businessDB string) Database {
if businessDB == "" {
switch global.GVA_CONFIG.System.DbType {
case "mysql":
return AutoCodeMysql
case "pgsql":
return AutoCodePgsql
case "mssql":
return AutoCodeMssql
case "oracle":
return AutoCodeOracle
case "sqlite":
return AutoCodeSqlite
default:
return AutoCodeMysql
}
} else {
for _, info := range global.GVA_CONFIG.DBList {
if info.AliasName == businessDB {
switch info.Type {
case "mysql":
return AutoCodeMysql
case "mssql":
return AutoCodeMssql
case "pgsql":
return AutoCodePgsql
case "oracle":
return AutoCodeOracle
case "sqlite":
return AutoCodeSqlite
default:
return AutoCodeMysql
}
}
}
return AutoCodeMysql
}
}

View File

@@ -0,0 +1,83 @@
package system
import (
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
)
var AutoCodeMssql = new(autoCodeMssql)
type autoCodeMssql struct{}
// GetDB 获取数据库的所有数据库名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeMssql) GetDB(businessDB string) (data []response.Db, err error) {
var entities []response.Db
sql := "select name AS 'database' from sys.databases;"
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&entities).Error
}
return entities, err
}
// GetTables 获取数据库的所有表名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeMssql) GetTables(businessDB string, dbName string) (data []response.Table, err error) {
var entities []response.Table
sql := fmt.Sprintf(`select name as 'table_name' from %s.DBO.sysobjects where xtype='U'`, dbName)
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&entities).Error
}
return entities, err
}
// GetColumn 获取指定数据库和指定数据表的所有字段名,类型值等
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeMssql) GetColumn(businessDB string, tableName string, dbName string) (data []response.Column, err error) {
var entities []response.Column
sql := fmt.Sprintf(`
SELECT
sc.name AS column_name,
st.name AS data_type,
sc.max_length AS data_type_long,
CASE
WHEN pk.object_id IS NOT NULL THEN 1
ELSE 0
END AS primary_key,
sc.column_id
FROM
%s.sys.columns sc
JOIN
sys.types st ON sc.user_type_id=st.user_type_id
LEFT JOIN
%s.sys.objects so ON so.name='%s' AND so.type='U'
LEFT JOIN
%s.sys.indexes si ON si.object_id = so.object_id AND si.is_primary_key = 1
LEFT JOIN
%s.sys.index_columns sic ON sic.object_id = si.object_id AND sic.index_id = si.index_id AND sic.column_id = sc.column_id
LEFT JOIN
%s.sys.key_constraints pk ON pk.object_id = si.object_id
WHERE
st.is_user_defined=0 AND sc.object_id = so.object_id
ORDER BY
sc.column_id
`, dbName, dbName, tableName, dbName, dbName, dbName)
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&entities).Error
}
return entities, err
}

View File

@@ -0,0 +1,83 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
)
var AutoCodeMysql = new(autoCodeMysql)
type autoCodeMysql struct{}
// GetDB 获取数据库的所有数据库名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeMysql) GetDB(businessDB string) (data []response.Db, err error) {
var entities []response.Db
sql := "SELECT SCHEMA_NAME AS `database` FROM INFORMATION_SCHEMA.SCHEMATA;"
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&entities).Error
}
return entities, err
}
// GetTables 获取数据库的所有表名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeMysql) GetTables(businessDB string, dbName string) (data []response.Table, err error) {
var entities []response.Table
sql := `select table_name as table_name from information_schema.tables where table_schema = ?`
if businessDB == "" {
err = global.GVA_DB.Raw(sql, dbName).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql, dbName).Scan(&entities).Error
}
return entities, err
}
// GetColumn 获取指定数据库和指定数据表的所有字段名,类型值等
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeMysql) GetColumn(businessDB string, tableName string, dbName string) (data []response.Column, err error) {
var entities []response.Column
sql := `
SELECT
c.COLUMN_NAME column_name,
c.DATA_TYPE data_type,
CASE c.DATA_TYPE
WHEN 'longtext' THEN c.CHARACTER_MAXIMUM_LENGTH
WHEN 'varchar' THEN c.CHARACTER_MAXIMUM_LENGTH
WHEN 'double' THEN CONCAT_WS(',', c.NUMERIC_PRECISION, c.NUMERIC_SCALE)
WHEN 'decimal' THEN CONCAT_WS(',', c.NUMERIC_PRECISION, c.NUMERIC_SCALE)
WHEN 'int' THEN c.NUMERIC_PRECISION
WHEN 'bigint' THEN c.NUMERIC_PRECISION
ELSE ''
END AS data_type_long,
c.COLUMN_COMMENT column_comment,
CASE WHEN kcu.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END AS primary_key,
c.ORDINAL_POSITION
FROM
INFORMATION_SCHEMA.COLUMNS c
LEFT JOIN
INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
ON
c.TABLE_SCHEMA = kcu.TABLE_SCHEMA
AND c.TABLE_NAME = kcu.TABLE_NAME
AND c.COLUMN_NAME = kcu.COLUMN_NAME
AND kcu.CONSTRAINT_NAME = 'PRIMARY'
WHERE
c.TABLE_NAME = ?
AND c.TABLE_SCHEMA = ?
ORDER BY
c.ORDINAL_POSITION;`
if businessDB == "" {
err = global.GVA_DB.Raw(sql, tableName, dbName).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql, tableName, dbName).Scan(&entities).Error
}
return entities, err
}

View File

@@ -0,0 +1,72 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
)
var AutoCodeOracle = new(autoCodeOracle)
type autoCodeOracle struct{}
// GetDB 获取数据库的所有数据库名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeOracle) GetDB(businessDB string) (data []response.Db, err error) {
var entities []response.Db
sql := `SELECT lower(username) AS "database" FROM all_users`
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&entities).Error
return entities, err
}
// GetTables 获取数据库的所有表名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeOracle) GetTables(businessDB string, dbName string) (data []response.Table, err error) {
var entities []response.Table
sql := `select lower(table_name) as "table_name" from all_tables where lower(owner) = ?`
err = global.GVA_DBList[businessDB].Raw(sql, dbName).Scan(&entities).Error
return entities, err
}
// GetColumn 获取指定数据库和指定数据表的所有字段名,类型值等
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (s *autoCodeOracle) GetColumn(businessDB string, tableName string, dbName string) (data []response.Column, err error) {
var entities []response.Column
sql := `
SELECT
lower(a.COLUMN_NAME) as "column_name",
(CASE WHEN a.DATA_TYPE = 'NUMBER' AND a.DATA_SCALE=0 THEN 'int' else lower(a.DATA_TYPE) end) as "data_type",
(CASE WHEN a.DATA_TYPE = 'NUMBER' THEN a.DATA_PRECISION else a.DATA_LENGTH end) as "data_type_long",
b.COMMENTS as "column_comment",
(CASE WHEN pk.COLUMN_NAME IS NOT NULL THEN 1 ELSE 0 END) as "primary_key",
a.COLUMN_ID
FROM
all_tab_columns a
JOIN
all_col_comments b ON a.OWNER = b.OWNER AND a.TABLE_NAME = b.TABLE_NAME AND a.COLUMN_NAME = b.COLUMN_NAME
LEFT JOIN
(
SELECT
acc.OWNER,
acc.TABLE_NAME,
acc.COLUMN_NAME
FROM
all_cons_columns acc
JOIN
all_constraints ac ON acc.OWNER = ac.OWNER AND acc.CONSTRAINT_NAME = ac.CONSTRAINT_NAME
WHERE
ac.CONSTRAINT_TYPE = 'P'
) pk ON a.OWNER = pk.OWNER AND a.TABLE_NAME = pk.TABLE_NAME AND a.COLUMN_NAME = pk.COLUMN_NAME
WHERE
lower(a.table_name) = ?
AND lower(a.OWNER) = ?
ORDER BY
a.COLUMN_ID
`
err = global.GVA_DBList[businessDB].Raw(sql, tableName, dbName).Scan(&entities).Error
return entities, err
}

View File

@@ -0,0 +1,135 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
)
var AutoCodePgsql = new(autoCodePgsql)
type autoCodePgsql struct{}
// GetDB 获取数据库的所有数据库名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (a *autoCodePgsql) GetDB(businessDB string) (data []response.Db, err error) {
var entities []response.Db
sql := `SELECT datname as database FROM pg_database WHERE datistemplate = false`
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Scan(&entities).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&entities).Error
}
return entities, err
}
// GetTables 获取数据库的所有表名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (a *autoCodePgsql) GetTables(businessDB string, dbName string) (data []response.Table, err error) {
var entities []response.Table
sql := `select table_name as table_name from information_schema.tables where table_catalog = ? and table_schema = ?`
db := global.GVA_DB
if businessDB != "" {
db = global.GVA_DBList[businessDB]
}
err = db.Raw(sql, dbName, "public").Scan(&entities).Error
return entities, err
}
// GetColumn 获取指定数据库和指定数据表的所有字段名,类型值等
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (a *autoCodePgsql) GetColumn(businessDB string, tableName string, dbName string) (data []response.Column, err error) {
// todo 数据获取不全, 待完善sql
sql := `
SELECT
psc.COLUMN_NAME AS COLUMN_NAME,
psc.udt_name AS data_type,
CASE
psc.udt_name
WHEN 'text' THEN
concat_ws ( '', '', psc.CHARACTER_MAXIMUM_LENGTH )
WHEN 'varchar' THEN
concat_ws ( '', '', psc.CHARACTER_MAXIMUM_LENGTH )
WHEN 'smallint' THEN
concat_ws ( ',', psc.NUMERIC_PRECISION, psc.NUMERIC_SCALE )
WHEN 'decimal' THEN
concat_ws ( ',', psc.NUMERIC_PRECISION, psc.NUMERIC_SCALE )
WHEN 'integer' THEN
concat_ws ( '', '', psc.NUMERIC_PRECISION )
WHEN 'int4' THEN
concat_ws ( '', '', psc.NUMERIC_PRECISION )
WHEN 'int8' THEN
concat_ws ( '', '', psc.NUMERIC_PRECISION )
WHEN 'bigint' THEN
concat_ws ( '', '', psc.NUMERIC_PRECISION )
WHEN 'timestamp' THEN
concat_ws ( '', '', psc.datetime_precision )
ELSE ''
END AS data_type_long,
(
SELECT
pd.description
FROM
pg_description pd
WHERE
(pd.objoid,pd.objsubid) in (
SELECT pa.attrelid,pa.attnum
FROM
pg_attribute pa
WHERE pa.attrelid = ( SELECT oid FROM pg_class pc WHERE
pc.relname = psc.table_name
)
and attname = psc.column_name
)
) AS column_comment,
(
SELECT
COUNT(*)
FROM
pg_constraint
WHERE
contype = 'p'
AND conrelid = (
SELECT
oid
FROM
pg_class
WHERE
relname = psc.table_name
)
AND conkey::int[] @> ARRAY[(
SELECT
attnum::integer
FROM
pg_attribute
WHERE
attrelid = conrelid
AND attname = psc.column_name
)]
) > 0 AS primary_key,
psc.ordinal_position
FROM
INFORMATION_SCHEMA.COLUMNS psc
WHERE
table_catalog = ?
AND table_schema = 'public'
AND TABLE_NAME = ?
ORDER BY
psc.ordinal_position;
`
var entities []response.Column
//sql = strings.ReplaceAll(sql, "@table_catalog", dbName)
//sql = strings.ReplaceAll(sql, "@table_name", tableName)
db := global.GVA_DB
if businessDB != "" {
db = global.GVA_DBList[businessDB]
}
err = db.Raw(sql, dbName, tableName).Scan(&entities).Error
return entities, err
}

View File

@@ -0,0 +1,84 @@
package system
import (
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
"path/filepath"
"strings"
)
var AutoCodeSqlite = new(autoCodeSqlite)
type autoCodeSqlite struct{}
// GetDB 获取数据库的所有数据库名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (a *autoCodeSqlite) GetDB(businessDB string) (data []response.Db, err error) {
var entities []response.Db
sql := "PRAGMA database_list;"
var databaseList []struct {
File string `gorm:"column:file"`
}
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Find(&databaseList).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Find(&databaseList).Error
}
for _, database := range databaseList {
if database.File != "" {
fileName := filepath.Base(database.File)
fileExt := filepath.Ext(fileName)
fileNameWithoutExt := strings.TrimSuffix(fileName, fileExt)
entities = append(entities, response.Db{fileNameWithoutExt})
}
}
// entities = append(entities, response.Db{global.GVA_CONFIG.Sqlite.Dbname})
return entities, err
}
// GetTables 获取数据库的所有表名
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (a *autoCodeSqlite) GetTables(businessDB string, dbName string) (data []response.Table, err error) {
var entities []response.Table
sql := `SELECT name FROM sqlite_master WHERE type='table'`
tabelNames := []string{}
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Find(&tabelNames).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Find(&tabelNames).Error
}
for _, tabelName := range tabelNames {
entities = append(entities, response.Table{tabelName})
}
return entities, err
}
// GetColumn 获取指定数据表的所有字段名,类型值等
// Author [piexlmax](https://github.com/piexlmax)
// Author [SliverHorn](https://github.com/SliverHorn)
func (a *autoCodeSqlite) GetColumn(businessDB string, tableName string, dbName string) (data []response.Column, err error) {
var entities []response.Column
sql := fmt.Sprintf("PRAGMA table_info(%s);", tableName)
var columnInfos []struct {
Name string `gorm:"column:name"`
Type string `gorm:"column:type"`
Pk int `gorm:"column:pk"`
}
if businessDB == "" {
err = global.GVA_DB.Raw(sql).Scan(&columnInfos).Error
} else {
err = global.GVA_DBList[businessDB].Raw(sql).Scan(&columnInfos).Error
}
for _, columnInfo := range columnInfos {
entities = append(entities, response.Column{
ColumnName: columnInfo.Name,
DataType: columnInfo.Type,
PrimaryKey: columnInfo.Pk == 1,
})
}
return entities, err
}

View File

@@ -0,0 +1,147 @@
package system
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"gorm.io/gorm"
)
type BaseMenuService struct{}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteBaseMenu
//@description: 删除基础路由
//@param: id float64
//@return: err error
var BaseMenuServiceApp = new(BaseMenuService)
func (baseMenuService *BaseMenuService) DeleteBaseMenu(id int) (err error) {
err = global.GVA_DB.First(&system.SysBaseMenu{}, "parent_id = ?", id).Error
if err == nil {
return errors.New("此菜单存在子菜单不可删除")
}
var menu system.SysBaseMenu
err = global.GVA_DB.First(&menu, id).Error
if err != nil {
return errors.New("记录不存在")
}
err = global.GVA_DB.First(&system.SysAuthority{}, "default_router = ?", menu.Name).Error
if err == nil {
return errors.New("此菜单有角色正在作为首页,不可删除")
}
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
err = tx.Delete(&system.SysBaseMenu{}, "id = ?", id).Error
if err != nil {
return err
}
err = tx.Delete(&system.SysBaseMenuParameter{}, "sys_base_menu_id = ?", id).Error
if err != nil {
return err
}
err = tx.Delete(&system.SysBaseMenuBtn{}, "sys_base_menu_id = ?", id).Error
if err != nil {
return err
}
err = tx.Delete(&system.SysAuthorityBtn{}, "sys_menu_id = ?", id).Error
if err != nil {
return err
}
err = tx.Delete(&system.SysAuthorityMenu{}, "sys_base_menu_id = ?", id).Error
if err != nil {
return err
}
return nil
})
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateBaseMenu
//@description: 更新路由
//@param: menu model.SysBaseMenu
//@return: err error
func (baseMenuService *BaseMenuService) UpdateBaseMenu(menu system.SysBaseMenu) (err error) {
var oldMenu system.SysBaseMenu
upDateMap := make(map[string]interface{})
upDateMap["keep_alive"] = menu.KeepAlive
upDateMap["transition_type"] = menu.TransitionType
upDateMap["close_tab"] = menu.CloseTab
upDateMap["default_menu"] = menu.DefaultMenu
upDateMap["parent_id"] = menu.ParentId
upDateMap["path"] = menu.Path
upDateMap["name"] = menu.Name
upDateMap["hidden"] = menu.Hidden
upDateMap["component"] = menu.Component
upDateMap["title"] = menu.Title
upDateMap["active_name"] = menu.ActiveName
upDateMap["icon"] = menu.Icon
upDateMap["sort"] = menu.Sort
err = global.GVA_DB.Transaction(func(tx *gorm.DB) error {
tx.Where("id = ?", menu.ID).Find(&oldMenu)
if oldMenu.Name != menu.Name {
if !errors.Is(tx.Where("id <> ? AND name = ?", menu.ID, menu.Name).First(&system.SysBaseMenu{}).Error, gorm.ErrRecordNotFound) {
global.GVA_LOG.Debug("存在相同name修改失败")
return errors.New("存在相同name修改失败")
}
}
txErr := tx.Unscoped().Delete(&system.SysBaseMenuParameter{}, "sys_base_menu_id = ?", menu.ID).Error
if txErr != nil {
global.GVA_LOG.Debug(txErr.Error())
return txErr
}
txErr = tx.Unscoped().Delete(&system.SysBaseMenuBtn{}, "sys_base_menu_id = ?", menu.ID).Error
if txErr != nil {
global.GVA_LOG.Debug(txErr.Error())
return txErr
}
if len(menu.Parameters) > 0 {
for k := range menu.Parameters {
menu.Parameters[k].SysBaseMenuID = menu.ID
}
txErr = tx.Create(&menu.Parameters).Error
if txErr != nil {
global.GVA_LOG.Debug(txErr.Error())
return txErr
}
}
if len(menu.MenuBtn) > 0 {
for k := range menu.MenuBtn {
menu.MenuBtn[k].SysBaseMenuID = menu.ID
}
txErr = tx.Create(&menu.MenuBtn).Error
if txErr != nil {
global.GVA_LOG.Debug(txErr.Error())
return txErr
}
}
txErr = tx.Model(&oldMenu).Updates(upDateMap).Error
if txErr != nil {
global.GVA_LOG.Debug(txErr.Error())
return txErr
}
return nil
})
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetBaseMenuById
//@description: 返回当前选中menu
//@param: id float64
//@return: menu system.SysBaseMenu, err error
func (baseMenuService *BaseMenuService) GetBaseMenuById(id int) (menu system.SysBaseMenu, err error) {
err = global.GVA_DB.Preload("MenuBtn").Preload("Parameters").Where("id = ?", id).First(&menu).Error
return
}

View File

@@ -0,0 +1,173 @@
package system
import (
"errors"
"strconv"
"gorm.io/gorm"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
gormadapter "github.com/casbin/gorm-adapter/v3"
_ "github.com/go-sql-driver/mysql"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateCasbin
//@description: 更新casbin权限
//@param: authorityId string, casbinInfos []request.CasbinInfo
//@return: error
type CasbinService struct{}
var CasbinServiceApp = new(CasbinService)
func (casbinService *CasbinService) UpdateCasbin(adminAuthorityID, AuthorityID uint, casbinInfos []request.CasbinInfo) error {
err := AuthorityServiceApp.CheckAuthorityIDAuth(adminAuthorityID, AuthorityID)
if err != nil {
return err
}
if global.GVA_CONFIG.System.UseStrictAuth {
apis, e := ApiServiceApp.GetAllApis(adminAuthorityID)
if e != nil {
return e
}
for i := range casbinInfos {
hasApi := false
for j := range apis {
if apis[j].Path == casbinInfos[i].Path && apis[j].Method == casbinInfos[i].Method {
hasApi = true
break
}
}
if !hasApi {
return errors.New("存在api不在权限列表中")
}
}
}
authorityId := strconv.Itoa(int(AuthorityID))
casbinService.ClearCasbin(0, authorityId)
rules := [][]string{}
//做权限去重处理
deduplicateMap := make(map[string]bool)
for _, v := range casbinInfos {
key := authorityId + v.Path + v.Method
if _, ok := deduplicateMap[key]; !ok {
deduplicateMap[key] = true
rules = append(rules, []string{authorityId, v.Path, v.Method})
}
}
if len(rules) == 0 {
return nil
} // 设置空权限无需调用 AddPolicies 方法
e := utils.GetCasbin()
success, _ := e.AddPolicies(rules)
if !success {
return errors.New("存在相同api,添加失败,请联系管理员")
}
return nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateCasbinApi
//@description: API更新随动
//@param: oldPath string, newPath string, oldMethod string, newMethod string
//@return: error
func (casbinService *CasbinService) UpdateCasbinApi(oldPath string, newPath string, oldMethod string, newMethod string) error {
err := global.GVA_DB.Model(&gormadapter.CasbinRule{}).Where("v1 = ? AND v2 = ?", oldPath, oldMethod).Updates(map[string]interface{}{
"v1": newPath,
"v2": newMethod,
}).Error
if err != nil {
return err
}
e := utils.GetCasbin()
return e.LoadPolicy()
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetPolicyPathByAuthorityId
//@description: 获取权限列表
//@param: authorityId string
//@return: pathMaps []request.CasbinInfo
func (casbinService *CasbinService) GetPolicyPathByAuthorityId(AuthorityID uint) (pathMaps []request.CasbinInfo) {
e := utils.GetCasbin()
authorityId := strconv.Itoa(int(AuthorityID))
list, _ := e.GetFilteredPolicy(0, authorityId)
for _, v := range list {
pathMaps = append(pathMaps, request.CasbinInfo{
Path: v[1],
Method: v[2],
})
}
return pathMaps
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: ClearCasbin
//@description: 清除匹配的权限
//@param: v int, p ...string
//@return: bool
func (casbinService *CasbinService) ClearCasbin(v int, p ...string) bool {
e := utils.GetCasbin()
success, _ := e.RemoveFilteredPolicy(v, p...)
return success
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: RemoveFilteredPolicy
//@description: 使用数据库方法清理筛选的politicy 此方法需要调用FreshCasbin方法才可以在系统中即刻生效
//@param: db *gorm.DB, authorityId string
//@return: error
func (casbinService *CasbinService) RemoveFilteredPolicy(db *gorm.DB, authorityId string) error {
return db.Delete(&gormadapter.CasbinRule{}, "v0 = ?", authorityId).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SyncPolicy
//@description: 同步目前数据库的policy 此方法需要调用FreshCasbin方法才可以在系统中即刻生效
//@param: db *gorm.DB, authorityId string, rules [][]string
//@return: error
func (casbinService *CasbinService) SyncPolicy(db *gorm.DB, authorityId string, rules [][]string) error {
err := casbinService.RemoveFilteredPolicy(db, authorityId)
if err != nil {
return err
}
return casbinService.AddPolicies(db, rules)
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: AddPolicies
//@description: 添加匹配的权限
//@param: v int, p ...string
//@return: bool
func (casbinService *CasbinService) AddPolicies(db *gorm.DB, rules [][]string) error {
var casbinRules []gormadapter.CasbinRule
for i := range rules {
casbinRules = append(casbinRules, gormadapter.CasbinRule{
Ptype: "p",
V0: rules[i][0],
V1: rules[i][1],
V2: rules[i][2],
})
}
return db.Create(&casbinRules).Error
}
func (casbinService *CasbinService) FreshCasbin() (err error) {
e := utils.GetCasbin()
err = e.LoadPolicy()
return err
}

View File

@@ -0,0 +1,297 @@
package system
import (
"encoding/json"
"errors"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"github.com/gin-gonic/gin"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"gorm.io/gorm"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateSysDictionary
//@description: 创建字典数据
//@param: sysDictionary model.SysDictionary
//@return: err error
type DictionaryService struct{}
var DictionaryServiceApp = new(DictionaryService)
func (dictionaryService *DictionaryService) CreateSysDictionary(sysDictionary system.SysDictionary) (err error) {
if (!errors.Is(global.GVA_DB.First(&system.SysDictionary{}, "type = ?", sysDictionary.Type).Error, gorm.ErrRecordNotFound)) {
return errors.New("存在相同的type不允许创建")
}
err = global.GVA_DB.Create(&sysDictionary).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteSysDictionary
//@description: 删除字典数据
//@param: sysDictionary model.SysDictionary
//@return: err error
func (dictionaryService *DictionaryService) DeleteSysDictionary(sysDictionary system.SysDictionary) (err error) {
err = global.GVA_DB.Where("id = ?", sysDictionary.ID).Preload("SysDictionaryDetails").First(&sysDictionary).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("请不要搞事")
}
if err != nil {
return err
}
err = global.GVA_DB.Delete(&sysDictionary).Error
if err != nil {
return err
}
if sysDictionary.SysDictionaryDetails != nil {
return global.GVA_DB.Where("sys_dictionary_id=?", sysDictionary.ID).Delete(sysDictionary.SysDictionaryDetails).Error
}
return
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateSysDictionary
//@description: 更新字典数据
//@param: sysDictionary *model.SysDictionary
//@return: err error
func (dictionaryService *DictionaryService) UpdateSysDictionary(sysDictionary *system.SysDictionary) (err error) {
var dict system.SysDictionary
sysDictionaryMap := map[string]interface{}{
"Name": sysDictionary.Name,
"Type": sysDictionary.Type,
"Status": sysDictionary.Status,
"Desc": sysDictionary.Desc,
"ParentID": sysDictionary.ParentID,
}
err = global.GVA_DB.Where("id = ?", sysDictionary.ID).First(&dict).Error
if err != nil {
global.GVA_LOG.Debug(err.Error())
return errors.New("查询字典数据失败")
}
if dict.Type != sysDictionary.Type {
if !errors.Is(global.GVA_DB.First(&system.SysDictionary{}, "type = ?", sysDictionary.Type).Error, gorm.ErrRecordNotFound) {
return errors.New("存在相同的type不允许创建")
}
}
// 检查是否会形成循环引用
if sysDictionary.ParentID != nil && *sysDictionary.ParentID != 0 {
if err := dictionaryService.checkCircularReference(sysDictionary.ID, *sysDictionary.ParentID); err != nil {
return err
}
}
err = global.GVA_DB.Model(&dict).Updates(sysDictionaryMap).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetSysDictionary
//@description: 根据id或者type获取字典单条数据
//@param: Type string, Id uint
//@return: err error, sysDictionary model.SysDictionary
func (dictionaryService *DictionaryService) GetSysDictionary(Type string, Id uint, status *bool) (sysDictionary system.SysDictionary, err error) {
var flag = false
if status == nil {
flag = true
} else {
flag = *status
}
err = global.GVA_DB.Where("(type = ? OR id = ?) and status = ?", Type, Id, flag).Preload("SysDictionaryDetails", func(db *gorm.DB) *gorm.DB {
return db.Where("status = ? and deleted_at is null", true).Order("sort")
}).First(&sysDictionary).Error
return
}
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: GetSysDictionaryInfoList
//@description: 分页获取字典列表
//@param: info request.SysDictionarySearch
//@return: err error, list interface{}, total int64
func (dictionaryService *DictionaryService) GetSysDictionaryInfoList(c *gin.Context, req request.SysDictionarySearch) (list interface{}, err error) {
var sysDictionarys []system.SysDictionary
query := global.GVA_DB.WithContext(c)
if req.Name != "" {
query = query.Where("name LIKE ? OR type LIKE ?", "%"+req.Name+"%", "%"+req.Name+"%")
}
// 预加载子字典
query = query.Preload("Children")
err = query.Find(&sysDictionarys).Error
return sysDictionarys, err
}
// checkCircularReference 检查是否会形成循环引用
func (dictionaryService *DictionaryService) checkCircularReference(currentID uint, parentID uint) error {
if currentID == parentID {
return errors.New("不能将字典设置为自己的父级")
}
// 递归检查父级链条
var parent system.SysDictionary
err := global.GVA_DB.Where("id = ?", parentID).First(&parent).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil // 父级不存在,允许设置
}
return err
}
// 如果父级还有父级,继续检查
if parent.ParentID != nil && *parent.ParentID != 0 {
return dictionaryService.checkCircularReference(currentID, *parent.ParentID)
}
return nil
}
//@author: [pixelMax]
//@function: ExportSysDictionary
//@description: 导出字典JSON包含字典详情
//@param: id uint
//@return: exportData map[string]interface{}, err error
func (dictionaryService *DictionaryService) ExportSysDictionary(id uint) (exportData map[string]interface{}, err error) {
var dictionary system.SysDictionary
// 查询字典及其所有详情
err = global.GVA_DB.Where("id = ?", id).Preload("SysDictionaryDetails", func(db *gorm.DB) *gorm.DB {
return db.Order("sort")
}).First(&dictionary).Error
if err != nil {
return nil, err
}
// 清空字典详情中的ID、创建时间、更新时间等字段
var cleanDetails []map[string]interface{}
for _, detail := range dictionary.SysDictionaryDetails {
cleanDetail := map[string]interface{}{
"label": detail.Label,
"value": detail.Value,
"extend": detail.Extend,
"status": detail.Status,
"sort": detail.Sort,
"level": detail.Level,
"path": detail.Path,
}
cleanDetails = append(cleanDetails, cleanDetail)
}
// 构造导出数据
exportData = map[string]interface{}{
"name": dictionary.Name,
"type": dictionary.Type,
"status": dictionary.Status,
"desc": dictionary.Desc,
"sysDictionaryDetails": cleanDetails,
}
return exportData, nil
}
//@author: [pixelMax]
//@function: ImportSysDictionary
//@description: 导入字典JSON包含字典详情
//@param: jsonStr string
//@return: err error
func (dictionaryService *DictionaryService) ImportSysDictionary(jsonStr string) error {
// 直接解析到 SysDictionary 结构体
var importData system.SysDictionary
if err := json.Unmarshal([]byte(jsonStr), &importData); err != nil {
return errors.New("JSON 格式错误: " + err.Error())
}
// 验证必填字段
if importData.Name == "" {
return errors.New("字典名称不能为空")
}
if importData.Type == "" {
return errors.New("字典类型不能为空")
}
// 检查字典类型是否已存在
if !errors.Is(global.GVA_DB.First(&system.SysDictionary{}, "type = ?", importData.Type).Error, gorm.ErrRecordNotFound) {
return errors.New("存在相同的type不允许导入")
}
// 创建字典清空导入数据的ID和时间戳
dictionary := system.SysDictionary{
Name: importData.Name,
Type: importData.Type,
Status: importData.Status,
Desc: importData.Desc,
}
// 开启事务
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 创建字典
if err := tx.Create(&dictionary).Error; err != nil {
return err
}
// 处理字典详情
if len(importData.SysDictionaryDetails) > 0 {
// 创建一个映射来跟踪旧ID到新ID的对应关系
idMap := make(map[uint]uint)
// 第一遍:创建所有详情记录
for _, detail := range importData.SysDictionaryDetails {
// 验证必填字段
if detail.Label == "" || detail.Value == "" {
continue
}
// 记录旧ID
oldID := detail.ID
// 创建新的详情记录ID会被GORM自动设置
detailRecord := system.SysDictionaryDetail{
Label: detail.Label,
Value: detail.Value,
Extend: detail.Extend,
Status: detail.Status,
Sort: detail.Sort,
Level: detail.Level,
Path: detail.Path,
SysDictionaryID: int(dictionary.ID),
}
// 创建详情记录
if err := tx.Create(&detailRecord).Error; err != nil {
return err
}
// 记录旧ID到新ID的映射
if oldID > 0 {
idMap[oldID] = detailRecord.ID
}
}
// 第二遍更新parent_id关系
for _, detail := range importData.SysDictionaryDetails {
if detail.ParentID != nil && *detail.ParentID > 0 && detail.ID > 0 {
if newID, exists := idMap[detail.ID]; exists {
if newParentID, parentExists := idMap[*detail.ParentID]; parentExists {
if err := tx.Model(&system.SysDictionaryDetail{}).
Where("id = ?", newID).
Update("parent_id", newParentID).Error; err != nil {
return err
}
}
}
}
}
}
return nil
})
}

View File

@@ -0,0 +1,392 @@
package system
import (
"fmt"
"strconv"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: CreateSysDictionaryDetail
//@description: 创建字典详情数据
//@param: sysDictionaryDetail model.SysDictionaryDetail
//@return: err error
type DictionaryDetailService struct{}
var DictionaryDetailServiceApp = new(DictionaryDetailService)
func (dictionaryDetailService *DictionaryDetailService) CreateSysDictionaryDetail(sysDictionaryDetail system.SysDictionaryDetail) (err error) {
// 计算层级和路径
if sysDictionaryDetail.ParentID != nil {
var parent system.SysDictionaryDetail
err = global.GVA_DB.First(&parent, *sysDictionaryDetail.ParentID).Error
if err != nil {
return err
}
sysDictionaryDetail.Level = parent.Level + 1
if parent.Path == "" {
sysDictionaryDetail.Path = strconv.Itoa(int(parent.ID))
} else {
sysDictionaryDetail.Path = parent.Path + "," + strconv.Itoa(int(parent.ID))
}
} else {
sysDictionaryDetail.Level = 0
sysDictionaryDetail.Path = ""
}
err = global.GVA_DB.Create(&sysDictionaryDetail).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteSysDictionaryDetail
//@description: 删除字典详情数据
//@param: sysDictionaryDetail model.SysDictionaryDetail
//@return: err error
func (dictionaryDetailService *DictionaryDetailService) DeleteSysDictionaryDetail(sysDictionaryDetail system.SysDictionaryDetail) (err error) {
// 检查是否有子项
var count int64
err = global.GVA_DB.Model(&system.SysDictionaryDetail{}).Where("parent_id = ?", sysDictionaryDetail.ID).Count(&count).Error
if err != nil {
return err
}
if count > 0 {
return fmt.Errorf("该字典详情下还有子项,无法删除")
}
err = global.GVA_DB.Delete(&sysDictionaryDetail).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: UpdateSysDictionaryDetail
//@description: 更新字典详情数据
//@param: sysDictionaryDetail *model.SysDictionaryDetail
//@return: err error
func (dictionaryDetailService *DictionaryDetailService) UpdateSysDictionaryDetail(sysDictionaryDetail *system.SysDictionaryDetail) (err error) {
// 如果更新了父级ID需要重新计算层级和路径
if sysDictionaryDetail.ParentID != nil {
var parent system.SysDictionaryDetail
err = global.GVA_DB.First(&parent, *sysDictionaryDetail.ParentID).Error
if err != nil {
return err
}
// 检查循环引用
if dictionaryDetailService.checkCircularReference(sysDictionaryDetail.ID, *sysDictionaryDetail.ParentID) {
return fmt.Errorf("不能将字典详情设置为自己或其子项的父级")
}
sysDictionaryDetail.Level = parent.Level + 1
if parent.Path == "" {
sysDictionaryDetail.Path = strconv.Itoa(int(parent.ID))
} else {
sysDictionaryDetail.Path = parent.Path + "," + strconv.Itoa(int(parent.ID))
}
} else {
sysDictionaryDetail.Level = 0
sysDictionaryDetail.Path = ""
}
err = global.GVA_DB.Save(sysDictionaryDetail).Error
if err != nil {
return err
}
// 更新所有子项的层级和路径
return dictionaryDetailService.updateChildrenLevelAndPath(sysDictionaryDetail.ID)
}
// checkCircularReference 检查循环引用
func (dictionaryDetailService *DictionaryDetailService) checkCircularReference(id, parentID uint) bool {
if id == parentID {
return true
}
var parent system.SysDictionaryDetail
err := global.GVA_DB.First(&parent, parentID).Error
if err != nil {
return false
}
if parent.ParentID == nil {
return false
}
return dictionaryDetailService.checkCircularReference(id, *parent.ParentID)
}
// updateChildrenLevelAndPath 更新子项的层级和路径
func (dictionaryDetailService *DictionaryDetailService) updateChildrenLevelAndPath(parentID uint) error {
var children []system.SysDictionaryDetail
err := global.GVA_DB.Where("parent_id = ?", parentID).Find(&children).Error
if err != nil {
return err
}
var parent system.SysDictionaryDetail
err = global.GVA_DB.First(&parent, parentID).Error
if err != nil {
return err
}
for _, child := range children {
child.Level = parent.Level + 1
if parent.Path == "" {
child.Path = strconv.Itoa(int(parent.ID))
} else {
child.Path = parent.Path + "," + strconv.Itoa(int(parent.ID))
}
err = global.GVA_DB.Save(&child).Error
if err != nil {
return err
}
// 递归更新子项的子项
err = dictionaryDetailService.updateChildrenLevelAndPath(child.ID)
if err != nil {
return err
}
}
return nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetSysDictionaryDetail
//@description: 根据id获取字典详情单条数据
//@param: id uint
//@return: sysDictionaryDetail system.SysDictionaryDetail, err error
func (dictionaryDetailService *DictionaryDetailService) GetSysDictionaryDetail(id uint) (sysDictionaryDetail system.SysDictionaryDetail, err error) {
err = global.GVA_DB.Where("id = ?", id).First(&sysDictionaryDetail).Error
return
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetSysDictionaryDetailInfoList
//@description: 分页获取字典详情列表
//@param: info request.SysDictionaryDetailSearch
//@return: list interface{}, total int64, err error
func (dictionaryDetailService *DictionaryDetailService) GetSysDictionaryDetailInfoList(info request.SysDictionaryDetailSearch) (list interface{}, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysDictionaryDetail{})
var sysDictionaryDetails []system.SysDictionaryDetail
// 如果有条件搜索 下方会自动创建搜索语句
if info.Label != "" {
db = db.Where("label LIKE ?", "%"+info.Label+"%")
}
if info.Value != "" {
db = db.Where("value = ?", info.Value)
}
if info.Status != nil {
db = db.Where("status = ?", info.Status)
}
if info.SysDictionaryID != 0 {
db = db.Where("sys_dictionary_id = ?", info.SysDictionaryID)
}
if info.ParentID != nil {
db = db.Where("parent_id = ?", *info.ParentID)
}
if info.Level != nil {
db = db.Where("level = ?", *info.Level)
}
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Limit(limit).Offset(offset).Order("sort").Order("id").Find(&sysDictionaryDetails).Error
return sysDictionaryDetails, total, err
}
// 按照字典id获取字典全部内容的方法
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryList(dictionaryID uint) (list []system.SysDictionaryDetail, err error) {
var sysDictionaryDetails []system.SysDictionaryDetail
err = global.GVA_DB.Find(&sysDictionaryDetails, "sys_dictionary_id = ?", dictionaryID).Error
return sysDictionaryDetails, err
}
// GetDictionaryTreeList 获取字典树形结构列表
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryTreeList(dictionaryID uint) (list []system.SysDictionaryDetail, err error) {
var sysDictionaryDetails []system.SysDictionaryDetail
// 只获取顶级项目parent_id为空
err = global.GVA_DB.Where("sys_dictionary_id = ? AND parent_id IS NULL", dictionaryID).Order("sort").Find(&sysDictionaryDetails).Error
if err != nil {
return nil, err
}
// 递归加载子项并设置disabled属性
for i := range sysDictionaryDetails {
// 设置disabled属性当status为false时disabled为true
if sysDictionaryDetails[i].Status != nil {
sysDictionaryDetails[i].Disabled = !*sysDictionaryDetails[i].Status
} else {
sysDictionaryDetails[i].Disabled = false // 默认不禁用
}
err = dictionaryDetailService.loadChildren(&sysDictionaryDetails[i])
if err != nil {
return nil, err
}
}
return sysDictionaryDetails, nil
}
// loadChildren 递归加载子项
func (dictionaryDetailService *DictionaryDetailService) loadChildren(detail *system.SysDictionaryDetail) error {
var children []system.SysDictionaryDetail
err := global.GVA_DB.Where("parent_id = ?", detail.ID).Order("sort").Find(&children).Error
if err != nil {
return err
}
for i := range children {
// 设置disabled属性当status为false时disabled为true
if children[i].Status != nil {
children[i].Disabled = !*children[i].Status
} else {
children[i].Disabled = false // 默认不禁用
}
err = dictionaryDetailService.loadChildren(&children[i])
if err != nil {
return err
}
}
detail.Children = children
return nil
}
// GetDictionaryDetailsByParent 根据父级ID获取字典详情
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryDetailsByParent(req request.GetDictionaryDetailsByParentRequest) (list []system.SysDictionaryDetail, err error) {
db := global.GVA_DB.Model(&system.SysDictionaryDetail{}).Where("sys_dictionary_id = ?", req.SysDictionaryID)
if req.ParentID != nil {
db = db.Where("parent_id = ?", *req.ParentID)
} else {
db = db.Where("parent_id IS NULL")
}
err = db.Order("sort").Find(&list).Error
if err != nil {
return list, err
}
// 设置disabled属性
for i := range list {
if list[i].Status != nil {
list[i].Disabled = !*list[i].Status
} else {
list[i].Disabled = false // 默认不禁用
}
}
// 如果需要包含子级数据,使用递归方式加载所有层级的子项
if req.IncludeChildren {
for i := range list {
err = dictionaryDetailService.loadChildren(&list[i])
if err != nil {
return list, err
}
}
}
return list, err
}
// 按照字典type获取字典全部内容的方法
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryListByType(t string) (list []system.SysDictionaryDetail, err error) {
var sysDictionaryDetails []system.SysDictionaryDetail
db := global.GVA_DB.Model(&system.SysDictionaryDetail{}).Joins("JOIN sys_dictionaries ON sys_dictionaries.id = sys_dictionary_details.sys_dictionary_id")
err = db.Find(&sysDictionaryDetails, "type = ?", t).Error
return sysDictionaryDetails, err
}
// GetDictionaryTreeListByType 根据字典类型获取树形结构
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryTreeListByType(t string) (list []system.SysDictionaryDetail, err error) {
var sysDictionaryDetails []system.SysDictionaryDetail
db := global.GVA_DB.Model(&system.SysDictionaryDetail{}).
Joins("JOIN sys_dictionaries ON sys_dictionaries.id = sys_dictionary_details.sys_dictionary_id").
Where("sys_dictionaries.type = ? AND sys_dictionary_details.parent_id IS NULL", t).
Order("sys_dictionary_details.sort")
err = db.Find(&sysDictionaryDetails).Error
if err != nil {
return nil, err
}
// 递归加载子项并设置disabled属性
for i := range sysDictionaryDetails {
// 设置disabled属性当status为false时disabled为true
if sysDictionaryDetails[i].Status != nil {
sysDictionaryDetails[i].Disabled = !*sysDictionaryDetails[i].Status
} else {
sysDictionaryDetails[i].Disabled = false // 默认不禁用
}
err = dictionaryDetailService.loadChildren(&sysDictionaryDetails[i])
if err != nil {
return nil, err
}
}
return sysDictionaryDetails, nil
}
// 按照字典id+字典内容value获取单条字典内容
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryInfoByValue(dictionaryID uint, value string) (detail system.SysDictionaryDetail, err error) {
var sysDictionaryDetail system.SysDictionaryDetail
err = global.GVA_DB.First(&sysDictionaryDetail, "sys_dictionary_id = ? and value = ?", dictionaryID, value).Error
return sysDictionaryDetail, err
}
// 按照字典type+字典内容value获取单条字典内容
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryInfoByTypeValue(t string, value string) (detail system.SysDictionaryDetail, err error) {
var sysDictionaryDetails system.SysDictionaryDetail
db := global.GVA_DB.Model(&system.SysDictionaryDetail{}).Joins("JOIN sys_dictionaries ON sys_dictionaries.id = sys_dictionary_details.sys_dictionary_id")
err = db.First(&sysDictionaryDetails, "sys_dictionaries.type = ? and sys_dictionary_details.value = ?", t, value).Error
return sysDictionaryDetails, err
}
// GetDictionaryPath 获取字典详情的完整路径
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryPath(id uint) (path []system.SysDictionaryDetail, err error) {
var detail system.SysDictionaryDetail
err = global.GVA_DB.First(&detail, id).Error
if err != nil {
return nil, err
}
path = append(path, detail)
if detail.ParentID != nil {
parentPath, err := dictionaryDetailService.GetDictionaryPath(*detail.ParentID)
if err != nil {
return nil, err
}
path = append(parentPath, path...)
}
return path, nil
}
// GetDictionaryPathByValue 根据值获取字典详情的完整路径
func (dictionaryDetailService *DictionaryDetailService) GetDictionaryPathByValue(dictionaryID uint, value string) (path []system.SysDictionaryDetail, err error) {
detail, err := dictionaryDetailService.GetDictionaryInfoByValue(dictionaryID, value)
if err != nil {
return nil, err
}
return dictionaryDetailService.GetDictionaryPath(detail.ID)
}

View File

@@ -0,0 +1,126 @@
package system
import (
"context"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
)
type SysErrorService struct{}
// CreateSysError 创建错误日志记录
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) CreateSysError(ctx context.Context, sysError *system.SysError) (err error) {
if global.GVA_DB == nil {
return nil
}
err = global.GVA_DB.Create(sysError).Error
return err
}
// DeleteSysError 删除错误日志记录
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) DeleteSysError(ctx context.Context, ID string) (err error) {
err = global.GVA_DB.Delete(&system.SysError{}, "id = ?", ID).Error
return err
}
// DeleteSysErrorByIds 批量删除错误日志记录
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) DeleteSysErrorByIds(ctx context.Context, IDs []string) (err error) {
err = global.GVA_DB.Delete(&[]system.SysError{}, "id in ?", IDs).Error
return err
}
// UpdateSysError 更新错误日志记录
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) UpdateSysError(ctx context.Context, sysError system.SysError) (err error) {
err = global.GVA_DB.Model(&system.SysError{}).Where("id = ?", sysError.ID).Updates(&sysError).Error
return err
}
// GetSysError 根据ID获取错误日志记录
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) GetSysError(ctx context.Context, ID string) (sysError system.SysError, err error) {
err = global.GVA_DB.Where("id = ?", ID).First(&sysError).Error
return
}
// GetSysErrorInfoList 分页获取错误日志记录
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) GetSysErrorInfoList(ctx context.Context, info systemReq.SysErrorSearch) (list []system.SysError, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysError{}).Order("created_at desc")
var sysErrors []system.SysError
// 如果有条件搜索 下方会自动创建搜索语句
if len(info.CreatedAtRange) == 2 {
db = db.Where("created_at BETWEEN ? AND ?", info.CreatedAtRange[0], info.CreatedAtRange[1])
}
if info.Form != nil && *info.Form != "" {
db = db.Where("form = ?", *info.Form)
}
if info.Info != nil && *info.Info != "" {
db = db.Where("info LIKE ?", "%"+*info.Info+"%")
}
err = db.Count(&total).Error
if err != nil {
return
}
if limit != 0 {
db = db.Limit(limit).Offset(offset)
}
err = db.Find(&sysErrors).Error
return sysErrors, total, err
}
// GetSysErrorSolution 异步处理错误
// Author [yourname](https://github.com/yourname)
func (sysErrorService *SysErrorService) GetSysErrorSolution(ctx context.Context, ID string) (err error) {
// 立即更新为处理中
err = global.GVA_DB.WithContext(ctx).Model(&system.SysError{}).Where("id = ?", ID).Update("status", "处理中").Error
if err != nil {
return err
}
// 异步协程在一分钟后更新为处理完成
go func(id string) {
// 查询当前错误信息用于生成方案
var se system.SysError
_ = global.GVA_DB.Model(&system.SysError{}).Where("id = ?", id).First(&se).Error
// 构造 LLM 请求参数,使用管家模式(butler)根据错误信息生成解决方案
var form, info string
if se.Form != nil {
form = *se.Form
}
if se.Info != nil {
info = *se.Info
}
llmReq := common.JSONMap{
"mode": "solution",
"info": info,
"form": form,
}
// 调用服务层 LLMAuto忽略错误但尽量写入方案
var solution string
if data, err := (&AutoCodeService{}).LLMAuto(context.Background(), llmReq); err == nil {
solution = fmt.Sprintf("%v", data.(map[string]interface{})["text"])
_ = global.GVA_DB.Model(&system.SysError{}).Where("id = ?", id).Updates(map[string]interface{}{"status": "处理完成", "solution": solution}).Error
} else {
// 即使生成失败也标记为完成,避免任务卡住
_ = global.GVA_DB.Model(&system.SysError{}).Where("id = ?", id).Update("status", "处理失败").Error
}
}(ID)
return nil
}

View File

@@ -0,0 +1,724 @@
package system
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"mime/multipart"
"net/url"
"strconv"
"strings"
"time"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/xuri/excelize/v2"
"gorm.io/gorm"
)
type SysExportTemplateService struct {
}
var SysExportTemplateServiceApp = new(SysExportTemplateService)
// CreateSysExportTemplate 创建导出模板记录
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) CreateSysExportTemplate(sysExportTemplate *system.SysExportTemplate) (err error) {
err = global.GVA_DB.Create(sysExportTemplate).Error
return err
}
// DeleteSysExportTemplate 删除导出模板记录
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) DeleteSysExportTemplate(sysExportTemplate system.SysExportTemplate) (err error) {
err = global.GVA_DB.Delete(&sysExportTemplate).Error
return err
}
// DeleteSysExportTemplateByIds 批量删除导出模板记录
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) DeleteSysExportTemplateByIds(ids request.IdsReq) (err error) {
err = global.GVA_DB.Delete(&[]system.SysExportTemplate{}, "id in ?", ids.Ids).Error
return err
}
// UpdateSysExportTemplate 更新导出模板记录
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) UpdateSysExportTemplate(sysExportTemplate system.SysExportTemplate) (err error) {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
conditions := sysExportTemplate.Conditions
e := tx.Delete(&[]system.Condition{}, "template_id = ?", sysExportTemplate.TemplateID).Error
if e != nil {
return e
}
sysExportTemplate.Conditions = nil
joins := sysExportTemplate.JoinTemplate
e = tx.Delete(&[]system.JoinTemplate{}, "template_id = ?", sysExportTemplate.TemplateID).Error
if e != nil {
return e
}
sysExportTemplate.JoinTemplate = nil
e = tx.Updates(&sysExportTemplate).Error
if e != nil {
return e
}
if len(conditions) > 0 {
for i := range conditions {
conditions[i].ID = 0
}
e = tx.Create(&conditions).Error
}
if len(joins) > 0 {
for i := range joins {
joins[i].ID = 0
}
e = tx.Create(&joins).Error
}
return e
})
}
// GetSysExportTemplate 根据id获取导出模板记录
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) GetSysExportTemplate(id uint) (sysExportTemplate system.SysExportTemplate, err error) {
err = global.GVA_DB.Where("id = ?", id).Preload("JoinTemplate").Preload("Conditions").First(&sysExportTemplate).Error
return
}
// GetSysExportTemplateInfoList 分页获取导出模板记录
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) GetSysExportTemplateInfoList(info systemReq.SysExportTemplateSearch) (list []system.SysExportTemplate, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysExportTemplate{})
var sysExportTemplates []system.SysExportTemplate
// 如果有条件搜索 下方会自动创建搜索语句
if info.StartCreatedAt != nil && info.EndCreatedAt != nil {
db = db.Where("created_at BETWEEN ? AND ?", info.StartCreatedAt, info.EndCreatedAt)
}
if info.Name != "" {
db = db.Where("name LIKE ?", "%"+info.Name+"%")
}
if info.TableName != "" {
db = db.Where("table_name = ?", info.TableName)
}
if info.TemplateID != "" {
db = db.Where("template_id = ?", info.TemplateID)
}
err = db.Count(&total).Error
if err != nil {
return
}
if limit != 0 {
db = db.Limit(limit).Offset(offset)
}
err = db.Find(&sysExportTemplates).Error
return sysExportTemplates, total, err
}
// ExportExcel 导出Excel
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) ExportExcel(templateID string, values url.Values) (file *bytes.Buffer, name string, err error) {
var params = values.Get("params")
paramsValues, err := url.ParseQuery(params)
if err != nil {
return nil, "", fmt.Errorf("解析 params 参数失败: %v", err)
}
var template system.SysExportTemplate
err = global.GVA_DB.Preload("Conditions").Preload("JoinTemplate").First(&template, "template_id = ?", templateID).Error
if err != nil {
return nil, "", err
}
f := excelize.NewFile()
defer func() {
if err := f.Close(); err != nil {
fmt.Println(err)
}
}()
// Create a new sheet.
index, err := f.NewSheet("Sheet1")
if err != nil {
fmt.Println(err)
return
}
var templateInfoMap = make(map[string]string)
columns, err := utils.GetJSONKeys(template.TemplateInfo)
if err != nil {
return nil, "", err
}
err = json.Unmarshal([]byte(template.TemplateInfo), &templateInfoMap)
if err != nil {
return nil, "", err
}
var tableTitle []string
var selectKeyFmt []string
for _, key := range columns {
selectKeyFmt = append(selectKeyFmt, key)
tableTitle = append(tableTitle, templateInfoMap[key])
}
selects := strings.Join(selectKeyFmt, ", ")
var tableMap []map[string]interface{}
db := global.GVA_DB
if template.DBName != "" {
db = global.MustGetGlobalDBByDBName(template.DBName)
}
// 如果有自定义SQL则优先使用自定义SQL
if template.SQL != "" {
// 将 url.Values 转换为 map[string]interface{} 以支持 GORM 的命名参数
sqlParams := make(map[string]interface{})
for k, v := range paramsValues {
if len(v) > 0 {
sqlParams[k] = v[0]
}
}
// 执行原生 SQL支持 @key 命名参数
err = db.Raw(template.SQL, sqlParams).Scan(&tableMap).Error
if err != nil {
return nil, "", err
}
} else {
if len(template.JoinTemplate) > 0 {
for _, join := range template.JoinTemplate {
db = db.Joins(join.JOINS + " " + join.Table + " ON " + join.ON)
}
}
db = db.Select(selects).Table(template.TableName)
filterDeleted := false
filterParam := paramsValues.Get("filterDeleted")
if filterParam == "true" {
filterDeleted = true
}
if filterDeleted {
// 自动过滤主表的软删除
db = db.Where(fmt.Sprintf("%s.deleted_at IS NULL", template.TableName))
// 过滤关联表的软删除(如果有)
if len(template.JoinTemplate) > 0 {
for _, join := range template.JoinTemplate {
// 检查关联表是否有deleted_at字段
hasDeletedAt := sysExportTemplateService.hasDeletedAtColumn(join.Table)
if hasDeletedAt {
db = db.Where(fmt.Sprintf("%s.deleted_at IS NULL", join.Table))
}
}
}
}
if len(template.Conditions) > 0 {
for _, condition := range template.Conditions {
sql := fmt.Sprintf("%s %s ?", condition.Column, condition.Operator)
value := paramsValues.Get(condition.From)
if condition.Operator == "IN" || condition.Operator == "NOT IN" {
sql = fmt.Sprintf("%s %s (?)", condition.Column, condition.Operator)
}
if condition.Operator == "BETWEEN" {
sql = fmt.Sprintf("%s BETWEEN ? AND ?", condition.Column)
startValue := paramsValues.Get("start" + condition.From)
endValue := paramsValues.Get("end" + condition.From)
if startValue != "" && endValue != "" {
db = db.Where(sql, startValue, endValue)
}
continue
}
if value != "" {
if condition.Operator == "LIKE" {
value = "%" + value + "%"
}
db = db.Where(sql, value)
}
}
}
// 通过参数传入limit
limit := paramsValues.Get("limit")
if limit != "" {
l, e := strconv.Atoi(limit)
if e == nil {
db = db.Limit(l)
}
}
// 模板的默认limit
if limit == "" && template.Limit != nil && *template.Limit != 0 {
db = db.Limit(*template.Limit)
}
// 通过参数传入offset
offset := paramsValues.Get("offset")
if offset != "" {
o, e := strconv.Atoi(offset)
if e == nil {
db = db.Offset(o)
}
}
// 获取当前表的所有字段
table := template.TableName
orderColumns, err := db.Migrator().ColumnTypes(table)
if err != nil {
return nil, "", err
}
// 创建一个 map 来存储字段名
fields := make(map[string]bool)
for _, column := range orderColumns {
fields[column.Name()] = true
}
// 通过参数传入order
order := paramsValues.Get("order")
if order == "" && template.Order != "" {
// 如果没有order入参这里会使用模板的默认排序
order = template.Order
}
if order != "" {
checkOrderArr := strings.Split(order, " ")
orderStr := ""
// 检查请求的排序字段是否在字段列表中
if _, ok := fields[checkOrderArr[0]]; !ok {
return nil, "", fmt.Errorf("order by %s is not in the fields", order)
}
orderStr = checkOrderArr[0]
if len(checkOrderArr) > 1 {
if checkOrderArr[1] != "asc" && checkOrderArr[1] != "desc" {
return nil, "", fmt.Errorf("order by %s is not secure", order)
}
orderStr = orderStr + " " + checkOrderArr[1]
}
db = db.Order(orderStr)
}
err = db.Debug().Find(&tableMap).Error
if err != nil {
return nil, "", err
}
}
var rows [][]string
rows = append(rows, tableTitle)
for _, exTable := range tableMap {
var row []string
for _, column := range columns {
column = strings.ReplaceAll(column, "\"", "")
column = strings.ReplaceAll(column, "`", "")
if len(template.JoinTemplate) > 0 {
columnAs := strings.Split(column, " as ")
if len(columnAs) > 1 {
column = strings.TrimSpace(strings.Split(column, " as ")[1])
} else {
columnArr := strings.Split(column, ".")
if len(columnArr) > 1 {
column = strings.Split(column, ".")[1]
}
}
}
// 需要对时间类型特殊处理
if t, ok := exTable[column].(time.Time); ok {
row = append(row, t.Format("2006-01-02 15:04:05"))
} else {
row = append(row, fmt.Sprintf("%v", exTable[column]))
}
}
rows = append(rows, row)
}
for i, row := range rows {
for j, colCell := range row {
cell := fmt.Sprintf("%s%d", getColumnName(j+1), i+1)
var sErr error
if v, err := strconv.ParseFloat(colCell, 64); err == nil {
sErr = f.SetCellValue("Sheet1", cell, v)
} else if v, err := strconv.ParseInt(colCell, 10, 64); err == nil {
sErr = f.SetCellValue("Sheet1", cell, v)
} else {
sErr = f.SetCellValue("Sheet1", cell, colCell)
}
if sErr != nil {
return nil, "", sErr
}
}
}
f.SetActiveSheet(index)
file, err = f.WriteToBuffer()
if err != nil {
return nil, "", err
}
return file, template.Name, nil
}
// PreviewSQL 预览最终生成的 SQL不执行查询仅返回 SQL 字符串)
// Author [piexlmax](https://github.com/piexlmax) & [trae-ai]
func (sysExportTemplateService *SysExportTemplateService) PreviewSQL(templateID string, values url.Values) (sqlPreview string, err error) {
// 解析 params与导出逻辑保持一致
var params = values.Get("params")
paramsValues, _ := url.ParseQuery(params)
// 加载模板
var template system.SysExportTemplate
err = global.GVA_DB.Preload("Conditions").Preload("JoinTemplate").First(&template, "template_id = ?", templateID).Error
if err != nil {
return "", err
}
// 解析模板列
var templateInfoMap = make(map[string]string)
columns, err := utils.GetJSONKeys(template.TemplateInfo)
if err != nil {
return "", err
}
err = json.Unmarshal([]byte(template.TemplateInfo), &templateInfoMap)
if err != nil {
return "", err
}
var selectKeyFmt []string
for _, key := range columns {
selectKeyFmt = append(selectKeyFmt, key)
}
selects := strings.Join(selectKeyFmt, ", ")
// 生成 FROM 与 JOIN 片段
var sb strings.Builder
sb.WriteString("SELECT ")
sb.WriteString(selects)
sb.WriteString(" FROM ")
sb.WriteString(template.TableName)
if len(template.JoinTemplate) > 0 {
for _, join := range template.JoinTemplate {
sb.WriteString(" ")
sb.WriteString(join.JOINS)
sb.WriteString(" ")
sb.WriteString(join.Table)
sb.WriteString(" ON ")
sb.WriteString(join.ON)
}
}
// WHERE 条件
var wheres []string
// 软删除过滤
filterDeleted := false
if paramsValues != nil {
filterParam := paramsValues.Get("filterDeleted")
if filterParam == "true" {
filterDeleted = true
}
}
if filterDeleted {
wheres = append(wheres, fmt.Sprintf("%s.deleted_at IS NULL", template.TableName))
if len(template.JoinTemplate) > 0 {
for _, join := range template.JoinTemplate {
if sysExportTemplateService.hasDeletedAtColumn(join.Table) {
wheres = append(wheres, fmt.Sprintf("%s.deleted_at IS NULL", join.Table))
}
}
}
}
// 模板条件(保留与 ExportExcel 同步的解析规则)
if len(template.Conditions) > 0 {
for _, condition := range template.Conditions {
op := strings.ToUpper(strings.TrimSpace(condition.Operator))
col := strings.TrimSpace(condition.Column)
// 预览优先展示传入值,没有则展示占位符
val := ""
if paramsValues != nil {
val = paramsValues.Get(condition.From)
}
switch op {
case "BETWEEN":
startValue := ""
endValue := ""
if paramsValues != nil {
startValue = paramsValues.Get("start" + condition.From)
endValue = paramsValues.Get("end" + condition.From)
}
if startValue != "" && endValue != "" {
wheres = append(wheres, fmt.Sprintf("%s BETWEEN '%s' AND '%s'", col, startValue, endValue))
} else {
wheres = append(wheres, fmt.Sprintf("%s BETWEEN {start%s} AND {end%s}", col, condition.From, condition.From))
}
case "IN", "NOT IN":
if val != "" {
// 逗号分隔值做简单展示
parts := strings.Split(val, ",")
for i := range parts {
parts[i] = strings.TrimSpace(parts[i])
}
wheres = append(wheres, fmt.Sprintf("%s %s ('%s')", col, op, strings.Join(parts, "','")))
} else {
wheres = append(wheres, fmt.Sprintf("%s %s ({%s})", col, op, condition.From))
}
case "LIKE":
if val != "" {
wheres = append(wheres, fmt.Sprintf("%s LIKE '%%%s%%'", col, val))
} else {
wheres = append(wheres, fmt.Sprintf("%s LIKE {%%%s%%}", col, condition.From))
}
default:
if val != "" {
wheres = append(wheres, fmt.Sprintf("%s %s '%s'", col, op, val))
} else {
wheres = append(wheres, fmt.Sprintf("%s %s {%s}", col, op, condition.From))
}
}
}
}
if len(wheres) > 0 {
sb.WriteString(" WHERE ")
sb.WriteString(strings.Join(wheres, " AND "))
}
// 排序
order := ""
if paramsValues != nil {
order = paramsValues.Get("order")
}
if order == "" && template.Order != "" {
order = template.Order
}
if order != "" {
sb.WriteString(" ORDER BY ")
sb.WriteString(order)
}
// limit/offset如果传入或默认值为0则不生成
limitStr := ""
offsetStr := ""
if paramsValues != nil {
limitStr = paramsValues.Get("limit")
offsetStr = paramsValues.Get("offset")
}
// 处理模板默认limit仅当非0时
if limitStr == "" && template.Limit != nil && *template.Limit != 0 {
limitStr = strconv.Itoa(*template.Limit)
}
// 解析为数值,用于判断是否生成
limitInt := 0
offsetInt := 0
if limitStr != "" {
if v, e := strconv.Atoi(limitStr); e == nil {
limitInt = v
}
}
if offsetStr != "" {
if v, e := strconv.Atoi(offsetStr); e == nil {
offsetInt = v
}
}
if limitInt > 0 {
sb.WriteString(" LIMIT ")
sb.WriteString(strconv.Itoa(limitInt))
if offsetInt > 0 {
sb.WriteString(" OFFSET ")
sb.WriteString(strconv.Itoa(offsetInt))
}
} else {
// 当limit未设置或为0时仅当offset>0才生成OFFSET
if offsetInt > 0 {
sb.WriteString(" OFFSET ")
sb.WriteString(strconv.Itoa(offsetInt))
}
}
return sb.String(), nil
}
// ExportTemplate 导出Excel模板
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) ExportTemplate(templateID string) (file *bytes.Buffer, name string, err error) {
var template system.SysExportTemplate
err = global.GVA_DB.First(&template, "template_id = ?", templateID).Error
if err != nil {
return nil, "", err
}
f := excelize.NewFile()
defer func() {
if err := f.Close(); err != nil {
fmt.Println(err)
}
}()
// Create a new sheet.
index, err := f.NewSheet("Sheet1")
if err != nil {
fmt.Println(err)
return
}
var templateInfoMap = make(map[string]string)
columns, err := utils.GetJSONKeys(template.TemplateInfo)
err = json.Unmarshal([]byte(template.TemplateInfo), &templateInfoMap)
if err != nil {
return nil, "", err
}
var tableTitle []string
for _, key := range columns {
tableTitle = append(tableTitle, templateInfoMap[key])
}
for i := range tableTitle {
fErr := f.SetCellValue("Sheet1", fmt.Sprintf("%s%d", getColumnName(i+1), 1), tableTitle[i])
if fErr != nil {
return nil, "", fErr
}
}
f.SetActiveSheet(index)
file, err = f.WriteToBuffer()
if err != nil {
return nil, "", err
}
return file, template.Name, nil
}
// 辅助函数检查表是否有deleted_at列
func (s *SysExportTemplateService) hasDeletedAtColumn(tableName string) bool {
var count int64
global.GVA_DB.Raw("SELECT COUNT(*) FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = ? AND COLUMN_NAME = 'deleted_at'", tableName).Count(&count)
return count > 0
}
// ImportExcel 导入Excel
// Author [piexlmax](https://github.com/piexlmax)
func (sysExportTemplateService *SysExportTemplateService) ImportExcel(templateID string, file *multipart.FileHeader) (err error) {
var template system.SysExportTemplate
err = global.GVA_DB.First(&template, "template_id = ?", templateID).Error
if err != nil {
return err
}
src, err := file.Open()
if err != nil {
return err
}
defer src.Close()
f, err := excelize.OpenReader(src)
if err != nil {
return err
}
rows, err := f.GetRows("Sheet1")
if err != nil {
return err
}
if len(rows) < 2 {
return errors.New("Excel data is not enough.\nIt should contain title row and data")
}
var templateInfoMap = make(map[string]string)
err = json.Unmarshal([]byte(template.TemplateInfo), &templateInfoMap)
if err != nil {
return err
}
db := global.GVA_DB
if template.DBName != "" {
db = global.MustGetGlobalDBByDBName(template.DBName)
}
items, err := sysExportTemplateService.parseExcelToMap(rows, templateInfoMap)
if err != nil {
return err
}
return db.Transaction(func(tx *gorm.DB) error {
if template.ImportSQL != "" {
return sysExportTemplateService.importBySQL(tx, template.ImportSQL, items)
}
return sysExportTemplateService.importByGORM(tx, template.TableName, items)
})
}
func (sysExportTemplateService *SysExportTemplateService) parseExcelToMap(rows [][]string, templateInfoMap map[string]string) ([]map[string]interface{}, error) {
var titleKeyMap = make(map[string]string)
for key, title := range templateInfoMap {
titleKeyMap[title] = key
}
excelTitle := rows[0]
for i, str := range excelTitle {
excelTitle[i] = strings.TrimSpace(str)
}
values := rows[1:]
items := make([]map[string]interface{}, 0, len(values))
for _, row := range values {
var item = make(map[string]interface{})
for ii, value := range row {
if ii >= len(excelTitle) {
continue
}
if _, ok := titleKeyMap[excelTitle[ii]]; !ok {
continue // excel中多余的标题在模板信息中没有对应的字段因此key为空必须跳过
}
key := titleKeyMap[excelTitle[ii]]
item[key] = value
}
items = append(items, item)
}
return items, nil
}
func (sysExportTemplateService *SysExportTemplateService) importBySQL(tx *gorm.DB, sql string, items []map[string]interface{}) error {
for _, item := range items {
if err := tx.Exec(sql, item).Error; err != nil {
return err
}
}
return nil
}
func (sysExportTemplateService *SysExportTemplateService) importByGORM(tx *gorm.DB, tableName string, items []map[string]interface{}) error {
needCreated := tx.Migrator().HasColumn(tableName, "created_at")
needUpdated := tx.Migrator().HasColumn(tableName, "updated_at")
for _, item := range items {
if item["created_at"] == nil && needCreated {
item["created_at"] = time.Now()
}
if item["updated_at"] == nil && needUpdated {
item["updated_at"] = time.Now()
}
}
return tx.Table(tableName).CreateInBatches(&items, 1000).Error
}
func getColumnName(n int) string {
columnName := ""
for n > 0 {
n--
columnName = string(rune('A'+n%26)) + columnName
n /= 26
}
return columnName
}

View File

@@ -0,0 +1,189 @@
package system
import (
"context"
"database/sql"
"errors"
"fmt"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"gorm.io/gorm"
"sort"
)
const (
Mysql = "mysql"
Pgsql = "pgsql"
Sqlite = "sqlite"
Mssql = "mssql"
InitSuccess = "\n[%v] --> 初始数据成功!\n"
InitDataExist = "\n[%v] --> %v 的初始数据已存在!\n"
InitDataFailed = "\n[%v] --> %v 初始数据失败! \nerr: %+v\n"
InitDataSuccess = "\n[%v] --> %v 初始数据成功!\n"
)
const (
InitOrderSystem = 10
InitOrderInternal = 1000
InitOrderExternal = 100000
)
var (
ErrMissingDBContext = errors.New("missing db in context")
ErrMissingDependentContext = errors.New("missing dependent value in context")
ErrDBTypeMismatch = errors.New("db type mismatch")
)
// SubInitializer 提供 source/*/init() 使用的接口,每个 initializer 完成一个初始化过程
type SubInitializer interface {
InitializerName() string // 不一定代表单独一个表,所以改成了更宽泛的语义
MigrateTable(ctx context.Context) (next context.Context, err error)
InitializeData(ctx context.Context) (next context.Context, err error)
TableCreated(ctx context.Context) bool
DataInserted(ctx context.Context) bool
}
// TypedDBInitHandler 执行传入的 initializer
type TypedDBInitHandler interface {
EnsureDB(ctx context.Context, conf *request.InitDB) (context.Context, error) // 建库,失败属于 fatal error因此让它 panic
WriteConfig(ctx context.Context) error // 回写配置
InitTables(ctx context.Context, inits initSlice) error // 建表 handler
InitData(ctx context.Context, inits initSlice) error // 建数据 handler
}
// orderedInitializer 组合一个顺序字段,以供排序
type orderedInitializer struct {
order int
SubInitializer
}
// initSlice 供 initializer 排序依赖时使用
type initSlice []*orderedInitializer
var (
initializers initSlice
cache map[string]*orderedInitializer
)
// RegisterInit 注册要执行的初始化过程,会在 InitDB() 时调用
func RegisterInit(order int, i SubInitializer) {
if initializers == nil {
initializers = initSlice{}
}
if cache == nil {
cache = map[string]*orderedInitializer{}
}
name := i.InitializerName()
if _, existed := cache[name]; existed {
panic(fmt.Sprintf("Name conflict on %s", name))
}
ni := orderedInitializer{order, i}
initializers = append(initializers, &ni)
cache[name] = &ni
}
/* ---- * service * ---- */
type InitDBService struct{}
// InitDB 创建数据库并初始化 总入口
func (initDBService *InitDBService) InitDB(conf request.InitDB) (err error) {
ctx := context.TODO()
ctx = context.WithValue(ctx, "adminPassword", conf.AdminPassword)
if len(initializers) == 0 {
return errors.New("无可用初始化过程,请检查初始化是否已执行完成")
}
sort.Sort(&initializers) // 保证有依赖的 initializer 排在后面执行
// Note: 若 initializer 只有单一依赖,可以写为 B=A+1, C=A+1; 由于 BC 之间没有依赖关系,所以谁先谁后并不影响初始化
// 若存在多个依赖,可以写为 C=A+B, D=A+B+C, E=A+1;
// C必然>A|B因此在AB之后执行D必然>A|B|C因此在ABC后执行而E只依赖A顺序与CD无关因此E与CD哪个先执行并不影响
var initHandler TypedDBInitHandler
switch conf.DBType {
case "mysql":
initHandler = NewMysqlInitHandler()
ctx = context.WithValue(ctx, "dbtype", "mysql")
case "pgsql":
initHandler = NewPgsqlInitHandler()
ctx = context.WithValue(ctx, "dbtype", "pgsql")
case "sqlite":
initHandler = NewSqliteInitHandler()
ctx = context.WithValue(ctx, "dbtype", "sqlite")
case "mssql":
initHandler = NewMssqlInitHandler()
ctx = context.WithValue(ctx, "dbtype", "mssql")
default:
initHandler = NewMysqlInitHandler()
ctx = context.WithValue(ctx, "dbtype", "mysql")
}
ctx, err = initHandler.EnsureDB(ctx, &conf)
if err != nil {
return err
}
db := ctx.Value("db").(*gorm.DB)
global.GVA_DB = db
if err = initHandler.InitTables(ctx, initializers); err != nil {
return err
}
if err = initHandler.InitData(ctx, initializers); err != nil {
return err
}
if err = initHandler.WriteConfig(ctx); err != nil {
return err
}
initializers = initSlice{}
cache = map[string]*orderedInitializer{}
return nil
}
// createDatabase 创建数据库( EnsureDB() 中调用
func createDatabase(dsn string, driver string, createSql string) error {
db, err := sql.Open(driver, dsn)
if err != nil {
return err
}
defer func(db *sql.DB) {
err = db.Close()
if err != nil {
fmt.Println(err)
}
}(db)
if err = db.Ping(); err != nil {
return err
}
_, err = db.Exec(createSql)
return err
}
// createTables 创建表(默认 dbInitHandler.initTables 行为)
func createTables(ctx context.Context, inits initSlice) error {
next, cancel := context.WithCancel(ctx)
defer cancel()
for _, init := range inits {
if init.TableCreated(next) {
continue
}
if n, err := init.MigrateTable(next); err != nil {
return err
} else {
next = n
}
}
return nil
}
/* -- sortable interface -- */
func (a initSlice) Len() int {
return len(a)
}
func (a initSlice) Less(i, j int) bool {
return a[i].order < a[j].order
}
func (a initSlice) Swap(i, j int) {
a[i], a[j] = a[j], a[i]
}

View File

@@ -0,0 +1,92 @@
package system
import (
"context"
"errors"
"git.echol.cn/loser/ai_proxy/server/config"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/google/uuid"
"github.com/gookit/color"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"path/filepath"
)
type MssqlInitHandler struct{}
func NewMssqlInitHandler() *MssqlInitHandler {
return &MssqlInitHandler{}
}
// WriteConfig mssql回写配置
func (h MssqlInitHandler) WriteConfig(ctx context.Context) error {
c, ok := ctx.Value("config").(config.Mssql)
if !ok {
return errors.New("mssql config invalid")
}
global.GVA_CONFIG.System.DbType = "mssql"
global.GVA_CONFIG.Mssql = c
global.GVA_CONFIG.JWT.SigningKey = uuid.New().String()
cs := utils.StructToMap(global.GVA_CONFIG)
for k, v := range cs {
global.GVA_VP.Set(k, v)
}
global.GVA_ACTIVE_DBNAME = &c.Dbname
return global.GVA_VP.WriteConfig()
}
// EnsureDB 创建数据库并初始化 mssql
func (h MssqlInitHandler) EnsureDB(ctx context.Context, conf *request.InitDB) (next context.Context, err error) {
if s, ok := ctx.Value("dbtype").(string); !ok || s != "mssql" {
return ctx, ErrDBTypeMismatch
}
c := conf.ToMssqlConfig()
next = context.WithValue(ctx, "config", c)
if c.Dbname == "" {
return ctx, nil
} // 如果没有数据库名, 则跳出初始化数据
dsn := conf.MssqlEmptyDsn()
mssqlConfig := sqlserver.Config{
DSN: dsn, // DSN data source name
DefaultStringSize: 191, // string 类型字段的默认长度
}
var db *gorm.DB
if db, err = gorm.Open(sqlserver.New(mssqlConfig), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true}); err != nil {
return nil, err
}
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..")
next = context.WithValue(next, "db", db)
return next, err
}
func (h MssqlInitHandler) InitTables(ctx context.Context, inits initSlice) error {
return createTables(ctx, inits)
}
func (h MssqlInitHandler) InitData(ctx context.Context, inits initSlice) error {
next, cancel := context.WithCancel(ctx)
defer cancel()
for _, init := range inits {
if init.DataInserted(next) {
color.Info.Printf(InitDataExist, Mssql, init.InitializerName())
continue
}
if n, err := init.InitializeData(next); err != nil {
color.Info.Printf(InitDataFailed, Mssql, init.InitializerName(), err)
return err
} else {
next = n
color.Info.Printf(InitDataSuccess, Mssql, init.InitializerName())
}
}
color.Info.Printf(InitSuccess, Mssql)
return nil
}

View File

@@ -0,0 +1,97 @@
package system
import (
"context"
"errors"
"fmt"
"path/filepath"
"git.echol.cn/loser/ai_proxy/server/config"
"github.com/gookit/color"
"git.echol.cn/loser/ai_proxy/server/utils"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"github.com/google/uuid"
"gorm.io/driver/mysql"
"gorm.io/gorm"
)
type MysqlInitHandler struct{}
func NewMysqlInitHandler() *MysqlInitHandler {
return &MysqlInitHandler{}
}
// WriteConfig mysql回写配置
func (h MysqlInitHandler) WriteConfig(ctx context.Context) error {
c, ok := ctx.Value("config").(config.Mysql)
if !ok {
return errors.New("mysql config invalid")
}
global.GVA_CONFIG.System.DbType = "mysql"
global.GVA_CONFIG.Mysql = c
global.GVA_CONFIG.JWT.SigningKey = uuid.New().String()
cs := utils.StructToMap(global.GVA_CONFIG)
for k, v := range cs {
global.GVA_VP.Set(k, v)
}
global.GVA_ACTIVE_DBNAME = &c.Dbname
return global.GVA_VP.WriteConfig()
}
// EnsureDB 创建数据库并初始化 mysql
func (h MysqlInitHandler) EnsureDB(ctx context.Context, conf *request.InitDB) (next context.Context, err error) {
if s, ok := ctx.Value("dbtype").(string); !ok || s != "mysql" {
return ctx, ErrDBTypeMismatch
}
c := conf.ToMysqlConfig()
next = context.WithValue(ctx, "config", c)
if c.Dbname == "" {
return ctx, nil
} // 如果没有数据库名, 则跳出初始化数据
dsn := conf.MysqlEmptyDsn()
createSql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS `%s` DEFAULT CHARACTER SET utf8mb4 DEFAULT COLLATE utf8mb4_general_ci;", c.Dbname)
if err = createDatabase(dsn, "mysql", createSql); err != nil {
return nil, err
} // 创建数据库
var db *gorm.DB
if db, err = gorm.Open(mysql.New(mysql.Config{
DSN: c.Dsn(), // DSN data source name
DefaultStringSize: 191, // string 类型字段的默认长度
SkipInitializeWithVersion: true, // 根据版本自动配置
}), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true}); err != nil {
return ctx, err
}
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..")
next = context.WithValue(next, "db", db)
return next, err
}
func (h MysqlInitHandler) InitTables(ctx context.Context, inits initSlice) error {
return createTables(ctx, inits)
}
func (h MysqlInitHandler) InitData(ctx context.Context, inits initSlice) error {
next, cancel := context.WithCancel(ctx)
defer cancel()
for _, init := range inits {
if init.DataInserted(next) {
color.Info.Printf(InitDataExist, Mysql, init.InitializerName())
continue
}
if n, err := init.InitializeData(next); err != nil {
color.Info.Printf(InitDataFailed, Mysql, init.InitializerName(), err)
return err
} else {
next = n
color.Info.Printf(InitDataSuccess, Mysql, init.InitializerName())
}
}
color.Info.Printf(InitSuccess, Mysql)
return nil
}

View File

@@ -0,0 +1,101 @@
package system
import (
"context"
"errors"
"fmt"
"path/filepath"
"git.echol.cn/loser/ai_proxy/server/config"
"github.com/gookit/color"
"git.echol.cn/loser/ai_proxy/server/utils"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"github.com/google/uuid"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
type PgsqlInitHandler struct{}
func NewPgsqlInitHandler() *PgsqlInitHandler {
return &PgsqlInitHandler{}
}
// WriteConfig pgsql 回写配置
func (h PgsqlInitHandler) WriteConfig(ctx context.Context) error {
c, ok := ctx.Value("config").(config.Pgsql)
if !ok {
return errors.New("postgresql config invalid")
}
global.GVA_CONFIG.System.DbType = "pgsql"
global.GVA_CONFIG.Pgsql = c
global.GVA_CONFIG.JWT.SigningKey = uuid.New().String()
cs := utils.StructToMap(global.GVA_CONFIG)
for k, v := range cs {
global.GVA_VP.Set(k, v)
}
global.GVA_ACTIVE_DBNAME = &c.Dbname
return global.GVA_VP.WriteConfig()
}
// EnsureDB 创建数据库并初始化 pg
func (h PgsqlInitHandler) EnsureDB(ctx context.Context, conf *request.InitDB) (next context.Context, err error) {
if s, ok := ctx.Value("dbtype").(string); !ok || s != "pgsql" {
return ctx, ErrDBTypeMismatch
}
c := conf.ToPgsqlConfig()
next = context.WithValue(ctx, "config", c)
if c.Dbname == "" {
return ctx, nil
} // 如果没有数据库名, 则跳出初始化数据
dsn := conf.PgsqlEmptyDsn()
var createSql string
if conf.Template != "" {
createSql = fmt.Sprintf("CREATE DATABASE %s WITH TEMPLATE %s;", c.Dbname, conf.Template)
} else {
createSql = fmt.Sprintf("CREATE DATABASE %s;", c.Dbname)
}
if err = createDatabase(dsn, "pgx", createSql); err != nil {
return nil, err
} // 创建数据库
var db *gorm.DB
if db, err = gorm.Open(postgres.New(postgres.Config{
DSN: c.Dsn(), // DSN data source name
PreferSimpleProtocol: false,
}), &gorm.Config{DisableForeignKeyConstraintWhenMigrating: true}); err != nil {
return ctx, err
}
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..")
next = context.WithValue(next, "db", db)
return next, err
}
func (h PgsqlInitHandler) InitTables(ctx context.Context, inits initSlice) error {
return createTables(ctx, inits)
}
func (h PgsqlInitHandler) InitData(ctx context.Context, inits initSlice) error {
next, cancel := context.WithCancel(ctx)
defer cancel()
for i := 0; i < len(inits); i++ {
if inits[i].DataInserted(next) {
color.Info.Printf(InitDataExist, Pgsql, inits[i].InitializerName())
continue
}
if n, err := inits[i].InitializeData(next); err != nil {
color.Info.Printf(InitDataFailed, Pgsql, inits[i].InitializerName(), err)
return err
} else {
next = n
color.Info.Printf(InitDataSuccess, Pgsql, inits[i].InitializerName())
}
}
color.Info.Printf(InitSuccess, Pgsql)
return nil
}

View File

@@ -0,0 +1,88 @@
package system
import (
"context"
"errors"
"github.com/glebarez/sqlite"
"github.com/google/uuid"
"github.com/gookit/color"
"gorm.io/gorm"
"path/filepath"
"git.echol.cn/loser/ai_proxy/server/config"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/utils"
)
type SqliteInitHandler struct{}
func NewSqliteInitHandler() *SqliteInitHandler {
return &SqliteInitHandler{}
}
// WriteConfig mysql回写配置
func (h SqliteInitHandler) WriteConfig(ctx context.Context) error {
c, ok := ctx.Value("config").(config.Sqlite)
if !ok {
return errors.New("sqlite config invalid")
}
global.GVA_CONFIG.System.DbType = "sqlite"
global.GVA_CONFIG.Sqlite = c
global.GVA_CONFIG.JWT.SigningKey = uuid.New().String()
cs := utils.StructToMap(global.GVA_CONFIG)
for k, v := range cs {
global.GVA_VP.Set(k, v)
}
global.GVA_ACTIVE_DBNAME = &c.Dbname
return global.GVA_VP.WriteConfig()
}
// EnsureDB 创建数据库并初始化 sqlite
func (h SqliteInitHandler) EnsureDB(ctx context.Context, conf *request.InitDB) (next context.Context, err error) {
if s, ok := ctx.Value("dbtype").(string); !ok || s != "sqlite" {
return ctx, ErrDBTypeMismatch
}
c := conf.ToSqliteConfig()
next = context.WithValue(ctx, "config", c)
if c.Dbname == "" {
return ctx, nil
} // 如果没有数据库名, 则跳出初始化数据
dsn := conf.SqliteEmptyDsn()
var db *gorm.DB
if db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
}); err != nil {
return ctx, err
}
global.GVA_CONFIG.AutoCode.Root, _ = filepath.Abs("..")
next = context.WithValue(next, "db", db)
return next, err
}
func (h SqliteInitHandler) InitTables(ctx context.Context, inits initSlice) error {
return createTables(ctx, inits)
}
func (h SqliteInitHandler) InitData(ctx context.Context, inits initSlice) error {
next, cancel := context.WithCancel(ctx)
defer cancel()
for _, init := range inits {
if init.DataInserted(next) {
color.Info.Printf(InitDataExist, Sqlite, init.InitializerName())
continue
}
if n, err := init.InitializeData(next); err != nil {
color.Info.Printf(InitDataFailed, Sqlite, init.InitializerName(), err)
return err
} else {
next = n
color.Info.Printf(InitDataSuccess, Sqlite, init.InitializerName())
}
}
color.Info.Printf(InitSuccess, Sqlite)
return nil
}

View File

@@ -0,0 +1,53 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
)
type LoginLogService struct{}
var LoginLogServiceApp = new(LoginLogService)
func (loginLogService *LoginLogService) CreateLoginLog(loginLog system.SysLoginLog) (err error) {
err = global.GVA_DB.Create(&loginLog).Error
return err
}
func (loginLogService *LoginLogService) DeleteLoginLogByIds(ids request.IdsReq) (err error) {
err = global.GVA_DB.Delete(&[]system.SysLoginLog{}, "id in (?)", ids.Ids).Error
return err
}
func (loginLogService *LoginLogService) DeleteLoginLog(loginLog system.SysLoginLog) (err error) {
err = global.GVA_DB.Delete(&loginLog).Error
return err
}
func (loginLogService *LoginLogService) GetLoginLog(id uint) (loginLog system.SysLoginLog, err error) {
err = global.GVA_DB.Where("id = ?", id).First(&loginLog).Error
return
}
func (loginLogService *LoginLogService) GetLoginLogInfoList(info systemReq.SysLoginLogSearch) (list interface{}, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysLoginLog{})
var loginLogs []system.SysLoginLog
// 如果有条件搜索 下方会自动创建搜索语句
if info.Username != "" {
db = db.Where("username LIKE ?", "%"+info.Username+"%")
}
if info.Status != false {
db = db.Where("status = ?", info.Status)
}
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Limit(limit).Offset(offset).Order("id desc").Preload("User").Find(&loginLogs).Error
return loginLogs, total, err
}

View File

@@ -0,0 +1,331 @@
package system
import (
"errors"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/system"
"gorm.io/gorm"
"strconv"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: getMenuTreeMap
//@description: 获取路由总树map
//@param: authorityId string
//@return: treeMap map[string][]system.SysMenu, err error
type MenuService struct{}
var MenuServiceApp = new(MenuService)
func (menuService *MenuService) getMenuTreeMap(authorityId uint) (treeMap map[uint][]system.SysMenu, err error) {
var allMenus []system.SysMenu
var baseMenu []system.SysBaseMenu
var btns []system.SysAuthorityBtn
treeMap = make(map[uint][]system.SysMenu)
var SysAuthorityMenus []system.SysAuthorityMenu
err = global.GVA_DB.Where("sys_authority_authority_id = ?", authorityId).Find(&SysAuthorityMenus).Error
if err != nil {
return
}
var MenuIds []string
for i := range SysAuthorityMenus {
MenuIds = append(MenuIds, SysAuthorityMenus[i].MenuId)
}
err = global.GVA_DB.Where("id in (?)", MenuIds).Order("sort").Preload("Parameters").Find(&baseMenu).Error
if err != nil {
return
}
for i := range baseMenu {
allMenus = append(allMenus, system.SysMenu{
SysBaseMenu: baseMenu[i],
AuthorityId: authorityId,
MenuId: baseMenu[i].ID,
Parameters: baseMenu[i].Parameters,
})
}
err = global.GVA_DB.Where("authority_id = ?", authorityId).Preload("SysBaseMenuBtn").Find(&btns).Error
if err != nil {
return
}
var btnMap = make(map[uint]map[string]uint)
for _, v := range btns {
if btnMap[v.SysMenuID] == nil {
btnMap[v.SysMenuID] = make(map[string]uint)
}
btnMap[v.SysMenuID][v.SysBaseMenuBtn.Name] = authorityId
}
for _, v := range allMenus {
v.Btns = btnMap[v.SysBaseMenu.ID]
treeMap[v.ParentId] = append(treeMap[v.ParentId], v)
}
return treeMap, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetMenuTree
//@description: 获取动态菜单树
//@param: authorityId string
//@return: menus []system.SysMenu, err error
func (menuService *MenuService) GetMenuTree(authorityId uint) (menus []system.SysMenu, err error) {
menuTree, err := menuService.getMenuTreeMap(authorityId)
menus = menuTree[0]
for i := 0; i < len(menus); i++ {
err = menuService.getChildrenList(&menus[i], menuTree)
}
return menus, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: getChildrenList
//@description: 获取子菜单
//@param: menu *model.SysMenu, treeMap map[string][]model.SysMenu
//@return: err error
func (menuService *MenuService) getChildrenList(menu *system.SysMenu, treeMap map[uint][]system.SysMenu) (err error) {
menu.Children = treeMap[menu.MenuId]
for i := 0; i < len(menu.Children); i++ {
err = menuService.getChildrenList(&menu.Children[i], treeMap)
}
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetInfoList
//@description: 获取路由分页
//@return: list interface{}, total int64,err error
func (menuService *MenuService) GetInfoList(authorityID uint) (list interface{}, err error) {
var menuList []system.SysBaseMenu
treeMap, err := menuService.getBaseMenuTreeMap(authorityID)
menuList = treeMap[0]
for i := 0; i < len(menuList); i++ {
err = menuService.getBaseChildrenList(&menuList[i], treeMap)
}
return menuList, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: getBaseChildrenList
//@description: 获取菜单的子菜单
//@param: menu *model.SysBaseMenu, treeMap map[string][]model.SysBaseMenu
//@return: err error
func (menuService *MenuService) getBaseChildrenList(menu *system.SysBaseMenu, treeMap map[uint][]system.SysBaseMenu) (err error) {
menu.Children = treeMap[menu.ID]
for i := 0; i < len(menu.Children); i++ {
err = menuService.getBaseChildrenList(&menu.Children[i], treeMap)
}
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: AddBaseMenu
//@description: 添加基础路由
//@param: menu model.SysBaseMenu
//@return: error
func (menuService *MenuService) AddBaseMenu(menu system.SysBaseMenu) error {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 检查name是否重复
if !errors.Is(tx.Where("name = ?", menu.Name).First(&system.SysBaseMenu{}).Error, gorm.ErrRecordNotFound) {
return errors.New("存在重复name请修改name")
}
if menu.ParentId != 0 {
// 检查父菜单是否存在
var parentMenu system.SysBaseMenu
if err := tx.First(&parentMenu, menu.ParentId).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("父菜单不存在")
}
return err
}
// 检查父菜单下现有子菜单数量
var existingChildrenCount int64
err := tx.Model(&system.SysBaseMenu{}).Where("parent_id = ?", menu.ParentId).Count(&existingChildrenCount).Error
if err != nil {
return err
}
// 如果父菜单原本是叶子菜单(没有子菜单),现在要变成枝干菜单,需要清空其权限分配
if existingChildrenCount == 0 {
// 检查父菜单是否被其他角色设置为首页
var defaultRouterCount int64
err := tx.Model(&system.SysAuthority{}).Where("default_router = ?", parentMenu.Name).Count(&defaultRouterCount).Error
if err != nil {
return err
}
if defaultRouterCount > 0 {
return errors.New("父菜单已被其他角色的首页占用,请先释放父菜单的首页权限")
}
// 清空父菜单的所有权限分配
err = tx.Where("sys_base_menu_id = ?", menu.ParentId).Delete(&system.SysAuthorityMenu{}).Error
if err != nil {
return err
}
}
}
// 创建菜单
return tx.Create(&menu).Error
})
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: getBaseMenuTreeMap
//@description: 获取路由总树map
//@return: treeMap map[string][]system.SysBaseMenu, err error
func (menuService *MenuService) getBaseMenuTreeMap(authorityID uint) (treeMap map[uint][]system.SysBaseMenu, err error) {
parentAuthorityID, err := AuthorityServiceApp.GetParentAuthorityID(authorityID)
if err != nil {
return nil, err
}
var allMenus []system.SysBaseMenu
treeMap = make(map[uint][]system.SysBaseMenu)
db := global.GVA_DB.Order("sort").Preload("MenuBtn").Preload("Parameters")
// 当开启了严格的树角色并且父角色不为0时需要进行菜单筛选
if global.GVA_CONFIG.System.UseStrictAuth && parentAuthorityID != 0 {
var authorityMenus []system.SysAuthorityMenu
err = global.GVA_DB.Where("sys_authority_authority_id = ?", authorityID).Find(&authorityMenus).Error
if err != nil {
return nil, err
}
var menuIds []string
for i := range authorityMenus {
menuIds = append(menuIds, authorityMenus[i].MenuId)
}
db = db.Where("id in (?)", menuIds)
}
err = db.Find(&allMenus).Error
for _, v := range allMenus {
treeMap[v.ParentId] = append(treeMap[v.ParentId], v)
}
return treeMap, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetBaseMenuTree
//@description: 获取基础路由树
//@return: menus []system.SysBaseMenu, err error
func (menuService *MenuService) GetBaseMenuTree(authorityID uint) (menus []system.SysBaseMenu, err error) {
treeMap, err := menuService.getBaseMenuTreeMap(authorityID)
menus = treeMap[0]
for i := 0; i < len(menus); i++ {
err = menuService.getBaseChildrenList(&menus[i], treeMap)
}
return menus, err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: AddMenuAuthority
//@description: 为角色增加menu树
//@param: menus []model.SysBaseMenu, authorityId string
//@return: err error
func (menuService *MenuService) AddMenuAuthority(menus []system.SysBaseMenu, adminAuthorityID, authorityId uint) (err error) {
var auth system.SysAuthority
auth.AuthorityId = authorityId
auth.SysBaseMenus = menus
err = AuthorityServiceApp.CheckAuthorityIDAuth(adminAuthorityID, authorityId)
if err != nil {
return err
}
var authority system.SysAuthority
_ = global.GVA_DB.First(&authority, "authority_id = ?", adminAuthorityID).Error
var menuIds []string
// 当开启了严格的树角色并且父角色不为0时需要进行菜单筛选
if global.GVA_CONFIG.System.UseStrictAuth && *authority.ParentId != 0 {
var authorityMenus []system.SysAuthorityMenu
err = global.GVA_DB.Where("sys_authority_authority_id = ?", adminAuthorityID).Find(&authorityMenus).Error
if err != nil {
return err
}
for i := range authorityMenus {
menuIds = append(menuIds, authorityMenus[i].MenuId)
}
for i := range menus {
hasMenu := false
for j := range menuIds {
idStr := strconv.Itoa(int(menus[i].ID))
if idStr == menuIds[j] {
hasMenu = true
}
}
if !hasMenu {
return errors.New("添加失败,请勿跨级操作")
}
}
}
err = AuthorityServiceApp.SetMenuAuthority(&auth)
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetMenuAuthority
//@description: 查看当前角色树
//@param: info *request.GetAuthorityId
//@return: menus []system.SysMenu, err error
func (menuService *MenuService) GetMenuAuthority(info *request.GetAuthorityId) (menus []system.SysMenu, err error) {
var baseMenu []system.SysBaseMenu
var SysAuthorityMenus []system.SysAuthorityMenu
err = global.GVA_DB.Where("sys_authority_authority_id = ?", info.AuthorityId).Find(&SysAuthorityMenus).Error
if err != nil {
return
}
var MenuIds []string
for i := range SysAuthorityMenus {
MenuIds = append(MenuIds, SysAuthorityMenus[i].MenuId)
}
err = global.GVA_DB.Where("id in (?) ", MenuIds).Order("sort").Find(&baseMenu).Error
for i := range baseMenu {
menus = append(menus, system.SysMenu{
SysBaseMenu: baseMenu[i],
AuthorityId: info.AuthorityId,
MenuId: baseMenu[i].ID,
Parameters: baseMenu[i].Parameters,
})
}
return menus, err
}
// UserAuthorityDefaultRouter 用户角色默认路由检查
//
// Author [SliverHorn](https://github.com/SliverHorn)
func (menuService *MenuService) UserAuthorityDefaultRouter(user *system.SysUser) {
var menuIds []string
err := global.GVA_DB.Model(&system.SysAuthorityMenu{}).Where("sys_authority_authority_id = ?", user.AuthorityId).Pluck("sys_base_menu_id", &menuIds).Error
if err != nil {
return
}
var am system.SysBaseMenu
err = global.GVA_DB.First(&am, "name = ? and id in (?)", user.Authority.DefaultRouter, menuIds).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
user.Authority.DefaultRouter = "404"
}
}

View File

@@ -0,0 +1,83 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/common/request"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
)
//@author: [granty1](https://github.com/granty1)
//@function: CreateSysOperationRecord
//@description: 创建记录
//@param: sysOperationRecord model.SysOperationRecord
//@return: err error
type OperationRecordService struct{}
var OperationRecordServiceApp = new(OperationRecordService)
//@author: [granty1](https://github.com/granty1)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteSysOperationRecordByIds
//@description: 批量删除记录
//@param: ids request.IdsReq
//@return: err error
func (operationRecordService *OperationRecordService) DeleteSysOperationRecordByIds(ids request.IdsReq) (err error) {
err = global.GVA_DB.Delete(&[]system.SysOperationRecord{}, "id in (?)", ids.Ids).Error
return err
}
//@author: [granty1](https://github.com/granty1)
//@function: DeleteSysOperationRecord
//@description: 删除操作记录
//@param: sysOperationRecord model.SysOperationRecord
//@return: err error
func (operationRecordService *OperationRecordService) DeleteSysOperationRecord(sysOperationRecord system.SysOperationRecord) (err error) {
err = global.GVA_DB.Delete(&sysOperationRecord).Error
return err
}
//@author: [granty1](https://github.com/granty1)
//@function: GetSysOperationRecord
//@description: 根据id获取单条操作记录
//@param: id uint
//@return: sysOperationRecord system.SysOperationRecord, err error
func (operationRecordService *OperationRecordService) GetSysOperationRecord(id uint) (sysOperationRecord system.SysOperationRecord, err error) {
err = global.GVA_DB.Where("id = ?", id).First(&sysOperationRecord).Error
return
}
//@author: [granty1](https://github.com/granty1)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetSysOperationRecordInfoList
//@description: 分页获取操作记录列表
//@param: info systemReq.SysOperationRecordSearch
//@return: list interface{}, total int64, err error
func (operationRecordService *OperationRecordService) GetSysOperationRecordInfoList(info systemReq.SysOperationRecordSearch) (list interface{}, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysOperationRecord{})
var sysOperationRecords []system.SysOperationRecord
// 如果有条件搜索 下方会自动创建搜索语句
if info.Method != "" {
db = db.Where("method = ?", info.Method)
}
if info.Path != "" {
db = db.Where("path LIKE ?", "%"+info.Path+"%")
}
if info.Status != 0 {
db = db.Where("status = ?", info.Status)
}
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Order("id desc").Limit(limit).Offset(offset).Preload("User").Find(&sysOperationRecords).Error
return sysOperationRecords, total, err
}

View File

@@ -0,0 +1,82 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
)
type SysParamsService struct{}
// CreateSysParams 创建参数记录
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) CreateSysParams(sysParams *system.SysParams) (err error) {
err = global.GVA_DB.Create(sysParams).Error
return err
}
// DeleteSysParams 删除参数记录
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) DeleteSysParams(ID string) (err error) {
err = global.GVA_DB.Delete(&system.SysParams{}, "id = ?", ID).Error
return err
}
// DeleteSysParamsByIds 批量删除参数记录
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) DeleteSysParamsByIds(IDs []string) (err error) {
err = global.GVA_DB.Delete(&[]system.SysParams{}, "id in ?", IDs).Error
return err
}
// UpdateSysParams 更新参数记录
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) UpdateSysParams(sysParams system.SysParams) (err error) {
err = global.GVA_DB.Model(&system.SysParams{}).Where("id = ?", sysParams.ID).Updates(&sysParams).Error
return err
}
// GetSysParams 根据ID获取参数记录
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) GetSysParams(ID string) (sysParams system.SysParams, err error) {
err = global.GVA_DB.Where("id = ?", ID).First(&sysParams).Error
return
}
// GetSysParamsInfoList 分页获取参数记录
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) GetSysParamsInfoList(info systemReq.SysParamsSearch) (list []system.SysParams, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysParams{})
var sysParamss []system.SysParams
// 如果有条件搜索 下方会自动创建搜索语句
if info.StartCreatedAt != nil && info.EndCreatedAt != nil {
db = db.Where("created_at BETWEEN ? AND ?", info.StartCreatedAt, info.EndCreatedAt)
}
if info.Name != "" {
db = db.Where("name LIKE ?", "%"+info.Name+"%")
}
if info.Key != "" {
db = db.Where("key LIKE ?", "%"+info.Key+"%")
}
err = db.Count(&total).Error
if err != nil {
return
}
if limit != 0 {
db = db.Limit(limit).Offset(offset)
}
err = db.Find(&sysParamss).Error
return sysParamss, total, err
}
// GetSysParam 根据key获取参数value
// Author [Mr.奇淼](https://github.com/pixelmaxQm)
func (sysParamsService *SysParamsService) GetSysParam(key string) (param system.SysParams, err error) {
err = global.GVA_DB.Where(system.SysParams{Key: key}).First(&param).Error
return
}

View File

@@ -0,0 +1,549 @@
package system
import (
"context"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"sort"
"strings"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"gopkg.in/yaml.v3"
)
const (
skillFileName = "SKILL.md"
globalConstraintFileName = "README.md"
)
var skillToolOrder = []string{"copilot", "claude", "cursor", "trae", "codex"}
var skillToolDirs = map[string]string{
"copilot": ".aone_copilot",
"claude": ".claude",
"trae": ".trae",
"codex": ".codex",
"cursor": ".cursor",
}
var skillToolLabels = map[string]string{
"copilot": "Copilot",
"claude": "Claude",
"trae": "Trae",
"codex": "Codex",
"cursor": "Cursor",
}
const defaultSkillMarkdown = "## 技能用途\n请在这里描述技能的目标、适用场景与限制条件。\n\n## 输入\n- 请补充输入格式与示例。\n\n## 输出\n- 请补充输出格式与示例。\n\n## 关键步骤\n1. 第一步\n2. 第二步\n\n## 示例\n在此补充一到两个典型示例。\n"
const defaultResourceMarkdown = "# 资源说明\n请在这里补充资源内容。\n"
const defaultReferenceMarkdown = "# 参考资料\n请在这里补充参考资料内容。\n"
const defaultTemplateMarkdown = "# 模板\n请在这里补充模板内容。\n"
const defaultGlobalConstraintMarkdown = "# 全局约束\n请在这里补充该工具的统一约束与使用规范。\n"
type SkillsService struct{}
func (s *SkillsService) Tools(_ context.Context) ([]system.SkillTool, error) {
tools := make([]system.SkillTool, 0, len(skillToolOrder))
for _, key := range skillToolOrder {
if _, err := s.toolSkillsDir(key); err != nil {
return nil, err
}
tools = append(tools, system.SkillTool{Key: key, Label: skillToolLabels[key]})
}
return tools, nil
}
func (s *SkillsService) List(_ context.Context, tool string) ([]string, error) {
skillsDir, err := s.toolSkillsDir(tool)
if err != nil {
return nil, err
}
entries, err := os.ReadDir(skillsDir)
if err != nil {
return nil, err
}
var skills []string
for _, entry := range entries {
if entry.IsDir() {
skills = append(skills, entry.Name())
}
}
sort.Strings(skills)
return skills, nil
}
func (s *SkillsService) Detail(_ context.Context, tool, skill string) (system.SkillDetail, error) {
var detail system.SkillDetail
if !isSafeName(skill) {
return detail, errors.New("技能名称不合法")
}
detail.Tool = tool
detail.Skill = skill
skillDir, err := s.skillDir(tool, skill)
if err != nil {
return detail, err
}
skillFilePath := filepath.Join(skillDir, skillFileName)
content, err := os.ReadFile(skillFilePath)
if err != nil {
if !os.IsNotExist(err) {
return detail, err
}
detail.Meta = system.SkillMeta{Name: skill}
detail.Markdown = defaultSkillMarkdown
} else {
meta, body, parseErr := parseSkillContent(string(content))
if parseErr != nil {
meta = system.SkillMeta{Name: skill}
body = string(content)
}
if meta.Name == "" {
meta.Name = skill
}
detail.Meta = meta
detail.Markdown = body
}
detail.Scripts = listFiles(filepath.Join(skillDir, "scripts"))
detail.Resources = listFiles(filepath.Join(skillDir, "resources"))
detail.References = listFiles(filepath.Join(skillDir, "references"))
detail.Templates = listFiles(filepath.Join(skillDir, "templates"))
return detail, nil
}
func (s *SkillsService) Save(_ context.Context, req request.SkillSaveRequest) error {
if !isSafeName(req.Skill) {
return errors.New("技能名称不合法")
}
skillDir, err := s.ensureSkillDir(req.Tool, req.Skill)
if err != nil {
return err
}
if req.Meta.Name == "" {
req.Meta.Name = req.Skill
}
content, err := buildSkillContent(req.Meta, req.Markdown)
if err != nil {
return err
}
if err := os.WriteFile(filepath.Join(skillDir, skillFileName), []byte(content), 0644); err != nil {
return err
}
if len(req.SyncTools) > 0 {
for _, tool := range req.SyncTools {
if tool == req.Tool {
continue
}
targetDir, err := s.ensureSkillDir(tool, req.Skill)
if err != nil {
return err
}
if err := copySkillDir(skillDir, targetDir); err != nil {
return err
}
}
}
return nil
}
func (s *SkillsService) CreateScript(_ context.Context, req request.SkillScriptCreateRequest) (string, string, error) {
if !isSafeName(req.Skill) {
return "", "", errors.New("技能名称不合法")
}
fileName, lang, err := buildScriptFileName(req.FileName, req.ScriptType)
if err != nil {
return "", "", err
}
if lang == "" {
return "", "", errors.New("脚本类型不支持")
}
skillDir, err := s.ensureSkillDir(req.Tool, req.Skill)
if err != nil {
return "", "", err
}
filePath := filepath.Join(skillDir, "scripts", fileName)
if _, err := os.Stat(filePath); err == nil {
return "", "", errors.New("脚本已存在")
}
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return "", "", err
}
content := scriptTemplate(lang)
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
return "", "", err
}
return fileName, content, nil
}
func (s *SkillsService) GetScript(_ context.Context, req request.SkillFileRequest) (string, error) {
return s.readSkillFile(req.Tool, req.Skill, "scripts", req.FileName)
}
func (s *SkillsService) SaveScript(_ context.Context, req request.SkillFileSaveRequest) error {
return s.writeSkillFile(req.Tool, req.Skill, "scripts", req.FileName, req.Content)
}
func (s *SkillsService) CreateResource(_ context.Context, req request.SkillResourceCreateRequest) (string, string, error) {
return s.createMarkdownFile(req.Tool, req.Skill, "resources", req.FileName, defaultResourceMarkdown, "资源")
}
func (s *SkillsService) GetResource(_ context.Context, req request.SkillFileRequest) (string, error) {
return s.readSkillFile(req.Tool, req.Skill, "resources", req.FileName)
}
func (s *SkillsService) SaveResource(_ context.Context, req request.SkillFileSaveRequest) error {
return s.writeSkillFile(req.Tool, req.Skill, "resources", req.FileName, req.Content)
}
func (s *SkillsService) CreateReference(_ context.Context, req request.SkillReferenceCreateRequest) (string, string, error) {
return s.createMarkdownFile(req.Tool, req.Skill, "references", req.FileName, defaultReferenceMarkdown, "参考")
}
func (s *SkillsService) GetReference(_ context.Context, req request.SkillFileRequest) (string, error) {
return s.readSkillFile(req.Tool, req.Skill, "references", req.FileName)
}
func (s *SkillsService) SaveReference(_ context.Context, req request.SkillFileSaveRequest) error {
return s.writeSkillFile(req.Tool, req.Skill, "references", req.FileName, req.Content)
}
func (s *SkillsService) CreateTemplate(_ context.Context, req request.SkillTemplateCreateRequest) (string, string, error) {
return s.createMarkdownFile(req.Tool, req.Skill, "templates", req.FileName, defaultTemplateMarkdown, "模板")
}
func (s *SkillsService) GetTemplate(_ context.Context, req request.SkillFileRequest) (string, error) {
return s.readSkillFile(req.Tool, req.Skill, "templates", req.FileName)
}
func (s *SkillsService) SaveTemplate(_ context.Context, req request.SkillFileSaveRequest) error {
return s.writeSkillFile(req.Tool, req.Skill, "templates", req.FileName, req.Content)
}
func (s *SkillsService) GetGlobalConstraint(_ context.Context, tool string) (string, bool, error) {
skillsDir, err := s.toolSkillsDir(tool)
if err != nil {
return "", false, err
}
filePath := filepath.Join(skillsDir, globalConstraintFileName)
content, err := os.ReadFile(filePath)
if err != nil {
if os.IsNotExist(err) {
return defaultGlobalConstraintMarkdown, false, nil
}
return "", false, err
}
return string(content), true, nil
}
func (s *SkillsService) SaveGlobalConstraint(_ context.Context, req request.SkillGlobalConstraintSaveRequest) error {
if strings.TrimSpace(req.Tool) == "" {
return errors.New("工具类型不能为空")
}
writeConstraint := func(tool, content string) error {
skillsDir, err := s.toolSkillsDir(tool)
if err != nil {
return err
}
filePath := filepath.Join(skillsDir, globalConstraintFileName)
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return err
}
return os.WriteFile(filePath, []byte(content), 0644)
}
if err := writeConstraint(req.Tool, req.Content); err != nil {
return err
}
if len(req.SyncTools) == 0 {
return nil
}
for _, tool := range req.SyncTools {
if tool == "" || tool == req.Tool {
continue
}
if err := writeConstraint(tool, req.Content); err != nil {
return err
}
}
return nil
}
func (s *SkillsService) toolSkillsDir(tool string) (string, error) {
toolDir, ok := skillToolDirs[tool]
if !ok {
return "", errors.New("工具类型不支持")
}
root := strings.TrimSpace(global.GVA_CONFIG.AutoCode.Root)
if root == "" {
root = "."
}
skillsDir := filepath.Join(root, toolDir, "skills")
if err := os.MkdirAll(skillsDir, os.ModePerm); err != nil {
return "", err
}
return skillsDir, nil
}
func (s *SkillsService) skillDir(tool, skill string) (string, error) {
skillsDir, err := s.toolSkillsDir(tool)
if err != nil {
return "", err
}
return filepath.Join(skillsDir, skill), nil
}
func (s *SkillsService) ensureSkillDir(tool, skill string) (string, error) {
if !isSafeName(skill) {
return "", errors.New("技能名称不合法")
}
skillDir, err := s.skillDir(tool, skill)
if err != nil {
return "", err
}
if err := os.MkdirAll(skillDir, os.ModePerm); err != nil {
return "", err
}
return skillDir, nil
}
func (s *SkillsService) createMarkdownFile(tool, skill, subDir, fileName, defaultContent, label string) (string, string, error) {
if !isSafeName(skill) {
return "", "", errors.New("技能名称不合法")
}
cleanName, err := buildResourceFileName(fileName)
if err != nil {
return "", "", err
}
skillDir, err := s.ensureSkillDir(tool, skill)
if err != nil {
return "", "", err
}
filePath := filepath.Join(skillDir, subDir, cleanName)
if _, err := os.Stat(filePath); err == nil {
if label == "" {
label = "文件"
}
return "", "", fmt.Errorf("%s已存在", label)
}
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return "", "", err
}
content := defaultContent
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
return "", "", err
}
return cleanName, content, nil
}
func (s *SkillsService) readSkillFile(tool, skill, subDir, fileName string) (string, error) {
if !isSafeName(skill) {
return "", errors.New("技能名称不合法")
}
if !isSafeFileName(fileName) {
return "", errors.New("文件名不合法")
}
skillDir, err := s.skillDir(tool, skill)
if err != nil {
return "", err
}
filePath := filepath.Join(skillDir, subDir, fileName)
content, err := os.ReadFile(filePath)
if err != nil {
return "", err
}
return string(content), nil
}
func (s *SkillsService) writeSkillFile(tool, skill, subDir, fileName, content string) error {
if !isSafeName(skill) {
return errors.New("技能名称不合法")
}
if !isSafeFileName(fileName) {
return errors.New("文件名不合法")
}
skillDir, err := s.ensureSkillDir(tool, skill)
if err != nil {
return err
}
filePath := filepath.Join(skillDir, subDir, fileName)
if err := os.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
return err
}
return os.WriteFile(filePath, []byte(content), 0644)
}
func parseSkillContent(content string) (system.SkillMeta, string, error) {
clean := strings.TrimPrefix(content, "\ufeff")
lines := strings.Split(clean, "\n")
if len(lines) == 0 || strings.TrimSpace(lines[0]) != "---" {
return system.SkillMeta{}, clean, nil
}
end := -1
for i := 1; i < len(lines); i++ {
if strings.TrimSpace(lines[i]) == "---" {
end = i
break
}
}
if end == -1 {
return system.SkillMeta{}, clean, nil
}
yamlText := strings.Join(lines[1:end], "\n")
body := strings.Join(lines[end+1:], "\n")
var meta system.SkillMeta
if err := yaml.Unmarshal([]byte(yamlText), &meta); err != nil {
return system.SkillMeta{}, body, err
}
return meta, body, nil
}
func buildSkillContent(meta system.SkillMeta, markdown string) (string, error) {
if meta.Name == "" {
return "", errors.New("name不能为空")
}
data, err := yaml.Marshal(meta)
if err != nil {
return "", err
}
yamlText := strings.TrimRight(string(data), "\n")
body := strings.TrimLeft(markdown, "\n")
if body != "" {
body = body + "\n"
}
return fmt.Sprintf("---\n%s\n---\n%s", yamlText, body), nil
}
func listFiles(dir string) []string {
entries, err := os.ReadDir(dir)
if err != nil {
return []string{}
}
files := make([]string, 0, len(entries))
for _, entry := range entries {
if entry.Type().IsRegular() {
files = append(files, entry.Name())
}
}
sort.Strings(files)
return files
}
func isSafeName(name string) bool {
if strings.TrimSpace(name) == "" {
return false
}
if strings.Contains(name, "..") {
return false
}
if strings.ContainsAny(name, "/\\") {
return false
}
return name == filepath.Base(name)
}
func isSafeFileName(name string) bool {
if strings.TrimSpace(name) == "" {
return false
}
if strings.Contains(name, "..") {
return false
}
if strings.ContainsAny(name, "/\\") {
return false
}
return name == filepath.Base(name)
}
func buildScriptFileName(fileName, scriptType string) (string, string, error) {
clean := strings.TrimSpace(fileName)
if clean == "" {
return "", "", errors.New("文件名不能为空")
}
if !isSafeFileName(clean) {
return "", "", errors.New("文件名不合法")
}
base := strings.TrimSuffix(clean, filepath.Ext(clean))
if base == "" {
return "", "", errors.New("文件名不合法")
}
switch strings.ToLower(scriptType) {
case "py", "python":
return base + ".py", "python", nil
case "js", "javascript", "script":
return base + ".js", "javascript", nil
case "sh", "shell", "bash":
return base + ".sh", "sh", nil
default:
return "", "", errors.New("脚本类型不支持")
}
}
func buildResourceFileName(fileName string) (string, error) {
clean := strings.TrimSpace(fileName)
if clean == "" {
return "", errors.New("文件名不能为空")
}
if !isSafeFileName(clean) {
return "", errors.New("文件名不合法")
}
base := strings.TrimSuffix(clean, filepath.Ext(clean))
if base == "" {
return "", errors.New("文件名不合法")
}
return base + ".md", nil
}
func scriptTemplate(lang string) string {
switch lang {
case "python":
return "# -*- coding: utf-8 -*-\n# TODO: 在这里实现脚本逻辑\n"
case "javascript":
return "// TODO: 在这里实现脚本逻辑\n"
case "sh":
return "#!/usr/bin/env bash\nset -euo pipefail\n\n# TODO: 在这里实现脚本逻辑\n"
default:
return ""
}
}
func copySkillDir(src, dst string) error {
return filepath.WalkDir(src, func(path string, d fs.DirEntry, err error) error {
if err != nil {
return err
}
rel, err := filepath.Rel(src, path)
if err != nil {
return err
}
if rel == "." {
return nil
}
target := filepath.Join(dst, rel)
if d.IsDir() {
return os.MkdirAll(target, os.ModePerm)
}
if !d.Type().IsRegular() {
return nil
}
data, err := os.ReadFile(path)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(target), os.ModePerm); err != nil {
return err
}
return os.WriteFile(target, data, 0644)
})
}

View File

@@ -0,0 +1,62 @@
package system
import (
"git.echol.cn/loser/ai_proxy/server/config"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/utils"
"go.uber.org/zap"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetSystemConfig
//@description: 读取配置文件
//@return: conf config.Server, err error
type SystemConfigService struct{}
var SystemConfigServiceApp = new(SystemConfigService)
func (systemConfigService *SystemConfigService) GetSystemConfig() (conf config.Server, err error) {
return global.GVA_CONFIG, nil
}
// @description set system config,
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetSystemConfig
//@description: 设置配置文件
//@param: system model.System
//@return: err error
func (systemConfigService *SystemConfigService) SetSystemConfig(system system.System) (err error) {
cs := utils.StructToMap(system.Config)
for k, v := range cs {
global.GVA_VP.Set(k, v)
}
err = global.GVA_VP.WriteConfig()
return err
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: GetServerInfo
//@description: 获取服务器信息
//@return: server *utils.Server, err error
func (systemConfigService *SystemConfigService) GetServerInfo() (server *utils.Server, err error) {
var s utils.Server
s.Os = utils.InitOS()
if s.Cpu, err = utils.InitCPU(); err != nil {
global.GVA_LOG.Error("func utils.InitCPU() Failed", zap.String("err", err.Error()))
return &s, err
}
if s.Ram, err = utils.InitRAM(); err != nil {
global.GVA_LOG.Error("func utils.InitRAM() Failed", zap.String("err", err.Error()))
return &s, err
}
if s.Disk, err = utils.InitDisk(); err != nil {
global.GVA_LOG.Error("func utils.InitDisk() Failed", zap.String("err", err.Error()))
return &s, err
}
return &s, nil
}

View File

@@ -2,206 +2,317 @@ package system
import (
"errors"
"fmt"
"time"
"git.echol.cn/loser/ai_proxy/server/model/common"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
"git.echol.cn/loser/ai_proxy/server/model/system/request"
"git.echol.cn/loser/ai_proxy/server/model/system/response"
"git.echol.cn/loser/ai_proxy/server/utils"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"
"github.com/google/uuid"
"gorm.io/gorm"
)
//@author: [piexlmax](https://github.com/piexlmax)
//@function: Register
//@description: 用户注册
//@param: u model.SysUser
//@return: userInter system.SysUser, err error
type UserService struct{}
// Login 用户登录
func (s *UserService) Login(req *request.LoginRequest) (resp response.LoginResponse, err error) {
var UserServiceApp = new(UserService)
func (userService *UserService) Register(u system.SysUser) (userInter system.SysUser, err error) {
var user system.SysUser
err = global.GVA_DB.Where("username = ?", req.Username).First(&user).Error
if err != nil {
return resp, errors.New("用户名或密码错误")
if !errors.Is(global.GVA_DB.Where("username = ?", u.Username).First(&user).Error, gorm.ErrRecordNotFound) { // 判断用户名是否注册
return userInter, errors.New("用户名已注册")
}
// 验证密码
err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password))
if err != nil {
return resp, errors.New("用户名或密码错误")
}
// 检查用户状态
if user.Status != "active" {
return resp, errors.New("用户已被禁用")
}
// 生成 JWT Token
token, err := s.generateToken(user.ID, user.Username, user.Role)
if err != nil {
return resp, errors.New("生成Token失败")
}
resp.Token = token
resp.User = response.UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
Avatar: user.Avatar,
Role: user.Role,
Status: user.Status,
}
return resp, nil
// 否则 附加uuid 密码hash加密 注册
u.Password = utils.BcryptHash(u.Password)
u.UUID = uuid.New()
err = global.GVA_DB.Create(&u).Error
return u, err
}
// Register 用户注册
func (s *UserService) Register(req *request.RegisterRequest) (user system.SysUser, err error) {
// 检查用户名是否存在
var count int64
global.GVA_DB.Model(&system.SysUser{}).Where("username = ?", req.Username).Count(&count)
if count > 0 {
return user, errors.New("用户名已存在")
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: Login
//@description: 用户登录
//@param: u *model.SysUser
//@return: err error, userInter *model.SysUser
func (userService *UserService) Login(u *system.SysUser) (userInter *system.SysUser, err error) {
if nil == global.GVA_DB {
return nil, fmt.Errorf("db not init")
}
// 加密密码
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
return user, errors.New("密码加密失败")
}
// 生成 API Key
apiKey, err := utils.GenerateRandomString(32)
if err != nil {
return user, errors.New("生成API Key失败")
}
apiKey = "ak-" + apiKey
user = system.SysUser{
Username: req.Username,
Password: string(hashedPassword),
Email: req.Email,
Role: "user",
Status: "active",
APIKey: apiKey,
}
err = global.GVA_DB.Create(&user).Error
return user, err
}
// GetUserInfo 获取用户信息
func (s *UserService) GetUserInfo(userID uint) (info response.UserInfo, err error) {
var user system.SysUser
err = global.GVA_DB.First(&user, userID).Error
if err != nil {
return info, errors.New("用户不存在")
err = global.GVA_DB.Where("username = ?", u.Username).Preload("Authorities").Preload("Authority").First(&user).Error
if err == nil {
if ok := utils.BcryptCheck(u.Password, user.Password); !ok {
return nil, errors.New("密码错误")
}
MenuServiceApp.UserAuthorityDefaultRouter(&user)
}
info = response.UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
Avatar: user.Avatar,
Role: user.Role,
Status: user.Status,
}
return info, nil
return &user, err
}
// GetUserList 获取用户列表
func (s *UserService) GetUserList(page, pageSize int) (list []response.UserInfo, total int64, err error) {
var users []system.SysUser
offset := (page - 1) * pageSize
//@author: [piexlmax](https://github.com/piexlmax)
//@function: ChangePassword
//@description: 修改用户密码
//@param: u *model.SysUser, newPassword string
//@return: err error
func (userService *UserService) ChangePassword(u *system.SysUser, newPassword string) (err error) {
var user system.SysUser
err = global.GVA_DB.Select("id, password").Where("id = ?", u.ID).First(&user).Error
if err != nil {
return err
}
if ok := utils.BcryptCheck(u.Password, user.Password); !ok {
return errors.New("原密码错误")
}
pwd := utils.BcryptHash(newPassword)
err = global.GVA_DB.Model(&user).Update("password", pwd).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: GetUserInfoList
//@description: 分页获取数据
//@param: info request.PageInfo
//@return: err error, list interface{}, total int64
func (userService *UserService) GetUserInfoList(info systemReq.GetUserList) (list interface{}, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
db := global.GVA_DB.Model(&system.SysUser{})
var userList []system.SysUser
if info.NickName != "" {
db = db.Where("nick_name LIKE ?", "%"+info.NickName+"%")
}
if info.Phone != "" {
db = db.Where("phone LIKE ?", "%"+info.Phone+"%")
}
if info.Username != "" {
db = db.Where("username LIKE ?", "%"+info.Username+"%")
}
if info.Email != "" {
db = db.Where("email LIKE ?", "%"+info.Email+"%")
}
err = db.Count(&total).Error
if err != nil {
return
}
err = db.Limit(limit).Offset(offset).Preload("Authorities").Preload("Authority").Find(&userList).Error
return userList, total, err
}
err = db.Limit(pageSize).Offset(offset).Order("created_at DESC").Find(&users).Error
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetUserAuthority
//@description: 设置一个用户的权限
//@param: uuid uuid.UUID, authorityId string
//@return: err error
func (userService *UserService) SetUserAuthority(id uint, authorityId uint) (err error) {
assignErr := global.GVA_DB.Where("sys_user_id = ? AND sys_authority_authority_id = ?", id, authorityId).First(&system.SysUserAuthority{}).Error
if errors.Is(assignErr, gorm.ErrRecordNotFound) {
return errors.New("该用户无此角色")
}
var authority system.SysAuthority
err = global.GVA_DB.Where("authority_id = ?", authorityId).First(&authority).Error
if err != nil {
return
return err
}
for _, user := range users {
list = append(list, response.UserInfo{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
Email: user.Email,
Phone: user.Phone,
Avatar: user.Avatar,
Role: user.Role,
Status: user.Status,
})
}
return list, total, nil
}
// UpdateUser 更新用户
func (s *UserService) UpdateUser(req *request.UpdateUserRequest) error {
updates := map[string]interface{}{}
if req.Nickname != "" {
updates["nickname"] = req.Nickname
}
if req.Email != "" {
updates["email"] = req.Email
}
if req.Phone != "" {
updates["phone"] = req.Phone
}
if req.Avatar != "" {
updates["avatar"] = req.Avatar
}
if req.Role != "" {
updates["role"] = req.Role
}
if req.Status != "" {
updates["status"] = req.Status
}
return global.GVA_DB.Model(&system.SysUser{}).Where("id = ?", req.ID).Updates(updates).Error
}
// DeleteUser 删除用户
func (s *UserService) DeleteUser(userID uint) error {
return global.GVA_DB.Delete(&system.SysUser{}, userID).Error
}
// GetAPIKey 获取用户的 API Key
func (s *UserService) GetAPIKey(userID uint) (string, error) {
var user system.SysUser
err := global.GVA_DB.Select("api_key").First(&user, userID).Error
return user.APIKey, err
}
// RegenerateAPIKey 重新生成 API Key
func (s *UserService) RegenerateAPIKey(userID uint) (string, error) {
randomStr, err := utils.GenerateRandomString(32)
var authorityMenu []system.SysAuthorityMenu
var authorityMenuIDs []string
err = global.GVA_DB.Where("sys_authority_authority_id = ?", authorityId).Find(&authorityMenu).Error
if err != nil {
return "", errors.New("生成API Key失败")
}
apiKey := "ak-" + randomStr
err = global.GVA_DB.Model(&system.SysUser{}).Where("id = ?", userID).Update("api_key", apiKey).Error
return apiKey, err
}
// generateToken 生成 JWT Token
func (s *UserService) generateToken(userID uint, username, role string) (string, error) {
claims := jwt.MapClaims{
"user_id": userID,
"username": username,
"role": role,
"exp": time.Now().Add(24 * time.Hour).Unix(),
return err
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(global.GVA_CONFIG.JWT.SigningKey))
for i := range authorityMenu {
authorityMenuIDs = append(authorityMenuIDs, authorityMenu[i].MenuId)
}
var authorityMenus []system.SysBaseMenu
err = global.GVA_DB.Preload("Parameters").Where("id in (?)", authorityMenuIDs).Find(&authorityMenus).Error
if err != nil {
return err
}
hasMenu := false
for i := range authorityMenus {
if authorityMenus[i].Name == authority.DefaultRouter {
hasMenu = true
break
}
}
if !hasMenu {
return errors.New("找不到默认路由,无法切换本角色")
}
err = global.GVA_DB.Model(&system.SysUser{}).Where("id = ?", id).Update("authority_id", authorityId).Error
return err
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetUserAuthorities
//@description: 设置一个用户的权限
//@param: id uint, authorityIds []string
//@return: err error
func (userService *UserService) SetUserAuthorities(adminAuthorityID, id uint, authorityIds []uint) (err error) {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
var user system.SysUser
TxErr := tx.Where("id = ?", id).First(&user).Error
if TxErr != nil {
global.GVA_LOG.Debug(TxErr.Error())
return errors.New("查询用户数据失败")
}
TxErr = tx.Delete(&[]system.SysUserAuthority{}, "sys_user_id = ?", id).Error
if TxErr != nil {
return TxErr
}
var useAuthority []system.SysUserAuthority
for _, v := range authorityIds {
e := AuthorityServiceApp.CheckAuthorityIDAuth(adminAuthorityID, v)
if e != nil {
return e
}
useAuthority = append(useAuthority, system.SysUserAuthority{
SysUserId: id, SysAuthorityAuthorityId: v,
})
}
TxErr = tx.Create(&useAuthority).Error
if TxErr != nil {
return TxErr
}
TxErr = tx.Model(&user).Update("authority_id", authorityIds[0]).Error
if TxErr != nil {
return TxErr
}
// 返回 nil 提交事务
return nil
})
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: DeleteUser
//@description: 删除用户
//@param: id float64
//@return: err error
func (userService *UserService) DeleteUser(id int) (err error) {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Where("id = ?", id).Delete(&system.SysUser{}).Error; err != nil {
return err
}
if err := tx.Delete(&[]system.SysUserAuthority{}, "sys_user_id = ?", id).Error; err != nil {
return err
}
return nil
})
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetUserInfo
//@description: 设置用户信息
//@param: reqUser model.SysUser
//@return: err error, user model.SysUser
func (userService *UserService) SetUserInfo(req system.SysUser) error {
return global.GVA_DB.Model(&system.SysUser{}).
Select("updated_at", "nick_name", "header_img", "phone", "email", "enable").
Where("id=?", req.ID).
Updates(map[string]interface{}{
"updated_at": time.Now(),
"nick_name": req.NickName,
"header_img": req.HeaderImg,
"phone": req.Phone,
"email": req.Email,
"enable": req.Enable,
}).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetSelfInfo
//@description: 设置用户信息
//@param: reqUser model.SysUser
//@return: err error, user model.SysUser
func (userService *UserService) SetSelfInfo(req system.SysUser) error {
return global.GVA_DB.Model(&system.SysUser{}).
Where("id=?", req.ID).
Updates(req).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: SetSelfSetting
//@description: 设置用户配置
//@param: req datatypes.JSON, uid uint
//@return: err error
func (userService *UserService) SetSelfSetting(req common.JSONMap, uid uint) error {
return global.GVA_DB.Model(&system.SysUser{}).Where("id = ?", uid).Update("origin_setting", req).Error
}
//@author: [piexlmax](https://github.com/piexlmax)
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: GetUserInfo
//@description: 获取用户信息
//@param: uuid uuid.UUID
//@return: err error, user system.SysUser
func (userService *UserService) GetUserInfo(uuid uuid.UUID) (user system.SysUser, err error) {
var reqUser system.SysUser
err = global.GVA_DB.Preload("Authorities").Preload("Authority").First(&reqUser, "uuid = ?", uuid).Error
if err != nil {
return reqUser, err
}
MenuServiceApp.UserAuthorityDefaultRouter(&reqUser)
return reqUser, err
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: FindUserById
//@description: 通过id获取用户信息
//@param: id int
//@return: err error, user *model.SysUser
func (userService *UserService) FindUserById(id int) (user *system.SysUser, err error) {
var u system.SysUser
err = global.GVA_DB.Where("id = ?", id).First(&u).Error
return &u, err
}
//@author: [SliverHorn](https://github.com/SliverHorn)
//@function: FindUserByUuid
//@description: 通过uuid获取用户信息
//@param: uuid string
//@return: err error, user *model.SysUser
func (userService *UserService) FindUserByUuid(uuid string) (user *system.SysUser, err error) {
var u system.SysUser
if err = global.GVA_DB.Where("uuid = ?", uuid).First(&u).Error; err != nil {
return &u, errors.New("用户不存在")
}
return &u, nil
}
//@author: [piexlmax](https://github.com/piexlmax)
//@function: ResetPassword
//@description: 修改用户密码
//@param: ID uint
//@return: err error
func (userService *UserService) ResetPassword(ID uint, password string) (err error) {
err = global.GVA_DB.Model(&system.SysUser{}).Where("id = ?", ID).Update("password", utils.BcryptHash(password)).Error
return err
}

View File

@@ -0,0 +1,230 @@
package system
import (
"context"
"git.echol.cn/loser/ai_proxy/server/global"
"git.echol.cn/loser/ai_proxy/server/model/system"
systemReq "git.echol.cn/loser/ai_proxy/server/model/system/request"
"gorm.io/gorm"
)
type SysVersionService struct{}
// CreateSysVersion 创建版本管理记录
// Author [yourname](https://github.com/yourname)
func (sysVersionService *SysVersionService) CreateSysVersion(ctx context.Context, sysVersion *system.SysVersion) (err error) {
err = global.GVA_DB.Create(sysVersion).Error
return err
}
// DeleteSysVersion 删除版本管理记录
// Author [yourname](https://github.com/yourname)
func (sysVersionService *SysVersionService) DeleteSysVersion(ctx context.Context, ID string) (err error) {
err = global.GVA_DB.Delete(&system.SysVersion{}, "id = ?", ID).Error
return err
}
// DeleteSysVersionByIds 批量删除版本管理记录
// Author [yourname](https://github.com/yourname)
func (sysVersionService *SysVersionService) DeleteSysVersionByIds(ctx context.Context, IDs []string) (err error) {
err = global.GVA_DB.Where("id in ?", IDs).Delete(&system.SysVersion{}).Error
return err
}
// GetSysVersion 根据ID获取版本管理记录
// Author [yourname](https://github.com/yourname)
func (sysVersionService *SysVersionService) GetSysVersion(ctx context.Context, ID string) (sysVersion system.SysVersion, err error) {
err = global.GVA_DB.Where("id = ?", ID).First(&sysVersion).Error
return
}
// GetSysVersionInfoList 分页获取版本管理记录
// Author [yourname](https://github.com/yourname)
func (sysVersionService *SysVersionService) GetSysVersionInfoList(ctx context.Context, info systemReq.SysVersionSearch) (list []system.SysVersion, total int64, err error) {
limit := info.PageSize
offset := info.PageSize * (info.Page - 1)
// 创建db
db := global.GVA_DB.Model(&system.SysVersion{})
var sysVersions []system.SysVersion
// 如果有条件搜索 下方会自动创建搜索语句
if len(info.CreatedAtRange) == 2 {
db = db.Where("created_at BETWEEN ? AND ?", info.CreatedAtRange[0], info.CreatedAtRange[1])
}
if info.VersionName != nil && *info.VersionName != "" {
db = db.Where("version_name LIKE ?", "%"+*info.VersionName+"%")
}
if info.VersionCode != nil && *info.VersionCode != "" {
db = db.Where("version_code = ?", *info.VersionCode)
}
err = db.Count(&total).Error
if err != nil {
return
}
if limit != 0 {
db = db.Limit(limit).Offset(offset)
}
err = db.Find(&sysVersions).Error
return sysVersions, total, err
}
func (sysVersionService *SysVersionService) GetSysVersionPublic(ctx context.Context) {
// 此方法为获取数据源定义的数据
// 请自行实现
}
// GetMenusByIds 根据ID列表获取菜单数据
func (sysVersionService *SysVersionService) GetMenusByIds(ctx context.Context, ids []uint) (menus []system.SysBaseMenu, err error) {
err = global.GVA_DB.Where("id in ?", ids).Preload("Parameters").Preload("MenuBtn").Find(&menus).Error
return
}
// GetApisByIds 根据ID列表获取API数据
func (sysVersionService *SysVersionService) GetApisByIds(ctx context.Context, ids []uint) (apis []system.SysApi, err error) {
err = global.GVA_DB.Where("id in ?", ids).Find(&apis).Error
return
}
// GetDictionariesByIds 根据ID列表获取字典数据
func (sysVersionService *SysVersionService) GetDictionariesByIds(ctx context.Context, ids []uint) (dictionaries []system.SysDictionary, err error) {
err = global.GVA_DB.Where("id in ?", ids).Preload("SysDictionaryDetails").Find(&dictionaries).Error
return
}
// ImportMenus 导入菜单数据
func (sysVersionService *SysVersionService) ImportMenus(ctx context.Context, menus []system.SysBaseMenu) error {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
// 递归创建菜单
return sysVersionService.createMenusRecursively(tx, menus, 0)
})
}
// createMenusRecursively 递归创建菜单
func (sysVersionService *SysVersionService) createMenusRecursively(tx *gorm.DB, menus []system.SysBaseMenu, parentId uint) error {
for _, menu := range menus {
// 检查菜单是否已存在
var existingMenu system.SysBaseMenu
if err := tx.Where("name = ? AND path = ?", menu.Name, menu.Path).First(&existingMenu).Error; err == nil {
// 菜单已存在使用现有菜单ID继续处理子菜单
if len(menu.Children) > 0 {
if err := sysVersionService.createMenusRecursively(tx, menu.Children, existingMenu.ID); err != nil {
return err
}
}
continue
}
// 保存参数和按钮数据,稍后处理
parameters := menu.Parameters
menuBtns := menu.MenuBtn
children := menu.Children
// 创建新菜单(不包含关联数据)
newMenu := system.SysBaseMenu{
ParentId: parentId,
Path: menu.Path,
Name: menu.Name,
Hidden: menu.Hidden,
Component: menu.Component,
Sort: menu.Sort,
Meta: menu.Meta,
}
if err := tx.Create(&newMenu).Error; err != nil {
return err
}
// 创建参数
if len(parameters) > 0 {
for _, param := range parameters {
newParam := system.SysBaseMenuParameter{
SysBaseMenuID: newMenu.ID,
Type: param.Type,
Key: param.Key,
Value: param.Value,
}
if err := tx.Create(&newParam).Error; err != nil {
return err
}
}
}
// 创建菜单按钮
if len(menuBtns) > 0 {
for _, btn := range menuBtns {
newBtn := system.SysBaseMenuBtn{
SysBaseMenuID: newMenu.ID,
Name: btn.Name,
Desc: btn.Desc,
}
if err := tx.Create(&newBtn).Error; err != nil {
return err
}
}
}
// 递归处理子菜单
if len(children) > 0 {
if err := sysVersionService.createMenusRecursively(tx, children, newMenu.ID); err != nil {
return err
}
}
}
return nil
}
// ImportApis 导入API数据
func (sysVersionService *SysVersionService) ImportApis(apis []system.SysApi) error {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
for _, api := range apis {
// 检查API是否已存在
var existingApi system.SysApi
if err := tx.Where("path = ? AND method = ?", api.Path, api.Method).First(&existingApi).Error; err == nil {
// API已存在跳过
continue
}
// 创建新API
newApi := system.SysApi{
Path: api.Path,
Description: api.Description,
ApiGroup: api.ApiGroup,
Method: api.Method,
}
if err := tx.Create(&newApi).Error; err != nil {
return err
}
}
return nil
})
}
// ImportDictionaries 导入字典数据
func (sysVersionService *SysVersionService) ImportDictionaries(dictionaries []system.SysDictionary) error {
return global.GVA_DB.Transaction(func(tx *gorm.DB) error {
for _, dict := range dictionaries {
// 检查字典是否已存在
var existingDict system.SysDictionary
if err := tx.Where("type = ?", dict.Type).First(&existingDict).Error; err == nil {
// 字典已存在,跳过
continue
}
// 创建新字典
newDict := system.SysDictionary{
Name: dict.Name,
Type: dict.Type,
Status: dict.Status,
Desc: dict.Desc,
SysDictionaryDetails: dict.SysDictionaryDetails,
}
if err := tx.Create(&newDict).Error; err != nil {
return err
}
}
return nil
})
}