feat: Enhance user update handling and introduce NullString type for optional fields

This commit is contained in:
2025-03-12 07:54:00 +00:00
parent 0379ea4ae4
commit da115dc3f6
8 changed files with 111 additions and 54 deletions
+21 -9
View File
@@ -35,12 +35,12 @@ const (
// User represents a user in the system
type User struct {
EntityBase
Email string `gorm:"column:email;unique;not null"`
Salt string `gorm:"column:salt;not null;type:varchar(64)"` // Base64-encoded Salt
Hash string `gorm:"column:hash;not null;type:varchar(128)"` // Base64-encoded Hash
Role string `gorm:"column:role;not null;default:'user'"`
CompanyID ULIDWrapper `gorm:"column:company_id;type:bytea;not null;index"`
HourlyRate float64 `gorm:"column:hourly_rate;not null;default:0"`
Email string `gorm:"column:email;unique;not null"`
Salt string `gorm:"column:salt;not null;type:varchar(64)"` // Base64-encoded Salt
Hash string `gorm:"column:hash;not null;type:varchar(128)"` // Base64-encoded Hash
Role string `gorm:"column:role;not null;default:'user'"`
CompanyID *ULIDWrapper `gorm:"column:company_id;type:bytea;index"`
HourlyRate float64 `gorm:"column:hourly_rate;not null;default:0"`
// Relationship for Eager Loading
Company *Company `gorm:"foreignKey:CompanyID"`
@@ -335,10 +335,22 @@ func GetAllUsers(ctx context.Context) ([]User, error) {
return users, nil
}
// getCompanyCondition builds the company condition for queries
func getCompanyCondition(companyID *ULIDWrapper) func(db *gorm.DB) *gorm.DB {
return func(db *gorm.DB) *gorm.DB {
if companyID == nil {
return db.Where("company_id IS NULL")
}
return db.Where("company_id = ?", *companyID)
}
}
// GetUsersByCompanyID returns all users of a company
func GetUsersByCompanyID(ctx context.Context, companyID ULIDWrapper) ([]User, error) {
var users []User
result := GetEngine(ctx).Where("company_id = ?", companyID).Find(&users)
// Apply the dynamic company condition
condition := getCompanyCondition(&companyID)
result := GetEngine(ctx).Scopes(condition).Find(&users)
if result.Error != nil {
return nil, result.Error
}
@@ -386,7 +398,7 @@ func CreateUser(ctx context.Context, create UserCreate) (*User, error) {
Salt: pwData.Salt,
Hash: pwData.Hash,
Role: create.Role,
CompanyID: create.CompanyID,
CompanyID: &create.CompanyID,
HourlyRate: create.HourlyRate,
}
@@ -435,7 +447,7 @@ func UpdateUser(ctx context.Context, update UserUpdate) (*User, error) {
}
// If CompanyID is updated, check if it exists
if update.CompanyID != nil && update.CompanyID.Compare(user.CompanyID) != 0 {
if update.CompanyID != nil && (user.CompanyID == nil || update.CompanyID.Compare(*user.CompanyID) != 0) {
var companyCount int64
if err := tx.Model(&Company{}).Where("id = ?", *update.CompanyID).Count(&companyCount).Error; err != nil {
return fmt.Errorf("error checking company: %w", err)