214 lines
5.2 KiB
Go
214 lines
5.2 KiB
Go
package svr
|
||
|
||
import (
|
||
"context"
|
||
"math/rand"
|
||
"net"
|
||
"net/url"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"git.hujye.com/infrastructure/go-web-gin/config"
|
||
"git.hujye.com/infrastructure/go-web-gin/utils"
|
||
"git.hujye.com/infrastructure/go-web-gin/web"
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
const (
|
||
// traceIDChars 用于生成 trace ID 的字符集
|
||
traceIDChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||
|
||
// locationCacheTTL 地理位置缓存有效期(3小时)
|
||
locationCacheTTL = 3 * time.Hour
|
||
)
|
||
|
||
// locationCacheItem 地理位置缓存项
|
||
type locationCacheItem struct {
|
||
province string
|
||
city string
|
||
expiresAt time.Time
|
||
}
|
||
|
||
// locationCache 地理位置缓存
|
||
var locationCache sync.Map
|
||
|
||
// getLocationWithCache 获取地理位置(带缓存)
|
||
func getLocationWithCache(ctx context.Context, ip string) (province, city string) {
|
||
// 先从缓存中查找
|
||
if val, ok := locationCache.Load(ip); ok {
|
||
if item, ok := val.(*locationCacheItem); ok {
|
||
// 检查是否过期
|
||
if time.Now().Before(item.expiresAt) {
|
||
return item.province, item.city
|
||
}
|
||
// 已过期,删除缓存
|
||
locationCache.Delete(ip)
|
||
}
|
||
}
|
||
|
||
// 调用 API 获取地理位置
|
||
province, city, err := utils.GetLocation(ctx, ip)
|
||
if err != nil {
|
||
return "", ""
|
||
}
|
||
|
||
// 存入缓存
|
||
locationCache.Store(ip, &locationCacheItem{
|
||
province: province,
|
||
city: city,
|
||
expiresAt: time.Now().Add(locationCacheTTL),
|
||
})
|
||
|
||
return province, city
|
||
}
|
||
|
||
// generateTraceID 生成长度为7的随机字符串作为 trace ID
|
||
func generateTraceID() string {
|
||
r := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||
b := make([]byte, 7)
|
||
for i := range b {
|
||
b[i] = traceIDChars[r.Intn(len(traceIDChars))]
|
||
}
|
||
return string(b)
|
||
}
|
||
|
||
// VisitorMiddleware 创建一个中间件,从请求中提取访问者信息并设置到上下文中
|
||
func VisitorMiddleware() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
ctx := c.Request.Context()
|
||
|
||
// 从请求中提取访问者信息
|
||
clientIP := getClientIP(c)
|
||
userInfo := web.NewUserInfo().
|
||
WithFromIP(clientIP).
|
||
WithVisitTime(time.Now()).
|
||
WithTrace(generateTraceID()) // 生成7位随机 trace ID
|
||
|
||
// 从 JWT token 中解析用户信息
|
||
if token := extractToken(c); token != "" {
|
||
jwtConfig := config.Get().JWT
|
||
if jwtConfig.Secret != "" {
|
||
jwtUtil := utils.NewJWT(utils.JWTConfig{
|
||
SecretKey: jwtConfig.Secret,
|
||
})
|
||
if claims, err := jwtUtil.ParseToken(token); err == nil {
|
||
userInfo.WithUserID(claims.UserID).WithUserName(claims.UserName).WithRole(claims.Role)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果 IP 不是保留地址,则通过 IP 获取地理位置(带缓存)
|
||
if clientIP != "" && !isPrivateIP(clientIP) {
|
||
if province, city := getLocationWithCache(ctx, clientIP); province != "" {
|
||
userInfo.WithProvince(province).WithCity(city)
|
||
}
|
||
}
|
||
|
||
// 将用户信息设置到上下文中
|
||
ctx = web.ToContext(ctx, userInfo)
|
||
c.Request = c.Request.WithContext(ctx)
|
||
|
||
c.Next()
|
||
|
||
// 通过响应头返回地理位置和 TRACE 信息(中文需要 URL 编码)
|
||
c.Header("X-Trace-ID", userInfo.Trace)
|
||
if userInfo.Province != "" {
|
||
c.Header("X-Province", url.QueryEscape(userInfo.Province))
|
||
}
|
||
if userInfo.City != "" {
|
||
c.Header("X-City", url.QueryEscape(userInfo.City))
|
||
}
|
||
}
|
||
}
|
||
|
||
// extractToken 从请求中提取 JWT token
|
||
// 支持 Authorization: Bearer <token> 和 Authorization: <token> 两种格式
|
||
func extractToken(c *gin.Context) string {
|
||
auth := c.GetHeader("Authorization")
|
||
if auth == "" {
|
||
return ""
|
||
}
|
||
|
||
// Bearer token 格式
|
||
if strings.HasPrefix(auth, "Bearer ") {
|
||
return strings.TrimPrefix(auth, "Bearer ")
|
||
}
|
||
|
||
// 直接 token 格式
|
||
return auth
|
||
}
|
||
|
||
// VisitorMiddlewareWithExtractor 创建一个支持自定义提取函数的中间件
|
||
// 允许应用程序定义自己的用户信息提取逻辑
|
||
func VisitorMiddlewareWithExtractor(extractor func(c *gin.Context) *web.UserInfo) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
if extractor == nil {
|
||
c.Next()
|
||
return
|
||
}
|
||
|
||
userInfo := extractor(c)
|
||
if userInfo != nil {
|
||
ctx := c.Request.Context()
|
||
ctx = web.ToContext(ctx, userInfo)
|
||
c.Request = c.Request.WithContext(ctx)
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// getClientIP 从请求中提取真实客户端 IP
|
||
func getClientIP(c *gin.Context) string {
|
||
// 优先尝试 X-Forwarded-For 请求头
|
||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||
// X-Forwarded-For 可能包含多个 IP,取第一个
|
||
for i := 0; i < len(xff); i++ {
|
||
if xff[i] == ',' {
|
||
return xff[:i]
|
||
}
|
||
}
|
||
return xff
|
||
}
|
||
|
||
// 尝试 X-Real-IP 请求头
|
||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||
return xri
|
||
}
|
||
|
||
// 回退到 ClientIP
|
||
return c.ClientIP()
|
||
}
|
||
|
||
// isPrivateIP 检查 IP 地址是否为私有/保留地址
|
||
// 环回地址、私有网络、链路本地地址都返回 true
|
||
func isPrivateIP(ipStr string) bool {
|
||
ip := net.ParseIP(ipStr)
|
||
if ip == nil {
|
||
return true // 无效 IP 视为私有地址
|
||
}
|
||
|
||
// 检查是否为环回地址
|
||
if ip.IsLoopback() {
|
||
return true
|
||
}
|
||
|
||
// 检查是否为私有地址 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
|
||
if ip.IsPrivate() {
|
||
return true
|
||
}
|
||
|
||
// 检查是否为链路本地地址 (169.254.0.0/16 或 fe80::/10)
|
||
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
|
||
return true
|
||
}
|
||
|
||
// 检查是否为未指定地址 (0.0.0.0 或 ::)
|
||
if ip.IsUnspecified() {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|