You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

167 lines
4.7 KiB
Go

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package handle
import (
"Lee-WineList/client"
"Lee-WineList/common/constant"
"Lee-WineList/model/cache"
"Lee-WineList/model/entity"
"Lee-WineList/repository"
"context"
"encoding/json"
"fmt"
"git.echol.cn/loser/logger/log"
"github.com/go-oauth2/oauth2/v4"
"github.com/go-oauth2/oauth2/v4/errors"
"net/http"
"strconv"
"strings"
"time"
)
// UserAuthorizationHandler resolves the ID of the user making a login request
// and caches a summary of that user in Redis.
//
// It reads the form values "type" (login type), "identity" (identity type),
// "username" (the account, or an authorization code), "nickName" and
// "avatarUrl" from r and hands them to getUser. The resulting user ID is then
// stored, together with the identity type, as a cache.UserInfo under the key
// constant.OAuth2UserCacheKey+userId with a 7 day expiry.
//
// On any failure the returned userId is empty. Errors from getUser are
// returned as-is; every later failure is logged and reported to the caller
// with a generic "login failed" message.
func UserAuthorizationHandler(w http.ResponseWriter, r *http.Request) (userId string, err error) {
	loginType := constant.LoginType(r.FormValue("type"))
	userIdentity := constant.UserIdentity(r.FormValue("identity"))
	account := r.FormValue("username") // account, or the authorization code
	nickName := r.FormValue("nickName")
	avatarUrl := r.FormValue("avatarUrl")
	log.Debugf("预处理用户登录请求,身份类型: %v => 登录类型: %s => 账号: %v", userIdentity, loginType, account)

	// Regular user lookup.
	userId, err = getUser(account, loginType, nickName, avatarUrl)
	if err != nil {
		return "", err
	}

	// The cache entry stores the ID as an int; refuse to cache a bogus 0 when
	// the ID is not numeric.
	uid, convErr := strconv.Atoi(userId)
	if convErr != nil {
		log.Errorf("解析用户ID失败: %v", convErr)
		return "", errors.New("登录失败,请联系管理员")
	}

	// Assemble the user info that goes into the cache.
	cached := cache.UserInfo{
		UserId:   uid,
		UserType: userIdentity.String(),
	}
	payload, err := cached.String()
	if err != nil {
		log.Errorf("序列化用户缓存信息失败: %v", err)
		return "", errors.New("登录失败,请联系管理员")
	}

	key := fmt.Sprintf("%s%v", constant.OAuth2UserCacheKey, userId)
	if err = client.Redis.Set(context.Background(), key, payload, 7*24*time.Hour).Err(); err != nil {
		log.Errorf("缓存用户信息失败用户ID%v错误信息%s", userId, err.Error())
		return "", errors.New("登录失败,请联系管理员")
	}
	return userId, nil
}
// LoginWithPassword handles the username/password login mode.
//
// It writes a debug log line containing clientId and userId and returns
// userId unchanged with a nil error. The password is deliberately kept out of
// the log so credentials never end up in log storage.
//
// NOTE(review): password is not verified in this function — presumably the
// credentials are checked elsewhere; confirm before relying on this grant.
func LoginWithPassword(ctx context.Context, clientId, userId, password string) (userID string, err error) {
	log.Debugf("[%v]处理登录请求用户Id%s", clientId, userId)
	return userId, nil
}
// CheckClient reports whether the client identified by clientID may request a
// token using the given grant type.
//
// The client's configuration is loaded through repository.OAuth2Client() and
// the grant is allowed when it appears in the client's Grant field. An empty
// grant is always rejected. When allowed is false a non-nil error is returned:
// either the wrapped lookup failure or a "grant_type not allowed" error.
func CheckClient(clientID string, grant oauth2.GrantType) (allowed bool, err error) {
	grantType := string(grant)
	// strings.Contains reports true for an empty substring, so an empty grant
	// would otherwise be accepted for every client.
	if grantType == "" {
		return false, errors.New("不受允许的grant_type")
	}

	// Load the client's configuration.
	c := entity.OAuth2Client{ClientId: clientID}
	if err = repository.OAuth2Client().FindOne(&c); err != nil {
		log.Errorf("客户端信息查询失败: %v", err.Error())
		return false, fmt.Errorf("客户端信息查询失败: %w", err)
	}

	// Check whether the requested grant is within the client's allowed grants.
	// NOTE(review): this is a substring match; presumably c.Grant is a
	// delimiter-separated list of grant types — confirm the format before
	// tightening this to an exact-token comparison.
	if !strings.Contains(c.Grant, grantType) {
		return false, errors.New("不受允许的grant_type")
	}
	return true, nil
}
// ExtensionFields builds the custom extension fields that are added to a
// token response.
//
// The result always contains "license", "newUser", "nickname", "phone" and
// "userId". The user is looked up by the numeric form of ti.GetUserID(). If
// that ID is not numeric, or the lookup fails, the problem is logged and the
// user-derived fields are filled from whatever the (possibly zero-valued)
// entity.User holds, so the response shape stays the same.
func ExtensionFields(ti oauth2.TokenInfo) (fieldsValue map[string]any) {
	fieldsValue = map[string]any{
		"license": "Made By Lee",
	}

	// Load the user the token belongs to. Failures are best-effort: they must
	// not prevent the token from being issued, but they should be visible.
	var userInfo entity.User
	id, err := strconv.Atoi(ti.GetUserID())
	if err != nil {
		// Skip the lookup rather than querying with a meaningless Id of 0.
		log.Errorf("解析用户ID失败: %v", err)
	} else {
		userInfo.Id = id
		if err = repository.User().GetUser(&userInfo); err != nil {
			log.Errorf("查询用户信息失败: %v", err)
		}
	}

	// A user created within the last minute is reported as new.
	fieldsValue["newUser"] = time.Since(userInfo.CreatedAt.ToTime()).Minutes() <= 1
	fieldsValue["nickname"] = userInfo.Nickname
	fieldsValue["phone"] = userInfo.Phone
	fieldsValue["userId"] = userInfo.Id
	return fieldsValue
}
// ResponseToken writes the result of token generation to w as a JSON envelope
// of the form {"code": ..., "data": ..., "message": ...}.
//
// Without a positive statusCode[0] the call is treated as a success: the HTTP
// status and "code" are 200, "message" is "login success" and data is echoed
// back. With a positive statusCode[0] the call is treated as an error: that
// value is used for the HTTP status and "code", "data" is null, and "message"
// is chosen from data["error"] / data["error_description"].
//
// Entries in header are copied onto the response after the defaults, so they
// replace Content-Type, Cache-Control and Pragma when present. All values of a
// multi-valued header are kept.
//
// It returns any error from encoding the body or writing it to w.
func ResponseToken(w http.ResponseWriter, data map[string]any, header http.Header, statusCode ...int) error {
	log.Debugf("返回Token原始数据: %+v", data)

	type response struct {
		Code int            `json:"code"`
		Data map[string]any `json:"data"`
		Msg  string         `json:"message"`
	}

	status := http.StatusOK
	msg := "login success"
	if len(statusCode) > 0 && statusCode[0] > 0 {
		status = statusCode[0]
		// Well-known errors (e.g. an expired refresh token) get a friendlier
		// message; everything else falls back to what the server reported.
		switch data["error"] {
		case "invalid_grant":
			msg = "登录已过期,请重新授权登录"
		case "invalid_request":
			msg = "登录参数错误"
		default:
			log.Errorf("收到未定义的登录错误: %v", data["error_description"])
			// Avoid sending the literal "<nil>" when no description is set.
			switch {
			case data["error_description"] != nil:
				msg = fmt.Sprintf("%v", data["error_description"])
			case data["error"] != nil:
				msg = fmt.Sprintf("%v", data["error"])
			default:
				msg = http.StatusText(status)
			}
		}
		// Error details are not exposed in the "data" field.
		data = nil
	}

	jsonBytes, err := json.Marshal(response{Code: status, Msg: msg, Data: data})
	if err != nil {
		return err
	}

	h := w.Header()
	h.Set("Content-Type", "application/json;charset=UTF-8")
	h.Set("Cache-Control", "no-store")
	h.Set("Pragma", "no-cache")
	for key, values := range header {
		if len(values) == 0 {
			continue
		}
		// Replace any default set above, but keep every value of the header.
		h.Del(key)
		for _, v := range values {
			h.Add(key, v)
		}
	}

	w.WriteHeader(status)
	if _, err = w.Write(jsonBytes); err != nil {
		log.Errorf("返回Token失败: %v", err.Error())
		return err
	}
	return nil
}
// InternalErrorHandler is the custom internal-error handler. It wraps err in
// an errors.Response carrying HTTP status 401 (Unauthorized) and uses
// err.Error() as the response description.
//
// err must be non-nil: err.Error() is called unconditionally.
func InternalErrorHandler(err error) (re *errors.Response) {
	resp := errors.NewResponse(err, http.StatusUnauthorized)
	resp.Description = err.Error()
	return resp
}