78 lines
1.9 KiB
Go
78 lines
1.9 KiB
Go
package models
|
|
|
|
import (
|
|
"context"
|
|
"database/sql/driver"
|
|
"fmt"
|
|
|
|
"github.com/oklog/ulid/v2"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
)
|
|
|
|
// ULIDWrapper wraps ulid.ULID to make it work nicely with GORM
|
|
type ULIDWrapper struct {
|
|
ulid.ULID
|
|
}
|
|
|
|
// NewULIDWrapper creates a new ULIDWrapper with a new ULID
|
|
func NewULIDWrapper() ULIDWrapper {
|
|
return ULIDWrapper{ULID: ulid.Make()}
|
|
}
|
|
|
|
// FromULID creates a ULIDWrapper from a ulid.ULID
|
|
func FromULID(id ulid.ULID) ULIDWrapper {
|
|
return ULIDWrapper{ULID: id}
|
|
}
|
|
|
|
// ULIDWrapperFromString creates a ULIDWrapper from a string
|
|
func ULIDWrapperFromString(id string) (ULIDWrapper, error) {
|
|
parsed, err := ulid.Parse(id)
|
|
if err != nil {
|
|
return ULIDWrapper{}, fmt.Errorf("failed to parse ULID string: %w", err)
|
|
}
|
|
return ULIDWrapper{ULID: parsed}, nil
|
|
}
|
|
|
|
// Scan implements the sql.Scanner interface for ULIDWrapper
|
|
func (u *ULIDWrapper) Scan(src any) error {
|
|
switch v := src.(type) {
|
|
case []byte:
|
|
// If it's exactly 16 bytes, it's the binary representation
|
|
if len(v) == 16 {
|
|
copy(u.ULID[:], v)
|
|
return nil
|
|
}
|
|
// Otherwise, try as string
|
|
return fmt.Errorf("cannot scan []byte of length %d into ULIDWrapper", len(v))
|
|
case string:
|
|
parsed, err := ulid.Parse(v)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to parse ULID: %w", err)
|
|
}
|
|
u.ULID = parsed
|
|
return nil
|
|
default:
|
|
return fmt.Errorf("cannot scan %T into ULIDWrapper", src)
|
|
}
|
|
}
|
|
|
|
// Value implements the driver.Valuer interface for ULIDWrapper
|
|
// Returns the binary representation of the ULID for maximum efficiency
|
|
func (u ULIDWrapper) Value() (driver.Value, error) {
|
|
return u.ULID.Bytes(), nil
|
|
}
|
|
|
|
// GormValue implements the gorm.Valuer interface for ULIDWrapper
|
|
func (u ULIDWrapper) GormValue(ctx context.Context, db *gorm.DB) clause.Expr {
|
|
return clause.Expr{
|
|
SQL: "?",
|
|
Vars: []any{u.Bytes()},
|
|
}
|
|
}
|
|
|
|
// Compare implements comparison for ULIDWrapper
|
|
func (u ULIDWrapper) Compare(other ULIDWrapper) int {
|
|
return u.ULID.Compare(other.ULID)
|
|
}
|