diff --git a/backend/cmd/dbtest/main.go b/backend/cmd/dbtest/main.go index bbdc974..c32ef85 100644 --- a/backend/cmd/dbtest/main.go +++ b/backend/cmd/dbtest/main.go @@ -7,6 +7,8 @@ import ( "log" "time" + "github.com/timetracker/backend/internal/config" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/models" ) @@ -15,15 +17,15 @@ func main() { flag.Parse() // Get database configuration with sensible defaults - dbConfig := models.DefaultDatabaseConfig() + dbConfig := config.DefaultDatabaseConfig() // Initialize database fmt.Println("Connecting to database...") - if err := models.InitDB(dbConfig); err != nil { + if err := db.InitDB(dbConfig); err != nil { log.Fatalf("Error initializing database: %v", err) } defer func() { - if err := models.CloseDB(); err != nil { + if err := db.CloseDB(); err != nil { log.Printf("Error closing database connection: %v", err) } }() @@ -34,7 +36,7 @@ func main() { defer cancel() // Get the database engine - db := models.GetEngine(ctx) + db := db.GetEngine(ctx) // Test database connection with a simple query var result int diff --git a/backend/cmd/migrate/main.go b/backend/cmd/migrate/main.go index 727b7c4..47ce009 100644 --- a/backend/cmd/migrate/main.go +++ b/backend/cmd/migrate/main.go @@ -6,7 +6,8 @@ import ( "log" "os" - "github.com/timetracker/backend/internal/models" + "github.com/timetracker/backend/internal/config" + "github.com/timetracker/backend/internal/db" "gorm.io/gorm/logger" ) @@ -29,7 +30,7 @@ func main() { } // Get database configuration with sensible defaults - dbConfig := models.DefaultDatabaseConfig() + dbConfig := config.DefaultDatabaseConfig() // Override with environment variables if provided if host := os.Getenv("DB_HOST"); host != "" { @@ -62,7 +63,7 @@ func main() { var err error - gormDB, err := models.GetGormDB(dbConfig, "postgres") + gormDB, err := db.GetGormDB(dbConfig, "postgres") if err != nil { log.Fatalf("Error getting gorm DB: %v", err) } @@ -89,11 +90,11 @@ func main() { fmt.Printf("✓ Database %s created successfully\n", dbConfig.DBName) } - if err = models.InitDB(dbConfig); err != nil { + if err = db.InitDB(dbConfig); err != nil { log.Fatalf("Error initializing database: %v", err) } defer func() { - if err := models.CloseDB(); err != nil { + if err := db.CloseDB(); err != nil { log.Printf("Error closing database connection: %v", err) } }() @@ -101,7 +102,7 @@ func main() { // Run migrations fmt.Println("Running database migrations...") - if err = models.MigrateDB(); err != nil { + if err = db.MigrateDB(); err != nil { log.Fatalf("Error migrating database: %v", err) } fmt.Println("✓ Database migrations completed successfully") diff --git a/backend/cmd/modeltest/main.go b/backend/cmd/modeltest/main.go index e529275..3835f06 100644 --- a/backend/cmd/modeltest/main.go +++ b/backend/cmd/modeltest/main.go @@ -7,21 +7,23 @@ import ( "time" "github.com/oklog/ulid/v2" + "github.com/timetracker/backend/internal/config" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/models" "github.com/timetracker/backend/internal/types" ) func main() { // Get database configuration with sensible defaults - dbConfig := models.DefaultDatabaseConfig() + dbConfig := config.DefaultDatabaseConfig() // Initialize database fmt.Println("Connecting to database...") - if err := models.InitDB(dbConfig); err != nil { + if err := db.InitDB(dbConfig); err != nil { log.Fatalf("Error initializing database: %v", err) } defer func() { - if err := models.CloseDB(); err != nil { + if err := db.CloseDB(); err != nil { log.Printf("Error closing database connection: %v", err) } }() diff --git a/backend/cmd/seed/main.go b/backend/cmd/seed/main.go index 2c1c93d..9bf8e3a 100644 --- a/backend/cmd/seed/main.go +++ b/backend/cmd/seed/main.go @@ -7,6 +7,7 @@ import ( "log" "github.com/timetracker/backend/internal/config" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/models" "gorm.io/gorm" ) @@ -23,11 +24,11 @@ func main() { } // Initialize database - if err := models.InitDB(cfg.Database); err != nil { + if err := db.InitDB(cfg.Database); err != nil { log.Fatalf("Error initializing database: %v", err) } defer func() { - if err := models.CloseDB(); err != nil { + if err := db.CloseDB(); err != nil { log.Printf("Error closing database connection: %v", err) } }() @@ -44,7 +45,7 @@ func main() { func seedDatabase(ctx context.Context) error { // Check if seeding is needed var count int64 - if err := models.GetEngine(ctx).Model(&models.Company{}).Count(&count).Error; err != nil { + if err := db.GetEngine(ctx).Model(&models.Company{}).Count(&count).Error; err != nil { return fmt.Errorf("error checking if seeding is needed: %w", err) } @@ -57,7 +58,7 @@ func seedDatabase(ctx context.Context) error { log.Println("Seeding database with initial data...") // Start transaction - return models.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { + return db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { // Create default company defaultCompany := models.Company{ Name: "Default Company", diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index fe3e4a2..fef856a 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -9,14 +9,54 @@ import ( "time" "github.com/joho/godotenv" - "github.com/timetracker/backend/internal/models" "gorm.io/gorm/logger" ) +// DatabaseConfig contains the configuration data for the database connection +type DatabaseConfig struct { + Host string + Port int + User string + Password string + DBName string + SSLMode string + MaxIdleConns int // Maximum number of idle connections + MaxOpenConns int // Maximum number of open connections + MaxLifetime time.Duration // Maximum lifetime of a connection + LogLevel logger.LogLevel +} + +// DefaultDatabaseConfig returns a default configuration with sensible values +func DefaultDatabaseConfig() DatabaseConfig { + return DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "timetracker", + Password: "password", + DBName: "timetracker", + SSLMode: "disable", + MaxIdleConns: 10, + MaxOpenConns: 100, + MaxLifetime: time.Hour, + LogLevel: logger.Info, + } +} + +// JWTConfig represents the configuration for JWT authentication +type JWTConfig struct { + Secret string + TokenDuration time.Duration + KeyGenerate bool + KeyDir string + PrivKeyFile string + PubKeyFile string + KeyBits int +} + // Config represents the application configuration type Config struct { - Database models.DatabaseConfig - JWTConfig models.JWTConfig + Database DatabaseConfig + JWTConfig JWTConfig APIKey string } @@ -26,8 +66,8 @@ func LoadConfig() (*Config, error) { _ = godotenv.Load() cfg := &Config{ - Database: models.DefaultDatabaseConfig(), - JWTConfig: models.JWTConfig{}, + Database: DefaultDatabaseConfig(), + JWTConfig: JWTConfig{}, } // Load database configuration diff --git a/backend/internal/db/db.go b/backend/internal/db/db.go new file mode 100644 index 0000000..27b9b3a --- /dev/null +++ b/backend/internal/db/db.go @@ -0,0 +1,84 @@ +package db + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/timetracker/backend/internal/config" + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// Global variable for the DB connection +var db *gorm.DB + +// InitDB initializes the database connection (once at startup) +// with the provided configuration +func InitDB(config config.DatabaseConfig) error { + // Create DSN (Data Source Name) + dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + config.Host, config.Port, config.User, config.Password, config.DBName, config.SSLMode) + + // Configure GORM logger + gormLogger := logger.New( + log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer + logger.Config{ + SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold + LogLevel: config.LogLevel, // Log level + IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger + Colorful: true, // Enable color + }, + ) + + // Establish database connection with custom logger + var err error + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: gormLogger, + }) + if err != nil { + return fmt.Errorf("error connecting to the database: %w", err) + } + + // Configure connection pool + sqlDB, err := db.DB() + if err != nil { + return fmt.Errorf("error getting database connection: %w", err) + } + + // Set connection pool parameters + sqlDB.SetMaxIdleConns(config.MaxIdleConns) + sqlDB.SetMaxOpenConns(config.MaxOpenConns) + sqlDB.SetConnMaxLifetime(config.MaxLifetime) + + return nil +} + +// GetEngine returns the DB instance, possibly with context +func GetEngine(ctx context.Context) *gorm.DB { + if db == nil { + panic("database not initialized") + } + // If a special transaction is in ctx, you could check it here + return db.WithContext(ctx) +} + +// CloseDB closes the database connection +func CloseDB() error { + if db == nil { + return nil + } + + sqlDB, err := db.DB() + if err != nil { + return fmt.Errorf("error getting database connection: %w", err) + } + + if err := sqlDB.Close(); err != nil { + return fmt.Errorf("error closing database connection: %w", err) + } + + return nil +} diff --git a/backend/internal/models/activity.go b/backend/internal/models/activity.go index 2a7b1be..2230956 100644 --- a/backend/internal/models/activity.go +++ b/backend/internal/models/activity.go @@ -4,6 +4,7 @@ import ( "context" "errors" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/types" "gorm.io/gorm" ) @@ -36,7 +37,7 @@ type ActivityCreate struct { // GetActivityByID finds an Activity by its ID func GetActivityByID(ctx context.Context, id types.ULID) (*Activity, error) { var activity Activity - result := GetEngine(ctx).Where("id = ?", id).First(&activity) + result := db.GetEngine(ctx).Where("id = ?", id).First(&activity) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -49,7 +50,7 @@ func GetActivityByID(ctx context.Context, id types.ULID) (*Activity, error) { // GetAllActivities returns all Activities func GetAllActivities(ctx context.Context) ([]Activity, error) { var activities []Activity - result := GetEngine(ctx).Find(&activities) + result := db.GetEngine(ctx).Find(&activities) if result.Error != nil { return nil, result.Error } @@ -63,7 +64,7 @@ func CreateActivity(ctx context.Context, create ActivityCreate) (*Activity, erro BillingRate: create.BillingRate, } - result := GetEngine(ctx).Create(&activity) + result := db.GetEngine(ctx).Create(&activity) if result.Error != nil { return nil, result.Error } @@ -91,6 +92,6 @@ func UpdateActivity(ctx context.Context, update ActivityUpdate) (*Activity, erro // DeleteActivity deletes an Activity by its ID func DeleteActivity(ctx context.Context, id types.ULID) error { - result := GetEngine(ctx).Delete(&Activity{}, id) + result := db.GetEngine(ctx).Delete(&Activity{}, id) return result.Error } diff --git a/backend/internal/models/company.go b/backend/internal/models/company.go index 3a22264..f9c2ee3 100644 --- a/backend/internal/models/company.go +++ b/backend/internal/models/company.go @@ -4,6 +4,7 @@ import ( "context" "errors" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/types" "gorm.io/gorm" ) @@ -33,7 +34,7 @@ type CompanyUpdate struct { // GetCompanyByID finds a company by its ID func GetCompanyByID(ctx context.Context, id types.ULID) (*Company, error) { var company Company - result := GetEngine(ctx).Where("id = ?", id).First(&company) + result := db.GetEngine(ctx).Where("id = ?", id).First(&company) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -46,7 +47,7 @@ func GetCompanyByID(ctx context.Context, id types.ULID) (*Company, error) { // GetAllCompanies returns all companies func GetAllCompanies(ctx context.Context) ([]Company, error) { var companies []Company - result := GetEngine(ctx).Find(&companies) + result := db.GetEngine(ctx).Find(&companies) if result.Error != nil { return nil, result.Error } @@ -55,7 +56,7 @@ func GetAllCompanies(ctx context.Context) ([]Company, error) { func GetCustomersByCompanyID(ctx context.Context, companyID int) ([]Customer, error) { var customers []Customer - result := GetEngine(ctx).Where("company_id = ?", companyID).Find(&customers) + result := db.GetEngine(ctx).Where("company_id = ?", companyID).Find(&customers) if result.Error != nil { return nil, result.Error } @@ -68,7 +69,7 @@ func CreateCompany(ctx context.Context, create CompanyCreate) (*Company, error) Name: create.Name, } - result := GetEngine(ctx).Create(&company) + result := db.GetEngine(ctx).Create(&company) if result.Error != nil { return nil, result.Error } @@ -96,6 +97,6 @@ func UpdateCompany(ctx context.Context, update CompanyUpdate) (*Company, error) // DeleteCompany deletes a company by its ID func DeleteCompany(ctx context.Context, id types.ULID) error { - result := GetEngine(ctx).Delete(&Company{}, id) + result := db.GetEngine(ctx).Delete(&Company{}, id) return result.Error } diff --git a/backend/internal/models/customer.go b/backend/internal/models/customer.go index 1ec224b..0c98bb4 100644 --- a/backend/internal/models/customer.go +++ b/backend/internal/models/customer.go @@ -4,6 +4,7 @@ import ( "context" "errors" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/types" "gorm.io/gorm" ) @@ -39,7 +40,7 @@ type CustomerUpdate struct { // GetCustomerByID finds a customer by its ID func GetCustomerByID(ctx context.Context, id types.ULID) (*Customer, error) { var customer Customer - result := GetEngine(ctx).Where("id = ?", id).First(&customer) + result := db.GetEngine(ctx).Where("id = ?", id).First(&customer) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -52,7 +53,7 @@ func GetCustomerByID(ctx context.Context, id types.ULID) (*Customer, error) { // GetAllCustomers returns all customers func GetAllCustomers(ctx context.Context) ([]Customer, error) { var customers []Customer - result := GetEngine(ctx).Find(&customers) + result := db.GetEngine(ctx).Find(&customers) if result.Error != nil { return nil, result.Error } @@ -66,7 +67,7 @@ func CreateCustomer(ctx context.Context, create CustomerCreate) (*Customer, erro CompanyID: create.CompanyID, } - result := GetEngine(ctx).Create(&customer) + result := db.GetEngine(ctx).Create(&customer) if result.Error != nil { return nil, result.Error } @@ -94,6 +95,6 @@ func UpdateCustomer(ctx context.Context, update CustomerUpdate) (*Customer, erro // DeleteCustomer deletes a customer by its ID func DeleteCustomer(ctx context.Context, id types.ULID) error { - result := GetEngine(ctx).Delete(&Customer{}, id) + result := db.GetEngine(ctx).Delete(&Customer{}, id) return result.Error } diff --git a/backend/internal/models/db.go b/backend/internal/models/db.go index 0709e89..c72aa68 100644 --- a/backend/internal/models/db.go +++ b/backend/internal/models/db.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "gorm.io/driver/postgres" // For PostgreSQL + "github.com/timetracker/backend/internal/permissions" // For PostgreSQL "gorm.io/gorm" "gorm.io/gorm/logger" ) @@ -47,47 +47,6 @@ func DefaultDatabaseConfig() DatabaseConfig { } } -// InitDB initializes the database connection (once at startup) -// with the provided configuration -func InitDB(config DatabaseConfig) error { - // Create DSN (Data Source Name) - dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", - config.Host, config.Port, config.User, config.Password, config.DBName, config.SSLMode) - - // Configure GORM logger - gormLogger := logger.New( - log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer - logger.Config{ - SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold - LogLevel: config.LogLevel, // Log level - IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger - Colorful: true, // Enable color - }, - ) - - // Establish database connection with custom logger - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ - Logger: gormLogger, - }) - if err != nil { - return fmt.Errorf("error connecting to the database: %w", err) - } - - // Configure connection pool - sqlDB, err := db.DB() - if err != nil { - return fmt.Errorf("error getting database connection: %w", err) - } - - // Set connection pool parameters - sqlDB.SetMaxIdleConns(config.MaxIdleConns) - sqlDB.SetMaxOpenConns(config.MaxOpenConns) - sqlDB.SetConnMaxLifetime(config.MaxLifetime) - - defaultDB = db - return nil -} - // MigrateDB performs database migrations for all models func MigrateDB() error { if defaultDB == nil { @@ -104,6 +63,8 @@ func MigrateDB() error { &Project{}, &Activity{}, &TimeEntry{}, + &permissions.Role{}, + &permissions.Policy{}, ) if err != nil { @@ -114,57 +75,33 @@ func MigrateDB() error { return nil } -// GetEngine returns the DB instance, possibly with context -func GetEngine(ctx context.Context) *gorm.DB { - if defaultDB == nil { - panic("database not initialized") - } - // If a special transaction is in ctx, you could check it here - return defaultDB.WithContext(ctx) -} - -// CloseDB closes the database connection -func CloseDB() error { - if defaultDB == nil { - return nil - } - - sqlDB, err := defaultDB.DB() - if err != nil { - return fmt.Errorf("error getting database connection: %w", err) - } - - if err := sqlDB.Close(); err != nil { - return fmt.Errorf("error closing database connection: %w", err) - } - - return nil -} - +// GetGormDB is no longer needed, as we use db.InitDB and db.GetEngine +/* func GetGormDB(dbConfig DatabaseConfig, dbName string) (*gorm.DB, error) { - dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", - dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbName, dbConfig.SSLMode) + dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + dbConfig.Host, dbConfig.Port, dbConfig.User, dbConfig.Password, dbName, dbConfig.SSLMode) - // Configure GORM logger - gormLogger := logger.New( - log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer - logger.Config{ - SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold - LogLevel: dbConfig.LogLevel, // Log level - IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger - Colorful: true, // Enable color - }, - ) + // Configure GORM logger + gormLogger := logger.New( + log.New(log.Writer(), "\r\n", log.LstdFlags), // io writer + logger.Config{ + SlowThreshold: 200 * time.Millisecond, // Slow SQL threshold + LogLevel: dbConfig.LogLevel, // Log level + IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger + Colorful: true, // Enable color + }, + ) - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ - Logger: gormLogger, - }) - if err != nil { - return nil, fmt.Errorf("error connecting to the database: %w", err) - } + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{ + Logger: gormLogger, + }) + if err != nil { + return nil, fmt.Errorf("error connecting to the database: %w", err) + } - return db, nil + return db, nil } +*/ // UpdateModel updates a model based on the set pointer fields func UpdateModel(ctx context.Context, model any, updates any) error { @@ -223,5 +160,5 @@ func UpdateModel(ctx context.Context, model any, updates any) error { return nil // Nothing to update } - return GetEngine(ctx).Model(model).Updates(updateMap).Error + return defaultDB.WithContext(ctx).Model(model).Updates(updateMap).Error } diff --git a/backend/internal/models/project.go b/backend/internal/models/project.go index 78b2591..d90536f 100644 --- a/backend/internal/models/project.go +++ b/backend/internal/models/project.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/types" "gorm.io/gorm" ) @@ -60,7 +61,7 @@ func (pu *ProjectUpdate) Validate() error { // GetProjectByID finds a project by its ID func GetProjectByID(ctx context.Context, id types.ULID) (*Project, error) { var project Project - result := GetEngine(ctx).Where("id = ?", id).First(&project) + result := db.GetEngine(ctx).Where("id = ?", id).First(&project) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -73,7 +74,7 @@ func GetProjectByID(ctx context.Context, id types.ULID) (*Project, error) { // GetProjectWithCustomer loads a project with the associated customer information func GetProjectWithCustomer(ctx context.Context, id types.ULID) (*Project, error) { var project Project - result := GetEngine(ctx).Preload("Customer").Where("id = ?", id).First(&project) + result := db.GetEngine(ctx).Preload("Customer").Where("id = ?", id).First(&project) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -86,7 +87,7 @@ func GetProjectWithCustomer(ctx context.Context, id types.ULID) (*Project, error // GetAllProjects returns all projects func GetAllProjects(ctx context.Context) ([]Project, error) { var projects []Project - result := GetEngine(ctx).Find(&projects) + result := db.GetEngine(ctx).Find(&projects) if result.Error != nil { return nil, result.Error } @@ -96,7 +97,7 @@ func GetAllProjects(ctx context.Context) ([]Project, error) { // GetAllProjectsWithCustomers returns all projects with customer information func GetAllProjectsWithCustomers(ctx context.Context) ([]Project, error) { var projects []Project - result := GetEngine(ctx).Preload("Customer").Find(&projects) + result := db.GetEngine(ctx).Preload("Customer").Find(&projects) if result.Error != nil { return nil, result.Error } @@ -106,7 +107,7 @@ func GetAllProjectsWithCustomers(ctx context.Context) ([]Project, error) { // GetProjectsByCustomerID returns all projects of a specific customer func GetProjectsByCustomerID(ctx context.Context, customerId types.ULID) ([]Project, error) { var projects []Project - result := GetEngine(ctx).Where("customer_id = ?", customerId.ULID).Find(&projects) + result := db.GetEngine(ctx).Where("customer_id = ?", customerId.ULID).Find(&projects) if result.Error != nil { return nil, result.Error } @@ -136,7 +137,7 @@ func CreateProject(ctx context.Context, create ProjectCreate) (*Project, error) CustomerID: create.CustomerID, } - result := GetEngine(ctx).Create(&project) + result := db.GetEngine(ctx).Create(&project) if result.Error != nil { return nil, fmt.Errorf("error creating the project: %w", result.Error) } @@ -181,7 +182,7 @@ func UpdateProject(ctx context.Context, update ProjectUpdate) (*Project, error) // DeleteProject deletes a project by its ID func DeleteProject(ctx context.Context, id types.ULID) error { // Here you could check if dependent entities exist - result := GetEngine(ctx).Delete(&Project{}, id) + result := db.GetEngine(ctx).Delete(&Project{}, id) if result.Error != nil { return fmt.Errorf("error deleting the project: %w", result.Error) } @@ -198,7 +199,7 @@ func CreateProjectWithTransaction(ctx context.Context, create ProjectCreate) (*P var project *Project // Start transaction - err := GetEngine(ctx).Transaction(func(tx *gorm.DB) error { + err := db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { // Customer check within the transaction var customer Customer if err := tx.Where("id = ?", create.CustomerID).First(&customer).Error; err != nil { diff --git a/backend/internal/models/timeentry.go b/backend/internal/models/timeentry.go index 6b32016..afd6f28 100644 --- a/backend/internal/models/timeentry.go +++ b/backend/internal/models/timeentry.go @@ -6,6 +6,7 @@ import ( "fmt" "time" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/types" "gorm.io/gorm" ) @@ -105,7 +106,7 @@ func (tu *TimeEntryUpdate) Validate() error { // GetTimeEntryByID finds a time entry by its ID func GetTimeEntryByID(ctx context.Context, id types.ULID) (*TimeEntry, error) { var timeEntry TimeEntry - result := GetEngine(ctx).Where("id = ?", id).First(&timeEntry) + result := db.GetEngine(ctx).Where("id = ?", id).First(&timeEntry) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -118,7 +119,7 @@ func GetTimeEntryByID(ctx context.Context, id types.ULID) (*TimeEntry, error) { // GetTimeEntryWithRelations loads a time entry with all associated data func GetTimeEntryWithRelations(ctx context.Context, id types.ULID) (*TimeEntry, error) { var timeEntry TimeEntry - result := GetEngine(ctx). + result := db.GetEngine(ctx). Preload("User"). Preload("Project"). Preload("Project.Customer"). // Nested relationship @@ -138,7 +139,7 @@ func GetTimeEntryWithRelations(ctx context.Context, id types.ULID) (*TimeEntry, // GetAllTimeEntries returns all time entries func GetAllTimeEntries(ctx context.Context) ([]TimeEntry, error) { var timeEntries []TimeEntry - result := GetEngine(ctx).Find(&timeEntries) + result := db.GetEngine(ctx).Find(&timeEntries) if result.Error != nil { return nil, result.Error } @@ -148,7 +149,7 @@ func GetAllTimeEntries(ctx context.Context) ([]TimeEntry, error) { // GetTimeEntriesByUserID returns all time entries of a user func GetTimeEntriesByUserID(ctx context.Context, userID types.ULID) ([]TimeEntry, error) { var timeEntries []TimeEntry - result := GetEngine(ctx).Where("user_id = ?", userID).Find(&timeEntries) + result := db.GetEngine(ctx).Where("user_id = ?", userID).Find(&timeEntries) if result.Error != nil { return nil, result.Error } @@ -158,7 +159,7 @@ func GetTimeEntriesByUserID(ctx context.Context, userID types.ULID) ([]TimeEntry // GetTimeEntriesByProjectID returns all time entries of a project func GetTimeEntriesByProjectID(ctx context.Context, projectID types.ULID) ([]TimeEntry, error) { var timeEntries []TimeEntry - result := GetEngine(ctx).Where("project_id = ?", projectID).Find(&timeEntries) + result := db.GetEngine(ctx).Where("project_id = ?", projectID).Find(&timeEntries) if result.Error != nil { return nil, result.Error } @@ -169,7 +170,7 @@ func GetTimeEntriesByProjectID(ctx context.Context, projectID types.ULID) ([]Tim func GetTimeEntriesByDateRange(ctx context.Context, start, end time.Time) ([]TimeEntry, error) { var timeEntries []TimeEntry // Search for overlaps in the time range - result := GetEngine(ctx). + result := db.GetEngine(ctx). Where("(start BETWEEN ? AND ?) OR (end BETWEEN ? AND ?)", start, end, start, end). Find(&timeEntries) @@ -189,7 +190,7 @@ func SumBillableHoursByProject(ctx context.Context, projectID types.ULID) (float var result Result // SQL calculation of weighted hours - err := GetEngine(ctx).Raw(` + err := db.GetEngine(ctx).Raw(` SELECT SUM( EXTRACT(EPOCH FROM (end - start)) / 3600 * (billable / 100.0) ) as total_hours @@ -214,7 +215,7 @@ func CreateTimeEntry(ctx context.Context, create TimeEntryCreate) (*TimeEntry, e // Start a transaction var timeEntry *TimeEntry - err := GetEngine(ctx).Transaction(func(tx *gorm.DB) error { + err := db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { // Check references if err := validateReferences(tx, create.UserID, create.ProjectID, create.ActivityID); err != nil { return err @@ -295,7 +296,7 @@ func UpdateTimeEntry(ctx context.Context, update TimeEntryUpdate) (*TimeEntry, e } // Start a transaction for the update - err = GetEngine(ctx).Transaction(func(tx *gorm.DB) error { + err = db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { // Check references if they are updated if update.UserID != nil || update.ProjectID != nil || update.ActivityID != nil { // Use current values if not updated @@ -352,7 +353,7 @@ func UpdateTimeEntry(ctx context.Context, update TimeEntryUpdate) (*TimeEntry, e // DeleteTimeEntry deletes a time entry by its ID func DeleteTimeEntry(ctx context.Context, id types.ULID) error { - result := GetEngine(ctx).Delete(&TimeEntry{}, id) + result := db.GetEngine(ctx).Delete(&TimeEntry{}, id) if result.Error != nil { return fmt.Errorf("error deleting the time entry: %w", result.Error) } diff --git a/backend/internal/models/user.go b/backend/internal/models/user.go index 2972dfc..3a079d6 100644 --- a/backend/internal/models/user.go +++ b/backend/internal/models/user.go @@ -11,28 +11,12 @@ import ( "slices" + "github.com/timetracker/backend/internal/db" "github.com/timetracker/backend/internal/types" "golang.org/x/crypto/argon2" "gorm.io/gorm" ) -// Argon2 Parameters -const ( - // Recommended values for Argon2id - ArgonTime = 1 - ArgonMemory = 64 * 1024 // 64MB - ArgonThreads = 4 - ArgonKeyLen = 32 - SaltLength = 16 -) - -// Role Constants -const ( - RoleAdmin = "admin" - RoleUser = "user" - RoleViewer = "viewer" -) - // User represents a user in the system type User struct { EntityBase @@ -42,6 +26,7 @@ type User struct { Role string `gorm:"column:role;not null;default:'user'"` CompanyID *types.ULID `gorm:"column:company_id;type:bytea;index"` HourlyRate float64 `gorm:"column:hourly_rate;not null;default:0"` + Companies []string `gorm:"type:text[]"` // Relationship for Eager Loading Company *Company `gorm:"foreignKey:CompanyID"` @@ -290,7 +275,7 @@ func (uu *UserUpdate) Validate() error { // GetUserByID finds a user by their ID func GetUserByID(ctx context.Context, id types.ULID) (*User, error) { var user User - result := GetEngine(ctx).Where("id = ?", id).First(&user) + result := db.GetEngine(ctx).Where("id = ?", id).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -303,7 +288,7 @@ func GetUserByID(ctx context.Context, id types.ULID) (*User, error) { // GetUserByEmail finds a user by their email func GetUserByEmail(ctx context.Context, email string) (*User, error) { var user User - result := GetEngine(ctx).Where("email = ?", email).First(&user) + result := db.GetEngine(ctx).Where("email = ?", email).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -316,7 +301,7 @@ func GetUserByEmail(ctx context.Context, email string) (*User, error) { // GetUserWithCompany loads a user with their company func GetUserWithCompany(ctx context.Context, id types.ULID) (*User, error) { var user User - result := GetEngine(ctx).Preload("Company").Where("id = ?", id).First(&user) + result := db.GetEngine(ctx).Preload("Company").Where("id = ?", id).First(&user) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, nil @@ -329,7 +314,7 @@ func GetUserWithCompany(ctx context.Context, id types.ULID) (*User, error) { // GetAllUsers returns all users func GetAllUsers(ctx context.Context) ([]User, error) { var users []User - result := GetEngine(ctx).Find(&users) + result := db.GetEngine(ctx).Find(&users) if result.Error != nil { return nil, result.Error } @@ -351,7 +336,7 @@ func GetUsersByCompanyID(ctx context.Context, companyID types.ULID) ([]User, err var users []User // Apply the dynamic company condition condition := getCompanyCondition(&companyID) - result := GetEngine(ctx).Scopes(condition).Find(&users) + result := db.GetEngine(ctx).Scopes(condition).Find(&users) if result.Error != nil { return nil, result.Error } @@ -368,7 +353,7 @@ func CreateUser(ctx context.Context, create UserCreate) (*User, error) { // Start a transaction var user *User - err := GetEngine(ctx).Transaction(func(tx *gorm.DB) error { + err := db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { // Check if email already exists var count int64 if err := tx.Model(&User{}).Where("email = ?", create.Email).Count(&count).Error; err != nil { @@ -435,7 +420,7 @@ func UpdateUser(ctx context.Context, update UserUpdate) (*User, error) { } // Start a transaction for the update - err = GetEngine(ctx).Transaction(func(tx *gorm.DB) error { + err = db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error { // If email is updated, check if it's already in use if update.Email != nil && *update.Email != user.Email { var count int64 @@ -492,7 +477,6 @@ func UpdateUser(ctx context.Context, update UserUpdate) (*User, error) { } else { updates["company_id"] = *update.CompanyID.Value } - } if update.HourlyRate != nil { updates["hourly_rate"] = *update.HourlyRate @@ -521,7 +505,7 @@ func DeleteUser(ctx context.Context, id types.ULID) error { // Here one could check if dependent entities exist // e.g., don't delete if time entries still exist - result := GetEngine(ctx).Delete(&User{}, id) + result := db.GetEngine(ctx).Delete(&User{}, id) if result.Error != nil { return fmt.Errorf("error deleting user: %w", result.Error) } @@ -551,3 +535,20 @@ func AuthenticateUser(ctx context.Context, email, password string) (*User, error return user, nil } + +// Argon2 Parameters +const ( + // Recommended values for Argon2id + ArgonTime = 1 + ArgonMemory = 64 * 1024 // 64MB + ArgonThreads = 4 + ArgonKeyLen = 32 + SaltLength = 16 +) + +// Role Constants +const ( + RoleAdmin = "admin" + RoleUser = "user" + RoleViewer = "viewer" +) diff --git a/backend/internal/permissions/evaluator.go b/backend/internal/permissions/evaluator.go new file mode 100644 index 0000000..a9fb251 --- /dev/null +++ b/backend/internal/permissions/evaluator.go @@ -0,0 +1,35 @@ +package permissions + +import ( + "context" +) + +func (u *User) EffectivePermissions(ctx context.Context, scope string) (Permission, error) { + if u.ActiveRole == nil { + return 0, nil + } + + // Load the role and its associated policies using the helper function. + role, err := LoadRoleWithPolicies(ctx, u.ActiveRole.ID) + if err != nil { + return 0, err + } + + var perm Permission + for _, policy := range role.Policies { + for pat, p := range policy.Scopes { + if MatchScope(pat, scope) { + perm |= p + } + } + } + return perm, nil +} + +func (u *User) HasPermission(ctx context.Context, scope string, requiredPerm Permission) (bool, error) { + effective, err := u.EffectivePermissions(ctx, scope) + if err != nil { + return false, err + } + return (effective & requiredPerm) == requiredPerm, nil +} diff --git a/backend/internal/permissions/helpers.go b/backend/internal/permissions/helpers.go new file mode 100644 index 0000000..e533e21 --- /dev/null +++ b/backend/internal/permissions/helpers.go @@ -0,0 +1,23 @@ +package permissions + +import ( + "context" + "fmt" + + "github.com/oklog/ulid/v2" + "github.com/timetracker/backend/internal/db" + "gorm.io/gorm" +) + +// LoadRoleWithPolicies loads a role with its associated policies from the database. +func LoadRoleWithPolicies(ctx context.Context, roleID ulid.ULID) (*Role, error) { + var role Role + err := db.GetEngine(ctx).Preload("Policies").First(&role, "id = ?", roleID).Error + if err != nil { + if err == gorm.ErrRecordNotFound { + return nil, fmt.Errorf("role with ID %s not found", roleID) + } + return nil, fmt.Errorf("failed to load role: %w", err) + } + return &role, nil +} diff --git a/backend/internal/permissions/matching.go b/backend/internal/permissions/matching.go new file mode 100644 index 0000000..19be802 --- /dev/null +++ b/backend/internal/permissions/matching.go @@ -0,0 +1,11 @@ +package permissions + +import "strings" + +func MatchScope(pattern, scope string) bool { + if strings.HasSuffix(pattern, "/*") { + prefix := strings.TrimSuffix(pattern, "/*") + return strings.HasPrefix(scope, prefix) + } + return pattern == scope +} diff --git a/backend/internal/permissions/permissions.go b/backend/internal/permissions/permissions.go new file mode 100644 index 0000000..49f3620 --- /dev/null +++ b/backend/internal/permissions/permissions.go @@ -0,0 +1,13 @@ +package permissions + +type Permission uint64 + +const ( + PermRead Permission = 1 << iota // 1 + PermWrite // 2 + PermCreate // 4 + PermList // 8 + PermDelete // 16 + PermModerate // 32 + PermSuperadmin // 64 +) \ No newline at end of file diff --git a/backend/internal/permissions/policy.go b/backend/internal/permissions/policy.go new file mode 100644 index 0000000..1bbe489 --- /dev/null +++ b/backend/internal/permissions/policy.go @@ -0,0 +1,40 @@ +package permissions + +import ( + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + + "github.com/oklog/ulid/v2" +) + +type Policy struct { + ID ulid.ULID `gorm:"primaryKey;type:bytea"` + Name string `gorm:"not null"` + RoleID ulid.ULID `gorm:"type:bytea"` //Fremdschlüssel + Scopes Scopes `gorm:"type:jsonb;not null"` // JSONB-Spalte +} + +// Scopes type to handle JSON marshalling +type Scopes map[string]Permission + +// Scan scan value into Jsonb, implements sql.Scanner interface +func (j *Scopes) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) + } + + var scopes map[string]Permission + if err := json.Unmarshal(bytes, &scopes); err != nil { + return err + } + *j = scopes + return nil +} + +// Value return json value, implement driver.Valuer interface +func (j Scopes) Value() (driver.Value, error) { + return json.Marshal(j) +} diff --git a/backend/internal/permissions/role.go b/backend/internal/permissions/role.go new file mode 100644 index 0000000..14c5e3d --- /dev/null +++ b/backend/internal/permissions/role.go @@ -0,0 +1,11 @@ +package permissions + +import ( + "github.com/oklog/ulid/v2" +) + +type Role struct { + ID ulid.ULID `gorm:"primaryKey;type:bytea"` + Name string `gorm:"unique;not null"` + Policies []Policy `gorm:"foreignKey:RoleID"` +} diff --git a/backend/internal/permissions/user.go b/backend/internal/permissions/user.go new file mode 100644 index 0000000..43a488e --- /dev/null +++ b/backend/internal/permissions/user.go @@ -0,0 +1,10 @@ +package permissions + +import ( + "github.com/oklog/ulid/v2" +) + +type User struct { + ActiveRole *Role `gorm:"foreignKey:UserID"` // Beziehung zur aktiven Rolle + UserID ulid.ULID `gorm:"type:bytea"` //Fremdschlüssel +} diff --git a/docu/permissions_plan.md b/docu/permissions_plan.md new file mode 100644 index 0000000..1c9ec45 --- /dev/null +++ b/docu/permissions_plan.md @@ -0,0 +1,250 @@ +# Berechtigungssystem Plan + +Dieser Plan beschreibt die Implementierung eines scope-basierten Berechtigungssystems für das TimeTracker-Projekt. + +## Grundkonzept + +- Ein **Benutzer** kann eine **Rolle** annehmen, aber immer nur eine ist aktiv. +- Eine **Rolle** besteht aus mehreren **Policies**. +- Eine **Policy** hat einen Namen und eine Map, die **Scopes** (z. B. `items/books`) einem **Berechtigungsschlüssel** (Bitflag) zuordnet. +- Berechtigungsschlüssel sind Bitflags, die Permissions wie `read`, `write`, `create`, `list`, `delete`, `moderate`, `superadmin` usw. repräsentieren. +- Scopes können **Wildcards** enthalten, z. B. `items/*`, das auf `items/books` vererbt wird. +- Ziel: Berechtigungen sowohl im Go-Backend (für API-Sicherheit) als auch im TypeScript-Frontend (für UI-Anpassung) evaluieren. + +## Implementierung im Go-Backend + +### 1. Ordnerstruktur + +- Neuer Ordner: `backend/internal/permissions` +- Dateien: + - `permissions.go`: `Permission`-Konstanten (Bitflags). + - `policy.go`: `Policy`-Struktur. + - `role.go`: `Role`-Struktur. + - `user.go`: Erweiterung der `User`-Struktur. + - `matching.go`: `matchScope`-Funktion. + - `evaluator.go`: `EffectivePermissions`- und `HasPermission`-Funktionen. + +### 2. Go-Strukturen + +- `permissions.go`: + +```go +package permissions + +type Permission uint64 + +const ( + PermRead Permission = 1 << iota // 1 + PermWrite // 2 + PermCreate // 4 + PermList // 8 + PermDelete // 16 + PermModerate // 32 + PermSuperadmin // 64 +) +``` + +- `policy.go`: + +```go +package permissions + +type Policy struct { + Name string + Scopes map[string]Permission +} +``` + +- `role.go`: + +```go +package permissions + +type Role struct { + Name string + Policies []Policy +} +``` + +- `user.go`: + +```go +package permissions + +import "github.com/your-org/your-project/backend/internal/models" // Pfad anpassen + +type User struct { + models.User // Einbettung + ActiveRole *Role +} +``` + +### 3. Funktionen + +- `matching.go`: + +```go +package permissions + +import "strings" + +func MatchScope(pattern, scope string) bool { + if strings.HasSuffix(pattern, "/*") { + prefix := strings.TrimSuffix(pattern, "/*") + return strings.HasPrefix(scope, prefix) + } + return pattern == scope +} +``` + +- `evaluator.go`: + +```go +package permissions + +func (u *User) EffectivePermissions(scope string) Permission { + if u.ActiveRole == nil { + return 0 + } + var perm Permission + for _, policy := range u.ActiveRole.Policies { + for pat, p := range policy.Scopes { + if MatchScope(pat, scope) { + perm |= p + } + } + } + return perm +} + +func (u *User) HasPermission(scope string, requiredPerm Permission) bool { + effective := u.EffectivePermissions(scope) + return (effective & requiredPerm) == requiredPerm +} +``` + +### 4. Integration in die API-Handler + +- Anpassung der `jwt_auth.go` Middleware. +- Verwendung von `HasPermission` in den API-Handlern. + +## Persistierung (Datenbank) + +### 1. Datenbankmodell + +- Zwei neue Tabellen: `roles` und `policies`. +- `roles`: + - `id` (ULID, Primärschlüssel) + - `name` (VARCHAR, eindeutig) +- `policies`: + - `id` (ULID, Primärschlüssel) + - `name` (VARCHAR, eindeutig) + - `role_id` (ULID, Fremdschlüssel, der auf `roles.id` verweist) + - `scopes` (JSONB oder TEXT, speichert die `map[string]Permission` als JSON) +- Beziehung: 1:n zwischen `roles` und `policies`. + +### 2. Go-Strukturen (Anpassungen) + +- `role.go`: + +```go +package permissions + +import ( + "github.com/your-org/your-project/backend/internal/types" // Pfad anpassen +) + +type Role struct { + ID types.ULID `gorm:"primaryKey;type:bytea"` + Name string `gorm:"unique;not null"` + Policies []Policy `gorm:"foreignKey:RoleID"` +} +``` + +- `policy.go`: + +```go +package permissions + +import ( + "encoding/json" + "database/sql/driver" + "errors" + "fmt" + + "github.com/your-org/your-project/backend/internal/types" // Pfad anpassen + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +type Policy struct { + ID types.ULID `gorm:"primaryKey;type:bytea"` + Name string `gorm:"not null"` + RoleID types.ULID `gorm:"type:bytea"` //Fremdschlüssel + Scopes Scopes `gorm:"type:jsonb;not null"` // JSONB-Spalte +} + +//Scopes type to handle JSON marshalling +type Scopes map[string]Permission + +// Scan scan value into Jsonb, implements sql.Scanner interface +func (j *Scopes) Scan(value interface{}) error { + bytes, ok := value.([]byte) + if !ok { + return errors.New(fmt.Sprint("Failed to unmarshal JSONB value:", value)) + } + + var scopes map[string]Permission + if err := json.Unmarshal(bytes, &scopes); err != nil { + return err + } + *j = scopes + return nil +} + +// Value return json value, implement driver.Valuer interface +func (j Scopes) Value() (driver.Value, error) { + return json.Marshal(j) +} +``` +### 3. Migration + +- Verwendung des vorhandenen Migrations-Frameworks (`backend/cmd/migrate/main.go`). + +### 4. Seed-Daten + +- Optionale Seed-Daten (`backend/cmd/seed/main.go`). + +### 5. Anpassung der Funktionen + +- Anpassung von `EffectivePermissions` und `HasPermission` in `evaluator.go` für Datenbankzugriff. + +## Mermaid Diagramm +```mermaid +graph LR + subgraph Benutzer + U[User] --> AR(ActiveRole) + end + subgraph Rolle + AR --> R(Role) + R --> P1(Policy 1) + R --> P2(Policy 2) + R --> Pn(Policy n) + end + subgraph Policy + P1 --> S1(Scope 1: Permissions) + P1 --> S2(Scope 2: Permissions) + P2 --> S3(Scope 3: Permissions) + Pn --> Sm(Scope m: Permissions) + end + + S1 -- Permissions --> PR(PermRead) + S1 -- Permissions --> PW(PermWrite) + S2 -- Permissions --> PL(PermList) + Sm -- Permissions --> PD(PermDelete) + + style U fill:#f9f,stroke:#333,stroke-width:2px + style R fill:#ccf,stroke:#333,stroke-width:2px + style P1,P2,Pn fill:#ddf,stroke:#333,stroke-width:2px + style S1,S2,S3,Sm fill:#eef,stroke:#333,stroke-width:1px + style PR,PW,PL,PD fill:#ff9,stroke:#333,stroke-width:1px \ No newline at end of file