From 2e13d775fa70da940348cf3585621cc123082f14 Mon Sep 17 00:00:00 2001 From: Jean Jacques Avril Date: Tue, 11 Mar 2025 23:54:29 +0000 Subject: [PATCH] feat: Implement RSA key generation and initialization for JWT authentication --- backend/cmd/api/main.go | 8 +- backend/internal/api/middleware/jwt_auth.go | 191 +++++++++++++++++++- backend/internal/config/config.go | 13 +- backend/internal/models/jwt.go | 5 +- 4 files changed, 203 insertions(+), 14 deletions(-) diff --git a/backend/cmd/api/main.go b/backend/cmd/api/main.go index d7ae432..caf6d27 100644 --- a/backend/cmd/api/main.go +++ b/backend/cmd/api/main.go @@ -13,7 +13,8 @@ import ( "github.com/gin-gonic/gin" swaggerFiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" - _ "github.com/timetracker/backend/docs" // This line is important for swag to work + _ "github.com/timetracker/backend/docs" + "github.com/timetracker/backend/internal/api/middleware" "github.com/timetracker/backend/internal/api/routes" "github.com/timetracker/backend/internal/config" "github.com/timetracker/backend/internal/models" @@ -60,6 +61,11 @@ func main() { log.Fatalf("Error migrating database: %v", err) } + // Initialize JWT keys + if err := middleware.InitJWTKeys(); err != nil { + log.Fatalf("Error initializing JWT keys: %v", err) + } + // Create Gin router r := gin.Default() diff --git a/backend/internal/api/middleware/jwt_auth.go b/backend/internal/api/middleware/jwt_auth.go index 89cb5ea..9a0b589 100644 --- a/backend/internal/api/middleware/jwt_auth.go +++ b/backend/internal/api/middleware/jwt_auth.go @@ -1,7 +1,14 @@ package middleware import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "errors" + "fmt" + "os" + "path/filepath" "time" "github.com/gin-gonic/gin" @@ -12,6 +19,148 @@ import ( "github.com/timetracker/backend/internal/models" ) +var ( + signKey *rsa.PrivateKey + verifyKey *rsa.PublicKey +) + +// InitJWTKeys initializes the JWT keys +func InitJWTKeys() error { + cfg := config.MustLoadConfig() + + // If a secret is provided, we'll use HMAC-SHA256, so no need for certificates + if cfg.JWTConfig.Secret != "" { + println("Using HMAC-SHA256 for JWT") + return nil + } + + // Check if keys exist + privKeyPath := filepath.Join(cfg.JWTConfig.KeyDir, cfg.JWTConfig.PrivKeyFile) + pubKeyPath := filepath.Join(cfg.JWTConfig.KeyDir, cfg.JWTConfig.PubKeyFile) + + keysExist := fileExists(privKeyPath) && fileExists(pubKeyPath) + + // Generate keys if they don't exist and KeyGenerate is true + if !keysExist && cfg.JWTConfig.KeyGenerate { + println("Generating RSA keys") + if err := generateRSAKeys(cfg.JWTConfig); err != nil { + return fmt.Errorf("failed to generate RSA keys: %w", err) + } + } else if !keysExist { + return errors.New("JWT keys not found and key generation is disabled") + } + + // Load keys + var err error + signKey, err = loadPrivateKey(privKeyPath) + if err != nil { + return fmt.Errorf("failed to load private key: %w", err) + } + + verifyKey, err = loadPublicKey(pubKeyPath) + if err != nil { + return fmt.Errorf("failed to load public key: %w", err) + } + + return nil +} + +// fileExists checks if a file exists +func fileExists(path string) bool { + _, err := os.Stat(path) + return !os.IsNotExist(err) +} + +// generateRSAKeys generates RSA keys and saves them to disk +func generateRSAKeys(cfg models.JWTConfig) error { + // Create key directory if it doesn't exist + if err := os.MkdirAll(cfg.KeyDir, 0700); err != nil { + return fmt.Errorf("failed to create key directory: %w", err) + } + + // Generate private key + privateKey, err := rsa.GenerateKey(rand.Reader, cfg.KeyBits) + if err != nil { + return fmt.Errorf("failed to generate private key: %w", err) + } + + // Save private key + privKeyPath := filepath.Join(cfg.KeyDir, cfg.PrivKeyFile) + privKeyFile, err := os.OpenFile(privKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + return fmt.Errorf("failed to create private key file: %w", err) + } + defer privKeyFile.Close() + + privKeyPEM := &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + } + + if err := pem.Encode(privKeyFile, privKeyPEM); err != nil { + return fmt.Errorf("failed to encode private key: %w", err) + } + + // Save public key + pubKeyPath := filepath.Join(cfg.KeyDir, cfg.PubKeyFile) + pubKeyFile, err := os.OpenFile(pubKeyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("failed to create public key file: %w", err) + } + defer pubKeyFile.Close() + + pubKeyPEM := &pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: x509.MarshalPKCS1PublicKey(&privateKey.PublicKey), + } + + if err := pem.Encode(pubKeyFile, pubKeyPEM); err != nil { + return fmt.Errorf("failed to encode public key: %w", err) + } + + return nil +} + +// loadPrivateKey loads a private key from a file +func loadPrivateKey(path string) (*rsa.PrivateKey, error) { + keyData, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read private key file: %w", err) + } + + block, _ := pem.Decode(keyData) + if block == nil { + return nil, errors.New("failed to parse PEM block containing the private key") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return privateKey, nil +} + +// loadPublicKey loads a public key from a file +func loadPublicKey(path string) (*rsa.PublicKey, error) { + keyData, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read public key file: %w", err) + } + + block, _ := pem.Decode(keyData) + if block == nil { + return nil, errors.New("failed to parse PEM block containing the public key") + } + + publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + return publicKey, nil +} + // Claims represents the JWT claims type Claims struct { UserID string `json:"userId"` @@ -100,12 +249,22 @@ func GenerateToken(user *models.User, c *gin.Context) (string, error) { }, } - // Create the token - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - cfg := config.MustLoadConfig() - // Sign the token - tokenString, err := token.SignedString([]byte(cfg.JWTConfig.Secret)) + var token *jwt.Token + var tokenString string + var err error + + // Choose signing method based on configuration + if cfg.JWTConfig.Secret != "" { + // Use HMAC-SHA256 if a secret is provided + token = jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err = token.SignedString([]byte(cfg.JWTConfig.Secret)) + } else { + // Use RSA if no secret is provided + token = jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + tokenString, err = token.SignedString(signKey) + } + if err != nil { return "", err } @@ -118,14 +277,26 @@ func GenerateToken(user *models.User, c *gin.Context) (string, error) { // validateToken validates a JWT token and returns the claims func validateToken(tokenString string) (*Claims, error) { + cfg := config.MustLoadConfig() + // Parse the token token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) { - // Validate the signing method - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, errors.New("unexpected signing method") + // Check which signing method was used + if _, ok := token.Method.(*jwt.SigningMethodHMAC); ok { + // HMAC method was used, validate with secret + if cfg.JWTConfig.Secret == "" { + return nil, errors.New("HMAC signing method used but no secret configured") + } + return []byte(cfg.JWTConfig.Secret), nil + } else if _, ok := token.Method.(*jwt.SigningMethodRSA); ok { + // RSA method was used, validate with public key + if verifyKey == nil { + return nil, errors.New("RSA signing method used but no public key loaded") + } + return verifyKey, nil } - cfg := config.MustLoadConfig() - return []byte(cfg.JWTConfig.Secret), nil + + return nil, errors.New("unexpected signing method") }) if err != nil { diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 83300a6..fe3e4a2 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -48,7 +48,7 @@ func LoadConfig() (*Config, error) { // loadJWTConfig loads JWT configuration from environment func loadJWTConfig(cfg *Config) error { - cfg.JWTConfig.Secret = getEnv("JWT_SECRET", "default-secret") + cfg.JWTConfig.Secret = getEnv("JWT_SECRET", "") defaultDuration := 24 * time.Hour durationStr := getEnv("JWT_TOKEN_DURATION", defaultDuration.String()) @@ -58,7 +58,7 @@ func loadJWTConfig(cfg *Config) error { } cfg.JWTConfig.TokenDuration = duration - keyGenerateStr := getEnv("JWT_KEY_GENERATE", "false") + keyGenerateStr := getEnv("JWT_KEY_GENERATE", "true") keyGenerate, err := strconv.ParseBool(keyGenerateStr) if err != nil { return fmt.Errorf("invalid JWT_KEY_GENERATE: %w", err) @@ -66,6 +66,15 @@ func loadJWTConfig(cfg *Config) error { cfg.JWTConfig.KeyGenerate = keyGenerate cfg.JWTConfig.KeyDir = getEnv("JWT_KEY_DIR", "./keys") + cfg.JWTConfig.PrivKeyFile = getEnv("JWT_PRIV_KEY_FILE", "jwt.key") + cfg.JWTConfig.PubKeyFile = getEnv("JWT_PUB_KEY_FILE", "jwt.key.pub") + + keyBitsStr := getEnv("JWT_KEY_BITS", "2048") + keyBits, err := strconv.Atoi(keyBitsStr) + if err != nil { + return fmt.Errorf("invalid JWT_KEY_BITS: %w", err) + } + cfg.JWTConfig.KeyBits = keyBits return nil } diff --git a/backend/internal/models/jwt.go b/backend/internal/models/jwt.go index d68edb8..239a854 100644 --- a/backend/internal/models/jwt.go +++ b/backend/internal/models/jwt.go @@ -5,6 +5,9 @@ import "time" type JWTConfig struct { Secret string `env:"JWT_SECRET" default:""` TokenDuration time.Duration `env:"JWT_TOKEN_DURATION" default:"24h"` - KeyGenerate bool `env:"JWT_KEY_GENERATE" default:"false"` + KeyGenerate bool `env:"JWT_KEY_GENERATE" default:"true"` KeyDir string `env:"JWT_KEY_DIR" default:"./keys"` + PrivKeyFile string `env:"JWT_PRIV_KEY_FILE" default:"jwt.key"` + PubKeyFile string `env:"JWT_PUB_KEY_FILE" default:"jwt.key.pub"` + KeyBits int `env:"JWT_KEY_BITS" default:"2048"` }