199 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			199 lines
		
	
	
		
			4.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | |
| 
 | |
import (
	"errors"
	"os"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/oklog/ulid/v2"
	"github.com/timetracker/backend/internal/api/utils"
	"github.com/timetracker/backend/internal/models"
)
 | |
| 
 | |
// JWT configuration.
const (
	// defaultJWTSecret is the development fallback used when the JWT_SECRET
	// environment variable is unset or empty. Production deployments must
	// set JWT_SECRET instead of relying on this value.
	defaultJWTSecret = "your-secret-key-change-in-production"

	// tokenDuration is how long a token issued by GenerateToken remains valid.
	tokenDuration = 24 * time.Hour
)

// jwtSecret is the HMAC key used to sign and verify tokens. It is read once,
// at package initialisation, from JWT_SECRET; when that variable is unset or
// empty it falls back to defaultJWTSecret.
var jwtSecret = func() string {
	if s := os.Getenv("JWT_SECRET"); s != "" {
		return s
	}
	return defaultJWTSecret
}()
 | |
| 
 | |
| // Claims represents the JWT claims
 | |
| type Claims struct {
 | |
| 	UserID    string `json:"userId"`
 | |
| 	Email     string `json:"email"`
 | |
| 	Role      string `json:"role"`
 | |
| 	CompanyID string `json:"companyId"`
 | |
| 	jwt.RegisteredClaims
 | |
| }
 | |
| 
 | |
| // AuthMiddleware checks if the user is authenticated
 | |
| func AuthMiddleware() gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		// Get the Authorization header
 | |
| 		authHeader := c.GetHeader("Authorization")
 | |
| 		if authHeader == "" {
 | |
| 			utils.UnauthorizedResponse(c, "Authorization header is required")
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// Check if the header has the Bearer prefix
 | |
| 		parts := strings.Split(authHeader, " ")
 | |
| 		if len(parts) != 2 || parts[0] != "Bearer" {
 | |
| 			utils.UnauthorizedResponse(c, "Invalid authorization format, expected 'Bearer TOKEN'")
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		tokenString := parts[1]
 | |
| 		claims, err := validateToken(tokenString)
 | |
| 		if err != nil {
 | |
| 			utils.UnauthorizedResponse(c, "Invalid or expired token")
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// Store user information in the context
 | |
| 		c.Set("userID", claims.UserID)
 | |
| 		c.Set("email", claims.Email)
 | |
| 		c.Set("role", claims.Role)
 | |
| 		c.Set("companyID", claims.CompanyID)
 | |
| 
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // RoleMiddleware checks if the user has the required role
 | |
| func RoleMiddleware(roles ...string) gin.HandlerFunc {
 | |
| 	return func(c *gin.Context) {
 | |
| 		userRole, exists := c.Get("role")
 | |
| 		if !exists {
 | |
| 			utils.UnauthorizedResponse(c, "User role not found in context")
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		// Check if the user's role is in the allowed roles
 | |
| 		roleStr, ok := userRole.(string)
 | |
| 		if !ok {
 | |
| 			utils.InternalErrorResponse(c, "Invalid role type in context")
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		allowed := false
 | |
| 		for _, role := range roles {
 | |
| 			if roleStr == role {
 | |
| 				allowed = true
 | |
| 				break
 | |
| 			}
 | |
| 		}
 | |
| 
 | |
| 		if !allowed {
 | |
| 			utils.ForbiddenResponse(c, "Insufficient permissions")
 | |
| 			c.Abort()
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		c.Next()
 | |
| 	}
 | |
| }
 | |
| 
 | |
| // GenerateToken creates a new JWT token for a user
 | |
| func GenerateToken(user *models.User) (string, error) {
 | |
| 	// Create the claims
 | |
| 	claims := Claims{
 | |
| 		UserID:    user.ID.String(),
 | |
| 		Email:     user.Email,
 | |
| 		Role:      user.Role,
 | |
| 		CompanyID: user.CompanyID.String(),
 | |
| 		RegisteredClaims: jwt.RegisteredClaims{
 | |
| 			ExpiresAt: jwt.NewNumericDate(time.Now().Add(tokenDuration)),
 | |
| 			IssuedAt:  jwt.NewNumericDate(time.Now()),
 | |
| 			NotBefore: jwt.NewNumericDate(time.Now()),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	// Create the token
 | |
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 | |
| 
 | |
| 	// Sign the token
 | |
| 	tokenString, err := token.SignedString([]byte(jwtSecret))
 | |
| 	if err != nil {
 | |
| 		return "", err
 | |
| 	}
 | |
| 
 | |
| 	return tokenString, nil
 | |
| }
 | |
| 
 | |
| // validateToken validates a JWT token and returns the claims
 | |
| func validateToken(tokenString string) (*Claims, error) {
 | |
| 	// Parse the token
 | |
| 	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
 | |
| 		// Validate the signing method
 | |
| 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 | |
| 			return nil, errors.New("unexpected signing method")
 | |
| 		}
 | |
| 		return []byte(jwtSecret), nil
 | |
| 	})
 | |
| 
 | |
| 	if err != nil {
 | |
| 		return nil, err
 | |
| 	}
 | |
| 
 | |
| 	// Check if the token is valid
 | |
| 	if !token.Valid {
 | |
| 		return nil, errors.New("invalid token")
 | |
| 	}
 | |
| 
 | |
| 	// Get the claims
 | |
| 	claims, ok := token.Claims.(*Claims)
 | |
| 	if !ok {
 | |
| 		return nil, errors.New("invalid claims")
 | |
| 	}
 | |
| 
 | |
| 	return claims, nil
 | |
| }
 | |
| 
 | |
| // GetUserIDFromContext extracts the user ID from the context
 | |
| func GetUserIDFromContext(c *gin.Context) (ulid.ULID, error) {
 | |
| 	userID, exists := c.Get("userID")
 | |
| 	if !exists {
 | |
| 		return ulid.ULID{}, errors.New("user ID not found in context")
 | |
| 	}
 | |
| 
 | |
| 	userIDStr, ok := userID.(string)
 | |
| 	if !ok {
 | |
| 		return ulid.ULID{}, errors.New("invalid user ID type in context")
 | |
| 	}
 | |
| 
 | |
| 	id, err := ulid.Parse(userIDStr)
 | |
| 	if err != nil {
 | |
| 		return ulid.ULID{}, err
 | |
| 	}
 | |
| 
 | |
| 	return id, nil
 | |
| }
 | |
| 
 | |
| // GetCompanyIDFromContext extracts the company ID from the context
 | |
| func GetCompanyIDFromContext(c *gin.Context) (ulid.ULID, error) {
 | |
| 	companyID, exists := c.Get("companyID")
 | |
| 	if !exists {
 | |
| 		return ulid.ULID{}, errors.New("company ID not found in context")
 | |
| 	}
 | |
| 
 | |
| 	companyIDStr, ok := companyID.(string)
 | |
| 	if !ok {
 | |
| 		return ulid.ULID{}, errors.New("invalid company ID type in context")
 | |
| 	}
 | |
| 
 | |
| 	id, err := ulid.Parse(companyIDStr)
 | |
| 	if err != nil {
 | |
| 		return ulid.ULID{}, err
 | |
| 	}
 | |
| 
 | |
| 	return id, nil
 | |
| }
 |