package middleware

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/joho/godotenv"
	"github.com/oklog/ulid/v2"
	"github.com/timetracker/backend/internal/api/utils"
	"github.com/timetracker/backend/internal/models"
)

// Keys under which AuthMiddleware stores the authenticated user's claims in
// the gin context. The string values are observable by other handlers, so
// they must stay stable.
const (
	contextKeyUserID    = "userID"
	contextKeyEmail     = "email"
	contextKeyRole      = "role"
	contextKeyCompanyID = "companyID"
)

var (
	// jwtSecret is the HMAC key used to sign (GenerateToken) and verify
	// (validateToken) HS256 tokens.
	jwtSecret string
	// tokenDuration is both the JWT lifetime and the auth cookie's Max-Age.
	tokenDuration = 24 * time.Hour
)

// init loads configuration from the environment (optionally populated from a
// .env file) and, when JWT_KEY_GENERATE=true, writes a new RSA key pair to
// JWT_KEY_DIR (default "./keys"). It panics on failure because the package
// cannot authenticate anyone without a secret.
func init() {
	// A missing .env file is normal (e.g. variables injected directly by the
	// runtime), so the load error is deliberately ignored.
	_ = godotenv.Load()

	jwtSecret = os.Getenv("JWT_SECRET")
	if jwtSecret == "" {
		// Fall back to a per-process random secret. NOTE: tokens signed with
		// it stop validating after a restart and are rejected by any other
		// instance of the service.
		secret := make([]byte, 32)
		if _, err := rand.Read(secret); err != nil {
			panic("failed to generate JWT secret: " + err.Error())
		}
		jwtSecret = string(secret)
	}

	// NOTE(review): the RSA key pair generated here is not used for signing
	// or verification in this file (tokens are HS256 with jwtSecret) —
	// confirm whether another component consumes these files.
	if os.Getenv("JWT_KEY_GENERATE") == "true" {
		keyDir := os.Getenv("JWT_KEY_DIR")
		if keyDir == "" {
			keyDir = "./keys"
		}
		if err := generateRSAKeyFiles(keyDir); err != nil {
			// err already carries the "failed to ..." prefix, so the panic
			// message matches the historical wording.
			panic(err.Error())
		}
	}
}

// generateRSAKeyFiles creates keyDir if needed and writes a fresh 2048-bit RSA
// key pair into it as PKCS#1 PEM: private.pem (mode 0600) and public.pem
// (mode 0644). Existing files are overwritten.
func generateRSAKeyFiles(keyDir string) error {
	if err := os.MkdirAll(keyDir, 0755); err != nil {
		return fmt.Errorf("failed to create key directory: %w", err)
	}

	privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fmt.Errorf("failed to generate RSA key pair: %w", err)
	}

	privateBlock := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
	}
	// The private key must only be readable by the owning user.
	if err := writePEMFile(filepath.Join(keyDir, "private.pem"), 0600, privateBlock, "private key"); err != nil {
		return err
	}

	publicBlock := &pem.Block{
		Type:  "RSA PUBLIC KEY",
		Bytes: x509.MarshalPKCS1PublicKey(&privateKey.PublicKey),
	}
	return writePEMFile(filepath.Join(keyDir, "public.pem"), 0644, publicBlock, "public key")
}

// writePEMFile PEM-encodes block into path, creating or truncating the file,
// and makes sure the file carries permission bits perm. what names the
// content for error messages (e.g. "private key"). Close errors are reported
// because a failed close on a written file can mean lost data.
func writePEMFile(path string, perm os.FileMode, block *pem.Block, what string) (err error) {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return fmt.Errorf("failed to create %s file: %w", what, err)
	}
	defer func() {
		if closeErr := f.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("failed to close %s file: %w", what, closeErr)
		}
	}()

	// OpenFile only applies perm when it creates the file; enforce it on a
	// pre-existing file too, before any key material is written.
	if chmodErr := f.Chmod(perm); chmodErr != nil {
		return fmt.Errorf("failed to set permissions on %s file: %w", what, chmodErr)
	}
	if encErr := pem.Encode(f, block); encErr != nil {
		return fmt.Errorf("failed to encode %s: %w", what, encErr)
	}
	return nil
}

// Claims represents the JWT claims issued by GenerateToken.
type Claims struct {
	UserID    string `json:"userId"`
	Email     string `json:"email"`
	Role      string `json:"role"`
	CompanyID string `json:"companyId"`
	jwt.RegisteredClaims
}

