// Package models holds the shared GORM database handle and helpers that
// operate on it.
package models

import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"strings"

	"gorm.io/driver/postgres" // For PostgreSQL
	"gorm.io/gorm"
)

// defaultDB is the package-wide connection set by InitDB and handed out by
// GetEngine. It is nil until InitDB succeeds.
var defaultDB *gorm.DB

// ErrDBNotInitialized reports that a database operation was attempted before
// InitDB successfully opened a connection.
var ErrDBNotInitialized = errors.New("database not initialized: call InitDB first")

// dsnValueEscaper escapes backslashes and single quotes so that a value can
// be placed inside single quotes in a keyword/value connection string.
var dsnValueEscaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`)

// DatabaseConfig contains the settings InitDB uses to build the PostgreSQL
// connection string. Each field is written to the DSN keyword noted beside it.
type DatabaseConfig struct {
	Host     string // host
	Port     int    // port
	User     string // user
	Password string // password
	DBName   string // dbname
	SSLMode  string // sslmode
}

// InitDB opens a PostgreSQL connection through GORM using config and stores
// it as the package-wide handle returned by GetEngine.
//
// It is meant to be called once at startup: the assignment is not
// synchronized with readers, and a previously stored handle is replaced
// without being closed. On error the previously stored handle (if any) is
// left unchanged and the GORM error is returned wrapped.
func InitDB(config DatabaseConfig) error {
	// libpq keyword/value syntax requires single quotes around empty values
	// and values containing whitespace, with ' and \ escaped inside them.
	// Quoting every string value keeps arbitrary passwords, user names, etc.
	// from corrupting the DSN.
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quoteDSNValue(config.Host),
		config.Port,
		quoteDSNValue(config.User),
		quoteDSNValue(config.Password),
		quoteDSNValue(config.DBName),
		quoteDSNValue(config.SSLMode),
	)

	db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
	if err != nil {
		return fmt.Errorf("error connecting to the database: %w", err)
	}

	defaultDB = db
	return nil
}

// quoteDSNValue wraps v in single quotes and backslash-escapes any
// backslashes and single quotes it contains, producing a value that is safe
// to embed in a keyword/value connection string.
func quoteDSNValue(v string) string {
	return "'" + dsnValueEscaper.Replace(v) + "'"
}

// GetEngine returns the handle stored by InitDB bound to ctx.
//
// It panics with ErrDBNotInitialized if InitDB has not succeeded yet.
// That is a programming error, and failing here with a descriptive value
// beats passing a nil *gorm.DB on to GORM.
func GetEngine(ctx context.Context) *gorm.DB {
	if defaultDB == nil {
		panic(ErrDBNotInitialized)
	}
	// If a special transaction is in ctx, you could check it here.
	return defaultDB.WithContext(ctx)
}

// UpdateModel performs a partial update of model using the fields of updates.
//
// updates must be a struct or a pointer to a struct. Only its exported
// pointer fields that are non-nil are applied, so a nil pointer means "leave
// this column unchanged". A field named ID is never written, because it
// identifies the row rather than carrying new data. Each field is keyed by the
// `column:` option of its `gorm` tag, or by its Go field name when no
// non-empty column is given.
//
// It returns nil without touching the database when no field is set, an
// error when updates is not a struct, ErrDBNotInitialized when InitDB has not
// succeeded, and otherwise the error reported by GORM's Updates.
func UpdateModel(ctx context.Context, model any, updates any) error {
	v := reflect.ValueOf(updates)
	// If updates is a pointer, use the value behind it.
	if v.Kind() == reflect.Pointer {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return errors.New("updates must be a struct")
	}

	updateMap := collectPointerUpdates(v)
	if len(updateMap) == 0 {
		return nil // Nothing to update
	}
	if defaultDB == nil {
		return ErrDBNotInitialized
	}
	return GetEngine(ctx).Model(model).Updates(updateMap).Error
}

// collectPointerUpdates maps every exported, non-nil pointer field of the
// struct value v (except one named ID) to its column name and dereferenced
// value.
func collectPointerUpdates(v reflect.Value) map[string]any {
	t := v.Type()
	out := make(map[string]any, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		sf := t.Field(i)
		if !sf.IsExported() || sf.Name == "ID" {
			continue
		}
		fv := v.Field(i)
		if fv.Kind() != reflect.Pointer || fv.IsNil() {
			continue
		}
		name := sf.Name
		if column := gormColumnName(sf.Tag.Get("gorm")); column != "" {
			name = column
		}
		out[name] = fv.Elem().Interface()
	}
	return out
}

// gormColumnName returns the value of the `column:` option in a GORM struct
// tag, or "" when the tag has none. Options are separated by ';'. As GORM
// itself does, the key is compared case-insensitively after trimming spaces,
// so "column:x", " column:x" and "COLUMN:x" all match.
func gormColumnName(tag string) string {
	for _, option := range strings.Split(tag, ";") {
		key, value, found := strings.Cut(option, ":")
		if found && strings.EqualFold(strings.TrimSpace(key), "column") {
			return strings.TrimSpace(value)
		}
	}
	return ""
}