// Package db manages the application's PostgreSQL connection, opened through
// GORM and shared at package level after InitDB succeeds.
package db

import (
	"context"
	"errors"
	"fmt"
	"log"
	"strings"
	"time"

	"github.com/timetracker/backend/internal/config"
	"gorm.io/driver/postgres"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

var (
	// db is the shared connection set by InitDB and cleared by CloseDB.
	db *gorm.DB

	// ErrDBNotInitialized is returned (or, from GetEngine, panicked with) when a
	// database operation is attempted before InitDB has succeeded.
	ErrDBNotInitialized = errors.New("database not initialized")
)

// dsnValueEscaper escapes backslashes and single quotes inside a quoted
// keyword/value DSN value, as the libpq connection-string format requires.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// InitDB opens a connection to dbConfig.DBName, applies the connection-pool
// settings from dbConfig, and publishes it as the shared handle returned by
// GetEngine. It is intended to be called once at startup.
//
// On failure the shared handle is left untouched and any partially opened
// connection is closed.
//
// NOTE(review): calling InitDB again replaces the shared handle without closing
// the previous pool; call CloseDB first when re-initializing.
func InitDB(dbConfig config.DatabaseConfig) error {
	gormDB, err := createConnection(dbConfig, dbConfig.DBName)
	if err != nil {
		return err
	}

	// Configure the pool before publishing so callers never observe a
	// half-configured handle.
	if err := configureConnectionPool(gormDB, dbConfig); err != nil {
		// Best effort: the pool-configuration error is the one worth reporting.
		_ = closeConnection(gormDB)
		return err
	}

	db = gormDB
	return nil
}

// GetEngine returns the shared connection bound to ctx.
//
// It panics with ErrDBNotInitialized if InitDB has not succeeded, or if
// CloseDB has since closed the connection.
func GetEngine(ctx context.Context) *gorm.DB {
	if db == nil {
		panic(ErrDBNotInitialized)
	}
	return db.WithContext(ctx)
}

// CloseDB closes the shared connection and clears it, so later GetEngine calls
// report ErrDBNotInitialized instead of handing out a closed pool.
// Calling it when no connection is open is a no-op that returns nil.
func CloseDB() error {
	if db == nil {
		return nil
	}
	if err := closeConnection(db); err != nil {
		return err
	}
	db = nil
	return nil
}

// GetGormDB opens a separate connection to dbName using dbConfig, independent
// of the shared handle. It is used for special cases like database creation.
// Pool settings are not applied. The caller owns the returned connection and
// is responsible for closing it.
func GetGormDB(dbConfig config.DatabaseConfig, dbName string) (*gorm.DB, error) {
	return createConnection(dbConfig, dbName)
}

// MigrateDB is a placeholder kept so migrate/main.go compiles; the real
// migration lives in models.MigrateDB.
//
// It returns ErrDBNotInitialized if InitDB has not succeeded, and otherwise
// always returns an error directing callers to the models package.
func MigrateDB() error {
	if db == nil {
		return ErrDBNotInitialized
	}
	return errors.New("MigrateDB should be called from models package")
}

// createConnection opens a GORM connection to dbName using the host, port,
// credentials, SSL mode and log level from dbConfig.
//
// If gorm.Open fails after already creating the underlying pool (for example
// when the initial ping fails), that pool is closed before returning.
func createConnection(dbConfig config.DatabaseConfig, dbName string) (*gorm.DB, error) {
	gormDB, err := gorm.Open(postgres.Open(buildDSN(dbConfig, dbName)), &gorm.Config{
		Logger: createGormLogger(dbConfig),
	})
	if err != nil {
		if gormDB != nil {
			// Best effort: avoid leaking the pool; the open error takes precedence.
			_ = closeConnection(gormDB)
		}
		return nil, fmt.Errorf("error connecting to the database: %w", err)
	}
	return gormDB, nil
}

// buildDSN renders a PostgreSQL keyword/value connection string.
//
// String values are single-quoted and escaped so that empty values (such as a
// blank password) or values containing spaces, quotes or backslashes are parsed
// as a single value instead of swallowing the next keyword.
func buildDSN(dbConfig config.DatabaseConfig, dbName string) string {
	return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(dbConfig.Host),
		dbConfig.Port,
		quoteDSNValue(dbConfig.User),
		quoteDSNValue(dbConfig.Password),
		quoteDSNValue(dbName),
		quoteDSNValue(dbConfig.SSLMode),
	)
}

// quoteDSNValue wraps v in single quotes, escaping embedded backslashes and
// single quotes, per the libpq keyword/value connection-string rules.
func quoteDSNValue(v string) string {
	return "'" + dsnValueEscaper.Replace(v) + "'"
}

// createGormLogger builds the GORM logger. It writes through the standard log
// package's current writer, flags queries slower than 200ms, uses
// dbConfig.LogLevel as its level, and does not log ErrRecordNotFound.
func createGormLogger(dbConfig config.DatabaseConfig) logger.Interface {
	return logger.New(
		log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
		logger.Config{
			SlowThreshold:             200 * time.Millisecond, // Slow SQL threshold
			LogLevel:                  dbConfig.LogLevel,      // Log level
			IgnoreRecordNotFoundError: true,                   // Ignore ErrRecordNotFound error for logger
			Colorful:                  true,                   // Enable color
		},
	)
}

// configureConnectionPool applies the idle/open connection limits and the
// maximum connection lifetime from dbConfig to gormDB's underlying *sql.DB.
func configureConnectionPool(gormDB *gorm.DB, dbConfig config.DatabaseConfig) error {
	sqlDB, err := gormDB.DB()
	if err != nil {
		return fmt.Errorf("error getting database connection: %w", err)
	}

	sqlDB.SetMaxIdleConns(dbConfig.MaxIdleConns)
	sqlDB.SetMaxOpenConns(dbConfig.MaxOpenConns)
	sqlDB.SetConnMaxLifetime(dbConfig.MaxLifetime)
	return nil
}

// closeConnection closes the *sql.DB pool underlying gormDB.
func closeConnection(gormDB *gorm.DB) error {
	sqlDB, err := gormDB.DB()
	if err != nil {
		return fmt.Errorf("error getting database connection: %w", err)
	}
	if err := sqlDB.Close(); err != nil {
		return fmt.Errorf("error closing database connection: %w", err)
	}
	return nil
}