package middleware

import (
	"errors"
	"os"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/oklog/ulid/v2"
	"github.com/timetracker/backend/internal/api/utils"
	"github.com/timetracker/backend/internal/models"
)

// JWT configuration
const (
	// jwtSecret is the fallback HMAC key used when the JWT_SECRET environment
	// variable is unset or empty. It is a placeholder value; production
	// deployments should set JWT_SECRET instead of relying on it.
	jwtSecret = "your-secret-key-change-in-production"

	// jwtSecretEnv names the environment variable that overrides jwtSecret.
	jwtSecretEnv = "JWT_SECRET"

	// tokenDuration is how long a token issued by GenerateToken remains valid.
	tokenDuration = 24 * time.Hour
)

// Gin context keys under which AuthMiddleware stores the authenticated user's
// claims. The string values are kept as-is so handlers that call c.Get with
// the literal keys continue to work.
const (
	contextKeyUserID    = "userID"
	contextKeyEmail     = "email"
	contextKeyRole      = "role"
	contextKeyCompanyID = "companyID"
)

// Claims represents the JWT claims issued by GenerateToken and accepted by
// AuthMiddleware: the user's ID, email, role and company ID, plus the standard
// registered claims (exp, iat, nbf, ...).
type Claims struct {
	UserID    string `json:"userId"`
	Email     string `json:"email"`
	Role      string `json:"role"`
	CompanyID string `json:"companyId"`
	jwt.RegisteredClaims
}

// AuthMiddleware returns a Gin handler that requires an
// "Authorization: Bearer <token>" header carrying a valid token.
//
// On success it stores the token's user ID, email, role and company ID in the
// Gin context and calls the next handler. On any failure it writes an
// unauthorized response and aborts the handler chain.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		authHeader := c.GetHeader("Authorization")
		if authHeader == "" {
			utils.UnauthorizedResponse(c, "Authorization header is required")
			c.Abort()
			return
		}

		// Exactly one space must separate the scheme from a non-empty token.
		// The scheme itself is matched case-insensitively (RFC 7235 §2.1),
		// so "bearer" and "BEARER" are accepted too.
		scheme, tokenString, found := strings.Cut(authHeader, " ")
		validFormat := found && strings.EqualFold(scheme, "Bearer") &&
			tokenString != "" && !strings.Contains(tokenString, " ")
		if !validFormat {
			utils.UnauthorizedResponse(c, "Invalid authorization format, expected 'Bearer TOKEN'")
			c.Abort()
			return
		}

		claims, err := validateToken(tokenString)
		if err != nil {
			utils.UnauthorizedResponse(c, "Invalid or expired token")
			c.Abort()
			return
		}

		// Expose the authenticated identity to downstream handlers.
		c.Set(contextKeyUserID, claims.UserID)
		c.Set(contextKeyEmail, claims.Email)
		c.Set(contextKeyRole, claims.Role)
		c.Set(contextKeyCompanyID, claims.CompanyID)

		c.Next()
	}
}

// RoleMiddleware returns a Gin handler that only lets the request through if
// the role stored in the context (set by AuthMiddleware, which must run
// earlier in the chain) equals one of roles.
//
// A missing role yields an unauthorized response. A non-string role yields an
// internal error. A role not in the list yields a forbidden response. Calling
// RoleMiddleware with no roles rejects every request.
func RoleMiddleware(roles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userRole, exists := c.Get(contextKeyRole)
		if !exists {
			utils.UnauthorizedResponse(c, "User role not found in context")
			c.Abort()
			return
		}

		roleStr, ok := userRole.(string)
		if !ok {
			utils.InternalErrorResponse(c, "Invalid role type in context")
			c.Abort()
			return
		}

		allowed := false
		for _, role := range roles {
			if roleStr == role {
				allowed = true
				break
			}
		}
		if !allowed {
			utils.ForbiddenResponse(c, "Insufficient permissions")
			c.Abort()
			return
		}

		c.Next()
	}
}

// GenerateToken creates an HS256-signed JWT for user.
//
// The token is valid from now for tokenDuration. It returns an error if user
// is nil or if signing fails.
func GenerateToken(user *models.User) (string, error) {
	if user == nil {
		return "", errors.New("nil user")
	}

	// Use a single timestamp so iat, nbf and exp are mutually consistent.
	now := time.Now()
	claims := Claims{
		UserID:    user.ID.String(),
		Email:     user.Email,
		Role:      user.Role,
		CompanyID: user.CompanyID.String(),
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(tokenDuration)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
		},
	}

	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	tokenString, err := token.SignedString(signingKey())
	if err != nil {
		return "", err
	}
	return tokenString, nil
}

// validateToken parses and verifies tokenString and returns its claims.
//
// Only HS256 tokens signed with signingKey() are accepted, and they must carry
// an exp claim. The parser rejects expired tokens and tokens that are not yet
// valid.
func validateToken(tokenString string) (*Claims, error) {
	token, err := jwt.ParseWithClaims(tokenString, &Claims{},
		func(token *jwt.Token) (interface{}, error) {
			// Redundant with WithValidMethods below; kept as defense in depth.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, errors.New("unexpected signing method")
			}
			return signingKey(), nil
		},
		// Pin the exact algorithm GenerateToken uses, not merely the HMAC family.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		// GenerateToken always sets exp; a token without one must not be
		// accepted as non-expiring.
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}

	if !token.Valid {
		return nil, errors.New("invalid token")
	}

	claims, ok := token.Claims.(*Claims)
	if !ok {
		return nil, errors.New("invalid claims")
	}
	return claims, nil
}

// signingKey returns the HMAC key used to sign and verify tokens. It uses the
// value of the JWT_SECRET environment variable when that is non-empty and
// falls back to jwtSecret otherwise.
func signingKey() []byte {
	if secret := os.Getenv(jwtSecretEnv); secret != "" {
		return []byte(secret)
	}
	return []byte(jwtSecret)
}

// GetUserIDFromContext extracts the authenticated user's ID, as stored by
// AuthMiddleware, from the context and parses it as a ULID.
func GetUserIDFromContext(c *gin.Context) (ulid.ULID, error) {
	return ulidFromContext(c, contextKeyUserID, "user ID")
}

// GetCompanyIDFromContext extracts the authenticated user's company ID, as
// stored by AuthMiddleware, from the context and parses it as a ULID.
func GetCompanyIDFromContext(c *gin.Context) (ulid.ULID, error) {
	return ulidFromContext(c, contextKeyCompanyID, "company ID")
}

// ulidFromContext reads the string stored under key in the Gin context and
// parses it as a ULID.
//
// label names the value in error messages (e.g. "user ID"). On any error the
// zero ULID is returned.
func ulidFromContext(c *gin.Context, key, label string) (ulid.ULID, error) {
	raw, exists := c.Get(key)
	if !exists {
		return ulid.ULID{}, errors.New(label + " not found in context")
	}

	str, ok := raw.(string)
	if !ok {
		return ulid.ULID{}, errors.New("invalid " + label + " type in context")
	}

	id, err := ulid.Parse(str)
	if err != nil {
		return ulid.ULID{}, err
	}
	return id, nil
}