feat: 新增访客拦截器

This commit is contained in:
2026-03-03 00:58:22 +08:00
parent 52cfe1b911
commit 00a0a01b17
7 changed files with 566 additions and 0 deletions

213
svr/visitor.go Normal file
View File

@@ -0,0 +1,213 @@
package svr
import (
"context"
"math/rand"
"net"
"net/url"
"strings"
"sync"
"time"
"git.hujye.com/infrastructure/go-web-gin/config"
"git.hujye.com/infrastructure/go-web-gin/utils"
"git.hujye.com/infrastructure/go-web-gin/web"
"github.com/gin-gonic/gin"
)
const (
// traceIDChars 用于生成 trace ID 的字符集
traceIDChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
// locationCacheTTL 地理位置缓存有效期3小时
locationCacheTTL = 3 * time.Hour
)
// locationCacheItem 地理位置缓存项
type locationCacheItem struct {
province string
city string
expiresAt time.Time
}
// locationCache 地理位置缓存
var locationCache sync.Map
// getLocationWithCache 获取地理位置(带缓存)
func getLocationWithCache(ctx context.Context, ip string) (province, city string) {
// 先从缓存中查找
if val, ok := locationCache.Load(ip); ok {
if item, ok := val.(*locationCacheItem); ok {
// 检查是否过期
if time.Now().Before(item.expiresAt) {
return item.province, item.city
}
// 已过期,删除缓存
locationCache.Delete(ip)
}
}
// 调用 API 获取地理位置
province, city, err := utils.GetLocation(ctx, ip)
if err != nil {
return "", ""
}
// 存入缓存
locationCache.Store(ip, &locationCacheItem{
province: province,
city: city,
expiresAt: time.Now().Add(locationCacheTTL),
})
return province, city
}
// generateTraceID 生成长度为7的随机字符串作为 trace ID
func generateTraceID() string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, 7)
for i := range b {
b[i] = traceIDChars[r.Intn(len(traceIDChars))]
}
return string(b)
}
// VisitorMiddleware 创建一个中间件,从请求中提取访问者信息并设置到上下文中
func VisitorMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
// 从请求中提取访问者信息
clientIP := getClientIP(c)
userInfo := web.NewUserInfo().
WithFromIP(clientIP).
WithVisitTime(time.Now()).
WithTrace(generateTraceID()) // 生成7位随机 trace ID
// 从 JWT token 中解析用户信息
if token := extractToken(c); token != "" {
jwtConfig := config.Get().JWT
if jwtConfig.Secret != "" {
jwtUtil := utils.NewJWT(utils.JWTConfig{
SecretKey: jwtConfig.Secret,
})
if claims, err := jwtUtil.ParseToken(token); err == nil {
userInfo.WithUserID(claims.UserID).WithUserName(claims.UserName)
}
}
}
// 如果 IP 不是保留地址,则通过 IP 获取地理位置(带缓存)
if clientIP != "" && !isPrivateIP(clientIP) {
if province, city := getLocationWithCache(ctx, clientIP); province != "" {
userInfo.WithProvince(province).WithCity(city)
}
}
// 将用户信息设置到上下文中
ctx = web.ToContext(ctx, userInfo)
c.Request = c.Request.WithContext(ctx)
c.Next()
// 通过响应头返回地理位置和 TRACE 信息(中文需要 URL 编码)
c.Header("X-Trace-ID", userInfo.Trace)
if userInfo.Province != "" {
c.Header("X-Province", url.QueryEscape(userInfo.Province))
}
if userInfo.City != "" {
c.Header("X-City", url.QueryEscape(userInfo.City))
}
}
}
// extractToken 从请求中提取 JWT token
// 支持 Authorization: Bearer <token> 和 Authorization: <token> 两种格式
func extractToken(c *gin.Context) string {
auth := c.GetHeader("Authorization")
if auth == "" {
return ""
}
// Bearer token 格式
if strings.HasPrefix(auth, "Bearer ") {
return strings.TrimPrefix(auth, "Bearer ")
}
// 直接 token 格式
return auth
}
// VisitorMiddlewareWithExtractor 创建一个支持自定义提取函数的中间件
// 允许应用程序定义自己的用户信息提取逻辑
func VisitorMiddlewareWithExtractor(extractor func(c *gin.Context) *web.UserInfo) gin.HandlerFunc {
return func(c *gin.Context) {
if extractor == nil {
c.Next()
return
}
userInfo := extractor(c)
if userInfo != nil {
ctx := c.Request.Context()
ctx = web.ToContext(ctx, userInfo)
c.Request = c.Request.WithContext(ctx)
}
c.Next()
}
}
// getClientIP 从请求中提取真实客户端 IP
func getClientIP(c *gin.Context) string {
// 优先尝试 X-Forwarded-For 请求头
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
// X-Forwarded-For 可能包含多个 IP取第一个
for i := 0; i < len(xff); i++ {
if xff[i] == ',' {
return xff[:i]
}
}
return xff
}
// 尝试 X-Real-IP 请求头
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
// 回退到 ClientIP
return c.ClientIP()
}
// isPrivateIP 检查 IP 地址是否为私有/保留地址
// 环回地址、私有网络、链路本地地址都返回 true
func isPrivateIP(ipStr string) bool {
ip := net.ParseIP(ipStr)
if ip == nil {
return true // 无效 IP 视为私有地址
}
// 检查是否为环回地址
if ip.IsLoopback() {
return true
}
// 检查是否为私有地址 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16)
if ip.IsPrivate() {
return true
}
// 检查是否为链路本地地址 (169.254.0.0/16 或 fe80::/10)
if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
return true
}
// 检查是否为未指定地址 (0.0.0.0 或 ::)
if ip.IsUnspecified() {
return true
}
return false
}