feat: Add database migration, seeding, and testing commands with Makefile integration
This commit is contained in:
@@ -4,11 +4,14 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres" // For PostgreSQL
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
// Global variable for the DB connection
|
||||
@@ -16,12 +19,32 @@ var defaultDB *gorm.DB
|
||||
|
||||
// DatabaseConfig contains the configuration data for the database connection
|
||||
type DatabaseConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
SSLMode string
|
||||
MaxIdleConns int // Maximum number of idle connections
|
||||
MaxOpenConns int // Maximum number of open connections
|
||||
MaxLifetime time.Duration // Maximum lifetime of a connection
|
||||
LogLevel logger.LogLevel
|
||||
}
|
||||
|
||||
// DefaultDatabaseConfig returns a default configuration with sensible values
|
||||
func DefaultDatabaseConfig() DatabaseConfig {
|
||||
return DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "timetracker",
|
||||
Password: "password",
|
||||
DBName: "timetracker",
|
||||
SSLMode: "disable",
|
||||
MaxIdleConns: 10,
|
||||
MaxOpenConns: 100,
|
||||
MaxLifetime: time.Hour,
|
||||
LogLevel: logger.Info,
|
||||
}
|
||||
}
|
||||
|
||||
// InitDB initializes the database connection (once at startup)
|
||||
@@ -31,22 +54,151 @@ func InitDB(config DatabaseConfig) error {
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
config.Host, config.Port, config.User, config.Password, config.DBName, config.SSLMode)
|
||||
|
||||
// Establish database connection
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
// Configure GORM logger
|
||||
gormLogger := logger.New(
|
||||
log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
|
||||
logger.Config{
|
||||
SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold
|
||||
LogLevel: config.LogLevel, // Log level
|
||||
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
|
||||
Colorful: true, // Enable color
|
||||
},
|
||||
)
|
||||
|
||||
// Establish database connection with custom logger
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: gormLogger,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error connecting to the database: %w", err)
|
||||
}
|
||||
|
||||
// Configure connection pool
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting database connection: %w", err)
|
||||
}
|
||||
|
||||
// Set connection pool parameters
|
||||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||||
sqlDB.SetConnMaxLifetime(config.MaxLifetime)
|
||||
|
||||
defaultDB = db
|
||||
return nil
|
||||
}
|
||||
|
||||
// MigrateDB performs database migrations for all models
|
||||
func MigrateDB() error {
|
||||
if defaultDB == nil {
|
||||
return errors.New("database not initialized")
|
||||
}
|
||||
|
||||
log.Println("Starting database migration...")
|
||||
|
||||
// Add all models that should be migrated here
|
||||
err := defaultDB.AutoMigrate(
|
||||
&Company{},
|
||||
&User{},
|
||||
&Customer{},
|
||||
&Project{},
|
||||
&Activity{},
|
||||
&TimeEntry{},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error migrating database: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Database migration completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// SeedDB seeds the database with initial data if needed
|
||||
func SeedDB(ctx context.Context) error {
|
||||
if defaultDB == nil {
|
||||
return errors.New("database not initialized")
|
||||
}
|
||||
|
||||
log.Println("Checking if database seeding is needed...")
|
||||
|
||||
// Check if we need to seed (e.g., no companies exist)
|
||||
var count int64
|
||||
if err := defaultDB.Model(&Company{}).Count(&count).Error; err != nil {
|
||||
return fmt.Errorf("error checking if seeding is needed: %w", err)
|
||||
}
|
||||
|
||||
// If data already exists, skip seeding
|
||||
if count > 0 {
|
||||
log.Println("Database already contains data, skipping seeding")
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Println("Seeding database with initial data...")
|
||||
|
||||
// Start a transaction for all seed operations
|
||||
return defaultDB.Transaction(func(tx *gorm.DB) error {
|
||||
// Create a default company
|
||||
defaultCompany := Company{
|
||||
Name: "Default Company",
|
||||
}
|
||||
if err := tx.Create(&defaultCompany).Error; err != nil {
|
||||
return fmt.Errorf("error creating default company: %w", err)
|
||||
}
|
||||
|
||||
// Create an admin user
|
||||
adminUser := User{
|
||||
Email: "admin@example.com",
|
||||
Role: RoleAdmin,
|
||||
CompanyID: defaultCompany.ID,
|
||||
HourlyRate: 100.0,
|
||||
}
|
||||
|
||||
// Hash a default password
|
||||
pwData, err := HashPassword("Admin@123456")
|
||||
if err != nil {
|
||||
return fmt.Errorf("error hashing password: %w", err)
|
||||
}
|
||||
|
||||
adminUser.Salt = pwData.Salt
|
||||
adminUser.Hash = pwData.Hash
|
||||
|
||||
if err := tx.Create(&adminUser).Error; err != nil {
|
||||
return fmt.Errorf("error creating admin user: %w", err)
|
||||
}
|
||||
|
||||
log.Println("Database seeding completed successfully")
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetEngine returns the DB instance, possibly with context
|
||||
func GetEngine(ctx context.Context) *gorm.DB {
|
||||
if defaultDB == nil {
|
||||
panic("database not initialized")
|
||||
}
|
||||
// If a special transaction is in ctx, you could check it here
|
||||
return defaultDB.WithContext(ctx)
|
||||
}
|
||||
|
||||
// CloseDB closes the database connection
|
||||
func CloseDB() error {
|
||||
if defaultDB == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
sqlDB, err := defaultDB.DB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error getting database connection: %w", err)
|
||||
}
|
||||
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
return fmt.Errorf("error closing database connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateModel updates a model based on the set pointer fields
|
||||
func UpdateModel(ctx context.Context, model any, updates any) error {
|
||||
updateValue := reflect.ValueOf(updates)
|
||||
|
||||
Reference in New Issue
Block a user