package database

import (
	"context"
	"fmt"
	"sync"
	"time"

	"git.hujye.com/infrastructure/go-web-gin/config"
	"git.hujye.com/infrastructure/go-web-gin/logger"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)

// Singleton state: db is the shared handle handed out by GetDB, and once
// ensures initDB runs a single time.
var (
	db   *gorm.DB
	once sync.Once
)

// gormLogger is the gorm logger implementation installed by initDB. It
// forwards messages to the project's logger package, filtered by logLevel.
type gormLogger struct {
	logLevel gormlogger.LogLevel
}

// newGormLogger builds a gormLogger that filters at the given level.
func newGormLogger(level gormlogger.LogLevel) *gormLogger {
	l := new(gormLogger)
	l.logLevel = level
	return l
}

// LogMode returns a copy of the logger that uses the given level; the
// receiver itself is left unmodified.
func (l *gormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	clone := new(gormLogger)
	*clone = *l
	clone.logLevel = level
	return clone
}

// Info formats msg with data and emits it at info severity, provided the
// configured level is at least gormlogger.Info.
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < gormlogger.Info {
		return
	}
	logger.GetLogger().Info(ctx, "gorm: "+fmt.Sprintf(msg, data...))
}

// Warn formats msg with data and emits it at warn severity, provided the
// configured level is at least gormlogger.Warn.
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < gormlogger.Warn {
		return
	}
	logger.GetLogger().Warn(ctx, "gorm: "+fmt.Sprintf(msg, data...))
}

// Error formats msg with data and emits it at error severity, provided the
// configured level is at least gormlogger.Error.
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < gormlogger.Error {
		return
	}
	logger.GetLogger().Error(ctx, "gorm: "+fmt.Sprintf(msg, data...))
}

// Trace reports one executed statement together with the time elapsed since
// begin and the values returned by fc (logged under "sql" and "rows").
//
// Nothing is logged at Silent level. Otherwise the first matching rule wins:
// a non-nil err is logged as an error, a statement slower than 200ms as a
// warning, and anything else as a debug entry when the level is at least
// gormlogger.Info.
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	if l.logLevel <= gormlogger.Silent {
		return
	}

	const slowQuery = 200 * time.Millisecond

	// Measure before invoking fc so its cost is not counted in the duration.
	elapsed := time.Since(begin)
	stmt, affected := fc()
	ms := elapsed.Milliseconds()

	if err != nil && l.logLevel >= gormlogger.Error {
		logger.GetLogger().Error(ctx, "gorm query error", "duration", ms, "sql", stmt, "rows", affected, "error", err.Error())
		return
	}
	if elapsed > slowQuery && l.logLevel >= gormlogger.Warn {
		logger.GetLogger().Warn(ctx, "gorm slow query", "duration", ms, "sql", stmt, "rows", affected)
		return
	}
	if l.logLevel >= gormlogger.Info {
		logger.GetLogger().Debug(ctx, "gorm query", "duration", ms, "sql", stmt, "rows", affected)
	}
}

// GetDB returns the shared *gorm.DB. The first call runs initDB; every
// later call returns the handle created then.
func GetDB() *gorm.DB {
	setup := func() {
		db = initDB()
	}
	once.Do(setup)
	return db
}

// initDB opens a MySQL connection from the MySQL section of the project
// configuration, applies the configured pool limits, and returns the
// resulting handle. It panics if the connection cannot be opened or the
// underlying database handle cannot be obtained.
func initDB() *gorm.DB {
	mysqlCfg := config.Get().MySQL

	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=Local",
		mysqlCfg.User,
		mysqlCfg.Password,
		mysqlCfg.Host,
		mysqlCfg.Port,
		mysqlCfg.DBName,
		mysqlCfg.Charset,
		mysqlCfg.ParseTime,
	)

	// Statement logging is enabled only in debug mode; otherwise the gorm
	// logger stays silent.
	level := gormlogger.Silent
	if config.Get().IsDebug() {
		level = gormlogger.Info
	}

	conn, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
		Logger: newGormLogger(level),
	})
	if err != nil {
		panic(fmt.Sprintf("failed to connect to database: %v", err))
	}

	pool, err := conn.DB()
	if err != nil {
		panic(fmt.Sprintf("failed to get database instance: %v", err))
	}

	pool.SetMaxIdleConns(mysqlCfg.MaxIdleConns)
	pool.SetMaxOpenConns(mysqlCfg.MaxOpenConns)
	// ConnMaxLifetime is multiplied by time.Second, i.e. treated as seconds.
	pool.SetConnMaxLifetime(time.Duration(mysqlCfg.ConnMaxLifetime) * time.Second)

	return conn
}

// Close closes the database handle underlying the shared connection. It
// returns nil when no connection has been created, and otherwise the error
// from obtaining or closing that handle.
func Close() error {
	if db == nil {
		return nil
	}

	pool, err := db.DB()
	if err != nil {
		return err
	}
	return pool.Close()
}