package app

import (
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"

	"git.echol.cn/loser/st/server/global"
	"git.echol.cn/loser/st/server/model/app/request"
	"git.echol.cn/loser/st/server/model/common"
	commonResponse "git.echol.cn/loser/st/server/model/common/response"
	"git.echol.cn/loser/st/server/service"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

// ConversationApi groups the HTTP handlers for the /app/conversation endpoints.
type ConversationApi struct{}

// CreateConversation binds a CreateConversationRequest from the JSON body and
// creates a conversation for the user returned by common.GetAppUserID.
// @Tags AppConversation
// @Summary 创建对话
// @Produce application/json
// @Param data body request.CreateConversationRequest true "对话信息"
// @Success 200 {object} commonResponse.Response{data=response.ConversationResponse} "创建成功"
// @Router /app/conversation [post]
// @Security ApiKeyAuth
func (a *ConversationApi) CreateConversation(c *gin.Context) {
	userID := common.GetAppUserID(c)

	var req request.CreateConversationRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}

	resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.CreateConversation(userID, &req)
	if err != nil {
		global.GVA_LOG.Error("创建对话失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithData(resp, c)
}

// GetConversationList returns one page of the user's conversations. The page
// and pageSize query parameters default to 1 and 20.
// @Tags AppConversation
// @Summary 获取对话列表
// @Produce application/json
// @Param page query int false "页码"
// @Param pageSize query int false "每页数量"
// @Success 200 {object} commonResponse.Response{data=response.ConversationListResponse} "获取成功"
// @Router /app/conversation [get]
// @Security ApiKeyAuth
func (a *ConversationApi) GetConversationList(c *gin.Context) {
	userID := common.GetAppUserID(c)

	var req request.GetConversationListRequest
	req.Page, req.PageSize = a.pagination(c, 20)

	resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.GetConversationList(userID, &req)
	if err != nil {
		global.GVA_LOG.Error("获取对话列表失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithDetailed(commonResponse.PageResult{
		List:     resp.List,
		Total:    resp.Total,
		Page:     resp.Page,
		PageSize: resp.PageSize,
	}, "获取成功", c)
}

// GetConversationByID returns the conversation identified by the :id path
// parameter for the requesting user.
// @Tags AppConversation
// @Summary 获取对话详情
// @Produce application/json
// @Param id path int true "对话ID"
// @Success 200 {object} commonResponse.Response{data=response.ConversationResponse} "获取成功"
// @Router /app/conversation/:id [get]
// @Security ApiKeyAuth
func (a *ConversationApi) GetConversationByID(c *gin.Context) {
	userID := common.GetAppUserID(c)
	conversationID, ok := a.conversationID(c)
	if !ok {
		return
	}

	resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.GetConversationByID(userID, conversationID)
	if err != nil {
		global.GVA_LOG.Error("获取对话详情失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithData(resp, c)
}

// UpdateConversationSettings applies the UpdateConversationSettingsRequest in
// the JSON body to the conversation identified by the :id path parameter.
// @Tags AppConversation
// @Summary 更新对话设置
// @Produce application/json
// @Param id path int true "对话ID"
// @Param data body request.UpdateConversationSettingsRequest true "设置信息"
// @Success 200 {object} commonResponse.Response{msg=string} "更新成功"
// @Router /app/conversation/:id/settings [put]
// @Security ApiKeyAuth
func (a *ConversationApi) UpdateConversationSettings(c *gin.Context) {
	userID := common.GetAppUserID(c)
	conversationID, ok := a.conversationID(c)
	if !ok {
		return
	}

	var req request.UpdateConversationSettingsRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}

	if err := service.ServiceGroupApp.AppServiceGroup.ConversationService.UpdateConversationSettings(userID, conversationID, &req); err != nil {
		global.GVA_LOG.Error("更新对话设置失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithMessage("更新成功", c)
}

// DeleteConversation deletes the conversation identified by the :id path
// parameter for the requesting user.
// @Tags AppConversation
// @Summary 删除对话
// @Produce application/json
// @Param id path int true "对话ID"
// @Success 200 {object} commonResponse.Response{msg=string} "删除成功"
// @Router /app/conversation/:id [delete]
// @Security ApiKeyAuth
func (a *ConversationApi) DeleteConversation(c *gin.Context) {
	userID := common.GetAppUserID(c)
	conversationID, ok := a.conversationID(c)
	if !ok {
		return
	}

	if err := service.ServiceGroupApp.AppServiceGroup.ConversationService.DeleteConversation(userID, conversationID); err != nil {
		global.GVA_LOG.Error("删除对话失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithMessage("删除成功", c)
}

// GetMessageList returns one page of messages from the conversation identified
// by the :id path parameter. The page and pageSize query parameters default to
// 1 and 50.
// @Tags AppConversation
// @Summary 获取消息列表
// @Produce application/json
// @Param id path int true "对话ID"
// @Param page query int false "页码"
// @Param pageSize query int false "每页数量"
// @Success 200 {object} commonResponse.Response{data=response.MessageListResponse} "获取成功"
// @Router /app/conversation/:id/messages [get]
// @Security ApiKeyAuth
func (a *ConversationApi) GetMessageList(c *gin.Context) {
	userID := common.GetAppUserID(c)
	conversationID, ok := a.conversationID(c)
	if !ok {
		return
	}

	var req request.GetMessageListRequest
	req.Page, req.PageSize = a.pagination(c, 50)

	resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.GetMessageList(userID, conversationID, &req)
	if err != nil {
		global.GVA_LOG.Error("获取消息列表失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithDetailed(commonResponse.PageResult{
		List:     resp.List,
		Total:    resp.Total,
		Page:     resp.Page,
		PageSize: resp.PageSize,
	}, "获取成功", c)
}

// RegenerateMessage regenerates the last AI reply in the conversation
// identified by the :id path parameter. With ?stream=true the reply is
// streamed as server-sent events; otherwise it is returned as a single JSON
// response.
// @Tags AppConversation
// @Summary 重新生成最后一条 AI 回复
// @Produce application/json
// @Param id path int true "对话ID"
// @Param stream query bool false "是否流式传输"
// @Success 200 {object} commonResponse.Response{data=response.MessageResponse} "重新生成成功"
// @Router /app/conversation/:id/regenerate [post]
// @Security ApiKeyAuth
func (a *ConversationApi) RegenerateMessage(c *gin.Context) {
	userID := common.GetAppUserID(c)
	conversationID, ok := a.conversationID(c)
	if !ok {
		return
	}

	if c.Query("stream") == "true" {
		a.regenerateMessageStream(c, userID, conversationID)
		return
	}

	resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.RegenerateMessage(userID, conversationID)
	if err != nil {
		global.GVA_LOG.Error("重新生成消息失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithData(resp, c)
}

// regenerateMessageStream regenerates the last AI reply and streams it to the
// client as server-sent events ("message" per chunk, then "done" or "error").
func (a *ConversationApi) regenerateMessageStream(c *gin.Context, userID, conversationID uint) {
	a.relaySSE(c, func(streamChan chan string, doneChan chan bool) error {
		return service.ServiceGroupApp.AppServiceGroup.ConversationService.RegenerateMessageStream(
			userID, conversationID, streamChan, doneChan,
		)
	})
}

// SendMessage sends the SendMessageRequest in the JSON body to the
// conversation identified by the :id path parameter. With ?stream=true the
// reply is streamed as server-sent events; otherwise it is returned as a
// single JSON response.
// @Tags AppConversation
// @Summary 发送消息
// @Produce application/json
// @Param id path int true "对话ID"
// @Param stream query bool false "是否流式传输"
// @Param data body request.SendMessageRequest true "消息内容"
// @Success 200 {object} commonResponse.Response{data=response.MessageResponse} "发送成功"
// @Router /app/conversation/:id/message [post]
// @Security ApiKeyAuth
func (a *ConversationApi) SendMessage(c *gin.Context) {
	userID := common.GetAppUserID(c)
	conversationID, ok := a.conversationID(c)
	if !ok {
		return
	}

	var req request.SendMessageRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}

	// stream=true switches the reply to server-sent events.
	if c.Query("stream") == "true" {
		a.SendMessageStream(c, userID, conversationID, &req)
		return
	}

	resp, err := service.ServiceGroupApp.AppServiceGroup.ConversationService.SendMessage(userID, conversationID, &req)
	if err != nil {
		global.GVA_LOG.Error("发送消息失败", zap.Error(err))
		commonResponse.FailWithMessage(err.Error(), c)
		return
	}
	commonResponse.OkWithData(resp, c)
}

// SendMessageStream sends req to the conversation and streams the reply to the
// client as server-sent events ("message" per chunk, then "done" or "error").
func (a *ConversationApi) SendMessageStream(c *gin.Context, userID, conversationID uint, req *request.SendMessageRequest) {
	a.relaySSE(c, func(streamChan chan string, doneChan chan bool) error {
		return service.ServiceGroupApp.AppServiceGroup.ConversationService.SendMessageStream(
			userID, conversationID, req, streamChan, doneChan,
		)
	})
}

// conversationID parses the :id path parameter as a 32-bit unsigned integer.
// On failure it writes the "无效的对话ID" error response and returns ok == false,
// in which case the caller should return without writing anything else.
func (a *ConversationApi) conversationID(c *gin.Context) (id uint, ok bool) {
	parsed, err := strconv.ParseUint(c.Param("id"), 10, 32)
	if err != nil {
		commonResponse.FailWithMessage("无效的对话ID", c)
		return 0, false
	}
	return uint(parsed), true
}

// pagination reads the page and pageSize query parameters. A page below 1
// becomes 1; a pageSize outside 1..100 becomes defaultPageSize.
func (a *ConversationApi) pagination(c *gin.Context, defaultPageSize int) (page, pageSize int) {
	// Parse errors are deliberately ignored: Atoi then yields 0, which the
	// range checks below replace with the defaults.
	page, _ = strconv.Atoi(c.DefaultQuery("page", "1"))
	pageSize, _ = strconv.Atoi(c.DefaultQuery("pageSize", strconv.Itoa(defaultPageSize)))
	if page < 1 {
		page = 1
	}
	if pageSize < 1 || pageSize > 100 {
		pageSize = defaultPageSize
	}
	return page, pageSize
}

// relaySSE runs produce in a background goroutine and relays its output to the
// client as server-sent events.
//
// Event mapping:
//   - each chunk sent on streamChan becomes a "message" event;
//   - a value on doneChan ends the stream with a "done" event;
//   - an error returned by produce ends it with an "error" event.
//
// If produce returns without signalling either, the stream is ended with
// "done" rather than left open. Chunks still queued when the stream ends are
// written before the terminating event, so the tail of the reply is not lost.
//
// The relay stops early if the client disconnects. produce cannot be
// cancelled (the service calls take no context), so in that case a background
// goroutine discards its remaining output until it returns.
func (a *ConversationApi) relaySSE(c *gin.Context, produce func(streamChan chan string, doneChan chan bool) error) {
	// Check streaming support before committing to SSE headers or starting the
	// producer, so the fallback is a plain JSON error and nothing is leaked.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		commonResponse.FailWithMessage("不支持流式传输", c)
		return
	}

	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	streamChan := make(chan string, 100)
	errorChan := make(chan error, 1)
	doneChan := make(chan bool, 1)
	finished := make(chan struct{}) // closed after produce returns

	go func() {
		// The deferred close runs after the error is queued, so anyone who
		// sees finished closed can also see that error in errorChan.
		defer close(finished)
		if err := produce(streamChan, doneChan); err != nil {
			errorChan <- err
		}
	}()

	// chunks is set to nil if the producer closes streamChan, so a closed
	// channel is never selected again (it would otherwise spin on zero values).
	var chunks <-chan string = streamChan

	// Once this handler stops reading, keep discarding chunks until produce
	// returns; otherwise a producer that outruns the buffer blocks forever.
	defer func() {
		go func(ch <-chan string) {
			for {
				select {
				case _, open := <-ch:
					if !open {
						ch = nil
					}
				case <-finished:
					return
				}
			}
		}(chunks)
	}()

	// writeQueued writes every chunk that is already buffered, without blocking.
	writeQueued := func() {
		for {
			select {
			case chunk, open := <-chunks:
				if !open {
					chunks = nil
					return
				}
				a.writeSSEEvent(c.Writer, "message", chunk)
			default:
				return
			}
		}
	}

	// finish writes the queued chunks, then the terminating event.
	finish := func(event, data string) {
		writeQueued()
		a.writeSSEEvent(c.Writer, event, data)
		flusher.Flush()
	}

	ctxDone := c.Request.Context().Done()
	for {
		select {
		case chunk, open := <-chunks:
			if !open {
				chunks = nil
				continue
			}
			a.writeSSEEvent(c.Writer, "message", chunk)
			flusher.Flush()
		case err := <-errorChan:
			finish("error", err.Error())
			return
		case <-doneChan:
			finish("done", "")
			return
		case <-finished:
			// produce has returned; report its error if it queued one,
			// otherwise end the stream normally.
			select {
			case err := <-errorChan:
				finish("error", err.Error())
			default:
				finish("done", "")
			}
			return
		case <-ctxDone:
			// Client disconnected; the deferred drainer takes over.
			return
		}
	}
}

// writeSSEEvent writes one server-sent event with the given event name. data is
// written verbatim (not JSON-encoded). A line break inside a data field would
// end that field, and a blank line would end the event early. So data is split
// on line breaks (\r\n, \r, \n) and each line gets its own "data:" field; SSE
// clients rejoin them with "\n".
func (a *ConversationApi) writeSSEEvent(w io.Writer, event, data string) {
	data = strings.ReplaceAll(data, "\r\n", "\n")
	data = strings.ReplaceAll(data, "\r", "\n")

	var b strings.Builder
	fmt.Fprintf(&b, "event: %s\n", event)
	for _, line := range strings.Split(data, "\n") {
		b.WriteString("data: ")
		b.WriteString(line)
		b.WriteByte('\n')
	}
	b.WriteByte('\n')

	// A failed write means the client is gone; the relay loop stops once the
	// request context is cancelled, so there is nothing useful to do here.
	_, _ = io.WriteString(w, b.String())
}