Some checks failed
Gitea Actions Demo / Explore-Gitea-Actions (push) Has been cancelled
557 lines
15 KiB
Go
557 lines
15 KiB
Go
package models
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"slices"
|
|
|
|
"github.com/timetracker/backend/internal/db"
|
|
"github.com/timetracker/backend/internal/types"
|
|
"golang.org/x/crypto/argon2"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// User is the GORM model for a row in the "users" table (see TableName).
// The plaintext password is never stored: only the base64-encoded Argon2id
// hash and the base64-encoded salt used to derive it are persisted.
type User struct {
	EntityBase

	// Email is the login identifier; the column is declared unique and not null.
	Email string `gorm:"column:email;unique;not null"`
	// Salt is the random salt that was used when hashing the password.
	Salt string `gorm:"column:salt;not null;type:varchar(64)"` // Base64-encoded Salt
	// Hash is the Argon2id hash of the password combined with Salt.
	Hash string `gorm:"column:hash;not null;type:varchar(128)"` // Base64-encoded Hash
	// Role is validated against RoleAdmin, RoleUser and RoleViewer; the column defaults to 'user'.
	Role string `gorm:"column:role;not null;default:'user'"`
	// CompanyID optionally links the user to a company; nil means no company.
	CompanyID *types.ULID `gorm:"column:company_id;type:bytea;index"`
	// HourlyRate defaults to 0 in the column; validation rejects negative values.
	HourlyRate float64 `gorm:"column:hourly_rate;not null;default:0"`
	// NOTE(review): Companies has no explicit column name and is not read or
	// written anywhere in this file — confirm whether it is still needed.
	Companies []string `gorm:"type:text[]"`

	// Company is the related company. It is populated when eager-loaded,
	// e.g. by GetUserWithCompany via Preload("Company").
	Company *Company `gorm:"foreignKey:CompanyID"`
}
|
|
|
|
// TableName provides the table name for GORM
|
|
func (User) TableName() string {
|
|
return "users"
|
|
}
|
|
|
|
// UserCreate is the input to CreateUser. It must pass Validate before a user
// is created; Validate also fills in RoleUser when Role is empty.
type UserCreate struct {
	Email      string      // required; must match the accepted email format
	Password   string      // plaintext; only its Argon2id hash and salt are stored
	Role       string      // RoleAdmin, RoleUser or RoleViewer; empty means RoleUser
	CompanyID  *types.ULID // optional; when set it must be non-zero and reference an existing company
	HourlyRate float64     // must not be negative
}
|
|
|
|
// UserUpdate describes a partial update applied by UpdateUser. A nil pointer
// field, or a CompanyID whose Valid flag is false, leaves that column unchanged.
type UserUpdate struct {
	ID         types.ULID                 `gorm:"-"`                  // Exclude from updates; identifies the user to change
	Email      *string                    `gorm:"column:email"`       // new email; must not belong to another user
	Password   *string                    `gorm:"-"`                  // Not stored directly in DB; re-hashed with a fresh salt
	Role       *string                    `gorm:"column:role"`        // new role; must be RoleAdmin, RoleUser or RoleViewer
	CompanyID  types.Nullable[types.ULID] `gorm:"column:company_id"`  // Valid with a nil Value clears the company; otherwise the company must exist
	HourlyRate *float64                   `gorm:"column:hourly_rate"` // new hourly rate; must not be negative
}
|
|
|
|
// PasswordData is the result of HashPassword: a random salt and the Argon2id
// hash derived from it, both base64-encoded with the standard encoding.
type PasswordData struct {
	Salt string // base64-encoded random salt
	Hash string // base64-encoded Argon2id hash of the password with Salt
}
|
|
|
|
// GenerateSalt generates a cryptographically secure salt
|
|
func GenerateSalt() (string, error) {
|
|
salt := make([]byte, SaltLength)
|
|
_, err := rand.Read(salt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return base64.StdEncoding.EncodeToString(salt), nil
|
|
}
|
|
|
|
// HashPassword creates a secure password hash with Argon2id and a random salt
|
|
func HashPassword(password string) (PasswordData, error) {
|
|
// Generate a cryptographically secure salt
|
|
saltStr, err := GenerateSalt()
|
|
if err != nil {
|
|
return PasswordData{}, fmt.Errorf("error generating salt: %w", err)
|
|
}
|
|
|
|
salt, err := base64.StdEncoding.DecodeString(saltStr)
|
|
if err != nil {
|
|
return PasswordData{}, fmt.Errorf("error decoding salt: %w", err)
|
|
}
|
|
|
|
// Create hash with Argon2id (modern, secure hash function)
|
|
hash := argon2.IDKey([]byte(password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeyLen)
|
|
hashStr := base64.StdEncoding.EncodeToString(hash)
|
|
|
|
return PasswordData{
|
|
Salt: saltStr,
|
|
Hash: hashStr,
|
|
}, nil
|
|
}
|
|
|
|
// VerifyPassword checks if a password matches the hash
|
|
func VerifyPassword(password, saltStr, hashStr string) (bool, error) {
|
|
salt, err := base64.StdEncoding.DecodeString(saltStr)
|
|
if err != nil {
|
|
return false, fmt.Errorf("error decoding salt: %w", err)
|
|
}
|
|
|
|
hash, err := base64.StdEncoding.DecodeString(hashStr)
|
|
if err != nil {
|
|
return false, fmt.Errorf("error decoding hash: %w", err)
|
|
}
|
|
|
|
// Calculate hash with the same salt
|
|
computedHash := argon2.IDKey([]byte(password), salt, ArgonTime, ArgonMemory, ArgonThreads, ArgonKeyLen)
|
|
|
|
// Constant time comparison to prevent timing attacks
|
|
return hmacEqual(hash, computedHash), nil
|
|
}
|
|
|
|
// hmacEqual reports whether a and b contain identical bytes.
//
// Slices of different lengths compare unequal immediately. For equal-length
// slices every byte pair is examined, with differences accumulated by OR, so
// the running time does not depend on where the first mismatch occurs.
func hmacEqual(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}

	var diff byte
	for i, x := range a {
		diff |= x ^ b[i]
	}
	return diff == 0
}
|
|
|
|
// Validate checks if the Create structure contains valid data
|
|
func (uc *UserCreate) Validate() error {
|
|
if uc.Email == "" {
|
|
return errors.New("email cannot be empty")
|
|
}
|
|
|
|
// Check email format
|
|
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
|
if !emailRegex.MatchString(uc.Email) {
|
|
return errors.New("invalid email format")
|
|
}
|
|
|
|
if uc.Password == "" {
|
|
return errors.New("password cannot be empty")
|
|
}
|
|
|
|
// Check password complexity
|
|
if len(uc.Password) < 10 {
|
|
return errors.New("password must be at least 10 characters long")
|
|
}
|
|
|
|
// More complex password validation
|
|
var (
|
|
hasUpper = false
|
|
hasLower = false
|
|
hasNumber = false
|
|
hasSpecial = false
|
|
)
|
|
|
|
for _, char := range uc.Password {
|
|
switch {
|
|
case 'A' <= char && char <= 'Z':
|
|
hasUpper = true
|
|
case 'a' <= char && char <= 'z':
|
|
hasLower = true
|
|
case '0' <= char && char <= '9':
|
|
hasNumber = true
|
|
case char == '!' || char == '@' || char == '#' || char == '$' ||
|
|
char == '%' || char == '^' || char == '&' || char == '*':
|
|
hasSpecial = true
|
|
}
|
|
}
|
|
|
|
if !hasUpper || !hasLower || !hasNumber || !hasSpecial {
|
|
return errors.New("password must contain uppercase letters, lowercase letters, numbers, and special characters")
|
|
}
|
|
|
|
// Check role
|
|
if uc.Role == "" {
|
|
uc.Role = RoleUser // Set default role
|
|
} else {
|
|
validRoles := []string{RoleAdmin, RoleUser, RoleViewer}
|
|
isValid := slices.Contains(validRoles, uc.Role)
|
|
if !isValid {
|
|
return fmt.Errorf("invalid role: %s, allowed are: %s",
|
|
uc.Role, strings.Join(validRoles, ", "))
|
|
}
|
|
}
|
|
|
|
if uc.CompanyID != nil && uc.CompanyID.Compare(types.ULID{}) == 0 {
|
|
return errors.New("companyID cannot be empty")
|
|
}
|
|
|
|
if uc.HourlyRate < 0 {
|
|
return errors.New("hourly rate cannot be negative")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Validate checks if the Update structure contains valid data
|
|
func (uu *UserUpdate) Validate() error {
|
|
if uu.Email != nil && *uu.Email == "" {
|
|
return errors.New("email cannot be empty")
|
|
}
|
|
|
|
// Check email format
|
|
if uu.Email != nil {
|
|
emailRegex := regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)
|
|
if !emailRegex.MatchString(*uu.Email) {
|
|
return errors.New("invalid email format")
|
|
}
|
|
}
|
|
|
|
if uu.Password != nil {
|
|
if *uu.Password == "" {
|
|
return errors.New("password cannot be empty")
|
|
}
|
|
|
|
// Check password complexity
|
|
if len(*uu.Password) < 10 {
|
|
return errors.New("password must be at least 10 characters long")
|
|
}
|
|
|
|
// More complex password validation
|
|
var (
|
|
hasUpper = false
|
|
hasLower = false
|
|
hasNumber = false
|
|
hasSpecial = false
|
|
)
|
|
|
|
for _, char := range *uu.Password {
|
|
switch {
|
|
case 'A' <= char && char <= 'Z':
|
|
hasUpper = true
|
|
case 'a' <= char && char <= 'z':
|
|
hasLower = true
|
|
case '0' <= char && char <= '9':
|
|
hasNumber = true
|
|
case char == '!' || char == '@' || char == '#' || char == '$' ||
|
|
char == '%' || char == '^' || char == '&' || char == '*':
|
|
hasSpecial = true
|
|
}
|
|
}
|
|
|
|
if !hasUpper || !hasLower || !hasNumber || !hasSpecial {
|
|
return errors.New("password must contain uppercase letters, lowercase letters, numbers, and special characters")
|
|
}
|
|
}
|
|
|
|
// Check role
|
|
if uu.Role != nil {
|
|
validRoles := []string{RoleAdmin, RoleUser, RoleViewer}
|
|
isValid := false
|
|
for _, role := range validRoles {
|
|
if *uu.Role == role {
|
|
isValid = true
|
|
break
|
|
}
|
|
}
|
|
if !isValid {
|
|
return fmt.Errorf("invalid role: %s, allowed are: %s",
|
|
*uu.Role, strings.Join(validRoles, ", "))
|
|
}
|
|
}
|
|
|
|
if uu.HourlyRate != nil && *uu.HourlyRate < 0 {
|
|
return errors.New("hourly rate cannot be negative")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetUserByID finds a user by their ID
|
|
func GetUserByID(ctx context.Context, id types.ULID) (*User, error) {
|
|
var user User
|
|
result := db.GetEngine(ctx).Where("id = ?", id).First(&user)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, result.Error
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// GetUserByEmail finds a user by their email
|
|
func GetUserByEmail(ctx context.Context, email string) (*User, error) {
|
|
var user User
|
|
result := db.GetEngine(ctx).Where("email = ?", email).First(&user)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, result.Error
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// GetUserWithCompany loads a user with their company
|
|
func GetUserWithCompany(ctx context.Context, id types.ULID) (*User, error) {
|
|
var user User
|
|
result := db.GetEngine(ctx).Preload("Company").Where("id = ?", id).First(&user)
|
|
if result.Error != nil {
|
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
|
return nil, nil
|
|
}
|
|
return nil, result.Error
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
// GetAllUsers returns all users
|
|
func GetAllUsers(ctx context.Context) ([]User, error) {
|
|
var users []User
|
|
result := db.GetEngine(ctx).Find(&users)
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
// getCompanyCondition builds the company condition for queries
|
|
func getCompanyCondition(companyID *types.ULID) func(db *gorm.DB) *gorm.DB {
|
|
return func(db *gorm.DB) *gorm.DB {
|
|
if companyID == nil {
|
|
return db.Where("company_id IS NULL")
|
|
}
|
|
return db.Where("company_id = ?", *companyID)
|
|
}
|
|
}
|
|
|
|
// GetUsersByCompanyID returns all users of a company
|
|
func GetUsersByCompanyID(ctx context.Context, companyID types.ULID) ([]User, error) {
|
|
var users []User
|
|
// Apply the dynamic company condition
|
|
condition := getCompanyCondition(&companyID)
|
|
result := db.GetEngine(ctx).Scopes(condition).Find(&users)
|
|
if result.Error != nil {
|
|
return nil, result.Error
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
// CreateUser creates a new user with validation and secure password hashing
|
|
func CreateUser(ctx context.Context, create UserCreate) (*User, error) {
|
|
// Validation
|
|
if err := create.Validate(); err != nil {
|
|
return nil, fmt.Errorf("validation error: %w", err)
|
|
}
|
|
|
|
// Start a transaction
|
|
var user *User
|
|
|
|
err := db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// Check if email already exists
|
|
var count int64
|
|
if err := tx.Model(&User{}).Where("email = ?", create.Email).Count(&count).Error; err != nil {
|
|
return fmt.Errorf("error checking email: %w", err)
|
|
}
|
|
if count > 0 {
|
|
return errors.New("email is already in use")
|
|
}
|
|
|
|
if create.CompanyID != nil {
|
|
// Check if company exists
|
|
var companyCount int64
|
|
if err := tx.Model(&Company{}).Where("id = ?", create.CompanyID).Count(&companyCount).Error; err != nil {
|
|
return fmt.Errorf("error checking company: %w", err)
|
|
}
|
|
if companyCount == 0 {
|
|
return errors.New("the specified company does not exist")
|
|
}
|
|
}
|
|
|
|
// Hash password with unique salt
|
|
pwData, err := HashPassword(create.Password)
|
|
if err != nil {
|
|
return fmt.Errorf("error hashing password: %w", err)
|
|
}
|
|
|
|
// Create user with salt and hash stored separately
|
|
newUser := User{
|
|
Email: create.Email,
|
|
Salt: pwData.Salt,
|
|
Hash: pwData.Hash,
|
|
Role: create.Role,
|
|
CompanyID: create.CompanyID,
|
|
HourlyRate: create.HourlyRate,
|
|
}
|
|
|
|
if err := tx.Create(&newUser).Error; err != nil {
|
|
return fmt.Errorf("error creating user: %w", err)
|
|
}
|
|
|
|
user = &newUser
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// UpdateUser updates an existing user
|
|
func UpdateUser(ctx context.Context, update UserUpdate) (*User, error) {
|
|
// Validation
|
|
if err := update.Validate(); err != nil {
|
|
return nil, fmt.Errorf("validation error: %w", err)
|
|
}
|
|
|
|
// Find user
|
|
user, err := GetUserByID(ctx, update.ID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user == nil {
|
|
return nil, errors.New("user not found")
|
|
}
|
|
|
|
// Start a transaction for the update
|
|
err = db.GetEngine(ctx).Transaction(func(tx *gorm.DB) error {
|
|
// If email is updated, check if it's already in use
|
|
if update.Email != nil && *update.Email != user.Email {
|
|
var count int64
|
|
if err := tx.Model(&User{}).Where("email = ? AND id != ?", *update.Email, update.ID).Count(&count).Error; err != nil {
|
|
return fmt.Errorf("error checking email: %w", err)
|
|
}
|
|
if count > 0 {
|
|
return errors.New("email is already in use")
|
|
}
|
|
}
|
|
|
|
// If CompanyID is updated, check if it exists
|
|
if update.CompanyID.Valid && update.CompanyID.Value != nil {
|
|
if user.CompanyID == nil || *update.CompanyID.Value != *user.CompanyID {
|
|
var companyCount int64
|
|
if err := tx.Model(&Company{}).Where("id = ?", *update.CompanyID.Value).Count(&companyCount).Error; err != nil {
|
|
return fmt.Errorf("error checking company: %w", err)
|
|
}
|
|
if companyCount == 0 {
|
|
return errors.New("the specified company does not exist")
|
|
}
|
|
}
|
|
}
|
|
|
|
// If password is updated, rehash with new salt
|
|
if update.Password != nil {
|
|
pwData, err := HashPassword(*update.Password)
|
|
if err != nil {
|
|
return fmt.Errorf("error hashing password: %w", err)
|
|
}
|
|
|
|
// Update salt and hash directly in the model
|
|
if err := tx.Model(user).Updates(map[string]any{
|
|
"salt": pwData.Salt,
|
|
"hash": pwData.Hash,
|
|
}).Error; err != nil {
|
|
return fmt.Errorf("error updating password: %w", err)
|
|
}
|
|
}
|
|
|
|
// Create map for generic update
|
|
updates := make(map[string]any)
|
|
|
|
// Add only non-password fields to the update
|
|
if update.Email != nil {
|
|
updates["email"] = *update.Email
|
|
}
|
|
if update.Role != nil {
|
|
updates["role"] = *update.Role
|
|
}
|
|
if update.CompanyID.Valid {
|
|
if update.CompanyID.Value == nil {
|
|
updates["company_id"] = nil
|
|
} else {
|
|
updates["company_id"] = *update.CompanyID.Value
|
|
}
|
|
}
|
|
if update.HourlyRate != nil {
|
|
updates["hourly_rate"] = *update.HourlyRate
|
|
}
|
|
|
|
// Only execute generic update if there are changes
|
|
if len(updates) > 0 {
|
|
if err := tx.Model(user).Updates(updates).Error; err != nil {
|
|
return fmt.Errorf("error updating user: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Load updated data from the database
|
|
return GetUserByID(ctx, update.ID)
|
|
}
|
|
|
|
// DeleteUser deletes a user by their ID
|
|
func DeleteUser(ctx context.Context, id types.ULID) error {
|
|
// Here one could check if dependent entities exist
|
|
// e.g., don't delete if time entries still exist
|
|
|
|
result := db.GetEngine(ctx).Delete(&User{}, id)
|
|
if result.Error != nil {
|
|
return fmt.Errorf("error deleting user: %w", result.Error)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// AuthenticateUser authenticates a user with email and password
|
|
func AuthenticateUser(ctx context.Context, email, password string) (*User, error) {
|
|
user, err := GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if user == nil {
|
|
// Same error message to avoid revealing information about existing accounts
|
|
return nil, errors.New("invalid login credentials")
|
|
}
|
|
|
|
// Verify password with the stored salt
|
|
isValid, err := VerifyPassword(password, user.Salt, user.Hash)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error verifying password: %w", err)
|
|
}
|
|
|
|
if !isValid {
|
|
return nil, errors.New("invalid login credentials")
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// Argon2id parameters shared by HashPassword and VerifyPassword.
//
// Changing any of these invalidates every stored hash. VerifyPassword
// re-derives with the current values, and the stored Salt/Hash columns do not
// record which parameters produced them.
const (
	// Recommended values for Argon2id
	// NOTE(review): RFC 9106's 64 MiB profile uses 3 passes (t=3) rather
	// than 1 — confirm ArgonTime = 1 is intentional.
	ArgonTime    = 1         // number of passes over memory
	ArgonMemory  = 64 * 1024 // in KiB (argon2 package convention), i.e. 64 MiB
	ArgonThreads = 4         // degree of parallelism
	ArgonKeyLen  = 32        // length of the derived hash in bytes
	SaltLength   = 16        // length of the random salt in bytes
)
|
|
|
|
// Role values accepted by UserCreate.Validate and UserUpdate.Validate.
// RoleUser is the default that UserCreate.Validate applies when Role is empty.
const (
	RoleAdmin  = "admin"
	RoleUser   = "user"
	RoleViewer = "viewer"
)
|