2025-03-31 19:07:30 +00:00

211 lines
5.6 KiB
Go

package main
import (
"context"
"fmt"
"log"
"time"
"github.com/oklog/ulid/v2"
"github.com/timetracker/backend/internal/config"
"github.com/timetracker/backend/internal/db"
"github.com/timetracker/backend/internal/models"
"github.com/timetracker/backend/internal/types"
)
// main connects to the database using config.DefaultDatabaseConfig, then runs
// the company, user, and relationship smoke tests in that order under one
// shared 30-second context, printing progress to stdout.
//
// NOTE(review): the test helpers abort with log.Fatalf, which calls os.Exit,
// so the deferred cancel and db.CloseDB below do not run when a test fails.
func main() {
	cfg := config.DefaultDatabaseConfig()

	fmt.Println("Connecting to database...")
	if err := db.InitDB(cfg); err != nil {
		log.Fatalf("Error initializing database: %v", err)
	}
	defer func() {
		if closeErr := db.CloseDB(); closeErr != nil {
			log.Printf("Error closing database connection: %v", closeErr)
		}
	}()
	fmt.Println("✓ Database connection successful")

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Each suite prints its banner and then runs against the shared context.
	suites := []struct {
		banner string
		run    func(context.Context)
	}{
		{"\n=== Testing Company Model ===", testCompanyModel},
		{"\n=== Testing User Model ===", testUserModel},
		{"\n=== Testing Relationships ===", testRelationships},
	}
	for _, suite := range suites {
		fmt.Println(suite.banner)
		suite.run(ctx)
	}

	fmt.Println("\nModel tests completed successfully!")
}
// testCompanyModel exercises the company helpers in sequence:
//   - create "Test Company"
//   - fetch it back by ID
//   - rename it to "Updated Test Company"
//   - list every company
//
// Progress is printed to stdout; any failure terminates the process via
// log.Fatalf.
func testCompanyModel(ctx context.Context) {
	created, err := models.CreateCompany(ctx, models.CompanyCreate{Name: "Test Company"})
	if err != nil {
		log.Fatalf("Error creating company: %v", err)
	}
	fmt.Printf("✓ Created company: %s (ID: %s)\n", created.Name, created.ID)

	fetched, err := models.GetCompanyByID(ctx, created.ID)
	switch {
	case err != nil:
		log.Fatalf("Error getting company: %v", err)
	case fetched == nil:
		// The lookup reports "not found" as a nil company with no error.
		log.Fatalf("Company not found")
	}
	fmt.Printf("✓ Retrieved company: %s\n", fetched.Name)

	renamed := "Updated Test Company"
	updated, err := models.UpdateCompany(ctx, models.CompanyUpdate{
		ID:   created.ID,
		Name: &renamed,
	})
	if err != nil {
		log.Fatalf("Error updating company: %v", err)
	}
	fmt.Printf("✓ Updated company name to: %s\n", updated.Name)

	all, err := models.GetAllCompanies(ctx)
	if err != nil {
		log.Fatalf("Error getting all companies: %v", err)
	}
	fmt.Printf("✓ Retrieved %d companies\n", len(all))
}
// testUserModel exercises the user helpers against the first company returned
// by models.GetAllCompanies. In order, it:
//   - creates a user
//   - fetches the user by ID and by e-mail
//   - updates e-mail, role, and hourly rate
//   - authenticates with the original password
//   - lists all users
//   - lists the users belonging to that company
//
// Progress is printed to stdout; any failure terminates the process via
// log.Fatalf.
//
// The e-mail addresses carry a per-run numeric suffix so that re-running the
// script against the same database does not reuse addresses created by an
// earlier run. NOTE(review): this assumes e-mails are unique per user — confirm
// against the schema.
func testUserModel(ctx context.Context) {
	companies, err := models.GetAllCompanies(ctx)
	if err != nil {
		log.Fatalf("Error getting companies: %v", err)
	}
	if len(companies) == 0 {
		// Reported separately: err is nil here, so a combined message would
		// print a misleading "<nil>".
		log.Fatalf("No companies found")
	}
	companyID := companies[0].ID

	// Per-run suffix keeps the test addresses distinct across executions.
	runSuffix := time.Now().UnixNano()
	// Shared by user creation and the later authentication check.
	const password = "Test@123456"

	userCreate := models.UserCreate{
		Email:      fmt.Sprintf("test%d@example.com", runSuffix),
		Password:   password,
		Role:       models.RoleUser,
		CompanyID:  &companyID,
		HourlyRate: 50.0,
	}
	user, err := models.CreateUser(ctx, userCreate)
	if err != nil {
		log.Fatalf("Error creating user: %v", err)
	}
	fmt.Printf("✓ Created user: %s (ID: %s)\n", user.Email, user.ID)

	retrievedUser, err := models.GetUserByID(ctx, user.ID)
	if err != nil {
		log.Fatalf("Error getting user: %v", err)
	}
	if retrievedUser == nil {
		log.Fatalf("User not found")
	}
	fmt.Printf("✓ Retrieved user: %s\n", retrievedUser.Email)

	emailUser, err := models.GetUserByEmail(ctx, user.Email)
	if err != nil {
		log.Fatalf("Error getting user by email: %v", err)
	}
	if emailUser == nil {
		log.Fatalf("User not found by email")
	}
	fmt.Printf("✓ Retrieved user by email: %s\n", emailUser.Email)

	newEmail := fmt.Sprintf("updated%d@example.com", runSuffix)
	newRole := models.RoleAdmin
	newHourlyRate := 75.0
	userUpdate := models.UserUpdate{
		ID:         user.ID,
		Email:      &newEmail,
		Role:       &newRole,
		HourlyRate: &newHourlyRate,
	}
	updatedUser, err := models.UpdateUser(ctx, userUpdate)
	if err != nil {
		log.Fatalf("Error updating user: %v", err)
	}
	fmt.Printf("✓ Updated user email to: %s, role to: %s\n", updatedUser.Email, updatedUser.Role)

	// Authenticate with the new e-mail; the password was not changed by the update.
	authUser, err := models.AuthenticateUser(ctx, updatedUser.Email, password)
	if err != nil {
		log.Fatalf("Error authenticating user: %v", err)
	}
	if authUser == nil {
		log.Fatalf("Authentication failed")
	}
	fmt.Printf("✓ User authentication successful\n")

	users, err := models.GetAllUsers(ctx)
	if err != nil {
		log.Fatalf("Error getting all users: %v", err)
	}
	fmt.Printf("✓ Retrieved %d users\n", len(users))

	companyUsers, err := models.GetUsersByCompanyID(ctx, companyID)
	if err != nil {
		log.Fatalf("Error getting users by company ID: %v", err)
	}
	fmt.Printf("✓ Retrieved %d users for company ID: %s\n", len(companyUsers), companyID)
}
// testRelationships checks that models.GetUserWithCompany loads the Company
// of the first user returned by models.GetAllUsers. It then checks that
// looking up a random, non-existent user ID yields a nil user and no error.
//
// Progress is printed to stdout; any failure terminates the process via
// log.Fatalf.
//
// NOTE(review): users[0] is not guaranteed to have a company; if the database
// contains users without one, this check fails even though the relationship
// loading works.
func testRelationships(ctx context.Context) {
	users, err := models.GetAllUsers(ctx)
	if err != nil {
		log.Fatalf("Error getting users: %v", err)
	}
	if len(users) == 0 {
		// Reported separately: err is nil here, so a combined message would
		// print a misleading "<nil>".
		log.Fatalf("No users found")
	}
	userID := users[0].ID

	user, err := models.GetUserWithCompany(ctx, userID)
	if err != nil {
		log.Fatalf("Error getting user with company: %v", err)
	}
	if user == nil {
		log.Fatalf("User not found")
	}
	if user.Company == nil {
		log.Fatalf("User's company not loaded")
	}
	fmt.Printf("✓ Retrieved user %s with company %s\n", user.Email, user.Company.Name)

	// A freshly generated ULID is well-formed but should match no stored user.
	// The expected result is "not found": a nil user with a nil error.
	invalidID := ulid.MustNew(ulid.Timestamp(time.Now()), ulid.DefaultEntropy())
	invalidUser, err := models.GetUserByID(ctx, types.FromULID(invalidID))
	if err != nil {
		log.Fatalf("Error getting user with invalid ID: %v", err)
	}
	if invalidUser != nil {
		log.Fatalf("User found with invalid ID")
	}
	fmt.Printf("✓ Correctly handled invalid user ID\n")
}