feat: Add JWT configuration to environment and refactor JWT middleware to use new config structure
This commit is contained in:
		
							parent
							
								
									9057adebdd
								
							
						
					
					
						commit
						b545392f27
					
				| @ -10,3 +10,5 @@ API_KEY= | |||||||
| JWT_SECRET=test | JWT_SECRET=test | ||||||
| JWT_KEY_DIR=keys | JWT_KEY_DIR=keys | ||||||
| JWT_KEY_GENERATE=true | JWT_KEY_GENERATE=true | ||||||
|  | JWT_TOKEN_DURATION=24h | ||||||
|  | ENVIRONMENT=production | ||||||
| @ -1,97 +1,17 @@ | |||||||
| package middleware | package middleware | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"crypto/rand" |  | ||||||
| 	"crypto/rsa" |  | ||||||
| 	"crypto/x509" |  | ||||||
| 	"encoding/pem" |  | ||||||
| 	"errors" | 	"errors" | ||||||
| 	"fmt" |  | ||||||
| 	"os" |  | ||||||
| 	"time" | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/gin-gonic/gin" | 	"github.com/gin-gonic/gin" | ||||||
| 	"github.com/golang-jwt/jwt/v5" | 	"github.com/golang-jwt/jwt/v5" | ||||||
| 	"github.com/joho/godotenv" |  | ||||||
| 	"github.com/oklog/ulid/v2" | 	"github.com/oklog/ulid/v2" | ||||||
| 	"github.com/timetracker/backend/internal/api/utils" | 	"github.com/timetracker/backend/internal/api/utils" | ||||||
|  | 	"github.com/timetracker/backend/internal/config" | ||||||
| 	"github.com/timetracker/backend/internal/models" | 	"github.com/timetracker/backend/internal/models" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| var ( |  | ||||||
| 	jwtSecret     string |  | ||||||
| 	tokenDuration = 24 * time.Hour |  | ||||||
| ) |  | ||||||
| 
 |  | ||||||
