389 lines
12 KiB
Go
389 lines
12 KiB
Go
package app
|
|
|
|
import (
	"fmt"
	"net/http"
	"strconv"
	"strings"

	"git.echol.cn/loser/st/server/global"
	"git.echol.cn/loser/st/server/model/app/request"
	"git.echol.cn/loser/st/server/model/common"
	commonResponse "git.echol.cn/loser/st/server/model/common/response"
	"git.echol.cn/loser/st/server/service"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)
|
|
|
|
// ConversationApi groups the gin HTTP handlers for the app conversation
// endpoints (conversation CRUD, message listing, sending and regeneration).
// It carries no state; handlers reach the service layer through
// service.ServiceGroupApp.
type ConversationApi struct{}
|
|
|
|
// CreateConversation
|
|
// @Tags AppConversation
|
|
// @Summary 创建对话
|
|
// @Produce application/json
|
|
// @Param data body request.CreateConversationRequest true "对话信息"
|
|
// @Success 200 {object} commonResponse.Response{data=response.ConversationResponse} "创建成功"
|
|
// @Router /app/conversation [post]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) CreateConversation(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
|
|
var req request.CreateConversationRequest
|
|
err := c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.CreateConversation(userID, &req)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("创建对话失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
commonResponse.OkWithData(resp, c)
|
|
}
|
|
|
|
// GetConversationList
|
|
// @Tags AppConversation
|
|
// @Summary 获取对话列表
|
|
// @Produce application/json
|
|
// @Param page query int false "页码"
|
|
// @Param pageSize query int false "每页数量"
|
|
// @Success 200 {object} commonResponse.Response{data=response.ConversationListResponse} "获取成功"
|
|
// @Router /app/conversation [get]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) GetConversationList(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
|
|
var req request.GetConversationListRequest
|
|
req.Page, _ = strconv.Atoi(c.DefaultQuery("page", "1"))
|
|
req.PageSize, _ = strconv.Atoi(c.DefaultQuery("pageSize", "20"))
|
|
|
|
if req.Page < 1 {
|
|
req.Page = 1
|
|
}
|
|
if req.PageSize < 1 || req.PageSize > 100 {
|
|
req.PageSize = 20
|
|
}
|
|
|
|
resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.GetConversationList(userID, &req)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
commonResponse.OkWithDetailed(commonResponse.PageResult{
|
|
List: resp.List,
|
|
Total: resp.Total,
|
|
Page: resp.Page,
|
|
PageSize: resp.PageSize,
|
|
}, "获取成功", c)
|
|
}
|
|
|
|
// GetConversationByID
|
|
// @Tags AppConversation
|
|
// @Summary 获取对话详情
|
|
// @Produce application/json
|
|
// @Param id path int true "对话ID"
|
|
// @Success 200 {object} commonResponse.Response{data=response.ConversationResponse} "获取成功"
|
|
// @Router /app/conversation/:id [get]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) GetConversationByID(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
conversationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage("无效的对话ID", c)
|
|
return
|
|
}
|
|
|
|
resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.GetConversationByID(userID, uint(conversationID))
|
|
if err != nil {
|
|
global.GVA_LOG.Error("获取对话详情失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
commonResponse.OkWithData(resp, c)
|
|
}
|
|
|
|
// UpdateConversationSettings
|
|
// @Tags AppConversation
|
|
// @Summary 更新对话设置
|
|
// @Produce application/json
|
|
// @Param id path int true "对话ID"
|
|
// @Param data body request.UpdateConversationSettingsRequest true "设置信息"
|
|
// @Success 200 {object} commonResponse.Response{msg=string} "更新成功"
|
|
// @Router /app/conversation/:id/settings [put]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) UpdateConversationSettings(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
conversationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage("无效的对话ID", c)
|
|
return
|
|
}
|
|
|
|
var req request.UpdateConversationSettingsRequest
|
|
err = c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
err = service.ServiceGroupApp.AppServiceGroup.ConversationService.UpdateConversationSettings(userID, uint(conversationID), &req)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("更新对话设置失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
commonResponse.OkWithMessage("更新成功", c)
|
|
}
|
|
|
|
// DeleteConversation
|
|
// @Tags AppConversation
|
|
// @Summary 删除对话
|
|
// @Produce application/json
|
|
// @Param id path int true "对话ID"
|
|
// @Success 200 {object} commonResponse.Response{msg=string} "删除成功"
|
|
// @Router /app/conversation/:id [delete]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) DeleteConversation(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
conversationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage("无效的对话ID", c)
|
|
return
|
|
}
|
|
|
|
err = service.ServiceGroupApp.AppServiceGroup.ConversationService.DeleteConversation(userID, uint(conversationID))
|
|
if err != nil {
|
|
global.GVA_LOG.Error("删除对话失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
commonResponse.OkWithMessage("删除成功", c)
|
|
}
|
|
|
|
// GetMessageList
|
|
// @Tags AppConversation
|
|
// @Summary 获取消息列表
|
|
// @Produce application/json
|
|
// @Param id path int true "对话ID"
|
|
// @Param page query int false "页码"
|
|
// @Param pageSize query int false "每页数量"
|
|
// @Success 200 {object} commonResponse.Response{data=response.MessageListResponse} "获取成功"
|
|
// @Router /app/conversation/:id/messages [get]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) GetMessageList(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
conversationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage("无效的对话ID", c)
|
|
return
|
|
}
|
|
|
|
var req request.GetMessageListRequest
|
|
req.Page, _ = strconv.Atoi(c.DefaultQuery("page", "1"))
|
|
req.PageSize, _ = strconv.Atoi(c.DefaultQuery("pageSize", "50"))
|
|
|
|
if req.Page < 1 {
|
|
req.Page = 1
|
|
}
|
|
if req.PageSize < 1 || req.PageSize > 100 {
|
|
req.PageSize = 50
|
|
}
|
|
|
|
resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.GetMessageList(userID, uint(conversationID), &req)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("获取消息列表失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
commonResponse.OkWithDetailed(commonResponse.PageResult{
|
|
List: resp.List,
|
|
Total: resp.Total,
|
|
Page: resp.Page,
|
|
PageSize: resp.PageSize,
|
|
}, "获取成功", c)
|
|
}
|
|
|
|
// RegenerateMessage
|
|
// @Tags AppConversation
|
|
// @Summary 重新生成最后一条 AI 回复
|
|
// @Produce application/json
|
|
// @Param id path int true "对话ID"
|
|
// @Param stream query bool false "是否流式传输"
|
|
// @Success 200 {object} commonResponse.Response{data=response.MessageResponse} "重新生成成功"
|
|
// @Router /app/conversation/:id/regenerate [post]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) RegenerateMessage(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
conversationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage("无效的对话ID", c)
|
|
return
|
|
}
|
|
|
|
if c.Query("stream") == "true" {
|
|
a.regenerateMessageStream(c, userID, uint(conversationID))
|
|
return
|
|
}
|
|
|
|
resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.RegenerateMessage(userID, uint(conversationID))
|
|
if err != nil {
|
|
global.GVA_LOG.Error("重新生成消息失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
commonResponse.OkWithData(resp, c)
|
|
}
|
|
|
|
func (a *ConversationApi) regenerateMessageStream(c *gin.Context, userID, conversationID uint) {
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Header("X-Accel-Buffering", "no")
|
|
|
|
streamChan := make(chan string, 100)
|
|
errorChan := make(chan error, 1)
|
|
doneChan := make(chan bool, 1)
|
|
|
|
go func() {
|
|
if err := service.ServiceGroupApp.AppServiceGroup.ConversationService.RegenerateMessageStream(
|
|
userID, conversationID, streamChan, doneChan,
|
|
); err != nil {
|
|
errorChan <- err
|
|
}
|
|
}()
|
|
|
|
flusher, ok := c.Writer.(http.Flusher)
|
|
if !ok {
|
|
commonResponse.FailWithMessage("不支持流式传输", c)
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case chunk := <-streamChan:
|
|
c.Writer.Write([]byte("event: message\n"))
|
|
c.Writer.Write([]byte(fmt.Sprintf("data: %s\n\n", chunk)))
|
|
flusher.Flush()
|
|
case err := <-errorChan:
|
|
c.Writer.Write([]byte("event: error\n"))
|
|
c.Writer.Write([]byte(fmt.Sprintf("data: %s\n\n", err.Error())))
|
|
flusher.Flush()
|
|
return
|
|
case <-doneChan:
|
|
c.Writer.Write([]byte("event: done\n"))
|
|
c.Writer.Write([]byte("data: \n\n"))
|
|
flusher.Flush()
|
|
return
|
|
case <-c.Request.Context().Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// SendMessage
|
|
// @Tags AppConversation
|
|
// @Summary 发送消息
|
|
// @Produce application/json
|
|
// @Param id path int true "对话ID"
|
|
// @Param data body request.SendMessageRequest true "消息内容"
|
|
// @Success 200 {object} commonResponse.Response{data=response.MessageResponse} "发送成功"
|
|
// @Router /app/conversation/:id/message [post]
|
|
// @Security ApiKeyAuth
|
|
func (a *ConversationApi) SendMessage(c *gin.Context) {
|
|
userID := common.GetAppUserID(c)
|
|
conversationID, err := strconv.ParseUint(c.Param("id"), 10, 32)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage("无效的对话ID", c)
|
|
return
|
|
}
|
|
|
|
var req request.SendMessageRequest
|
|
err = c.ShouldBindJSON(&req)
|
|
if err != nil {
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
|
|
// 检查是否启用流式传输
|
|
stream := c.Query("stream") == "true"
|
|
|
|
if stream {
|
|
// 流式传输
|
|
a.SendMessageStream(c, userID, uint(conversationID), &req)
|
|
} else {
|
|
// 普通传输
|
|
resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.SendMessage(userID, uint(conversationID), &req)
|
|
if err != nil {
|
|
global.GVA_LOG.Error("发送消息失败", zap.Error(err))
|
|
commonResponse.FailWithMessage(err.Error(), c)
|
|
return
|
|
}
|
|
commonResponse.OkWithData(resp, c)
|
|
}
|
|
}
|
|
|
|
// SendMessageStream 流式传输消息
|
|
func (a *ConversationApi) SendMessageStream(c *gin.Context, userID, conversationID uint, req *request.SendMessageRequest) {
|
|
// 设置SSE响应头
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Header("X-Accel-Buffering", "no")
|
|
|
|
// 创建流式传输通道
|
|
streamChan := make(chan string, 100)
|
|
errorChan := make(chan error, 1)
|
|
doneChan := make(chan bool, 1)
|
|
|
|
// 启动流式传输
|
|
go func() {
|
|
err := service.ServiceGroupApp.AppServiceGroup.ConversationService.SendMessageStream(
|
|
userID, conversationID, req, streamChan, doneChan,
|
|
)
|
|
if err != nil {
|
|
errorChan <- err
|
|
}
|
|
}()
|
|
|
|
// 发送流式数据
|
|
flusher, ok := c.Writer.(http.Flusher)
|
|
if !ok {
|
|
commonResponse.FailWithMessage("不支持流式传输", c)
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case chunk := <-streamChan:
|
|
// 手动写入 SSE 格式,避免 Gin 的 SSEvent 进行 JSON 序列化
|
|
c.Writer.Write([]byte("event: message\n"))
|
|
c.Writer.Write([]byte(fmt.Sprintf("data: %s\n\n", chunk)))
|
|
flusher.Flush()
|
|
case err := <-errorChan:
|
|
// 发送错误
|
|
c.Writer.Write([]byte("event: error\n"))
|
|
c.Writer.Write([]byte(fmt.Sprintf("data: %s\n\n", err.Error())))
|
|
flusher.Flush()
|
|
return
|
|
case <-doneChan:
|
|
// 发送完成信号
|
|
c.Writer.Write([]byte("event: done\n"))
|
|
c.Writer.Write([]byte("data: \n\n"))
|
|
flusher.Flush()
|
|
return
|
|
case <-c.Request.Context().Done():
|
|
// 客户端断开连接
|
|
return
|
|
}
|
|
}
|
|
}
|