Files
go-web-gin/svr/visitor.go
2026-03-05 16:53:06 +08:00

214 lines
5.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package svr
import (
"context"
"math/rand"
"net"
"net/url"
"strings"
"sync"
"time"
"git.hujye.com/infrastructure/go-web-gin/config"
"git.hujye.com/infrastructure/go-web-gin/utils"
"git.hujye.com/infrastructure/go-web-gin/web"
"github.com/gin-gonic/gin"
)
const (
// traceIDChars 用于生成 trace ID 的字符集
traceIDChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
// locationCacheTTL 地理位置缓存有效期3小时
locationCacheTTL = 3 * time.Hour
)
// locationCacheItem 地理位置缓存项
type locationCacheItem struct {
province string
city string
expiresAt time.Time
}
// locationCache 地理位置缓存
var locationCache sync.Map
// getLocationWithCache 获取地理位置(带缓存)
func getLocationWithCache(ctx context.Context, ip string) (province, city string) {
// 先从缓存中查找
if val, ok := locationCache.Load(ip); ok {
if item, ok := val.(*locationCacheItem); ok {
// 检查是否过期
if time.Now().Before(item.expiresAt) {
return item.province, item.city
}
// 已过期,删除缓存
locationCache.Delete(ip)
}
}
// 调用 API 获取地理位置
province, city, err := utils.GetLocation(ctx, ip)
if err != nil {
return "", ""
}
// 存入缓存
locationCache.Store(ip, &locationCacheItem{
province: province,
city: city,
expiresAt: time.Now().Add(locationCacheTTL),
})
return province, city
}
// generateTraceID 生成长度为7的随机字符串作为 trace ID
func generateTraceID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 7)
for i := range b {
b[i] = traceIDChars[r.Intn(len(traceIDChars))]
}
return string(b)
}
// VisitorMiddleware 创建一个中间件,从请求中提取访问者信息并设置到上下文中
func VisitorMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
// 从请求中提取访问者信息
clientIP := getClientIP(c)
userInfo := web.NewUserInfo().
WithFromIP(clientIP).
WithVisitTime(time.Now()).
WithTrace(generateTraceID()) // 生成7位随机 trace ID
// 从 JWT token 中解析用户信息
if token := extractToken(c); token != "" {
jwtConfig := config.Get().JWT
if jwtConfig.Secret != "" {
jwtUtil := utils.NewJWT(utils.JWTConfig{
SecretKey: jwtConfig.Secret,
})
if claims, err := jwtUtil.ParseToken(token); err == nil {
userInfo.WithUserID(claims.UserID).WithUserName(claims.UserName).WithRole(claims.Role)
}
}
}
// 如果 IP 不是保留地址,则通过 IP 获取地理位置(带缓存)
if clientIP != "" && !isPrivateIP(clientIP) {
if province, city := getLocationWithCache(ctx, clientIP); province != "" {
userInfo.WithProvince(province).WithCity(city)
}
}
// 将用户信息设置到上下文中
ctx = web.ToContext(ctx, userInfo)
c.Request = c.Request.WithContext(ctx)
c.Next()
// 通过响应头返回地理位置和 TRACE 信息(中文需要 URL 编码)
c.Header("X-Trace-ID", userInfo.Trace)
if userInfo.Province != "" {
c.Header("X-Province", url.QueryEscape(userInfo.Province))
}
if userInfo.City != "" {
c.Header("X-City", url.QueryEscape(userInfo.City))
}
}
}
// extractToken 从请求中提取 JWT token
// 支持 Authorization: Bearer <token> 和 Authorization: <token> 两种格式
func extractToken(c *gin.Context) string {
auth := c.GetHeader("Authorization")
if auth == "" {
return ""
}
// Bearer token 格式
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
}
// 直接 token 格式
return auth
}
// VisitorMiddlewareWithExtractor 创建一个支持自定义提取函数的中间件
// 允许应用程序定义自己的用户信息提取逻辑
func VisitorMiddlewareWithExtractor(extractor func(c *gin.Context) *web.UserInfo) gin.HandlerFunc {
return func(c *gin.Context) {
if extractor == nil {
c.Next()
return
}
userInfo := extractor(c)
if userInfo != nil {
ctx := c.Request.Context()
ctx = web.ToContext(ctx, userInfo)
c.Request = c.Request.WithContext(ctx)
}
c.Next()
}
}
// getClientIP 从请求中提取真实客户端 IP
func getClientIP(c *gin.Context) string {
// 优先尝试 X-Forwarded-For 请求头
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
// X-Forwarded-For 可能包含多个 IP取第一个
for i := 0; i < len(xff); i++ {
if xff[i] == ',' {
return xff[:i]
}
}
return xff
}
// 尝试 X-Real-IP 请求头
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
// 回退到 ClientIP
return c.ClientIP()
}
// isPrivateIP 检查 IP 地址是否为私有/保留地址
// 环回地址、私有网络、链路本地地址都返回 true
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return true // 无效 IP 视为私有地址
}
// 检查是否为环回地址
if ip.IsLoopback() {
return true
}
// 检查是否为私有地址 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
if ip.IsPrivate() {
return true
}
// 检查是否为链路本地地址 (169.254.0.0/16 或 fe80::/10)
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
// 检查是否为未指定地址 (0.0.0.0 或 ::)
if ip.IsUnspecified() {
return true
}
return false
}