228 lines
5.9 KiB
Go
228 lines
5.9 KiB
Go
package models
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"gorm.io/driver/postgres" // For PostgreSQL
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// Global variable for the DB connection
|
|
var defaultDB *gorm.DB
|
|
|
|
// DatabaseConfig contains the configuration data for the database connection
|
|
type DatabaseConfig struct {
|
|
Host string
|
|
Port int
|
|
User string
|
|
Password string
|
|
DBName string
|
|
SSLMode string
|
|
MaxIdleConns int // Maximum number of idle connections
|
|
MaxOpenConns int // Maximum number of open connections
|
|
MaxLifetime time.Duration // Maximum lifetime of a connection
|
|
LogLevel logger.LogLevel
|
|
}
|
|
|
|
// DefaultDatabaseConfig returns a default configuration with sensible values
|
|
func DefaultDatabaseConfig() DatabaseConfig {
|
|
return DatabaseConfig{
|
|
Host: "localhost",
|
|
Port: 5432,
|
|
User: "timetracker",
|
|
Password: "password",
|
|
DBName: "timetracker",
|
|
SSLMode: "disable",
|
|
MaxIdleConns: 10,
|
|
MaxOpenConns: 100,
|
|
MaxLifetime: time.Hour,
|
|
LogLevel: logger.Info,
|
|
}
|
|
}
|
|
|
|
// InitDB initializes the database connection (once at startup)
|
|
// with the provided configuration
|
|
func InitDB(config DatabaseConfig) error {
|
|
// Create DSN (Data Source Name)
|
|
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
|
config.Host, config.Port, config.User, config.Password, config.DBName, config.SSLMode)
|
|
|
|
// Configure GORM logger
|
|
gormLogger := logger.New(
|
|
log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
|
|
logger.Config{
|
|
SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold
|
|
LogLevel: config.LogLevel, // Log level
|
|
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
|
|
Colorful: true, // Enable color
|
|
},
|
|
)
|
|
|
|
// Establish database connection with custom logger
|
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
|
Logger: gormLogger,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error connecting to the database: %w", err)
|
|
}
|
|
|
|
// Configure connection pool
|
|
sqlDB, err := db.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("error getting database connection: %w", err)
|
|
}
|
|
|
|
// Set connection pool parameters
|
|
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
|
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
|
sqlDB.SetConnMaxLifetime(config.MaxLifetime)
|
|
|
|
defaultDB = db
|
|
return nil
|
|
}
|
|
|
|
// MigrateDB performs database migrations for all models
|
|
func MigrateDB() error {
|
|
if defaultDB == nil {
|
|
return errors.New("database not initialized")
|
|
}
|
|
|
|
log.Println("Starting database migration...")
|
|
|
|
// Add all models that should be migrated here
|
|
err := defaultDB.AutoMigrate(
|
|
&Company{},
|
|
&User{},
|
|
&Customer{},
|
|
&Project{},
|
|
&Activity{},
|
|
&TimeEntry{},
|
|
)
|
|
|
|
if err != nil {
|
|
return fmt.Errorf("error migrating database: %w", err)
|
|
}
|
|
|
|
log.Println("Database migration completed successfully")
|
|
return nil
|
|
}
|
|
|
|
// GetEngine returns the DB instance, possibly with context
|
|
func GetEngine(ctx context.Context) *gorm.DB {
|
|
if defaultDB == nil {
|
|
panic("database not initialized")
|
|
}
|
|
// If a special transaction is in ctx, you could check it here
|
|
return defaultDB.WithContext(ctx)
|
|
}
|
|
|
|
// CloseDB closes the database connection
|
|
func CloseDB() error {
|
|
if defaultDB == nil {
|
|
return nil
|
|
}
|
|
|
|
sqlDB, err := defaultDB.DB()
|
|
if err != nil {
|
|
return fmt.Errorf("error getting database connection: %w", err)
|
|
}
|
|
|
|
if err := sqlDB.Close(); err != nil {
|
|
return fmt.Errorf("error closing database connection: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func GetGormDB(dbConfig DatabaseConfig, dbName string) (*gorm.DB, error) {
|
|
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
|
dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbName, dbConfig.SSLMode)
|
|
|
|
// Configure GORM logger
|
|
gormLogger := logger.New(
|
|
log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
|
|
logger.Config{
|
|
SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold
|
|
LogLevel: dbConfig.LogLevel, // Log level
|
|
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
|
|
Colorful: true, // Enable color
|
|
},
|
|
)
|
|
|
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
|
Logger: gormLogger,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error connecting to the database: %w", err)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
|
|
// UpdateModel updates a model based on the set pointer fields
|
|
func UpdateModel(ctx context.Context, model any, updates any) error {
|
|
updateValue := reflect.ValueOf(updates)
|
|
|
|
// If updates is a pointer, use the value behind it
|
|
if updateValue.Kind() == reflect.Ptr {
|
|
updateValue = updateValue.Elem()
|
|
}
|
|
|
|
// Make sure updates is a struct
|
|
if updateValue.Kind() != reflect.Struct {
|
|
return errors.New("updates must be a struct")
|
|
}
|
|
|
|
updateType := updateValue.Type()
|
|
updateMap := make(map[string]any)
|
|
|
|
// Iterate through all fields
|
|
for i := 0; i < updateValue.NumField(); i++ {
|
|
field := updateValue.Field(i)
|
|
fieldType := updateType.Field(i)
|
|
|
|
// Skip unexported fields
|
|
if !fieldType.IsExported() {
|
|
continue
|
|
}
|
|
|
|
// Special case: Skip ID field (use only for updates)
|
|
if fieldType.Name == "ID" {
|
|
continue
|
|
}
|
|
|
|
// For pointer types, check if they are not nil
|
|
if field.Kind() == reflect.Ptr && !field.IsNil() {
|
|
// Extract field name from GORM tag or use default field name
|
|
fieldName := fieldType.Name
|
|
|
|
if tag, ok := fieldType.Tag.Lookup("gorm"); ok {
|
|
// Separate tag options
|
|
options := strings.Split(tag, ";")
|
|
for _, option := range options {
|
|
if strings.HasPrefix(option, "column:") {
|
|
fieldName = strings.TrimPrefix(option, "column:")
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use the value behind the pointer
|
|
updateMap[fieldName] = field.Elem().Interface()
|
|
}
|
|
}
|
|
|
|
if len(updateMap) == 0 {
|
|
return nil // Nothing to update
|
|
}
|
|
|
|
return GetEngine(ctx).Model(model).Updates(updateMap).Error
|
|
}
|