2025-03-31 19:07:30 +00:00

165 lines
4.4 KiB
Go

package models
import (
"context"
"errors"
"fmt"
"log"
"reflect"
"strings"
"time"
"github.com/timetracker/backend/internal/permissions" // For PostgreSQL
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// defaultDB is the package-level GORM connection shared by the functions in
// this file (MigrateDB and UpdateModel read it). It is nil until assigned;
// MigrateDB reports "database not initialized" in that case.
// NOTE(review): the assignment is not visible in this section — confirm where
// it is initialized, and note it is a plain unsynchronized variable, so it
// should not be reassigned while requests are in flight.
var defaultDB *gorm.DB
// DatabaseConfig holds the settings needed to reach the database server and
// to size its connection pool, plus the verbosity of GORM's logger.
// DefaultDatabaseConfig returns a populated instance suitable as a base.
type DatabaseConfig struct {
Host string // Hostname of the database server
Port int // TCP port of the database server
User string // Login user name
Password string // Password for User
DBName string // Name of the database to connect to
SSLMode string // SSL mode setting, e.g. "disable" (the default in DefaultDatabaseConfig)
MaxIdleConns int // Maximum number of idle connections kept in the pool
MaxOpenConns int // Maximum number of simultaneously open connections
MaxLifetime time.Duration // Maximum time a single connection may be reused
LogLevel logger.LogLevel // Verbosity for GORM's logger (gorm.io/gorm/logger)
}
// DefaultDatabaseConfig returns a DatabaseConfig pre-filled for a local
// setup: server "localhost" on port 5432, user and database both
// "timetracker", SSL disabled, a pool allowing 100 open / 10 idle
// connections with a one-hour maximum connection lifetime, and GORM logging
// at Info level.
//
// NOTE(review): the password default is the literal "password" and SSL is
// off; callers are expected to override these outside local development.
func DefaultDatabaseConfig() DatabaseConfig {
	var cfg DatabaseConfig

	// Where to connect and how to authenticate.
	cfg.Host = "localhost"
	cfg.Port = 5432
	cfg.User = "timetracker"
	cfg.Password = "password"
	cfg.DBName = "timetracker"
	cfg.SSLMode = "disable"

	// Connection-pool sizing.
	cfg.MaxIdleConns = 10
	cfg.MaxOpenConns = 100
	cfg.MaxLifetime = time.Hour

	// Diagnostics.
	cfg.LogLevel = logger.Info

	return cfg
}
// MigrateDB brings the database schema up to date by running GORM's
// AutoMigrate over every persisted model: Company, User, Customer, Project,
// Activity, TimeEntry, and the permissions package's Role and Policy.
//
// It returns an error if defaultDB has not been set, or wraps (with %w) the
// error returned by AutoMigrate. Start and successful completion are written
// to the standard logger.
func MigrateDB() error {
	db := defaultDB
	if db == nil {
		return errors.New("database not initialized")
	}

	log.Println("Starting database migration...")

	// Every model that needs a table must be listed here to be migrated.
	models := []any{
		&Company{},
		&User{},
		&Customer{},
		&Project{},
		&Activity{},
		&TimeEntry{},
		&permissions.Role{},
		&permissions.Policy{},
	}
	if err := db.AutoMigrate(models...); err != nil {
		return fmt.Errorf("error migrating database: %w", err)
	}

	log.Println("Database migration completed successfully")
	return nil
}
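// GetGormDB is no longer used: connections are now created via db.InitDB and
// db.GetEngine. It is kept below, commented out, for reference only; it would
// not compile as-is because the postgres driver package it calls is no longer
// imported by this file.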
/*
func GetGormDB(dbConfig DatabaseConfig, dbName string) (*gorm.DB, error) {
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbName, dbConfig.SSLMode)
// Configure GORM logger
gormLogger := logger.New(
log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
logger.Config{
SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold
LogLevel: dbConfig.LogLevel, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
Colorful: true, // Enable color
},
)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: gormLogger,
})
if err != nil {
return nil, fmt.Errorf("error connecting to the database: %w", err)
}
return db, nil
}
*/
// UpdateModel applies a partial update to model, writing only the fields that
// the caller explicitly set in updates.
//
// updates must be a struct or a non-nil pointer to a struct. Every exported,
// pointer-typed field of updates that is non-nil contributes one entry: the
// key is the column named by the field's `gorm:"column:..."` tag option when
// present, otherwise the Go field name (left for GORM to resolve against
// model). The pointed-to value is what gets written. Unexported fields,
// non-pointer fields, nil pointers and any field named "ID" are skipped, so a
// nil pointer means "leave unchanged" and the ID (presumably the primary key)
// can never be overwritten through this helper.
//
// model is handed to GORM's Model(); presumably it is a pointer to a model
// whose primary key selects the row to update — confirm at call sites.
//
// Returns:
//   - an error if updates is not a struct (including a nil pointer or nil);
//   - nil, without touching the database, when no field is set;
//   - an error if the package database handle is not initialized;
//   - otherwise the error, if any, from GORM's Updates.
func UpdateModel(ctx context.Context, model any, updates any) error {
	updateValue := reflect.ValueOf(updates)
	// Accept both a struct and a pointer to one.
	if updateValue.Kind() == reflect.Pointer {
		updateValue = updateValue.Elem()
	}
	if updateValue.Kind() != reflect.Struct {
		return errors.New("updates must be a struct")
	}

	updateType := updateValue.Type()
	updateMap := make(map[string]any, updateValue.NumField())
	for i := 0; i < updateValue.NumField(); i++ {
		fieldType := updateType.Field(i)
		if !fieldType.IsExported() || fieldType.Name == "ID" {
			continue
		}
		field := updateValue.Field(i)
		// Only set (non-nil) pointers signal an intended change.
		if field.Kind() != reflect.Pointer || field.IsNil() {
			continue
		}
		updateMap[columnNameFromGormTag(fieldType)] = field.Elem().Interface()
	}

	if len(updateMap) == 0 {
		return nil // Nothing to update
	}
	// Guard like MigrateDB does; dereferencing a nil *gorm.DB would panic.
	if defaultDB == nil {
		return errors.New("database not initialized")
	}
	return defaultDB.WithContext(ctx).Model(model).Updates(updateMap).Error
}

// columnNameFromGormTag returns the value of the `column:` option in field's
// gorm struct tag, or field.Name when the tag or option is absent or empty.
// Option keys are matched case-insensitively and ignoring surrounding spaces,
// as GORM itself does, so `gorm:"type:text; COLUMN:note"` yields "note".
func columnNameFromGormTag(field reflect.StructField) string {
	tag, ok := field.Tag.Lookup("gorm")
	if !ok {
		return field.Name
	}
	for _, option := range strings.Split(tag, ";") {
		key, value, found := strings.Cut(option, ":")
		if found && value != "" && strings.EqualFold(strings.TrimSpace(key), "column") {
			return value
		}
	}
	return field.Name
}