// AuthMiddleware checks if the user is authenticated. It reads the token
// from the "jwt" cookie, validates it, and stores the user's ID, email, role
// and company ID in the gin context for downstream handlers. Requests
// without a valid token are rejected as unauthorized and aborted.
func AuthMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		tokenString, err := c.Cookie("jwt")
		if err != nil {
			utils.UnauthorizedResponse(c, "Authentication cookie is required")
			c.Abort()
			return
		}

		claims, err := validateToken(tokenString)
		if err != nil {
			utils.UnauthorizedResponse(c, "Invalid or expired token")
			c.Abort()
			return
		}

		c.Set(contextKeyUserID, claims.UserID)
		c.Set(contextKeyEmail, claims.Email)
		c.Set(contextKeyRole, claims.Role)
		c.Set(contextKeyCompanyID, claims.CompanyID)

		c.Next()
	}
}

// RoleMiddleware checks if the user has one of the given roles. It must run
// after AuthMiddleware, which stores the role in the context. A missing role
// yields 401, a non-string role 500, and a role outside roles 403.
func RoleMiddleware(roles ...string) gin.HandlerFunc {
	return func(c *gin.Context) {
		userRole, exists := c.Get(contextKeyRole)
		if !exists {
			utils.UnauthorizedResponse(c, "User role not found in context")
			c.Abort()
			return
		}

		roleStr, ok := userRole.(string)
		if !ok {
			utils.InternalErrorResponse(c, "Invalid role type in context")
			c.Abort()
			return
		}

		if !containsRole(roles, roleStr) {
			utils.ForbiddenResponse(c, "Insufficient permissions")
			c.Abort()
			return
		}

		c.Next()
	}
}

// containsRole reports whether role is present in roles.
func containsRole(roles []string, role string) bool {
	for _, r := range roles {
		if r == role {
			return true
		}
	}
	return false
}

// GenerateToken creates an HS256-signed JWT for user, valid for
// tokenDuration, sets it as the Secure, HttpOnly "jwt" cookie on c (path "/"),
// and returns the signed token string.
func GenerateToken(user *models.User, c *gin.Context) (string, error) {
	// Use a single timestamp so iat, nbf and exp are mutually consistent.
	now := time.Now()
	claims := Claims{
		UserID:    user.ID.String(),
		Email:     user.Email,
		Role:      user.Role,
		CompanyID: user.CompanyID.String(),
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(tokenDuration)),
			IssuedAt:  jwt.NewNumericDate(now),
			NotBefore: jwt.NewNumericDate(now),
		},
	}

	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(jwtSecret))
	if err != nil {
		return "", fmt.Errorf("signing JWT: %w", err)
	}

	// secure=true: only sent over HTTPS; httpOnly=true: not readable from JS.
	c.SetCookie("jwt", tokenString, int(tokenDuration.Seconds()), "/", "", true, true)

	return tokenString, nil
}

// validateToken parses and verifies tokenString and returns its claims. Only
// HS256 (the algorithm GenerateToken uses) is accepted, and the token must
// carry an unexpired "exp" claim.
func validateToken(tokenString string) (*Claims, error) {
	claims := &Claims{}
	token, err := jwt.ParseWithClaims(
		tokenString,
		claims,
		func(*jwt.Token) (interface{}, error) {
			return []byte(jwtSecret), nil
		},
		// Pin the algorithm: rejects "none", RSA/ECDSA and other HMAC sizes
		// before the key is ever consulted.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		jwt.WithExpirationRequired(),
	)
	if err != nil {
		return nil, err
	}
	if !token.Valid {
		return nil, errors.New("invalid token")
	}
	return claims, nil
}

// GetUserIDFromContext extracts the user ID stored by AuthMiddleware from
// the context and parses it as a ULID.
func GetUserIDFromContext(c *gin.Context) (ulid.ULID, error) {
	return ulidFromContext(c, contextKeyUserID, "user ID")
}

// GetCompanyIDFromContext extracts the company ID stored by AuthMiddleware
// from the context and parses it as a ULID.
func GetCompanyIDFromContext(c *gin.Context) (ulid.ULID, error) {
	return ulidFromContext(c, contextKeyCompanyID, "company ID")
}

// ulidFromContext reads the string stored under key in c and parses it as a
// ULID. label names the value in error messages (e.g. "user ID"). On error
// the returned ULID is the zero value.
func ulidFromContext(c *gin.Context, key, label string) (ulid.ULID, error) {
	value, exists := c.Get(key)
	if !exists {
		return ulid.ULID{}, errors.New(label + " not found in context")
	}

	str, ok := value.(string)
	if !ok {
		return ulid.ULID{}, errors.New("invalid " + label + " type in context")
	}

	id, err := ulid.Parse(str)
	if err != nil {
		return ulid.ULID{}, fmt.Errorf("parsing %s %q from context: %w", label, str, err)
	}
	return id, nil
}