287 lines
7.8 KiB
Go
287 lines
7.8 KiB
Go
package utils
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/timetracker/backend/internal/api/responses"
|
|
"github.com/timetracker/backend/internal/types"
|
|
)
|
|
|
|
// ParseID parses an ID from the URL parameter and converts it to a types.ULID
|
|
func ParseID(c *gin.Context, paramName string) (types.ULID, error) {
|
|
idStr := c.Param(paramName)
|
|
return types.ULIDFromString(idStr)
|
|
}
|
|
|
|
// BindJSON binds the request body to the provided struct
|
|
func BindJSON(c *gin.Context, obj interface{}) error {
|
|
if err := c.ShouldBindJSON(obj); err != nil {
|
|
return fmt.Errorf("invalid request body: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ConvertToDTO maps every element of models through convertFn and returns the
// results in the same order. The result always has len(models) elements and is
// never nil, so an empty or nil input still yields an empty (non-nil) slice.
func ConvertToDTO[M any, D any](models []M, convertFn func(*M) D) []D {
	out := make([]D, len(models))
	for idx := range models {
		// Hand convertFn a pointer to a fresh per-iteration copy rather than into
		// models itself: the caller's slice cannot be mutated through it, and any
		// pointers convertFn retains never alias one another.
		item := models[idx]
		out[idx] = convertFn(&item)
	}
	return out
}
|
|
|
|
// HandleGetAll is a generic function to handle GET all entities endpoints
|
|
func HandleGetAll[M any, D any](
|
|
c *gin.Context,
|
|
getAllFn func(ctx context.Context) ([]M, error),
|
|
convertFn func(*M) D,
|
|
entityName string,
|
|
) {
|
|
// Get entities from the database
|
|
entities, err := getAllFn(c.Request.Context())
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error retrieving %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
// Convert to DTOs
|
|
dtos := ConvertToDTO(entities, convertFn)
|
|
|
|
responses.SuccessResponse(c, 200, dtos)
|
|
}
|
|
|
|
// HandleGetByID is a generic function to handle GET entity by ID endpoints
|
|
func HandleGetByID[M any, D any](
|
|
c *gin.Context,
|
|
getByIDFn func(ctx context.Context, id types.ULID) (*M, error),
|
|
convertFn func(*M) D,
|
|
entityName string,
|
|
) {
|
|
// Parse ID from URL
|
|
id, err := ParseID(c, "id")
|
|
if err != nil {
|
|
responses.BadRequestResponse(c, fmt.Sprintf("Invalid %s ID format", entityName))
|
|
return
|
|
}
|
|
|
|
// Get entity from the database
|
|
entity, err := getByIDFn(c.Request.Context(), id)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error retrieving %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
if entity == nil {
|
|
responses.NotFoundResponse(c, fmt.Sprintf("%s not found", entityName))
|
|
return
|
|
}
|
|
|
|
// Convert to DTO
|
|
dto := convertFn(entity)
|
|
|
|
responses.SuccessResponse(c, 200, dto)
|
|
}
|
|
|
|
// HandleCreate is a generic function to handle POST entity endpoints
|
|
func HandleCreate[C any, M any, D any](
|
|
c *gin.Context,
|
|
createFn func(ctx context.Context, create C) (*M, error),
|
|
convertFn func(*M) D,
|
|
entityName string,
|
|
) {
|
|
// Parse request body
|
|
var createDTO C
|
|
if err := BindJSON(c, &createDTO); err != nil {
|
|
responses.BadRequestResponse(c, err.Error())
|
|
return
|
|
}
|
|
|
|
// Create entity in the database
|
|
entity, err := createFn(c.Request.Context(), createDTO)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error creating %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
// Convert to DTO
|
|
dto := convertFn(entity)
|
|
|
|
responses.SuccessResponse(c, 201, dto)
|
|
}
|
|
|
|
// HandleDelete is a generic function to handle DELETE entity endpoints
|
|
func HandleDelete(
|
|
c *gin.Context,
|
|
deleteFn func(ctx context.Context, id types.ULID) error,
|
|
entityName string,
|
|
) {
|
|
// Parse ID from URL
|
|
id, err := ParseID(c, "id")
|
|
if err != nil {
|
|
responses.BadRequestResponse(c, fmt.Sprintf("Invalid %s ID format", entityName))
|
|
return
|
|
}
|
|
|
|
// Delete entity from the database
|
|
err = deleteFn(c.Request.Context(), id)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error deleting %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
responses.SuccessResponse(c, 204, nil)
|
|
}
|
|
|
|
// HandleUpdate is a generic function to handle PUT entity endpoints
|
|
// It takes a prepareUpdateFn that handles parsing the ID, binding the JSON, and converting the DTO to a model update object
|
|
func HandleUpdate[U any, M any, D any](
|
|
c *gin.Context,
|
|
updateFn func(ctx context.Context, update U) (*M, error),
|
|
convertFn func(*M) D,
|
|
prepareUpdateFn func(*gin.Context) (U, error),
|
|
entityName string,
|
|
) {
|
|
// Prepare the update object (parse ID, bind JSON, convert DTO to model)
|
|
update, err := prepareUpdateFn(c)
|
|
if err != nil {
|
|
// The prepareUpdateFn should handle setting the appropriate error response
|
|
return
|
|
}
|
|
|
|
// Update entity in the database
|
|
entity, err := updateFn(c.Request.Context(), update)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error updating %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
if entity == nil {
|
|
responses.NotFoundResponse(c, fmt.Sprintf("%s not found", entityName))
|
|
return
|
|
}
|
|
|
|
// Convert to DTO
|
|
dto := convertFn(entity)
|
|
|
|
responses.SuccessResponse(c, http.StatusOK, dto)
|
|
}
|
|
|
|
// HandleGetByFilter is a generic function to handle GET entities by a filter parameter
|
|
func HandleGetByFilter[M any, D any](
|
|
c *gin.Context,
|
|
getByFilterFn func(ctx context.Context, filterID types.ULID) ([]M, error),
|
|
convertFn func(*M) D,
|
|
entityName string,
|
|
paramName string,
|
|
) {
|
|
// Parse filter ID from URL
|
|
filterID, err := ParseID(c, paramName)
|
|
if err != nil {
|
|
responses.BadRequestResponse(c, fmt.Sprintf("Invalid %s ID format", paramName))
|
|
return
|
|
}
|
|
|
|
// Get entities from the database
|
|
entities, err := getByFilterFn(c.Request.Context(), filterID)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error retrieving %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
// Convert to DTOs
|
|
dtos := ConvertToDTO(entities, convertFn)
|
|
|
|
responses.SuccessResponse(c, http.StatusOK, dtos)
|
|
}
|
|
|
|
// HandleGetByUserID is a specialized function to handle GET entities by user ID
|
|
func HandleGetByUserID[M any, D any](
|
|
c *gin.Context,
|
|
getByUserIDFn func(ctx context.Context, userID types.ULID) ([]M, error),
|
|
convertFn func(*M) D,
|
|
entityName string,
|
|
) {
|
|
// Get user ID from context (set by AuthMiddleware)
|
|
userID, exists := c.Get("userID")
|
|
if !exists {
|
|
responses.UnauthorizedResponse(c, "User not authenticated")
|
|
return
|
|
}
|
|
|
|
userIDStr, ok := userID.(string)
|
|
if !ok {
|
|
responses.InternalErrorResponse(c, "Invalid user ID type in context")
|
|
return
|
|
}
|
|
|
|
parsedUserID, err := types.ULIDFromString(userIDStr)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error parsing user ID: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// Get entities from the database
|
|
entities, err := getByUserIDFn(c.Request.Context(), parsedUserID)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error retrieving %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
// Convert to DTOs
|
|
dtos := ConvertToDTO(entities, convertFn)
|
|
|
|
responses.SuccessResponse(c, http.StatusOK, dtos)
|
|
}
|
|
|
|
// HandleGetByDateRange is a specialized function to handle GET entities by date range
|
|
func HandleGetByDateRange[M any, D any](
|
|
c *gin.Context,
|
|
getByDateRangeFn func(ctx context.Context, start, end time.Time) ([]M, error),
|
|
convertFn func(*M) D,
|
|
entityName string,
|
|
) {
|
|
// Parse date range from query parameters
|
|
startStr := c.Query("start")
|
|
endStr := c.Query("end")
|
|
|
|
if startStr == "" || endStr == "" {
|
|
responses.BadRequestResponse(c, "Start and end dates are required")
|
|
return
|
|
}
|
|
|
|
start, err := time.Parse(time.RFC3339, startStr)
|
|
if err != nil {
|
|
responses.BadRequestResponse(c, "Invalid start date format. Use ISO 8601 format (e.g., 2023-01-01T00:00:00Z)")
|
|
return
|
|
}
|
|
|
|
end, err := time.Parse(time.RFC3339, endStr)
|
|
if err != nil {
|
|
responses.BadRequestResponse(c, "Invalid end date format. Use ISO 8601 format (e.g., 2023-01-01T00:00:00Z)")
|
|
return
|
|
}
|
|
|
|
if end.Before(start) {
|
|
responses.BadRequestResponse(c, "End date cannot be before start date")
|
|
return
|
|
}
|
|
|
|
// Get entities from the database
|
|
entities, err := getByDateRangeFn(c.Request.Context(), start, end)
|
|
if err != nil {
|
|
responses.InternalErrorResponse(c, fmt.Sprintf("Error retrieving %s: %s", entityName, err.Error()))
|
|
return
|
|
}
|
|
|
|
// Convert to DTOs
|
|
dtos := ConvertToDTO(entities, convertFn)
|
|
|
|
responses.SuccessResponse(c, http.StatusOK, dtos)
|
|
}
|