package svr import ( "context" "math/rand" "net" "net/url" "strings" "sync" "time" "git.hujye.com/infrastructure/go-web-gin/config" "git.hujye.com/infrastructure/go-web-gin/utils" "git.hujye.com/infrastructure/go-web-gin/web" "github.com/gin-gonic/gin" ) const ( // traceIDChars 用于生成 trace ID 的字符集 traceIDChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" // locationCacheTTL 地理位置缓存有效期(3小时) locationCacheTTL = 3 * time.Hour ) // locationCacheItem 地理位置缓存项 type locationCacheItem struct { province string city string expiresAt time.Time } // locationCache 地理位置缓存 var locationCache sync.Map // getLocationWithCache 获取地理位置(带缓存) func getLocationWithCache(ctx context.Context, ip string) (province, city string) { // 先从缓存中查找 if val, ok := locationCache.Load(ip); ok { if item, ok := val.(*locationCacheItem); ok { // 检查是否过期 if time.Now().Before(item.expiresAt) { return item.province, item.city } // 已过期,删除缓存 locationCache.Delete(ip) } } // 调用 API 获取地理位置 province, city, err := utils.GetLocation(ctx, ip) if err != nil { return "", "" } // 存入缓存 locationCache.Store(ip, &locationCacheItem{ province: province, city: city, expiresAt: time.Now().Add(locationCacheTTL), }) return province, city } // generateTraceID 生成长度为7的随机字符串作为 trace ID func generateTraceID() string { r := rand.New(rand.NewSource(time.Now().UnixNano())) b := make([]byte, 7) for i := range b { b[i] = traceIDChars[r.Intn(len(traceIDChars))] } return string(b) } // VisitorMiddleware 创建一个中间件,从请求中提取访问者信息并设置到上下文中 func VisitorMiddleware() gin.HandlerFunc { return func(c *gin.Context) { ctx := c.Request.Context() // 从请求中提取访问者信息 clientIP := getClientIP(c) userInfo := web.NewUserInfo(). WithFromIP(clientIP). WithVisitTime(time.Now()). WithTrace(generateTraceID()) // 生成7位随机 trace ID // 从 JWT token 中解析用户信息 if token := extractToken(c); token != "" { jwtConfig := config.Get().JWT if jwtConfig.Secret != "" { jwtUtil := utils.NewJWT(utils.JWTConfig{ SecretKey: jwtConfig.Secret, }) if claims, err := jwtUtil.ParseToken(token); err == nil { userInfo.WithUserID(claims.UserID).WithUserName(claims.UserName).WithRole(claims.Role) } } } // 如果 IP 不是保留地址,则通过 IP 获取地理位置(带缓存) if clientIP != "" && !isPrivateIP(clientIP) { if province, city := getLocationWithCache(ctx, clientIP); province != "" { userInfo.WithProvince(province).WithCity(city) } } // 将用户信息设置到上下文中 ctx = web.ToContext(ctx, userInfo) c.Request = c.Request.WithContext(ctx) c.Next() // 通过响应头返回地理位置和 TRACE 信息(中文需要 URL 编码) c.Header("X-Trace-ID", userInfo.Trace) if userInfo.Province != "" { c.Header("X-Province", url.QueryEscape(userInfo.Province)) } if userInfo.City != "" { c.Header("X-City", url.QueryEscape(userInfo.City)) } } } // extractToken 从请求中提取 JWT token // 支持 Authorization: Bearer 和 Authorization: 两种格式 func extractToken(c *gin.Context) string { auth := c.GetHeader("Authorization") if auth == "" { return "" } // Bearer token 格式 if strings.HasPrefix(auth, "Bearer ") { return strings.TrimPrefix(auth, "Bearer ") } // 直接 token 格式 return auth } // VisitorMiddlewareWithExtractor 创建一个支持自定义提取函数的中间件 // 允许应用程序定义自己的用户信息提取逻辑 func VisitorMiddlewareWithExtractor(extractor func(c *gin.Context) *web.UserInfo) gin.HandlerFunc { return func(c *gin.Context) { if extractor == nil { c.Next() return } userInfo := extractor(c) if userInfo != nil { ctx := c.Request.Context() ctx = web.ToContext(ctx, userInfo) c.Request = c.Request.WithContext(ctx) } c.Next() } } // getClientIP 从请求中提取真实客户端 IP func getClientIP(c *gin.Context) string { // 优先尝试 X-Forwarded-For 请求头 if xff := c.GetHeader("X-Forwarded-For"); xff != "" { // X-Forwarded-For 可能包含多个 IP,取第一个 for i := 0; i < len(xff); i++ { if xff[i] == ',' { return xff[:i] } } return xff } // 尝试 X-Real-IP 请求头 if xri := c.GetHeader("X-Real-IP"); xri != "" { return xri } // 回退到 ClientIP return c.ClientIP() } // isPrivateIP 检查 IP 地址是否为私有/保留地址 // 环回地址、私有网络、链路本地地址都返回 true func isPrivateIP(ipStr string) bool { ip := net.ParseIP(ipStr) if ip == nil { return true // 无效 IP 视为私有地址 } // 检查是否为环回地址 if ip.IsLoopback() { return true } // 检查是否为私有地址 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) if ip.IsPrivate() { return true } // 检查是否为链路本地地址 (169.254.0.0/16 或 fe80::/10) if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { return true } // 检查是否为未指定地址 (0.0.0.0 或 ::) if ip.IsUnspecified() { return true } return false }