175 lines
4.5 KiB
Go
175 lines
4.5 KiB
Go
package utils
|
||
|
||
import (
|
||
"errors"
|
||
"time"
|
||
|
||
"github.com/golang-jwt/jwt/v5"
|
||
)
|
||
|
||
var (
|
||
// ErrTokenInvalid token 无效
|
||
ErrTokenInvalid = errors.New("token 无效")
|
||
// ErrTokenExpired token 已过期
|
||
ErrTokenExpired = errors.New("token 已过期")
|
||
// ErrTokenNotValidYet token 尚未生效
|
||
ErrTokenNotValidYet = errors.New("token 尚未生效")
|
||
// ErrTokenMalformed token 格式错误
|
||
ErrTokenMalformed = errors.New("token 格式错误")
|
||
// ErrTokenSignatureInvalid 签名无效
|
||
ErrTokenSignatureInvalid = errors.New("签名无效")
|
||
)
|
||
|
||
// JWTConfig JWT 配置
|
||
type JWTConfig struct {
|
||
SecretKey string // 密钥
|
||
ExpiresTime time.Duration // 过期时间
|
||
Issuer string // 签发者
|
||
}
|
||
|
||
// JWT jwt 工具
|
||
type JWT struct {
|
||
config JWTConfig
|
||
}
|
||
|
||
// CustomClaims 自定义 claims
|
||
type CustomClaims struct {
|
||
UserID string `json:"user_id"`
|
||
UserName string `json:"user_name"`
|
||
Role int `json:"role"`
|
||
Data map[string]interface{} `json:"data,omitempty"`
|
||
jwt.RegisteredClaims
|
||
}
|
||
|
||
// NewJWT 创建 JWT 实例
|
||
func NewJWT(config JWTConfig) *JWT {
|
||
return &JWT{config: config}
|
||
}
|
||
|
||
// GenerateToken 生成 token
|
||
func (j *JWT) GenerateToken(claims *CustomClaims) (string, error) {
|
||
// 设置默认值
|
||
if claims.Issuer == "" {
|
||
claims.Issuer = j.config.Issuer
|
||
}
|
||
if claims.ExpiresAt == nil && j.config.ExpiresTime > 0 {
|
||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.config.ExpiresTime))
|
||
}
|
||
if claims.IssuedAt == nil {
|
||
claims.IssuedAt = jwt.NewNumericDate(time.Now())
|
||
}
|
||
if claims.NotBefore == nil {
|
||
claims.NotBefore = jwt.NewNumericDate(time.Now())
|
||
}
|
||
|
||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||
return token.SignedString([]byte(j.config.SecretKey))
|
||
}
|
||
|
||
// ParseToken 解析 token
|
||
func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) {
|
||
token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||
return []byte(j.config.SecretKey), nil
|
||
})
|
||
|
||
if err != nil {
|
||
switch {
|
||
case errors.Is(err, jwt.ErrTokenMalformed):
|
||
return nil, ErrTokenMalformed
|
||
case errors.Is(err, jwt.ErrTokenExpired):
|
||
return nil, ErrTokenExpired
|
||
case errors.Is(err, jwt.ErrTokenNotValidYet):
|
||
return nil, ErrTokenNotValidYet
|
||
case errors.Is(err, jwt.ErrTokenSignatureInvalid):
|
||
return nil, ErrTokenSignatureInvalid
|
||
default:
|
||
return nil, ErrTokenInvalid
|
||
}
|
||
}
|
||
|
||
if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid {
|
||
return claims, nil
|
||
}
|
||
|
||
return nil, ErrTokenInvalid
|
||
}
|
||
|
||
// RefreshToken 刷新 token
|
||
func (j *JWT) RefreshToken(tokenString string) (string, error) {
|
||
claims, err := j.ParseToken(tokenString)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 重置过期时间
|
||
claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.config.ExpiresTime))
|
||
claims.IssuedAt = jwt.NewNumericDate(time.Now())
|
||
|
||
return j.GenerateToken(claims)
|
||
}
|
||
|
||
// GetUserIDFromToken 从 token 中获取用户 ID
|
||
func (j *JWT) GetUserIDFromToken(tokenString string) (string, error) {
|
||
claims, err := j.ParseToken(tokenString)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return claims.UserID, nil
|
||
}
|
||
|
||
// GetUserNameFromToken 从 token 中获取用户名
|
||
func (j *JWT) GetUserNameFromToken(tokenString string) (string, error) {
|
||
claims, err := j.ParseToken(tokenString)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return claims.UserName, nil
|
||
}
|
||
|
||
// GetClaimFromToken 从 token 中获取指定字段
|
||
func (j *JWT) GetClaimFromToken(tokenString string, key string) (interface{}, error) {
|
||
claims, err := j.ParseToken(tokenString)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if claims.Data == nil {
|
||
return nil, errors.New("未找到指定字段")
|
||
}
|
||
val, ok := claims.Data[key]
|
||
if !ok {
|
||
return nil, errors.New("未找到指定字段")
|
||
}
|
||
return val, nil
|
||
}
|
||
|
||
// ValidateToken 验证 token 是否有效
|
||
func (j *JWT) ValidateToken(tokenString string) bool {
|
||
_, err := j.ParseToken(tokenString)
|
||
return err == nil
|
||
}
|
||
|
||
// CreateToken 快速创建 token(简化方法)
|
||
// secretKey: 密钥
|
||
// userID: 用户 ID
|
||
// userName: 用户名
|
||
// expires: 过期时间(小时)
|
||
func CreateToken(secretKey, userID, userName string, expires int) (string, error) {
|
||
jwtUtil := NewJWT(JWTConfig{
|
||
SecretKey: secretKey,
|
||
ExpiresTime: time.Duration(expires) * time.Hour,
|
||
})
|
||
|
||
claims := &CustomClaims{
|
||
UserID: userID,
|
||
UserName: userName,
|
||
}
|
||
|
||
return jwtUtil.GenerateToken(claims)
|
||
}
|
||
|
||
// ParseTokenString 快速解析 token(简化方法)
|
||
func ParseTokenString(secretKey, tokenString string) (*CustomClaims, error) {
|
||
jwtUtil := NewJWT(JWTConfig{SecretKey: secretKey})
|
||
return jwtUtil.ParseToken(tokenString)
|
||
}
|