package models

import (
	"context"
	"errors"
	"fmt"
	"log"
	"reflect"
	"strings"
	"time"

	"gorm.io/driver/postgres" // For PostgreSQL
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
)

// defaultDB is the package-wide GORM handle. It is nil until InitDB succeeds.
var defaultDB *gorm.DB

// dsnValueEscaper escapes the two characters that are special inside a
// single-quoted value of a keyword/value connection string.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// DatabaseConfig contains the configuration data for the database connection.
type DatabaseConfig struct {
	Host         string
	Port         int
	User         string
	Password     string
	DBName       string
	SSLMode      string
	MaxIdleConns int             // Maximum number of idle connections
	MaxOpenConns int             // Maximum number of open connections
	MaxLifetime  time.Duration   // Maximum lifetime of a connection
	LogLevel     logger.LogLevel // Log level handed to the GORM logger
}

// DefaultDatabaseConfig returns a configuration pointing at a local
// "timetracker" database on port 5432 with SSL disabled, a pool of
// 10 idle / 100 open connections, a one-hour connection lifetime and
// Info-level SQL logging.
func DefaultDatabaseConfig() DatabaseConfig {
	return DatabaseConfig{
		Host:         "localhost",
		Port:         5432,
		User:         "timetracker",
		Password:     "password",
		DBName:       "timetracker",
		SSLMode:      "disable",
		MaxIdleConns: 10,
		MaxOpenConns: 100,
		MaxLifetime:  time.Hour,
		LogLevel:     logger.Info,
	}
}

// quoteDSNValue wraps v in single quotes and backslash-escapes embedded
// single quotes and backslashes, so that empty values and values containing
// whitespace or quotes survive keyword/value connection-string parsing.
func quoteDSNValue(v string) string {
	return "'" + dsnValueEscaper.Replace(v) + "'"
}

// InitDB opens the PostgreSQL connection described by config, configures the
// connection pool and stores the resulting handle in the package-level
// defaultDB. It is meant to be called once at startup.
//
// It returns a wrapped error if the connection cannot be opened or the
// underlying *sql.DB cannot be obtained; defaultDB is left untouched then.
func InitDB(config DatabaseConfig) error {
	// Build the DSN (Data Source Name). Every string value is quoted so that
	// empty or unusual values (spaces, quotes, backslashes) cannot corrupt
	// the neighbouring settings.
	dsn := fmt.Sprintf(
		"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(config.Host),
		config.Port,
		quoteDSNValue(config.User),
		quoteDSNValue(config.Password),
		quoteDSNValue(config.DBName),
		quoteDSNValue(config.SSLMode),
	)

	// Configure the GORM logger.
	gormLogger := logger.New(
		log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
		logger.Config{
			SlowThreshold:             200 * time.Millisecond, // Slow SQL threshold
			LogLevel:                  config.LogLevel,        // Log level
			IgnoreRecordNotFoundError: true,                   // Ignore ErrRecordNotFound error for logger
			Colorful:                  true,                   // Enable color
		},
	)

	// Establish the database connection with the custom logger.
	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
		Logger: gormLogger,
	})
	if err != nil {
		return fmt.Errorf("error connecting to the database: %w", err)
	}

	// The pool settings live on the underlying *sql.DB.
	sqlDB, err := db.DB()
	if err != nil {
		return fmt.Errorf("error getting database connection: %w", err)
	}

	sqlDB.SetMaxIdleConns(config.MaxIdleConns)
	sqlDB.SetMaxOpenConns(config.MaxOpenConns)
	sqlDB.SetConnMaxLifetime(config.MaxLifetime)

	defaultDB = db
	return nil
}

// MigrateDB runs GORM's AutoMigrate for all models listed below.
// It returns an error if InitDB has not been called successfully or if the
// migration itself fails.
func MigrateDB() error {
	if defaultDB == nil {
		return errors.New("database not initialized")
	}

	log.Println("Starting database migration...")

	// Add all models that should be migrated here.
	err := defaultDB.AutoMigrate(
		&Company{},
		&User{},
		&Customer{},
		&Project{},
		&Activity{},
		&TimeEntry{},
	)
	if err != nil {
		return fmt.Errorf("error migrating database: %w", err)
	}

	log.Println("Database migration completed successfully")
	return nil
}

