138 lines
3.4 KiB
Go
138 lines
3.4 KiB
Go
package database
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"git.hujye.com/infrastructure/go-web-gin/config"
|
|
"git.hujye.com/infrastructure/go-web-gin/logger"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/gorm"
|
|
gormlogger "gorm.io/gorm/logger"
|
|
)
|
|
|
|
var (
|
|
db *gorm.DB
|
|
once sync.Once
|
|
)
|
|
|
|
// gormLogger implements gorm/logger.Interface using project's zap logger
|
|
type gormLogger struct {
|
|
logLevel gormlogger.LogLevel
|
|
}
|
|
|
|
// newGormLogger creates a new gorm logger
|
|
func newGormLogger(level gormlogger.LogLevel) *gormLogger {
|
|
return &gormLogger{logLevel: level}
|
|
}
|
|
|
|
// LogMode sets log level
|
|
func (l *gormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
|
|
newLogger := *l
|
|
newLogger.logLevel = level
|
|
return &newLogger
|
|
}
|
|
|
|
// Info logs info messages
|
|
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
|
|
if l.logLevel >= gormlogger.Info {
|
|
logger.GetLogger().Info(ctx, fmt.Sprintf("gorm: %s", fmt.Sprintf(msg, data...)))
|
|
}
|
|
}
|
|
|
|
// Warn logs warn messages
|
|
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
|
|
if l.logLevel >= gormlogger.Warn {
|
|
logger.GetLogger().Warn(ctx, fmt.Sprintf("gorm: %s", fmt.Sprintf(msg, data...)))
|
|
}
|
|
}
|
|
|
|
// Error logs error messages
|
|
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
|
|
if l.logLevel >= gormlogger.Error {
|
|
logger.GetLogger().Error(ctx, fmt.Sprintf("gorm: %s", fmt.Sprintf(msg, data...)))
|
|
}
|
|
}
|
|
|
|
// Trace logs sql query with execution time
|
|
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
|
|
if l.logLevel <= gormlogger.Silent {
|
|
return
|
|
}
|
|
|
|
elapsed := time.Since(begin)
|
|
sql, rows := fc()
|
|
|
|
switch {
|
|
case err != nil && l.logLevel >= gormlogger.Error:
|
|
logger.GetLogger().Error(ctx, "gorm query error", "duration", elapsed.Milliseconds(), "sql", sql, "rows", rows, "error", err.Error())
|
|
case elapsed > 200*time.Millisecond && l.logLevel >= gormlogger.Warn:
|
|
logger.GetLogger().Warn(ctx, "gorm slow query", "duration", elapsed.Milliseconds(), "sql", sql, "rows", rows)
|
|
case l.logLevel >= gormlogger.Info:
|
|
logger.GetLogger().Debug(ctx, "gorm query", "duration", elapsed.Milliseconds(), "sql", sql, "rows", rows)
|
|
}
|
|
}
|
|
|
|
// GetDB returns the gorm.DB singleton instance
|
|
// It will initialize the connection on first call
|
|
func GetDB() *gorm.DB {
|
|
once.Do(func() {
|
|
db = initDB()
|
|
})
|
|
return db
|
|
}
|
|
|
|
// initDB initializes and returns a new gorm.DB connection
|
|
func initDB() *gorm.DB {
|
|
cfg := config.Get().MySQL
|
|
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=Local",
|
|
cfg.User,
|
|
cfg.Password,
|
|
cfg.Host,
|
|
cfg.Port,
|
|
cfg.DBName,
|
|
cfg.Charset,
|
|
cfg.ParseTime,
|
|
)
|
|
|
|
var logLevel gormlogger.LogLevel
|
|
if config.Get().IsDebug() {
|
|
logLevel = gormlogger.Info
|
|
} else {
|
|
logLevel = gormlogger.Silent
|
|
}
|
|
|
|
db, err := gorm.Open(mysql.Open(dsn), &gorm.Config{
|
|
Logger: newGormLogger(logLevel),
|
|
})
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to connect to database: %v", err))
|
|
}
|
|
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to get database instance: %v", err))
|
|
}
|
|
|
|
sqlDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
|
sqlDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
|
sqlDB.SetConnMaxLifetime(time.Duration(cfg.ConnMaxLifetime) * time.Second)
|
|
|
|
return db
|
|
}
|
|
|
|
// Close closes the database connection
|
|
func Close() error {
|
|
if db != nil {
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return sqlDB.Close()
|
|
}
|
|
return nil
|
|
}
|