From b47c29cf5a190ba96e4e58ee33036ea4d0e7ec2b Mon Sep 17 00:00:00 2001 From: Jean Jacques Avril Date: Wed, 12 Mar 2025 11:03:48 +0000 Subject: [PATCH] feat: Introduce Undefined function for Nullable type and refactor DTOs to use Nullable directly --- backend/cmd/modeltest/main.go | 5 +- .../internal/api/handlers/company_handler.go | 12 ++--- .../internal/api/handlers/customer_handler.go | 16 +++--- .../internal/api/handlers/project_handler.go | 53 +++++++++---------- .../api/handlers/timeentry_handler.go | 16 ++---- backend/internal/api/handlers/user_handler.go | 19 +++---- backend/internal/dtos/company_dto.go | 8 ++- backend/internal/dtos/customer_dto.go | 14 ++--- backend/internal/dtos/project_dto.go | 18 +++---- backend/internal/dtos/timeentry_dto.go | 20 ++++--- backend/internal/dtos/user_dto.go | 17 +++--- backend/internal/models/base.go | 2 - backend/internal/models/project.go | 20 +++---- backend/internal/models/user.go | 37 +++++++------ backend/internal/types/nullable.go | 4 ++ 15 files changed, 126 insertions(+), 135 deletions(-) diff --git a/backend/cmd/modeltest/main.go b/backend/cmd/modeltest/main.go index 1c0384f..e529275 100644 --- a/backend/cmd/modeltest/main.go +++ b/backend/cmd/modeltest/main.go @@ -8,6 +8,7 @@ import ( "github.com/oklog/ulid/v2" "github.com/timetracker/backend/internal/models" + "github.com/timetracker/backend/internal/types" ) func main() { @@ -101,7 +102,7 @@ func testUserModel(ctx context.Context) { Email: "test@example.com", Password: "Test@123456", Role: models.RoleUser, - CompanyID: companyID, + CompanyID: &companyID, HourlyRate: 50.0, } @@ -196,7 +197,7 @@ func testRelationships(ctx context.Context) { // Test invalid ID invalidID := ulid.MustNew(ulid.Timestamp(time.Now()), ulid.DefaultEntropy()) - invalidUser, err := models.GetUserByID(ctx, models.FromULID(invalidID)) + invalidUser, err := models.GetUserByID(ctx, types.FromULID(invalidID)) if err != nil { log.Fatalf("Error getting user with invalid ID: %v", err) } diff --git a/backend/internal/api/handlers/company_handler.go b/backend/internal/api/handlers/company_handler.go index d55758c..da036ff 100644 --- a/backend/internal/api/handlers/company_handler.go +++ b/backend/internal/api/handlers/company_handler.go @@ -147,7 +147,7 @@ func (h *CompanyHandler) CreateCompany(c *gin.Context) { func (h *CompanyHandler) UpdateCompany(c *gin.Context) { // Parse ID from URL idStr := c.Param("id") - id, err := ulid.Parse(idStr) + id, err := types.ULIDFromString(idStr) if err != nil { utils.BadRequestResponse(c, "Invalid company ID format") return @@ -160,11 +160,8 @@ func (h *CompanyHandler) UpdateCompany(c *gin.Context) { return } - // Set ID from URL - companyUpdateDTO.ID = id.String() - // Convert DTO to model - companyUpdate := convertUpdateCompanyDTOToModel(companyUpdateDTO) + companyUpdate := convertUpdateCompanyDTOToModel(companyUpdateDTO, id) // Update company in the database company, err := models.UpdateCompany(c.Request.Context(), companyUpdate) @@ -234,10 +231,9 @@ func convertCreateCompanyDTOToModel(dto dto.CompanyCreateDto) models.CompanyCrea } } -func convertUpdateCompanyDTOToModel(dto dto.CompanyUpdateDto) models.CompanyUpdate { - id, _ := ulid.Parse(dto.ID) +func convertUpdateCompanyDTOToModel(dto dto.CompanyUpdateDto, id types.ULID) models.CompanyUpdate { update := models.CompanyUpdate{ - ID: types.FromULID(id), + ID: id, } if dto.Name != nil { diff --git a/backend/internal/api/handlers/customer_handler.go b/backend/internal/api/handlers/customer_handler.go index 36bf4d1..7d1fa58 100644 --- a/backend/internal/api/handlers/customer_handler.go +++ b/backend/internal/api/handlers/customer_handler.go @@ -319,16 +319,14 @@ func convertUpdateCustomerDTOToModel(dto dto.CustomerUpdateDto) (models.Customer update.Name = dto.Name } - if dto.CompanyID != nil { - if dto.CompanyID.Valid { - companyID, err := types.ULIDFromString(*dto.CompanyID.Value) - if err != nil { - return models.CustomerUpdate{}, fmt.Errorf("invalid company ID: %w", err) - } - update.CompanyID = &companyID - } else { - update.CompanyID = nil + if dto.CompanyID.Valid { + companyID, err := types.ULIDFromString(*dto.CompanyID.Value) + if err != nil { + return models.CustomerUpdate{}, fmt.Errorf("invalid company ID: %w", err) } + update.CompanyID = &companyID + } else { + update.CompanyID = nil } return update, nil diff --git a/backend/internal/api/handlers/project_handler.go b/backend/internal/api/handlers/project_handler.go index e98e41e..63dda6d 100644 --- a/backend/internal/api/handlers/project_handler.go +++ b/backend/internal/api/handlers/project_handler.go @@ -220,7 +220,7 @@ func (h *ProjectHandler) CreateProject(c *gin.Context) { func (h *ProjectHandler) UpdateProject(c *gin.Context) { // Parse ID from URL idStr := c.Param("id") - id, err := ulid.Parse(idStr) + id, err := types.ULIDFromString(idStr) if err != nil { utils.BadRequestResponse(c, "Invalid project ID format") return @@ -233,11 +233,8 @@ func (h *ProjectHandler) UpdateProject(c *gin.Context) { return } - // Set ID from URL - projectUpdateDTO.ID = id.String() - // Convert DTO to model - projectUpdate, err := convertUpdateProjectDTOToModel(projectUpdateDTO) + projectUpdate, err := convertUpdateProjectDTOToModel(projectUpdateDTO, id) if err != nil { utils.BadRequestResponse(c, err.Error()) return @@ -297,49 +294,51 @@ func (h *ProjectHandler) DeleteProject(c *gin.Context) { // Helper functions for DTO conversion func convertProjectToDTO(project *models.Project) dto.ProjectDto { - + customerId := project.CustomerID.String() return dto.ProjectDto{ ID: project.ID.String(), CreatedAt: project.CreatedAt, UpdatedAt: project.UpdatedAt, Name: project.Name, - CustomerID: project.CustomerID.String(), + CustomerID: &customerId, } } func convertCreateProjectDTOToModel(dto dto.ProjectCreateDto) (models.ProjectCreate, error) { + create := models.ProjectCreate{Name: dto.Name} // Convert CustomerID from int to ULID (this is a simplification, adjust as needed) - customerID, err := types.ULIDFromString(dto.CustomerID) - if err != nil { - return models.ProjectCreate{}, fmt.Errorf("invalid customer ID: %w", err) - } + if dto.CustomerID != nil { - return models.ProjectCreate{ - Name: dto.Name, - CustomerID: customerID, - }, nil + customerID, err := types.ULIDFromString(*dto.CustomerID) + if err != nil { + return models.ProjectCreate{}, fmt.Errorf("invalid customer ID: %w", err) + } + create.CustomerID = &customerID + } + return create, nil } -func convertUpdateProjectDTOToModel(dto dto.ProjectUpdateDto) (models.ProjectUpdate, error) { - id, err := ulid.Parse(dto.ID) - if err != nil { - return models.ProjectUpdate{}, fmt.Errorf("invalid project ID: %w", err) - } +func convertUpdateProjectDTOToModel(dto dto.ProjectUpdateDto, id types.ULID) (models.ProjectUpdate, error) { update := models.ProjectUpdate{ - ID: types.FromULID(id), + ID: id, } if dto.Name != nil { update.Name = dto.Name } - if dto.CustomerID != nil { - // Convert CustomerID from int to ULID (this is a simplification, adjust as needed) - customerID, err := types.ULIDFromString(*dto.CustomerID) - if err != nil { - return models.ProjectUpdate{}, fmt.Errorf("invalid customer ID: %w", err) + if dto.CustomerID.Valid { + if dto.CustomerID.Value == nil { + update.CustomerID = nil + + } else { + // Convert CustomerID from int to ULID (this is a simplification, adjust as needed) + customerID, err := types.ULIDFromString(*dto.CustomerID.Value) + if err != nil { + return models.ProjectUpdate{}, fmt.Errorf("invalid customer ID: %w", err) + } + update.CustomerID = &customerID } - update.CustomerID = &customerID } return update, nil diff --git a/backend/internal/api/handlers/timeentry_handler.go b/backend/internal/api/handlers/timeentry_handler.go index 2f4ae1d..aa2663a 100644 --- a/backend/internal/api/handlers/timeentry_handler.go +++ b/backend/internal/api/handlers/timeentry_handler.go @@ -326,7 +326,7 @@ func (h *TimeEntryHandler) CreateTimeEntry(c *gin.Context) { func (h *TimeEntryHandler) UpdateTimeEntry(c *gin.Context) { // Parse ID from URL idStr := c.Param("id") - id, err := ulid.Parse(idStr) + id, err := types.ULIDFromString(idStr) if err != nil { utils.BadRequestResponse(c, "Invalid time entry ID format") return @@ -339,11 +339,8 @@ func (h *TimeEntryHandler) UpdateTimeEntry(c *gin.Context) { return } - // Set ID from URL - timeEntryUpdateDTO.ID = id.String() - // Convert DTO to model - timeEntryUpdate, err := convertUpdateTimeEntryDTOToModel(timeEntryUpdateDTO) + timeEntryUpdate, err := convertUpdateTimeEntryDTOToModel(timeEntryUpdateDTO, id) if err != nil { utils.BadRequestResponse(c, err.Error()) return @@ -445,13 +442,10 @@ func convertCreateTimeEntryDTOToModel(dto dto.TimeEntryCreateDto) (models.TimeEn }, nil } -func convertUpdateTimeEntryDTOToModel(dto dto.TimeEntryUpdateDto) (models.TimeEntryUpdate, error) { - id, err := ulid.Parse(dto.ID) - if err != nil { - return models.TimeEntryUpdate{}, fmt.Errorf("invalid time entry ID: %w", err) - } +func convertUpdateTimeEntryDTOToModel(dto dto.TimeEntryUpdateDto, id types.ULID) (models.TimeEntryUpdate, error) { + update := models.TimeEntryUpdate{ - ID: types.FromULID(id), + ID: id, } if dto.UserID != nil { diff --git a/backend/internal/api/handlers/user_handler.go b/backend/internal/api/handlers/user_handler.go index 6a5625a..ce845e2 100644 --- a/backend/internal/api/handlers/user_handler.go +++ b/backend/internal/api/handlers/user_handler.go @@ -148,7 +148,7 @@ func (h *UserHandler) CreateUser(c *gin.Context) { func (h *UserHandler) UpdateUser(c *gin.Context) { // Parse ID from URL idStr := c.Param("id") - id, err := ulid.Parse(idStr) + id, err := types.ULIDFromString(idStr) if err != nil { utils.BadRequestResponse(c, "Invalid user ID format") return @@ -161,13 +161,9 @@ func (h *UserHandler) UpdateUser(c *gin.Context) { return } - // Set ID from URL - userUpdateDTO.ID = id.String() - // Convert DTO to Model - idWrapper := types.FromULID(id) update := models.UserUpdate{ - ID: idWrapper, + ID: id, } if userUpdateDTO.Email != nil { @@ -179,22 +175,23 @@ func (h *UserHandler) UpdateUser(c *gin.Context) { if userUpdateDTO.Role != nil { update.Role = userUpdateDTO.Role } - if userUpdateDTO.CompanyID != nil { - if userUpdateDTO.CompanyID.Valid { + + if userUpdateDTO.CompanyID.Valid { + if userUpdateDTO.CompanyID.Value != nil { companyID, err := types.ULIDFromString(*userUpdateDTO.CompanyID.Value) if err != nil { utils.BadRequestResponse(c, "Invalid company ID format") return } - update.CompanyID = &companyID + update.CompanyID = types.NewNullable(companyID) } else { - update.CompanyID = nil + update.CompanyID = types.Null[types.ULID]() } } + if userUpdateDTO.HourlyRate != nil { update.HourlyRate = userUpdateDTO.HourlyRate } - // Update user in the database user, err := models.UpdateUser(c.Request.Context(), update) if err != nil { diff --git a/backend/internal/dtos/company_dto.go b/backend/internal/dtos/company_dto.go index e55746b..884a5b7 100644 --- a/backend/internal/dtos/company_dto.go +++ b/backend/internal/dtos/company_dto.go @@ -17,9 +17,7 @@ type CompanyCreateDto struct { } type CompanyUpdateDto struct { - ID string `json:"id" example:"01HGW2BBG0000000000000000"` - CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` - UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` - LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` - Name *string `json:"name" example:"Acme Corp"` + CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` + UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` + Name *string `json:"name" example:"Acme Corp"` } diff --git a/backend/internal/dtos/customer_dto.go b/backend/internal/dtos/customer_dto.go index b0bb45a..471f409 100644 --- a/backend/internal/dtos/customer_dto.go +++ b/backend/internal/dtos/customer_dto.go @@ -22,11 +22,11 @@ type CustomerCreateDto struct { } type CustomerUpdateDto struct { - ID string `json:"id" example:"01HGW2BBG0000000000000000"` - CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` - UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` - LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` - Name *string `json:"name" example:"John Doe"` - CompanyID *types.Nullable[string] `json:"companyId" example:"01HGW2BBG0000000000000000"` - OwnerUserID *types.Nullable[string] `json:"owningUserID" example:"01HGW2BBG0000000000000000"` + ID string `json:"id" example:"01HGW2BBG0000000000000000"` + CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` + UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` + LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` + Name *string `json:"name" example:"John Doe"` + CompanyID types.Nullable[string] `json:"companyId" example:"01HGW2BBG0000000000000000"` + OwnerUserID types.Nullable[string] `json:"owningUserID" example:"01HGW2BBG0000000000000000"` } diff --git a/backend/internal/dtos/project_dto.go b/backend/internal/dtos/project_dto.go index f88f7f2..6ff6f58 100644 --- a/backend/internal/dtos/project_dto.go +++ b/backend/internal/dtos/project_dto.go @@ -2,6 +2,8 @@ package dto import ( "time" + + "github.com/timetracker/backend/internal/types" ) type ProjectDto struct { @@ -10,19 +12,17 @@ type ProjectDto struct { UpdatedAt time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` LastEditorID string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` Name string `json:"name" example:"Time Tracking App"` - CustomerID string `json:"customerId" example:"01HGW2BBG0000000000000000"` + CustomerID *string `json:"customerId" example:"01HGW2BBG0000000000000000"` } type ProjectCreateDto struct { - Name string `json:"name" example:"Time Tracking App"` - CustomerID string `json:"customerId" example:"01HGW2BBG0000000000000000"` + Name string `json:"name" example:"Time Tracking App"` + CustomerID *string `json:"customerId" example:"01HGW2BBG0000000000000000"` } type ProjectUpdateDto struct { - ID string `json:"id" example:"01HGW2BBG0000000000000000"` - CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` - UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` - LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` - Name *string `json:"name" example:"Time Tracking App"` - CustomerID *string `json:"customerId" example:"01HGW2BBG0000000000000000"` + CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` + UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` + Name *string `json:"name" example:"Time Tracking App"` + CustomerID types.Nullable[string] `json:"customerId" example:"01HGW2BBG0000000000000000"` } diff --git a/backend/internal/dtos/timeentry_dto.go b/backend/internal/dtos/timeentry_dto.go index 915a921..27a100d 100644 --- a/backend/internal/dtos/timeentry_dto.go +++ b/backend/internal/dtos/timeentry_dto.go @@ -29,15 +29,13 @@ type TimeEntryCreateDto struct { } type TimeEntryUpdateDto struct { - ID string `json:"id" example:"01HGW2BBG0000000000000000"` - CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` - UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` - LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` - UserID *string `json:"userId" example:"01HGW2BBG0000000000000000"` - ProjectID *string `json:"projectId" example:"01HGW2BBG0000000000000000"` - ActivityID *string `json:"activityId" example:"01HGW2BBG0000000000000000"` - Start *time.Time `json:"start" example:"2024-01-01T08:00:00Z"` - End *time.Time `json:"end" example:"2024-01-01T17:00:00Z"` - Description *string `json:"description" example:"Working on the Time Tracking App"` - Billable *int `json:"billable" example:"100"` // Percentage (0-100) + CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` + UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` + UserID *string `json:"userId" example:"01HGW2BBG0000000000000000"` + ProjectID *string `json:"projectId" example:"01HGW2BBG0000000000000000"` + ActivityID *string `json:"activityId" example:"01HGW2BBG0000000000000000"` + Start *time.Time `json:"start" example:"2024-01-01T08:00:00Z"` + End *time.Time `json:"end" example:"2024-01-01T17:00:00Z"` + Description *string `json:"description" example:"Working on the Time Tracking App"` + Billable *int `json:"billable" example:"100"` // Percentage (0-100) } diff --git a/backend/internal/dtos/user_dto.go b/backend/internal/dtos/user_dto.go index d5dbae8..89d49fd 100644 --- a/backend/internal/dtos/user_dto.go +++ b/backend/internal/dtos/user_dto.go @@ -26,13 +26,12 @@ type UserCreateDto struct { } type UserUpdateDto struct { - ID string `json:"id" example:"01HGW2BBG0000000000000000"` - CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` - UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` - LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` - Email *string `json:"email" example:"test@example.com"` - Password *string `json:"password" example:"password123"` - Role *string `json:"role" example:"admin"` - CompanyID *types.Nullable[string] `json:"companyId" example:"01HGW2BBG0000000000000000"` - HourlyRate *float64 `json:"hourlyRate" example:"50.00"` + CreatedAt *time.Time `json:"createdAt" example:"2024-01-01T00:00:00Z"` + UpdatedAt *time.Time `json:"updatedAt" example:"2024-01-01T00:00:00Z"` + LastEditorID *string `json:"lastEditorID" example:"01HGW2BBG0000000000000000"` + Email *string `json:"email" example:"test@example.com"` + Password *string `json:"password" example:"password123"` + Role *string `json:"role" example:"admin"` + CompanyID types.Nullable[string] `json:"companyId" example:"01HGW2BBG0000000000000000"` + HourlyRate *float64 `json:"hourlyRate" example:"50.00"` } diff --git a/backend/internal/models/base.go b/backend/internal/models/base.go index 2017cb8..9d888ed 100644 --- a/backend/internal/models/base.go +++ b/backend/internal/models/base.go @@ -1,7 +1,6 @@ package models import ( - "fmt" "math/rand" "time" @@ -24,7 +23,6 @@ func (eb *EntityBase) BeforeCreate(tx *gorm.DB) error { entropy := ulid.Monotonic(rand.New(rand.NewSource(time.Now().UnixNano())), 0) newID := ulid.MustNew(ulid.Timestamp(time.Now()), entropy) eb.ID = types.ULID{ULID: newID} - fmt.Println("Generated ID:", eb.ID) } return nil } diff --git a/backend/internal/models/project.go b/backend/internal/models/project.go index 1f6e65a..2247098 100644 --- a/backend/internal/models/project.go +++ b/backend/internal/models/project.go @@ -13,8 +13,8 @@ import ( // Project represents a project in the system type Project struct { EntityBase - Name string `gorm:"column:name;not null"` - CustomerID types.ULID `gorm:"column:customer_id;type:bytea;not null"` + Name string `gorm:"column:name;not null"` + CustomerID *types.ULID `gorm:"column:customer_id;type:bytea;not null"` // Relationships (for Eager Loading) Customer *Customer `gorm:"foreignKey:CustomerID"` @@ -28,7 +28,7 @@ func (Project) TableName() string { // ProjectCreate contains the fields for creating a new project type ProjectCreate struct { Name string - CustomerID types.ULID + CustomerID *types.ULID } // ProjectUpdate contains the updatable fields of a project @@ -122,12 +122,14 @@ func CreateProject(ctx context.Context, create ProjectCreate) (*Project, error) } // Check if the customer exists - customer, err := GetCustomerByID(ctx, create.CustomerID) - if err != nil { - return nil, fmt.Errorf("error checking the customer: %w", err) - } - if customer == nil { - return nil, errors.New("the specified customer does not exist") + if create.CustomerID == nil { + customer, err := GetCustomerByID(ctx, *create.CustomerID) + if err != nil { + return nil, fmt.Errorf("error checking the customer: %w", err) + } + if customer == nil { + return nil, errors.New("the specified customer does not exist") + } } project := Project{ diff --git a/backend/internal/models/user.go b/backend/internal/models/user.go index e34f8f2..2972dfc 100644 --- a/backend/internal/models/user.go +++ b/backend/internal/models/user.go @@ -63,12 +63,12 @@ type UserCreate struct { // UserUpdate contains the updatable fields of a user type UserUpdate struct { - ID types.ULID `gorm:"-"` // Exclude from updates - Email *string `gorm:"column:email"` - Password *string `gorm:"-"` // Not stored directly in DB - Role *string `gorm:"column:role"` - CompanyID *types.ULID `gorm:"column:company_id"` - HourlyRate *float64 `gorm:"column:hourly_rate"` + ID types.ULID `gorm:"-"` // Exclude from updates + Email *string `gorm:"column:email"` + Password *string `gorm:"-"` // Not stored directly in DB + Role *string `gorm:"column:role"` + CompanyID types.Nullable[types.ULID] `gorm:"column:company_id"` + HourlyRate *float64 `gorm:"column:hourly_rate"` } // PasswordData contains the data for password hash and salt @@ -448,13 +448,15 @@ func UpdateUser(ctx context.Context, update UserUpdate) (*User, error) { } // If CompanyID is updated, check if it exists - if update.CompanyID != nil && (user.CompanyID == nil || update.CompanyID.Compare(*user.CompanyID) != 0) { - var companyCount int64 - if err := tx.Model(&Company{}).Where("id = ?", *update.CompanyID).Count(&companyCount).Error; err != nil { - return fmt.Errorf("error checking company: %w", err) - } - if companyCount == 0 { - return errors.New("the specified company does not exist") + if update.CompanyID.Valid && update.CompanyID.Value != nil { + if user.CompanyID == nil || *update.CompanyID.Value != *user.CompanyID { + var companyCount int64 + if err := tx.Model(&Company{}).Where("id = ?", *update.CompanyID.Value).Count(&companyCount).Error; err != nil { + return fmt.Errorf("error checking company: %w", err) + } + if companyCount == 0 { + return errors.New("the specified company does not exist") + } } } @@ -484,8 +486,13 @@ func UpdateUser(ctx context.Context, update UserUpdate) (*User, error) { if update.Role != nil { updates["role"] = *update.Role } - if update.CompanyID != nil { - updates["company_id"] = *update.CompanyID + if update.CompanyID.Valid { + if update.CompanyID.Value == nil { + updates["company_id"] = nil + } else { + updates["company_id"] = *update.CompanyID.Value + } + } if update.HourlyRate != nil { updates["hourly_rate"] = *update.HourlyRate diff --git a/backend/internal/types/nullable.go b/backend/internal/types/nullable.go index 74f131b..664405d 100644 --- a/backend/internal/types/nullable.go +++ b/backend/internal/types/nullable.go @@ -18,6 +18,10 @@ func NewNullable[T any](value T) Nullable[T] { // Null erstellt eine leere Nullable-Instanz (ungesetzt) func Null[T any]() Nullable[T] { + return Nullable[T]{Valid: true} +} + +func Undefined[T any]() Nullable[T] { return Nullable[T]{Valid: false} }