From 00a0a01b1725fbfbe4cebf195c5b7f14a39a5864 Mon Sep 17 00:00:00 2001 From: hujie Date: Tue, 3 Mar 2026 00:58:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E8=AE=BF=E5=AE=A2?= =?UTF-8?q?=E6=8B=A6=E6=88=AA=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/config.go | 8 ++ go.mod | 1 + logger/logger.go | 6 ++ svr/visitor.go | 213 ++++++++++++++++++++++++++++++++++++++++++++++ utils/jwt.go | 173 +++++++++++++++++++++++++++++++++++++ utils/location.go | 101 ++++++++++++++++++++++ web/user.go | 64 ++++++++++++++ 7 files changed, 566 insertions(+) create mode 100644 svr/visitor.go create mode 100644 utils/jwt.go create mode 100644 utils/location.go diff --git a/config/config.go b/config/config.go index eca380d..d9364cc 100644 --- a/config/config.go +++ b/config/config.go @@ -20,6 +20,7 @@ type Config struct { Log LogConfig `mapstructure:"log"` MySQL MySQLConfig `mapstructure:"mysql"` Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` } // ServerConfig represents server configuration @@ -73,6 +74,13 @@ type RedisConfig struct { WriteTimeout int `mapstructure:"write_timeout"` // seconds } +// JWTConfig represents jwt configuration +type JWTConfig struct { + Secret string `mapstructure:"secret"` + Issuer string `mapstructure:"issuer"` + Expires int `mapstructure:"expires"` // hours +} + // Load loads configuration from environment variable CFG_PATH // Default path is config/config.yml func Load() (*Config, error) { diff --git a/go.mod b/go.mod index 76b2b25..57dfa24 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect diff --git a/logger/logger.go b/logger/logger.go index 7569742..7805866 100644 --- a/logger/logger.go +++ b/logger/logger.go @@ -116,6 +116,12 @@ func getUserFields(ctx context.Context) []zap.Field { if fromIP := web.GetFromIP(ctx); fromIP != "" { fields = append(fields, zap.String("from_ip", fromIP)) } + if province := web.GetProvince(ctx); province != "" { + fields = append(fields, zap.String("province", province)) + } + if city := web.GetCity(ctx); city != "" { + fields = append(fields, zap.String("city", city)) + } if len(fields) == 0 { return nil diff --git a/svr/visitor.go b/svr/visitor.go new file mode 100644 index 0000000..9d9aec6 --- /dev/null +++ b/svr/visitor.go @@ -0,0 +1,213 @@ +package svr + +import ( + "context" + "math/rand" + "net" + "net/url" + "strings" + "sync" + "time" + + "git.hujye.com/infrastructure/go-web-gin/config" + "git.hujye.com/infrastructure/go-web-gin/utils" + "git.hujye.com/infrastructure/go-web-gin/web" + "github.com/gin-gonic/gin" +) + +const ( + // traceIDChars 用于生成 trace ID 的字符集 + traceIDChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + // locationCacheTTL 地理位置缓存有效期(3小时) + locationCacheTTL = 3 * time.Hour +) + +// locationCacheItem 地理位置缓存项 +type locationCacheItem struct { + province string + city string + expiresAt time.Time +} + +// locationCache 地理位置缓存 +var locationCache sync.Map + +// getLocationWithCache 获取地理位置(带缓存) +func getLocationWithCache(ctx context.Context, ip string) (province, city string) { + // 先从缓存中查找 + if val, ok := locationCache.Load(ip); ok { + if item, ok := val.(*locationCacheItem); ok { + // 检查是否过期 + if time.Now().Before(item.expiresAt) { + return item.province, item.city + } + // 已过期,删除缓存 + locationCache.Delete(ip) + } + } + + // 调用 API 获取地理位置 + province, city, err := utils.GetLocation(ctx, ip) + if err != nil { + return "", "" + } + + // 存入缓存 + locationCache.Store(ip, &locationCacheItem{ + province: province, + city: city, + expiresAt: time.Now().Add(locationCacheTTL), + }) + + return province, city +} + +// generateTraceID 生成长度为7的随机字符串作为 trace ID +func generateTraceID() string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, 7) + for i := range b { + b[i] = traceIDChars[r.Intn(len(traceIDChars))] + } + return string(b) +} + +// VisitorMiddleware 创建一个中间件,从请求中提取访问者信息并设置到上下文中 +func VisitorMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + ctx := c.Request.Context() + + // 从请求中提取访问者信息 + clientIP := getClientIP(c) + userInfo := web.NewUserInfo(). + WithFromIP(clientIP). + WithVisitTime(time.Now()). + WithTrace(generateTraceID()) // 生成7位随机 trace ID + + // 从 JWT token 中解析用户信息 + if token := extractToken(c); token != "" { + jwtConfig := config.Get().JWT + if jwtConfig.Secret != "" { + jwtUtil := utils.NewJWT(utils.JWTConfig{ + SecretKey: jwtConfig.Secret, + }) + if claims, err := jwtUtil.ParseToken(token); err == nil { + userInfo.WithUserID(claims.UserID).WithUserName(claims.UserName) + } + } + } + + // 如果 IP 不是保留地址,则通过 IP 获取地理位置(带缓存) + if clientIP != "" && !isPrivateIP(clientIP) { + if province, city := getLocationWithCache(ctx, clientIP); province != "" { + userInfo.WithProvince(province).WithCity(city) + } + } + + // 将用户信息设置到上下文中 + ctx = web.ToContext(ctx, userInfo) + c.Request = c.Request.WithContext(ctx) + + c.Next() + + // 通过响应头返回地理位置和 TRACE 信息(中文需要 URL 编码) + c.Header("X-Trace-ID", userInfo.Trace) + if userInfo.Province != "" { + c.Header("X-Province", url.QueryEscape(userInfo.Province)) + } + if userInfo.City != "" { + c.Header("X-City", url.QueryEscape(userInfo.City)) + } + } +} + +// extractToken 从请求中提取 JWT token +// 支持 Authorization: Bearer 和 Authorization: 两种格式 +func extractToken(c *gin.Context) string { + auth := c.GetHeader("Authorization") + if auth == "" { + return "" + } + + // Bearer token 格式 + if strings.HasPrefix(auth, "Bearer ") { + return strings.TrimPrefix(auth, "Bearer ") + } + + // 直接 token 格式 + return auth +} + +// VisitorMiddlewareWithExtractor 创建一个支持自定义提取函数的中间件 +// 允许应用程序定义自己的用户信息提取逻辑 +func VisitorMiddlewareWithExtractor(extractor func(c *gin.Context) *web.UserInfo) gin.HandlerFunc { + return func(c *gin.Context) { + if extractor == nil { + c.Next() + return + } + + userInfo := extractor(c) + if userInfo != nil { + ctx := c.Request.Context() + ctx = web.ToContext(ctx, userInfo) + c.Request = c.Request.WithContext(ctx) + } + + c.Next() + } +} + +// getClientIP 从请求中提取真实客户端 IP +func getClientIP(c *gin.Context) string { + // 优先尝试 X-Forwarded-For 请求头 + if xff := c.GetHeader("X-Forwarded-For"); xff != "" { + // X-Forwarded-For 可能包含多个 IP,取第一个 + for i := 0; i < len(xff); i++ { + if xff[i] == ',' { + return xff[:i] + } + } + return xff + } + + // 尝试 X-Real-IP 请求头 + if xri := c.GetHeader("X-Real-IP"); xri != "" { + return xri + } + + // 回退到 ClientIP + return c.ClientIP() +} + +// isPrivateIP 检查 IP 地址是否为私有/保留地址 +// 环回地址、私有网络、链路本地地址都返回 true +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return true // 无效 IP 视为私有地址 + } + + // 检查是否为环回地址 + if ip.IsLoopback() { + return true + } + + // 检查是否为私有地址 (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16) + if ip.IsPrivate() { + return true + } + + // 检查是否为链路本地地址 (169.254.0.0/16 或 fe80::/10) + if ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + + // 检查是否为未指定地址 (0.0.0.0 或 ::) + if ip.IsUnspecified() { + return true + } + + return false +} diff --git a/utils/jwt.go b/utils/jwt.go new file mode 100644 index 0000000..baf0865 --- /dev/null +++ b/utils/jwt.go @@ -0,0 +1,173 @@ +package utils + +import ( + "errors" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + // ErrTokenInvalid token 无效 + ErrTokenInvalid = errors.New("token 无效") + // ErrTokenExpired token 已过期 + ErrTokenExpired = errors.New("token 已过期") + // ErrTokenNotValidYet token 尚未生效 + ErrTokenNotValidYet = errors.New("token 尚未生效") + // ErrTokenMalformed token 格式错误 + ErrTokenMalformed = errors.New("token 格式错误") + // ErrTokenSignatureInvalid 签名无效 + ErrTokenSignatureInvalid = errors.New("签名无效") +) + +// JWTConfig JWT 配置 +type JWTConfig struct { + SecretKey string // 密钥 + ExpiresTime time.Duration // 过期时间 + Issuer string // 签发者 +} + +// JWT jwt 工具 +type JWT struct { + config JWTConfig +} + +// CustomClaims 自定义 claims +type CustomClaims struct { + UserID string `json:"user_id"` + UserName string `json:"user_name"` + Data map[string]interface{} `json:"data,omitempty"` + jwt.RegisteredClaims +} + +// NewJWT 创建 JWT 实例 +func NewJWT(config JWTConfig) *JWT { + return &JWT{config: config} +} + +// GenerateToken 生成 token +func (j *JWT) GenerateToken(claims *CustomClaims) (string, error) { + // 设置默认值 + if claims.Issuer == "" { + claims.Issuer = j.config.Issuer + } + if claims.ExpiresAt == nil && j.config.ExpiresTime > 0 { + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.config.ExpiresTime)) + } + if claims.IssuedAt == nil { + claims.IssuedAt = jwt.NewNumericDate(time.Now()) + } + if claims.NotBefore == nil { + claims.NotBefore = jwt.NewNumericDate(time.Now()) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(j.config.SecretKey)) +} + +// ParseToken 解析 token +func (j *JWT) ParseToken(tokenString string) (*CustomClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &CustomClaims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(j.config.SecretKey), nil + }) + + if err != nil { + switch { + case errors.Is(err, jwt.ErrTokenMalformed): + return nil, ErrTokenMalformed + case errors.Is(err, jwt.ErrTokenExpired): + return nil, ErrTokenExpired + case errors.Is(err, jwt.ErrTokenNotValidYet): + return nil, ErrTokenNotValidYet + case errors.Is(err, jwt.ErrTokenSignatureInvalid): + return nil, ErrTokenSignatureInvalid + default: + return nil, ErrTokenInvalid + } + } + + if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid { + return claims, nil + } + + return nil, ErrTokenInvalid +} + +// RefreshToken 刷新 token +func (j *JWT) RefreshToken(tokenString string) (string, error) { + claims, err := j.ParseToken(tokenString) + if err != nil { + return "", err + } + + // 重置过期时间 + claims.ExpiresAt = jwt.NewNumericDate(time.Now().Add(j.config.ExpiresTime)) + claims.IssuedAt = jwt.NewNumericDate(time.Now()) + + return j.GenerateToken(claims) +} + +// GetUserIDFromToken 从 token 中获取用户 ID +func (j *JWT) GetUserIDFromToken(tokenString string) (string, error) { + claims, err := j.ParseToken(tokenString) + if err != nil { + return "", err + } + return claims.UserID, nil +} + +// GetUserNameFromToken 从 token 中获取用户名 +func (j *JWT) GetUserNameFromToken(tokenString string) (string, error) { + claims, err := j.ParseToken(tokenString) + if err != nil { + return "", err + } + return claims.UserName, nil +} + +// GetClaimFromToken 从 token 中获取指定字段 +func (j *JWT) GetClaimFromToken(tokenString string, key string) (interface{}, error) { + claims, err := j.ParseToken(tokenString) + if err != nil { + return nil, err + } + if claims.Data == nil { + return nil, errors.New("未找到指定字段") + } + val, ok := claims.Data[key] + if !ok { + return nil, errors.New("未找到指定字段") + } + return val, nil +} + +// ValidateToken 验证 token 是否有效 +func (j *JWT) ValidateToken(tokenString string) bool { + _, err := j.ParseToken(tokenString) + return err == nil +} + +// CreateToken 快速创建 token(简化方法) +// secretKey: 密钥 +// userID: 用户 ID +// userName: 用户名 +// expires: 过期时间(小时) +func CreateToken(secretKey, userID, userName string, expires int) (string, error) { + jwtUtil := NewJWT(JWTConfig{ + SecretKey: secretKey, + ExpiresTime: time.Duration(expires) * time.Hour, + }) + + claims := &CustomClaims{ + UserID: userID, + UserName: userName, + } + + return jwtUtil.GenerateToken(claims) +} + +// ParseTokenString 快速解析 token(简化方法) +func ParseTokenString(secretKey, tokenString string) (*CustomClaims, error) { + jwtUtil := NewJWT(JWTConfig{SecretKey: secretKey}) + return jwtUtil.ParseToken(tokenString) +} diff --git a/utils/location.go b/utils/location.go new file mode 100644 index 0000000..6575055 --- /dev/null +++ b/utils/location.go @@ -0,0 +1,101 @@ +package utils + +import ( + "context" + "encoding/json" + "fmt" + "io" + "math/rand" + "net/http" + "time" +) + +// LocationResponse represents the response from IP location API +type LocationResponse struct { + IP string `json:"ip"` + Pro string `json:"pro"` + ProCode string `json:"proCode"` + City string `json:"city"` + CityCode string `json:"cityCode"` + Region string `json:"region"` + RegionCode string `json:"regionCode"` + Addr string `json:"addr"` + RegionNames string `json:"regionNames"` + Err string `json:"err"` +} + +// userAgents is a list of common browser user agents for random selection +var userAgents = []string{ + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Safari/605.1.15", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:121.0) Gecko/20100101 Firefox/121.0", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36", + "Mozilla/5.0 (iPhone; CPU iPhone OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1", + "Mozilla/5.0 (iPad; CPU OS 17_0 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.0 Mobile/15E148 Safari/604.1", + "Mozilla/5.0 (Linux; Android 14; SM-S918B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Mobile Safari/537.36", +} + +// randomUserAgent returns a random user agent string +func randomUserAgent() string { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return userAgents[r.Intn(len(userAgents))] +} + +// GetLocation 获取IP的地理位置信息 +// 调用太平洋网络IP地址查询API,返回省份和城市 +func GetLocation(ctx context.Context, ip string) (province string, city string, err error) { + url := fmt.Sprintf("https://whois.pconline.com.cn/ipJson.jsp?ip=%s&json=true", ip) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", "", fmt.Errorf("create request failed: %w", err) + } + + // Set random user agent to simulate browser + req.Header.Set("User-Agent", randomUserAgent()) + req.Header.Set("Accept", "application/json, text/plain, */*") + req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8") + req.Header.Set("Referer", "https://whois.pconline.com.cn/") + + client := &http.Client{ + Timeout: 5 * time.Second, + } + + resp, err := client.Do(req) + if err != nil { + return "", "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("request failed with status: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", fmt.Errorf("read response failed: %w", err) + } + + var location LocationResponse + if err := json.Unmarshal(body, &location); err != nil { + return "", "", fmt.Errorf("parse response failed: %w", err) + } + + // Check for error in response + if location.Err != "" { + return "", "", fmt.Errorf("API error: %s", location.Err) + } + + return location.Pro, location.City, nil +} + +// GetLocationWithCache 获取IP的地理位置信息(带缓存) +// 如果缓存中有则直接返回,否则调用API并缓存结果 +func GetLocationWithCache(ctx context.Context, ip string, cache func(key string, fetch func() (string, string, error)) (string, string, error)) (province string, city string, err error) { + return cache(ip, func() (string, string, error) { + return GetLocation(ctx, ip) + }) +} diff --git a/web/user.go b/web/user.go index a667a9f..faf8659 100644 --- a/web/user.go +++ b/web/user.go @@ -12,6 +12,8 @@ type UserInfo struct { Trace string `json:"trace"` VisitTime time.Time `json:"visit_time"` FromIP string `json:"from_ip"` + Province string `json:"province"` + City string `json:"city"` } // contextKey is the type for context keys to prevent collisions @@ -24,6 +26,8 @@ const ( TraceKey contextKey = "trace" VisitTimeKey contextKey = "visitTime" FromIPKey contextKey = "fromIP" + ProvinceKey contextKey = "province" + CityKey contextKey = "city" ) // GetUserID returns user ID from context @@ -121,6 +125,44 @@ func SetFromIP(ctx context.Context, fromIP string) context.Context { return context.WithValue(ctx, FromIPKey, fromIP) } +// GetProvince returns province from context +func GetProvince(ctx context.Context) string { + if ctx == nil { + return "" + } + if val, ok := ctx.Value(ProvinceKey).(string); ok { + return val + } + return "" +} + +// SetProvince sets province in context and returns new context +func SetProvince(ctx context.Context, province string) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, ProvinceKey, province) +} + +// GetCity returns city from context +func GetCity(ctx context.Context) string { + if ctx == nil { + return "" + } + if val, ok := ctx.Value(CityKey).(string); ok { + return val + } + return "" +} + +// SetCity sets city in context and returns new context +func SetCity(ctx context.Context, city string) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, CityKey, city) +} + // FromContext builds UserInfo from individual context fields // Returns nil if no fields are found func FromContext(ctx context.Context) *UserInfo { @@ -150,6 +192,14 @@ func FromContext(ctx context.Context) *UserInfo { info.FromIP = fromIP hasData = true } + if province := GetProvince(ctx); province != "" { + info.Province = province + hasData = true + } + if city := GetCity(ctx); city != "" { + info.City = city + hasData = true + } if !hasData { return nil @@ -171,6 +221,8 @@ func ToContext(ctx context.Context, userInfo *UserInfo) context.Context { ctx = SetTrace(ctx, userInfo.Trace) ctx = SetVisitTime(ctx, userInfo.VisitTime) ctx = SetFromIP(ctx, userInfo.FromIP) + ctx = SetProvince(ctx, userInfo.Province) + ctx = SetCity(ctx, userInfo.City) return ctx } @@ -211,3 +263,15 @@ func (u *UserInfo) WithVisitTime(visitTime time.Time) *UserInfo { u.VisitTime = visitTime return u } + +// WithProvince sets the province +func (u *UserInfo) WithProvince(province string) *UserInfo { + u.Province = province + return u +} + +// WithCity sets the city +func (u *UserInfo) WithCity(city string) *UserInfo { + u.City = city + return u +}