2025-03-31 19:07:30 +00:00

103 lines
2.6 KiB
Go

package main
import (
	"context"
	"flag"
	"fmt"
	"log"
	"regexp"
	"time"

	"github.com/timetracker/backend/internal/config"
	"github.com/timetracker/backend/internal/db"
	"github.com/timetracker/backend/internal/models"
)
func main() {
dropTable := flag.String("drop_table", "", "Drop the specified table")
flag.Parse()
// Get database configuration with sensible defaults
dbConfig := config.DefaultDatabaseConfig()
// Initialize database
fmt.Println("Connecting to database...")
if err := db.InitDB(dbConfig); err != nil {
log.Fatalf("Error initializing database: %v", err)
}
defer func() {
if err := db.CloseDB(); err != nil {
log.Printf("Error closing database connection: %v", err)
}
}()
fmt.Println("✓ Database connection successful")
// Test a simple query
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// Get the database engine
db := db.GetEngine(ctx)
// Test database connection with a simple query
var result int
var err error
if *dropTable != "" {
fmt.Printf("Dropping table %s...\n", *dropTable)
dropErr := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s", *dropTable)).Error
if dropErr != nil {
log.Fatalf("Error dropping table %s: %v", *dropTable, dropErr)
}
fmt.Printf("✓ Table %s dropped successfully\n", *dropTable)
return
}
err = db.Raw("SELECT 1").Scan(&result).Error
if err != nil {
log.Fatalf("Error executing test query: %v", err)
}
fmt.Println("✓ Test query executed successfully")
// Check if tables exist
fmt.Println("Checking database tables...")
var tables []string
err = db.Raw("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'").Scan(&tables).Error
if err != nil {
log.Fatalf("Error checking tables: %v", err)
}
if len(tables) == 0 {
fmt.Println("No tables found. You may need to run migrations.")
fmt.Println("Attempting to run migrations...")
// Run migrations
if err := models.MigrateDB(); err != nil {
log.Fatalf("Error migrating database: %v", err)
}
fmt.Println("✓ Migrations completed successfully")
} else {
fmt.Println("Found tables:")
for _, table := range tables {
fmt.Printf(" - %s\n", table)
}
}
// Count users
var userCount int64
err = db.Model(&models.User{}).Count(&userCount).Error
if err != nil {
log.Fatalf("Error counting users: %v", err)
}
fmt.Printf("✓ User count: %d\n", userCount)
// Count companies
var companyCount int64
err = db.Model(&models.Company{}).Count(&companyCount).Error
if err != nil {
log.Fatalf("Error counting companies: %v", err)
}
fmt.Printf("✓ Company count: %d\n", companyCount)
fmt.Println("\nDatabase test completed successfully!")
}