package middleware import ( "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "fmt" "os" "path/filepath" "time" "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/oklog/ulid/v2" "github.com/timetracker/backend/internal/api/responses" "github.com/timetracker/backend/internal/config" "github.com/timetracker/backend/internal/models" "github.com/timetracker/backend/internal/types" ) var ( signKey *rsa.PrivateKey verifyKey *rsa.PublicKey ) // InitJWTKeys initializes the JWT keys func InitJWTKeys() error { cfg := config.MustLoadConfig() // If a secret is provided, we'll use HMAC-SHA256, so no need for certificates if cfg.JWTConfig.Secret != "" { println("Using HMAC-SHA256 for JWT") return nil } // Check if keys exist privKeyPath := filepath.Join(cfg.JWTConfig.KeyDir, cfg.JWTConfig.PrivKeyFile) pubKeyPath := filepath.Join(cfg.JWTConfig.KeyDir, cfg.JWTConfig.PubKeyFile) keysExist := fileExists(privKeyPath) && fileExists(pubKeyPath) // Generate keys if they don't exist and KeyGenerate is true if !keysExist && cfg.JWTConfig.KeyGenerate { println("Generating RSA keys") if err := generateRSAKeys(cfg.JWTConfig); err != nil { return fmt.Errorf("failed to generate RSA keys: %w", err) } } else if !keysExist { return errors.New("JWT keys not found and key generation is disabled") } // Load keys var err error signKey, err = loadPrivateKey(privKeyPath) if err != nil { return fmt.Errorf("failed to load private key: %w", err) } verifyKey, err = loadPublicKey(pubKeyPath) if err != nil { return fmt.Errorf("failed to load public key: %w", err) } return nil } // fileExists checks if a file exists func fileExists(path string) bool { _, err := os.Stat(path) return !os.IsNotExist(err) } // generateRSAKeys generates RSA keys and saves them to disk func generateRSAKeys(cfg models.JWTConfig) error { // Create key directory if it doesn't exist if err := os.MkdirAll(cfg.KeyDir, 0700); err != nil { return fmt.Errorf("failed to create key directory: %w", err) } // Generate private key privateKey, err := rsa.GenerateKey(rand.Reader, cfg.KeyBits) if err != nil { return fmt.Errorf("failed to generate private key: %w", err) } // Save private key privKeyPath := filepath.Join(cfg.KeyDir, cfg.PrivKeyFile) privKeyFile, err := os.OpenFile(privKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) if err != nil { return fmt.Errorf("failed to create private key file: %w", err) } defer privKeyFile.Close() privKeyPEM := &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey), } if err := pem.Encode(privKeyFile, privKeyPEM); err != nil { return fmt.Errorf("failed to encode private key: %w", err) } // Save public key pubKeyPath := filepath.Join(cfg.KeyDir, cfg.PubKeyFile) pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) if err != nil { return fmt.Errorf("failed to create public key file: %w", err) } defer pubKeyFile.Close() pubKeyPEM := &pem.Block{ Type: "RSA PUBLIC KEY", Bytes: x509.MarshalPKCS1PublicKey(&privateKey.PublicKey), } if err := pem.Encode(pubKeyFile, pubKeyPEM); err != nil { return fmt.Errorf("failed to encode public key: %w", err) } return nil } // loadPrivateKey loads a private key from a file func loadPrivateKey(path string) (*rsa.PrivateKey, error) { keyData, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to read private key file: %w", err) } block, _ := pem.Decode(keyData) if block == nil { return nil, errors.New("failed to parse PEM block containing the private key") } privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { return nil, fmt.Errorf("failed to parse private key: %w", err) } return privateKey, nil } // loadPublicKey loads a public key from a file func loadPublicKey(path string) (*rsa.PublicKey, error) { keyData, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to read public key file: %w", err) } block, _ := pem.Decode(keyData) if block == nil { return nil, errors.New("failed to parse PEM block containing the public key") } publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) if err != nil { return nil, fmt.Errorf("failed to parse public key: %w", err) } return publicKey, nil } // Claims represents the JWT claims type Claims struct { UserID string `json:"userId"` Email string `json:"email"` Role string `json:"role"` CompanyID *string `json:"companyId"` jwt.RegisteredClaims } // AuthMiddleware checks if the user is authenticated func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { // Get the token from cookie tokenString, err := c.Cookie("jwt") if err != nil { responses.UnauthorizedResponse(c, "Authentication cookie is required") c.Abort() return } claims, err := validateToken(tokenString) if err != nil { responses.UnauthorizedResponse(c, "Invalid or expired token") c.Abort() return } // Store user information in the context c.Set("userID", claims.UserID) c.Set("email", claims.Email) c.Set("role", claims.Role) c.Set("companyID", claims.CompanyID) c.Next() } } // RoleMiddleware checks if the user has the required role func RoleMiddleware(roles ...string) gin.HandlerFunc { return func(c *gin.Context) { userRole, exists := c.Get("role") if !exists { responses.UnauthorizedResponse(c, "User role not found in context") c.Abort() return } // Check if the user's role is in the allowed roles roleStr, ok := userRole.(string) if !ok { responses.InternalErrorResponse(c, "Invalid role type in context") c.Abort() return } allowed := false for _, role := range roles { if roleStr == role { allowed = true break } } if !allowed { responses.ForbiddenResponse(c, "Insufficient permissions") c.Abort() return } c.Next() } } // GenerateToken creates a new JWT token for a user func GenerateToken(user *models.User, c *gin.Context) (string, error) { // Create the claims var companyId *string if user.CompanyID != nil { wrapper := user.CompanyID.String() companyId = &wrapper } claims := Claims{ UserID: user.ID.String(), Email: user.Email, Role: user.Role, CompanyID: companyId, RegisteredClaims: jwt.RegisteredClaims{ ExpiresAt: jwt.NewNumericDate(time.Now().Add(config.MustLoadConfig().JWTConfig.TokenDuration)), IssuedAt: jwt.NewNumericDate(time.Now()), NotBefore: jwt.NewNumericDate(time.Now()), }, } cfg := config.MustLoadConfig() var token *jwt.Token var tokenString string var err error // Choose signing method based on configuration if cfg.JWTConfig.Secret != "" { // Use HMAC-SHA256 if a secret is provided token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err = token.SignedString([]byte(cfg.JWTConfig.Secret)) } else { // Use RSA if no secret is provided token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) tokenString, err = token.SignedString(signKey) } if err != nil { return "", err } // Set the cookie c.SetCookie("jwt", tokenString, int(cfg.JWTConfig.TokenDuration.Seconds()), "/", "", true, true) return tokenString, nil } // validateToken validates a JWT token and returns the claims func validateToken(tokenString string) (*Claims, error) { cfg := config.MustLoadConfig() // Parse the token token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) { // Check which signing method was used if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok { // HMAC method was used, validate with secret if cfg.JWTConfig.Secret == "" { return nil, errors.New("HMAC signing method used but no secret configured") } return []byte(cfg.JWTConfig.Secret), nil } else if _, ok := token.Method.(*jwt.SigningMethodRSA); ok { // RSA method was used, validate with public key if verifyKey == nil { return nil, errors.New("RSA signing method used but no public key loaded") } return verifyKey, nil } return nil, errors.New("unexpected signing method") }) if err != nil { return nil, err } // Check if the token is valid if !token.Valid { return nil, errors.New("invalid token") } // Get the claims claims, ok := token.Claims.(*Claims) if !ok { return nil, errors.New("invalid claims") } return claims, nil } // GetUserIDFromContext extracts the user ID from the context func GetUserIDFromContext(c *gin.Context) (types.ULID, error) { userID, exists := c.Get("userID") if !exists { return types.ULID{}, errors.New("user ID not found in context") } userIDStr, ok := userID.(string) if !ok { return types.ULID{}, errors.New("invalid user ID type in context") } id, err := ulid.Parse(userIDStr) if err != nil { return types.ULID{}, err } return types.FromULID(id), nil } // GetCompanyIDFromContext extracts the company ID from the context func GetCompanyIDFromContext(c *gin.Context) (types.ULID, error) { companyID, exists := c.Get("companyID") if !exists { return types.ULID{}, errors.New("company ID not found in context") } companyIDStr, ok := companyID.(string) if !ok { return types.ULID{}, errors.New("invalid company ID type in context") } id, err := ulid.Parse(companyIDStr) if err != nil { return types.ULID{}, err } return types.FromULID(id), nil }