// SeedDB inserts a default company and an admin user when the companies
// table is empty; otherwise it does nothing. All queries run with ctx, so
// cancellation and deadlines are honoured. A nil ctx is treated as
// context.Background().
//
// It returns an error if the database is not initialized, if the emptiness
// check fails, if password hashing fails, or if the seeding transaction
// fails (in which case the transaction is rolled back).
func SeedDB(ctx context.Context) error {
	if defaultDB == nil {
		return errors.New("database not initialized")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	db := defaultDB.WithContext(ctx)

	log.Println("Checking if database seeding is needed...")

	// Seeding is needed only when no company exists yet.
	var count int64
	if err := db.Model(&Company{}).Count(&count).Error; err != nil {
		return fmt.Errorf("error checking if seeding is needed: %w", err)
	}
	if count > 0 {
		log.Println("Database already contains data, skipping seeding")
		return nil
	}

	log.Println("Seeding database with initial data...")

	// Hash the default password before opening the transaction so that no
	// transaction is held open while hashing.
	// NOTE(review): this is a fixed, well-known default credential; it should
	// be changed after the first login.
	pwData, err := HashPassword("Admin@123456")
	if err != nil {
		return fmt.Errorf("error hashing password: %w", err)
	}

	// Run all seed operations in a single transaction.
	err = db.Transaction(func(tx *gorm.DB) error {
		// Create a default company.
		defaultCompany := Company{
			Name: "Default Company",
		}
		if err := tx.Create(&defaultCompany).Error; err != nil {
			return fmt.Errorf("error creating default company: %w", err)
		}

		// Create an admin user belonging to that company.
		adminUser := User{
			Email:      "admin@example.com",
			Role:       RoleAdmin,
			CompanyID:  defaultCompany.ID,
			HourlyRate: 100.0,
		}
		adminUser.Salt = pwData.Salt
		adminUser.Hash = pwData.Hash

		if err := tx.Create(&adminUser).Error; err != nil {
			return fmt.Errorf("error creating admin user: %w", err)
		}
		return nil
	})
	if err != nil {
		return err
	}

	// Report success only once the transaction has returned without error.
	log.Println("Database seeding completed successfully")
	return nil
}

// GetEngine returns the package-level DB handle bound to ctx.
// It panics if InitDB has not been called successfully.
func GetEngine(ctx context.Context) *gorm.DB {
	if defaultDB == nil {
		panic("database not initialized")
	}
	// If a special transaction is in ctx, you could check it here.
	return defaultDB.WithContext(ctx)
}

// CloseDB closes the underlying database connection pool.
// It is a no-op returning nil when the database was never initialized.
func CloseDB() error {
	if defaultDB == nil {
		return nil
	}

	sqlDB, err := defaultDB.DB()
	if err != nil {
		return fmt.Errorf("error getting database connection: %w", err)
	}

	if err := sqlDB.Close(); err != nil {
		return fmt.Errorf("error closing database connection: %w", err)
	}
	return nil
}

// updateColumnName returns the name under which field is put into the
// update map: the value of the "column" option of its gorm struct tag when
// present and non-empty, otherwise the Go field name. The option key is
// matched case-insensitively and surrounding whitespace is ignored.
func updateColumnName(field reflect.StructField) string {
	tag, ok := field.Tag.Lookup("gorm")
	if !ok {
		return field.Name
	}
	for _, option := range strings.Split(tag, ";") {
		key, value, found := strings.Cut(option, ":")
		if !found || !strings.EqualFold(strings.TrimSpace(key), "column") {
			continue
		}
		if name := strings.TrimSpace(value); name != "" {
			return name
		}
		break
	}
	return field.Name
}

// UpdateModel updates model with every exported, non-nil pointer field of
// updates. updates must be a struct or a pointer to a struct; a field named
// "ID" is always skipped. Non-pointer fields and nil pointers are ignored,
// which lets callers express "leave unchanged" by leaving a pointer nil.
//
// It returns an error when updates is not a struct (including a nil
// pointer), or the error reported by the database update. When no field is
// set it returns nil without touching the database.
func UpdateModel(ctx context.Context, model any, updates any) error {
	updateValue := reflect.ValueOf(updates)

	// If updates is a pointer, use the value behind it.
	if updateValue.Kind() == reflect.Ptr {
		updateValue = updateValue.Elem()
	}

	// Make sure updates is a struct.
	if updateValue.Kind() != reflect.Struct {
		return errors.New("updates must be a struct")
	}

	updateType := updateValue.Type()
	updateMap := make(map[string]any)

	for i := 0; i < updateValue.NumField(); i++ {
		field := updateValue.Field(i)
		fieldType := updateType.Field(i)

		// Skip unexported fields and the ID field, which is never updated.
		if !fieldType.IsExported() || fieldType.Name == "ID" {
			continue
		}

		// Only non-nil pointer fields carry an update.
		if field.Kind() != reflect.Ptr || field.IsNil() {
			continue
		}

		// Store the value behind the pointer.
		updateMap[updateColumnName(fieldType)] = field.Elem().Interface()
	}

	if len(updateMap) == 0 {
		return nil // Nothing to update
	}

	return GetEngine(ctx).Model(model).Updates(updateMap).Error
}