Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cmd/slurm-tracker/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,18 @@ func run(cmd *cobra.Command, args []string) error {
Msg("Starting Slurm usage event collector")

// Initialize state driver
stateDriver, err := state.NewStateDriver(cfg.StateFile)
stateDriver, err := state.NewDriver(cfg.StateFile)
if err != nil {
return fmt.Errorf("failed to initialize state driver: %w", err)
}
defer func() {
if err := stateDriver.Shutdown(); err != nil {
log.Error().Err(err).Msg("Error shutting down state driver")
if shutdownErr := stateDriver.Shutdown(); shutdownErr != nil {
log.Error().Err(shutdownErr).Msg("Error shutting down state driver")
}
}()

// Get Slurm jobs
jobs, err := slurm.GetSlurmJobs(cfg.LookbackMinutes)
jobs, err := slurm.GetJobs(cfg.LookbackMinutes)
if err != nil {
return fmt.Errorf("failed to get Slurm jobs: %w", err)
}
Expand All @@ -124,20 +124,20 @@ func run(cmd *cobra.Command, args []string) error {
semaphore := make(chan struct{}, 10)
waitGroup := sync.WaitGroup{}
// Process each job and create usage events
for _, job := range jobs {
for i := range jobs {
waitGroup.Add(1)
semaphore <- struct{}{}
go func(job slurm.SlurmJob) {
go func(job *slurm.Job) {
defer waitGroup.Done()
defer func() { <-semaphore }()
if err := tracker.ProcessJob(cfg, job, stateDriver, pwClient, cfg.DryRun); err != nil {
if err := tracker.ProcessJob(&cfg, job, stateDriver, pwClient, cfg.DryRun); err != nil {
log.Error().
Err(err).
Int("job_id", job.JobID).
Str("job_name", job.Name).
Msg("Failed to process job")
}
}(job)
}(&jobs[i])
}
waitGroup.Wait()

Expand Down
8 changes: 4 additions & 4 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ type PartitionMapping struct {
SKU string `json:"sku"`
}

// ConfigFile represents the JSON configuration file structure
type ConfigFile struct {
// File represents the JSON configuration file structure
type File struct {
DefaultSku string `json:"defaultSku"`
DefaultAllocation string `json:"defaultAllocation"`
Partition []PartitionMapping `json:"partition"`
Expand Down Expand Up @@ -49,10 +49,10 @@ func LoadConfigFile(path string, cfg *Config) error {
if os.IsNotExist(err) {
return fmt.Errorf("config file %s not found: please create the config file or specify the correct path", path)
}
return fmt.Errorf("failed to read config file %s: %w. Please check file permissions or path.", path, err)
return fmt.Errorf("failed to read config file %s: %w", path, err)
}

var configFile ConfigFile
var configFile File
if err := json.Unmarshal(data, &configFile); err != nil {
return fmt.Errorf("failed to parse config file: %w", err)
}
Expand Down
16 changes: 8 additions & 8 deletions internal/slurm/slurm.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type SacctOutput struct {
} `json:"version"`
} `json:"Slurm"`
} `json:"meta"`
Jobs []SlurmJob `json:"jobs"`
Jobs []Job `json:"jobs"`
Errors []any `json:"errors"`
Warnings []any `json:"warnings"`
}
Expand All @@ -36,8 +36,8 @@ type NumberValue struct {
Number int `json:"number"`
}

// SlurmJob represents a single job from sacct output
type SlurmJob struct {
// Job represents a single job from sacct output
type Job struct {
JobID int `json:"job_id"`
Name string `json:"name"`
User string `json:"user"`
Expand Down Expand Up @@ -107,8 +107,8 @@ type TresAlloc struct {
Count int `json:"count"`
}

// GetSlurmJobs queries sacct and returns parsed job data
func GetSlurmJobs(lookbackMinutes int) ([]SlurmJob, error) {
// GetJobs queries sacct and returns parsed job data
func GetJobs(lookbackMinutes int) ([]Job, error) {
// Calculate the start time
startTime := time.Now().Add(-time.Duration(lookbackMinutes) * time.Minute)
startTimeStr := startTime.Format("2006-01-02T15:04:05")
Expand Down Expand Up @@ -156,12 +156,12 @@ func GetSlurmJobs(lookbackMinutes int) ([]SlurmJob, error) {
}

// IsJobRunning returns true if the job is currently running
func IsJobRunning(job SlurmJob) bool {
func IsJobRunning(job *Job) bool {
return job.State.Current == "RUNNING"
}

// IsJobCompleted returns true if the job has reached a terminal state
func IsJobCompleted(job SlurmJob) bool {
func IsJobCompleted(job *Job) bool {
completedStates := map[string]bool{
"COMPLETED": true,
"FAILED": true,
Expand All @@ -175,7 +175,7 @@ func IsJobCompleted(job SlurmJob) bool {
}

// CalculateCoreHoursForElapsed calculates core hours for a given elapsed time
func CalculateCoreHoursForElapsed(job SlurmJob, elapsedSeconds int) float64 {
func CalculateCoreHoursForElapsed(job *Job, elapsedSeconds int) float64 {
// Elapsed time is in seconds
elapsedHours := float64(elapsedSeconds) / 3600.0

Expand Down
28 changes: 14 additions & 14 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"sync"
"time"

_ "modernc.org/sqlite"
_ "modernc.org/sqlite" // register sqlite driver

"github.com/rs/zerolog/log"
)
Expand All @@ -20,15 +20,15 @@ type JobState struct {
CompletedAt int64 `json:"completed_at,omitempty"` // unix timestamp when job completed (0 if still running)
}

// StateDriver manages concurrent access to job states with SQLite persistence
type StateDriver struct {
// Driver manages concurrent access to job states with SQLite persistence
type Driver struct {
db *sql.DB
dbPath string
mutex sync.RWMutex
}

// NewStateDriver creates a new state driver with SQLite backend
func NewStateDriver(dbPath string) (*StateDriver, error) {
// NewDriver creates a new state driver with SQLite backend
func NewDriver(dbPath string) (*Driver, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
Expand All @@ -38,7 +38,7 @@ func NewStateDriver(dbPath string) (*StateDriver, error) {
db.SetMaxOpenConns(1) // SQLite works best with single writer
db.SetMaxIdleConns(1)

driver := &StateDriver{
driver := &Driver{
db: db,
dbPath: dbPath,
}
Expand All @@ -57,7 +57,7 @@ func NewStateDriver(dbPath string) (*StateDriver, error) {
}

// initSchema creates the job_states table if it doesn't exist
func (d *StateDriver) initSchema() error {
func (d *Driver) initSchema() error {
query := `
CREATE TABLE IF NOT EXISTS job_states (
job_id INTEGER PRIMARY KEY,
Expand All @@ -73,7 +73,7 @@ func (d *StateDriver) initSchema() error {
}

// GetState retrieves a job state by ID
func (d *StateDriver) GetState(jobID int) (JobState, bool) {
func (d *Driver) GetState(jobID int) (JobState, bool) {
var state JobState
query := `
SELECT job_id, last_reported_elapsed, last_reported_at, total_core_hours, completed_at
Expand Down Expand Up @@ -103,7 +103,7 @@ func (d *StateDriver) GetState(jobID int) (JobState, bool) {
}

// UpdateState updates a job state immediately
func (d *StateDriver) UpdateState(state JobState) {
func (d *Driver) UpdateState(state JobState) {
query := `
INSERT INTO job_states (job_id, last_reported_elapsed, last_reported_at, total_core_hours, completed_at)
VALUES (?, ?, ?, ?, ?)
Expand All @@ -130,7 +130,7 @@ func (d *StateDriver) UpdateState(state JobState) {
}

// DeleteState removes a job state
func (d *StateDriver) DeleteState(jobID int) {
func (d *Driver) DeleteState(jobID int) {
d.mutex.Lock()
_, err := d.db.Exec("DELETE FROM job_states WHERE job_id = ?", jobID)
d.mutex.Unlock()
Expand All @@ -141,7 +141,7 @@ func (d *StateDriver) DeleteState(jobID int) {
}

// GetAllStates returns all job states
func (d *StateDriver) GetAllStates() map[int]JobState {
func (d *Driver) GetAllStates() map[int]JobState {
query := `
SELECT job_id, last_reported_elapsed, last_reported_at, total_core_hours, completed_at
FROM job_states
Expand Down Expand Up @@ -177,14 +177,14 @@ func (d *StateDriver) GetAllStates() map[int]JobState {
}

// getJobCount returns the number of tracked jobs
func (d *StateDriver) getJobCount() (int, error) {
func (d *Driver) getJobCount() (int, error) {
var count int
err := d.db.QueryRow("SELECT COUNT(*) FROM job_states").Scan(&count)
return count, err
}

// Shutdown gracefully shuts down the driver
func (d *StateDriver) Shutdown() error {
func (d *Driver) Shutdown() error {
if err := d.db.Close(); err != nil {
log.Error().Err(err).Msg("Error closing database")
return err
Expand All @@ -195,7 +195,7 @@ func (d *StateDriver) Shutdown() error {
}

// CleanupOldStates removes states for jobs completed more than the specified duration ago
func (d *StateDriver) CleanupOldStates(olderThan time.Duration) int {
func (d *Driver) CleanupOldStates(olderThan time.Duration) int {
cutoffTime := time.Now().Add(-olderThan).Unix()

d.mutex.Lock()
Expand Down
4 changes: 2 additions & 2 deletions internal/tracker/tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ type UsageEventResponse struct {
}

// ProcessJob processes a single Slurm job, calculating usage and reporting it
func ProcessJob(cfg config.Config, job slurm.SlurmJob, stateDriver *state.StateDriver, pwClient *parallelworks.ClientWithResponses, dryRun bool) error {
func ProcessJob(cfg *config.Config, job *slurm.Job, stateDriver *state.Driver, pwClient *parallelworks.ClientWithResponses, dryRun bool) error {
isRunning := slurm.IsJobRunning(job)
isCompleted := slurm.IsJobCompleted(job)

// Skip jobs that are neither running nor completed
if !isRunning && !isCompleted {
log.Debug().
Int("job_id", job.JobID).
Str("state", fmt.Sprintf("%v", job.State.Current)).
Str("state", job.State.Current).
Msg("Skipping job (not running or completed)")
return nil
}
Expand Down
Loading