Files
go-web-gin/database/mysql.go

143 lines
3.5 KiB
Go

package database
import (
	"context"
	"errors"
	"fmt"
	"sync"
	"time"

	"git.hujye.com/infrastructure/go-web-gin/config"
	"git.hujye.com/infrastructure/go-web-gin/logger"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	gormlogger "gorm.io/gorm/logger"
)
// Package-level state backing the shared database handle.
var (
// db is the shared *gorm.DB; GetDB assigns it on first use and Close reads it.
db *gorm.DB
// once ensures the initialization performed by GetDB runs at most one time.
once sync.Once
)
// gormLogger implements gormlogger.Interface by forwarding GORM's log output
// to the project's logger package (logger.GetLogger()).
type gormLogger struct {
// logLevel is the GORM verbosity threshold each logging method compares against.
logLevel gormlogger.LogLevel
}
// newGormLogger builds a gormLogger whose verbosity threshold is level.
func newGormLogger(level gormlogger.LogLevel) *gormLogger {
	l := new(gormLogger)
	l.logLevel = level
	return l
}
// LogMode returns a copy of the logger with its threshold set to level.
// The receiver itself is left unmodified.
func (l *gormLogger) LogMode(level gormlogger.LogLevel) gormlogger.Interface {
	clone := new(gormLogger)
	*clone = *l
	clone.logLevel = level
	return clone
}
// Info formats msg with data and writes it, prefixed with "gorm: ", through
// the project logger's Info method. Nothing is written when the logger's
// threshold is below gormlogger.Info.
func (l *gormLogger) Info(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < gormlogger.Info {
		return
	}
	text := "gorm: " + fmt.Sprintf(msg, data...)
	logger.GetLogger().Info(ctx, text)
}
// Warn formats msg with data and writes it, prefixed with "gorm: ", through
// the project logger's Warn method. Nothing is written when the logger's
// threshold is below gormlogger.Warn.
func (l *gormLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < gormlogger.Warn {
		return
	}
	text := "gorm: " + fmt.Sprintf(msg, data...)
	logger.GetLogger().Warn(ctx, text)
}
// Error formats msg with data and writes it, prefixed with "gorm: ", through
// the project logger's Error method. Nothing is written when the logger's
// threshold is below gormlogger.Error.
func (l *gormLogger) Error(ctx context.Context, msg string, data ...interface{}) {
	if l.logLevel < gormlogger.Error {
		return
	}
	text := "gorm: " + fmt.Sprintf(msg, data...)
	logger.GetLogger().Error(ctx, text)
}
// Trace logs one executed SQL statement together with its duration in
// milliseconds and its affected row count.
//
// begin is when the statement started, fc lazily produces the rendered SQL and
// the row count, and err is the statement's result. Exactly one of the
// following is emitted, checked in order:
//   - an error entry when err is non-nil and the threshold is at least Error;
//   - a slow-query warning when the statement took longer than 200ms and the
//     threshold is at least Warn;
//   - a debug entry when the threshold is at least Info.
//
// Nothing is logged when the threshold is Silent or lower, or when err is
// (or wraps) gorm.ErrRecordNotFound.
func (l *gormLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
	// Statements slower than this are reported as slow queries.
	const slowThreshold = 200 * time.Millisecond

	if l.logLevel <= gormlogger.Silent {
		return
	}
	// Record-not-found is deliberately ignored and never logged. errors.Is is
	// used so that a wrapped ErrRecordNotFound is recognised as well.
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return
	}

	elapsed := time.Since(begin)

	// fc renders the SQL text, so it is only invoked inside a branch that
	// actually writes a log entry.
	switch {
	case err != nil && l.logLevel >= gormlogger.Error:
		sql, rows := fc()
		logger.GetLogger().Error(ctx, "gorm query error", "duration", elapsed.Milliseconds(), "sql", sql, "rows", rows, "error", err.Error())
	case elapsed > slowThreshold && l.logLevel >= gormlogger.Warn:
		sql, rows := fc()
		logger.GetLogger().Warn(ctx, "gorm slow query", "duration", elapsed.Milliseconds(), "sql", sql, "rows", rows)
	case l.logLevel >= gormlogger.Info:
		sql, rows := fc()
		logger.GetLogger().Debug(ctx, "gorm query", "duration", elapsed.Milliseconds(), "sql", sql, "rows", rows)
	}
}
// GetDB returns the shared *gorm.DB, creating it through initDB on the first
// call. Subsequent calls return the same handle.
//
// NOTE(review): initDB panics on failure, and sync.Once treats a panicking
// function as having run, so if that panic is recovered upstream later calls
// will return a nil handle rather than retrying.
func GetDB() *gorm.DB {
	initialize := func() {
		db = initDB()
	}
	once.Do(initialize)
	return db
}
// initDB opens a MySQL connection from the application's MySQL config and
// applies the configured connection-pool limits.
//
// GORM logging is set to Info when the application config reports debug mode
// and to Silent otherwise. ConnMaxLifetime from the config is interpreted as a
// number of seconds. The function panics if the connection cannot be opened or
// the underlying database handle cannot be obtained.
func initDB() *gorm.DB {
	mysqlCfg := config.Get().MySQL

	dsn := fmt.Sprintf(
		"%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=%t&loc=Local",
		mysqlCfg.User, mysqlCfg.Password,
		mysqlCfg.Host, mysqlCfg.Port,
		mysqlCfg.DBName, mysqlCfg.Charset, mysqlCfg.ParseTime,
	)

	// Default to silent; statement logging is enabled only in debug mode.
	level := gormlogger.Silent
	if config.Get().IsDebug() {
		level = gormlogger.Info
	}

	// conn is a local handle; the package-level db is assigned by the caller.
	gormCfg := &gorm.Config{Logger: newGormLogger(level)}
	conn, err := gorm.Open(mysql.Open(dsn), gormCfg)
	if err != nil {
		panic(fmt.Sprintf("failed to connect to database: %v", err))
	}

	pool, err := conn.DB()
	if err != nil {
		panic(fmt.Sprintf("failed to get database instance: %v", err))
	}

	// Connection-pool tuning taken from the MySQL config section.
	pool.SetMaxIdleConns(mysqlCfg.MaxIdleConns)
	pool.SetMaxOpenConns(mysqlCfg.MaxOpenConns)
	pool.SetConnMaxLifetime(time.Duration(mysqlCfg.ConnMaxLifetime) * time.Second)

	return conn
}
// Close closes the underlying database handle of the shared connection.
// It returns nil when no connection has been created, and otherwise returns
// any error from retrieving or closing the underlying handle.
func Close() error {
	if db == nil {
		return nil
	}
	pool, err := db.DB()
	if err != nil {
		return err
	}
	return pool.Close()
}