refactor: move JWTConfig to config package and update database initialization methods

This commit is contained in:
Jean Jacques Avril 2025-03-31 19:11:38 +00:00
parent fcdeedf7e9
commit bcc3aadb85
3 changed files with 41 additions and 72 deletions

View File

@ -73,7 +73,7 @@ func fileExists(path string) bool {
} }
// generateRSAKeys generates RSA keys and saves them to disk // generateRSAKeys generates RSA keys and saves them to disk
func generateRSAKeys(cfg models.JWTConfig) error { func generateRSAKeys(cfg config.JWTConfig) error {
// Create key directory if it doesn't exist // Create key directory if it doesn't exist
if err := os.MkdirAll(cfg.KeyDir, 0700); err != nil { if err := os.MkdirAll(cfg.KeyDir, 0700); err != nil {
return fmt.Errorf("failed to create key directory: %w", err) return fmt.Errorf("failed to create key directory: %w", err)

View File

@ -9,54 +9,25 @@ import (
"strings" "strings"
"time" "time"
"github.com/timetracker/backend/internal/permissions" // For PostgreSQL "github.com/timetracker/backend/internal/config"
"github.com/timetracker/backend/internal/db"
"github.com/timetracker/backend/internal/permissions"
"gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
) )
// Global variable for the DB connection
var defaultDB *gorm.DB
// DatabaseConfig contains the configuration data for the database connection
type DatabaseConfig struct {
Host string
Port int
User string
Password string
DBName string
SSLMode string
MaxIdleConns int // Maximum number of idle connections
MaxOpenConns int // Maximum number of open connections
MaxLifetime time.Duration // Maximum lifetime of a connection
LogLevel logger.LogLevel
}
// DefaultDatabaseConfig returns a default configuration with sensible values
func DefaultDatabaseConfig() DatabaseConfig {
return DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "timetracker",
Password: "password",
DBName: "timetracker",
SSLMode: "disable",
MaxIdleConns: 10,
MaxOpenConns: 100,
MaxLifetime: time.Hour,
LogLevel: logger.Info,
}
}
// MigrateDB performs database migrations for all models // MigrateDB performs database migrations for all models
func MigrateDB() error { func MigrateDB() error {
if defaultDB == nil { gormDB := db.GetEngine(context.Background())
if gormDB == nil {
return errors.New("database not initialized") return errors.New("database not initialized")
} }
log.Println("Starting database migration...") log.Println("Starting database migration...")
// Add all models that should be migrated here // Add all models that should be migrated here
err := defaultDB.AutoMigrate( err := gormDB.AutoMigrate(
&Company{}, &Company{},
&User{}, &User{},
&Customer{}, &Customer{},
@ -75,33 +46,31 @@ func MigrateDB() error {
return nil return nil
} }
// GetGormDB is no longer needed, as we use db.InitDB and db.GetEngine // GetGormDB is used for special cases like database creation
/* func GetGormDB(dbConfig config.DatabaseConfig, dbName string) (*gorm.DB, error) {
func GetGormDB(dbConfig DatabaseConfig, dbName string) (*gorm.DB, error) { dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbName, dbConfig.SSLMode)
dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbName, dbConfig.SSLMode)
// Configure GORM logger // Configure GORM logger
gormLogger := logger.New( gormLogger := logger.New(
log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer
logger.Config{ logger.Config{
SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold
LogLevel: dbConfig.LogLevel, // Log level LogLevel: dbConfig.LogLevel, // Log level
IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger
Colorful: true, // Enable color Colorful: true, // Enable color
}, },
) )
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: gormLogger, Logger: gormLogger,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("error connecting to the database: %w", err) return nil, fmt.Errorf("error connecting to the database: %w", err)
} }
return db, nil return db, nil
} }
*/
// UpdateModel updates a model based on the set pointer fields // UpdateModel updates a model based on the set pointer fields
func UpdateModel(ctx context.Context, model any, updates any) error { func UpdateModel(ctx context.Context, model any, updates any) error {
@ -160,5 +129,14 @@ func UpdateModel(ctx context.Context, model any, updates any) error {
return nil // Nothing to update return nil // Nothing to update
} }
return defaultDB.WithContext(ctx).Model(model).Updates(updateMap).Error return db.GetEngine(ctx).Model(model).Updates(updateMap).Error
}
// InitDB and CloseDB are forwarded to the db package for backward compatibility
func InitDB(config config.DatabaseConfig) error {
return db.InitDB(config)
}
func CloseDB() error {
return db.CloseDB()
} }

View File

@ -1,13 +1,4 @@
package models package models
import "time" // This file is intentionally left empty.
// The JWTConfig struct has been moved to the config package.
type JWTConfig struct {
Secret string `env:"JWT_SECRET" default:""`
TokenDuration time.Duration `env:"JWT_TOKEN_DURATION" default:"24h"`
KeyGenerate bool `env:"JWT_KEY_GENERATE" default:"true"`
KeyDir string `env:"JWT_KEY_DIR" default:"./keys"`
PrivKeyFile string `env:"JWT_PRIV_KEY_FILE" default:"jwt.key"`
PubKeyFile string `env:"JWT_PUB_KEY_FILE" default:"jwt.key.pub"`
KeyBits int `env:"JWT_KEY_BITS" default:"2048"`
}