| func init() { |  | ||||||
| 	// Load .env file |  | ||||||
| 	_ = godotenv.Load() |  | ||||||
| 
 |  | ||||||
| 	// Get JWT secret from environment |  | ||||||
| 	jwtSecret = os.Getenv("JWT_SECRET") |  | ||||||
| 
 |  | ||||||
| 	// Generate a random secret if none is provided |  | ||||||
| 	if jwtSecret == "" { |  | ||||||
| 		randomBytes := make([]byte, 32) |  | ||||||
| 		_, err := rand.Read(randomBytes) |  | ||||||
| 		if err != nil { |  | ||||||
| 			panic("failed to generate JWT secret: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		jwtSecret = string(randomBytes) |  | ||||||
| 	} |  | ||||||
| 
 |  | ||||||
| 	// Generate and store RSA keys if configured |  | ||||||
| 	if os.Getenv("JWT_KEY_GENERATE") == "true" { |  | ||||||
| 		keyDir := os.Getenv("JWT_KEY_DIR") |  | ||||||
| 		if keyDir == "" { |  | ||||||
| 			keyDir = "./keys" |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Create directory if it doesn't exist |  | ||||||
| 		if err := os.MkdirAll(keyDir, 0755); err != nil { |  | ||||||
| 			panic("failed to create key directory: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Generate RSA key pair |  | ||||||
| 		privateKey, err := rsa.GenerateKey(rand.Reader, 2048) |  | ||||||
| 		if err != nil { |  | ||||||
| 			panic("failed to generate RSA key pair: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Save private key |  | ||||||
| 		privateKeyFile, err := os.Create(fmt.Sprintf("%s/private.pem", keyDir)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			panic("failed to create private key file: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		defer privateKeyFile.Close() |  | ||||||
| 
 |  | ||||||
| 		privateKeyPEM := &pem.Block{ |  | ||||||
| 			Type:  "RSA PRIVATE KEY", |  | ||||||
| 			Bytes: x509.MarshalPKCS1PrivateKey(privateKey), |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil { |  | ||||||
| 			panic("failed to encode private key: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		// Save public key |  | ||||||
| 		publicKeyFile, err := os.Create(fmt.Sprintf("%s/public.pem", keyDir)) |  | ||||||
| 		if err != nil { |  | ||||||
| 			panic("failed to create public key file: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 		defer publicKeyFile.Close() |  | ||||||
| 
 |  | ||||||
| 		publicKeyPEM := &pem.Block{ |  | ||||||
| 			Type:  "RSA PUBLIC KEY", |  | ||||||
| 			Bytes: x509.MarshalPKCS1PublicKey(&privateKey.PublicKey), |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		if err := pem.Encode(publicKeyFile, publicKeyPEM); err != nil { |  | ||||||
| 			panic("failed to encode public key: " + err.Error()) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Claims represents the JWT claims | // Claims represents the JWT claims | ||||||
| type Claims struct { | type Claims struct { | ||||||
| 	UserID    string `json:"userId"` | 	UserID    string `json:"userId"` | ||||||
| @ -174,7 +94,7 @@ func GenerateToken(user *models.User, c *gin.Context) (string, error) { | |||||||
| 		Role:      user.Role, | 		Role:      user.Role, | ||||||
| 		CompanyID: user.CompanyID.String(), | 		CompanyID: user.CompanyID.String(), | ||||||
| 		RegisteredClaims: jwt.RegisteredClaims{ | 		RegisteredClaims: jwt.RegisteredClaims{ | ||||||
| 			ExpiresAt: jwt.NewNumericDate(time.Now().Add(tokenDuration)), | 			ExpiresAt: jwt.NewNumericDate(time.Now().Add(config.MustLoadConfig().JWTConfig.TokenDuration)), | ||||||
| 			IssuedAt:  jwt.NewNumericDate(time.Now()), | 			IssuedAt:  jwt.NewNumericDate(time.Now()), | ||||||
| 			NotBefore: jwt.NewNumericDate(time.Now()), | 			NotBefore: jwt.NewNumericDate(time.Now()), | ||||||
| 		}, | 		}, | ||||||
| @ -183,14 +103,15 @@ func GenerateToken(user *models.User, c *gin.Context) (string, error) { | |||||||
| 	// Create the token | 	// Create the token | ||||||
| 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | ||||||
| 
 | 
 | ||||||
|  | 	cfg := config.MustLoadConfig() | ||||||
| 	// Sign the token | 	// Sign the token | ||||||
| 	tokenString, err := token.SignedString([]byte(jwtSecret)) | 	tokenString, err := token.SignedString([]byte(cfg.JWTConfig.Secret)) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Set the cookie | 	// Set the cookie | ||||||
| 	c.SetCookie("jwt", tokenString, int(tokenDuration.Seconds()), "/", "", true, true) | 	c.SetCookie("jwt", tokenString, int(cfg.JWTConfig.TokenDuration.Seconds()), "/", "", true, true) | ||||||
| 
 | 
 | ||||||
| 	return tokenString, nil | 	return tokenString, nil | ||||||
| } | } | ||||||
| @ -198,12 +119,13 @@ func GenerateToken(user *models.User, c *gin.Context) (string, error) { | |||||||
| // validateToken validates a JWT token and returns the claims | // validateToken validates a JWT token and returns the claims | ||||||
| func validateToken(tokenString string) (*Claims, error) { | func validateToken(tokenString string) (*Claims, error) { | ||||||
| 	// Parse the token | 	// Parse the token | ||||||
| 	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { | 	token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) { | ||||||
| 		// Validate the signing method | 		// Validate the signing method | ||||||
| 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | 		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { | ||||||
| 			return nil, errors.New("unexpected signing method") | 			return nil, errors.New("unexpected signing method") | ||||||
| 		} | 		} | ||||||
| 		return []byte(jwtSecret), nil | 		cfg := config.MustLoadConfig() | ||||||
|  | 		return []byte(cfg.JWTConfig.Secret), nil | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
|  | |||||||
| @ -6,6 +6,7 @@ import ( | |||||||
| 	"log" | 	"log" | ||||||
| 	"os" | 	"os" | ||||||
| 	"strconv" | 	"strconv" | ||||||
|  | 	"time" | ||||||
| 
 | 
 | ||||||
| 	"github.com/joho/godotenv" | 	"github.com/joho/godotenv" | ||||||
| 	"github.com/timetracker/backend/internal/models" | 	"github.com/timetracker/backend/internal/models" | ||||||
| @ -15,6 +16,7 @@ import ( | |||||||
| // Config represents the application configuration | // Config represents the application configuration | ||||||
| type Config struct { | type Config struct { | ||||||
| 	Database  models.DatabaseConfig | 	Database  models.DatabaseConfig | ||||||
|  | 	JWTConfig models.JWTConfig | ||||||
| 	APIKey    string | 	APIKey    string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| @ -25,6 +27,7 @@ func LoadConfig() (*Config, error) { | |||||||
| 
 | 
 | ||||||
| 	cfg := &Config{ | 	cfg := &Config{ | ||||||
| 		Database:  models.DefaultDatabaseConfig(), | 		Database:  models.DefaultDatabaseConfig(), | ||||||
|  | 		JWTConfig: models.JWTConfig{}, | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Load database configuration | 	// Load database configuration | ||||||
| @ -32,12 +35,41 @@ func LoadConfig() (*Config, error) { | |||||||
| 		return nil, fmt.Errorf("failed to load database config: %w", err) | 		return nil, fmt.Errorf("failed to load database config: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
|  | 	// Load JWT configuration | ||||||
|  | 	if err := loadJWTConfig(cfg); err != nil { | ||||||
|  | 		return nil, fmt.Errorf("failed to load JWT config: %w", err) | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
| 	// Load API key | 	// Load API key | ||||||
| 	cfg.APIKey = getEnv("API_KEY", "") | 	cfg.APIKey = getEnv("API_KEY", "") | ||||||
| 
 | 
 | ||||||
| 	return cfg, nil | 	return cfg, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | // loadJWTConfig loads JWT configuration from environment | ||||||
|  | func loadJWTConfig(cfg *Config) error { | ||||||
|  | 	cfg.JWTConfig.Secret = getEnv("JWT_SECRET", "default-secret") | ||||||
|  | 	defaultDuration := 24 * time.Hour | ||||||
|  | 	durationStr := getEnv("JWT_TOKEN_DURATION", defaultDuration.String()) | ||||||
|  | 
 | ||||||
|  | 	duration, err := time.ParseDuration(durationStr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("invalid JWT_TOKEN_DURATION: %w", err) | ||||||
|  | 	} | ||||||
|  | 	cfg.JWTConfig.TokenDuration = duration | ||||||
|  | 
 | ||||||
|  | 	keyGenerateStr := getEnv("JWT_KEY_GENERATE", "false") | ||||||
|  | 	keyGenerate, err := strconv.ParseBool(keyGenerateStr) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return fmt.Errorf("invalid JWT_KEY_GENERATE: %w", err) | ||||||
|  | 	} | ||||||
|  | 	cfg.JWTConfig.KeyGenerate = keyGenerate | ||||||
|  | 
 | ||||||
|  | 	cfg.JWTConfig.KeyDir = getEnv("JWT_KEY_DIR", "./keys") | ||||||
|  | 
 | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
| // loadDatabaseConfig loads database configuration from environment | // loadDatabaseConfig loads database configuration from environment | ||||||
| func loadDatabaseConfig(cfg *Config) error { | func loadDatabaseConfig(cfg *Config) error { | ||||||
| 	// Required fields | 	// Required fields | ||||||
|  | |||||||
							
								
								
									
										10
									
								
								backend/internal/models/jwt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								backend/internal/models/jwt.go
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,10 @@ | |||||||
|  | package models | ||||||
|  | 
 | ||||||
|  | import "time" | ||||||
|  | 
 | ||||||
|  | type JWTConfig struct { | ||||||
|  | 	Secret        string        `env:"JWT_SECRET" default:""` | ||||||
|  | 	TokenDuration time.Duration `env:"JWT_TOKEN_DURATION" default:"24h"` | ||||||
|  | 	KeyGenerate   bool          `env:"JWT_KEY_GENERATE" default:"false"` | ||||||
|  | 	KeyDir        string        `env:"JWT_KEY_DIR" default:"./keys"` | ||||||
|  | } | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user