Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 97 additions & 11 deletions src/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ package main

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net/http"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"
Expand All @@ -16,34 +18,40 @@ import (
)

// App bundles the long-lived dependencies of the service: the Redis client,
// the scheduler and supervisor, the HTTP server and the supervisor status
// registry that backs the /supervisors endpoints.
//
// Each field is declared exactly once; declaring a field twice in the same
// struct is a compile error.
type App struct {
	redisClient *redis.Client
	scheduler   *Scheduler
	supervisor  *Supervisor
	httpServer  *http.Server
	// wg tracks goroutines launched by the app (Start adds to it before
	// launching the HTTP server goroutine).
	wg  sync.WaitGroup
	log *slog.Logger
	// statusRegistry is queried by the supervisor status handlers.
	statusRegistry *StatusRegistry
}

func NewApp(redisAddr, gpuType string, log *slog.Logger) *App {
client := redis.NewClient(&redis.Options{Addr: redisAddr})
scheduler := NewScheduler(redisAddr, log)
statusRegistry := NewStatusRegistry(client, log)

consumerID := fmt.Sprintf("worker_%d", os.Getpid())
supervisor := NewSupervisor(redisAddr, consumerID, gpuType, log)

mux := http.NewServeMux()
a := &App{
redisClient: client,
scheduler: scheduler,
supervisor: supervisor,
httpServer: &http.Server{Addr: ":3000", Handler: mux},
log: log,
redisClient: client,
scheduler: scheduler,
supervisor: supervisor,
httpServer: &http.Server{Addr: ":3000", Handler: mux},
log: log,
statusRegistry: statusRegistry,
}

mux.HandleFunc("/auth/login", a.login)
mux.HandleFunc("/auth/refresh", a.refresh)
mux.HandleFunc("/jobs", a.enqueueJob)
mux.HandleFunc("/jobs/status", a.getJobStatus)
mux.HandleFunc("/supervisors/status", a.getSupervisorStatus)
mux.HandleFunc("/supervisors/status/", a.getSupervisorStatusByID)
mux.HandleFunc("/supervisors", a.getAllSupervisors)

a.log.Info("new app initialized", "redis_address", redisAddr,
"gpu_type", gpuType, "http_address", a.httpServer.Addr)
Expand All @@ -58,6 +66,12 @@ func (a *App) Start() error {
return err
}

// Start supervisor
if err := a.supervisor.Start(); err != nil {
a.log.Error("supervisor start failed", "err", err)
return err
}

// Launch HTTP server
a.wg.Add(1)
go func() {
Expand Down Expand Up @@ -164,3 +178,75 @@ func (a *App) getJobStatus(w http.ResponseWriter, r *http.Request) {
id := r.URL.Query().Get("id")
fmt.Fprintln(w, "job id=", id)
}

// getSupervisorStatus responds with a JSON object holding every supervisor
// known to the status registry ("supervisors") and how many there are
// ("count").
//
// The payload is marshalled before anything is written to the response, so a
// serialisation failure can still be reported as a clean 500. A failure while
// writing the body is only logged: the status line has already been sent at
// that point and calling http.Error would merely trigger a superfluous
// WriteHeader.
func (a *App) getSupervisorStatus(w http.ResponseWriter, r *http.Request) {
	supervisors, err := a.statusRegistry.GetAllSupervisors()
	if err != nil {
		a.log.Error("failed to get supervisor status", "error", err)
		http.Error(w, "failed to get supervisor status", http.StatusInternalServerError)
		return
	}

	body, err := json.Marshal(map[string]any{
		"supervisors": supervisors,
		"count":       len(supervisors),
	})
	if err != nil {
		a.log.Error("failed to encode supervisor status response", "error", err)
		http.Error(w, "failed to encode response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Keep the trailing newline that json.Encoder.Encode used to emit.
	if _, err := w.Write(append(body, '\n')); err != nil {
		a.log.Error("failed to write supervisor status response", "error", err)
	}
}

// getSupervisorStatusByID responds with the JSON status of a single
// supervisor. The consumer ID is whatever follows "/supervisors/status/" in
// the request path; an empty ID yields 400.
//
// The payload is marshalled before anything is written to the response, so a
// serialisation failure can still be reported as a clean 500. A failure while
// writing the body is only logged because the response is already committed.
func (a *App) getSupervisorStatusByID(w http.ResponseWriter, r *http.Request) {
	// Extract the consumer ID from the URL path.
	consumerID := strings.TrimPrefix(r.URL.Path, "/supervisors/status/")
	if consumerID == "" {
		http.Error(w, "consumer ID required", http.StatusBadRequest)
		return
	}

	supervisor, err := a.statusRegistry.GetSupervisor(consumerID)
	if err != nil {
		// NOTE(review): every lookup failure (including backend or decode
		// errors) is reported to the client as 404; only the log line carries
		// the underlying cause.
		a.log.Error("failed to get supervisor status", "consumer_id", consumerID, "error", err)
		http.Error(w, "supervisor not found", http.StatusNotFound)
		return
	}

	body, err := json.Marshal(supervisor)
	if err != nil {
		a.log.Error("failed to encode supervisor status response", "error", err)
		http.Error(w, "failed to encode response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Keep the trailing newline that json.Encoder.Encode used to emit.
	if _, err := w.Write(append(body, '\n')); err != nil {
		a.log.Error("failed to write supervisor status response", "consumer_id", consumerID, "error", err)
	}
}

// getAllSupervisors responds with a JSON object listing supervisors from the
// status registry. When the query parameter active is exactly "true" only
// active supervisors are returned; any other value (or none) returns all of
// them. The response carries "supervisors", "count" and "active_only".
//
// The payload is marshalled before anything is written to the response, so a
// serialisation failure can still be reported as a clean 500. A failure while
// writing the body is only logged because the response is already committed.
func (a *App) getAllSupervisors(w http.ResponseWriter, r *http.Request) {
	activeOnly := r.URL.Query().Get("active") == "true"

	var (
		supervisors []SupervisorStatus
		err         error
	)
	if activeOnly {
		supervisors, err = a.statusRegistry.GetActiveSupervisors()
	} else {
		supervisors, err = a.statusRegistry.GetAllSupervisors()
	}
	if err != nil {
		a.log.Error("failed to get supervisors", "active_only", activeOnly, "error", err)
		http.Error(w, "failed to get supervisors", http.StatusInternalServerError)
		return
	}

	body, err := json.Marshal(map[string]any{
		"supervisors": supervisors,
		"count":       len(supervisors),
		"active_only": activeOnly,
	})
	if err != nil {
		a.log.Error("failed to encode supervisors response", "error", err)
		http.Error(w, "failed to encode response", http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Keep the trailing newline that json.Encoder.Encode used to emit.
	if _, err := w.Write(append(body, '\n')); err != nil {
		a.log.Error("failed to write supervisors response", "active_only", activeOnly, "error", err)
	}
}
138 changes: 138 additions & 0 deletions src/int_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,51 @@ import (
"sync"
"syscall"
"testing"
"time"

"github.com/redis/go-redis/v9"
)

// addDummySupervisors writes three fixed supervisor records (two active, one
// inactive) into the given registry. A failed write is logged and does not
// stop the remaining records from being written.
func addDummySupervisors(statusRegistry *StatusRegistry, log *slog.Logger) {
	ts := time.Now()

	fixtures := []SupervisorStatus{
		{
			ConsumerID: "worker_amd_001",
			GPUType:    "AMD",
			Status:     SupervisorStateActive,
			LastSeen:   ts,                     // seen just now
			StartedAt:  ts.Add(-2 * time.Hour), // started 2 hours ago
		},
		{
			ConsumerID: "worker_nvidia_002",
			GPUType:    "NVIDIA",
			Status:     SupervisorStateActive,
			LastSeen:   ts.Add(-30 * time.Second), // seen 30 seconds ago
			StartedAt:  ts.Add(-time.Hour),        // started 1 hour ago
		},
		{
			ConsumerID: "worker_tt_003",
			GPUType:    "TT",
			Status:     SupervisorStateInactive,
			LastSeen:   ts.Add(-5 * time.Minute), // seen 5 minutes ago
			StartedAt:  ts.Add(-3 * time.Hour),   // started 3 hours ago
		},
	}

	for i := range fixtures {
		entry := fixtures[i]
		err := statusRegistry.UpdateStatus(entry.ConsumerID, entry)
		if err != nil {
			log.Error("failed to add dummy supervisor", "consumer_id", entry.ConsumerID, "error", err)
			continue
		}
		log.Info("added dummy supervisor", "consumer_id", entry.ConsumerID, "gpu_type", entry.GPUType, "status", entry.Status)
	}
}

func TestIntegration(t *testing.T) {
os.Setenv("ENV", "test")
defer os.Unsetenv("ENV")

redisAddr := "localhost:6379"
log := slog.New(slog.NewJSONHandler(os.Stdout, nil))

Expand Down Expand Up @@ -58,3 +98,101 @@ func TestIntegration(t *testing.T) {
wg.Wait()
supervisor.Stop()
}

// TestDummySupervisors seeds the status registry with the dummy supervisors
// and verifies that each of their consumer IDs is returned by
// GetAllSupervisors. It requires a Redis instance on localhost:6379 and
// flushes the selected database before running.
func TestDummySupervisors(t *testing.T) {
	redisAddr := "localhost:6379"
	log := slog.New(slog.NewJSONHandler(os.Stdout, nil))

	// t.Setenv restores the previous value of ENV when the test ends instead
	// of unconditionally unsetting it.
	t.Setenv("ENV", "test")

	// Clean up Redis data before the test; a failed flush would leave stale
	// entries behind and make the assertions below meaningless.
	client := redis.NewClient(&redis.Options{Addr: redisAddr})
	defer client.Close()
	if err := client.FlushDB(context.Background()).Err(); err != nil {
		t.Fatalf("Failed to flush Redis: %v", err)
	}

	app := NewApp(redisAddr, "AMD", log)
	defer app.redisClient.Close()

	// Manually add dummy supervisors for testing.
	addDummySupervisors(app.statusRegistry, log)

	supervisors, err := app.statusRegistry.GetAllSupervisors()
	if err != nil {
		// Nothing below can be checked without the list.
		t.Fatalf("Failed to get supervisors: %v", err)
	}

	// Index the returned IDs once so each expected ID is a single lookup.
	seen := make(map[string]struct{}, len(supervisors))
	for _, supervisor := range supervisors {
		seen[supervisor.ConsumerID] = struct{}{}
	}

	// Verify dummy supervisor IDs exist.
	for _, dummyID := range []string{"worker_amd_001", "worker_nvidia_002", "worker_tt_003"} {
		if _, ok := seen[dummyID]; !ok {
			t.Errorf("Expected dummy supervisor %s not found", dummyID)
		}
	}
}

// Unit tests for StatusRegistry
// TestStatusRegistry_BasicOperations stores one supervisor status and checks
// that it can be read back individually, in the full listing and in the
// active-only listing. It requires a Redis instance on localhost:6379 and
// flushes the selected database before running.
func TestStatusRegistry_BasicOperations(t *testing.T) {
	redisAddr := "localhost:6379"
	log := slog.New(slog.NewJSONHandler(os.Stdout, nil))

	client := redis.NewClient(&redis.Options{Addr: redisAddr})
	defer client.Close()
	// The count assertions below assume an empty database.
	if err := client.FlushDB(context.Background()).Err(); err != nil {
		t.Fatalf("FlushDB failed: %v", err)
	}

	registry := NewStatusRegistry(client, log)

	now := time.Now()

	// Test adding and retrieving a supervisor.
	status := SupervisorStatus{
		ConsumerID: "test_worker_001",
		GPUType:    "AMD",
		Status:     SupervisorStateActive,
		LastSeen:   now,
		StartedAt:  now.Add(-1 * time.Hour),
	}

	// Add status. Every later step depends on this write.
	if err := registry.UpdateStatus(status.ConsumerID, status); err != nil {
		t.Fatalf("UpdateStatus failed: %v", err)
	}

	// Retrieve status. Fatal on error: GetSupervisor returns a nil pointer
	// alongside the error, and dereferencing it below would panic.
	retrievedStatus, err := registry.GetSupervisor(status.ConsumerID)
	if err != nil {
		t.Fatalf("GetSupervisor failed: %v", err)
	}

	if retrievedStatus.Status != status.Status {
		t.Errorf("Expected Status %s, got %s", status.Status, retrievedStatus.Status)
	}

	// Test getting all supervisors.
	allSupervisors, err := registry.GetAllSupervisors()
	if err != nil {
		t.Fatalf("GetAllSupervisors failed: %v", err)
	}

	if len(allSupervisors) != 1 {
		t.Errorf("Expected 1 supervisor, got %d", len(allSupervisors))
	}

	// Test getting active supervisors.
	activeSupervisors, err := registry.GetActiveSupervisors()
	if err != nil {
		t.Fatalf("GetActiveSupervisors failed: %v", err)
	}

	if len(activeSupervisors) != 1 {
		t.Errorf("Expected 1 active supervisor, got %d", len(activeSupervisors))
	}
}
89 changes: 89 additions & 0 deletions src/status.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package main

import (
"context"
"encoding/json"
"fmt"
"log/slog"

"github.com/redis/go-redis/v9"
)

// StatusRegistry reads and writes supervisor status records in the Redis hash
// named by SupervisorStatusKey. Records are stored as JSON, one hash field per
// consumer ID.
type StatusRegistry struct {
// redisClient is the connection used for the HGet/HGetAll/HSet calls.
redisClient *redis.Client
// log receives decode errors and status-update notices.
log *slog.Logger
}

// NewStatusRegistry returns a StatusRegistry that uses the given Redis client
// and logger. Neither argument is copied or validated.
func NewStatusRegistry(redisClient *redis.Client, log *slog.Logger) *StatusRegistry {
	registry := new(StatusRegistry)
	registry.redisClient = redisClient
	registry.log = log
	return registry
}

// GetAllSupervisors returns every supervisor status stored in the
// SupervisorStatusKey hash. Entries whose JSON cannot be decoded are logged
// and skipped rather than failing the whole call.
//
// The returned slice is never nil on success: an empty registry yields an
// empty slice, so callers that serialise the result to JSON emit [] instead
// of null. The order of the entries is unspecified (map iteration order).
func (sr *StatusRegistry) GetAllSupervisors() ([]SupervisorStatus, error) {
	ctx := context.Background()
	result := sr.redisClient.HGetAll(ctx, SupervisorStatusKey)
	if err := result.Err(); err != nil {
		return nil, fmt.Errorf("failed to get supervisor status: %w", err)
	}

	entries := result.Val()
	// Pre-size for the common case where every entry decodes.
	supervisors := make([]SupervisorStatus, 0, len(entries))
	for consumerID, statusJSON := range entries {
		var status SupervisorStatus
		if err := json.Unmarshal([]byte(statusJSON), &status); err != nil {
			sr.log.Error("failed to unmarshal supervisor status", "consumer_id", consumerID, "error", err)
			continue
		}
		supervisors = append(supervisors, status)
	}

	return supervisors, nil
}

// GetSupervisor looks up the status stored under consumerID in the
// SupervisorStatusKey hash and decodes it from JSON. It returns a nil status
// together with a wrapped error when the Redis command reports an error or
// the stored value cannot be decoded.
func (sr *StatusRegistry) GetSupervisor(consumerID string) (*SupervisorStatus, error) {
	cmd := sr.redisClient.HGet(context.Background(), SupervisorStatusKey, consumerID)
	if cmdErr := cmd.Err(); cmdErr != nil {
		return nil, fmt.Errorf("failed to get supervisor status: %w", cmdErr)
	}

	raw := []byte(cmd.Val())
	decoded := new(SupervisorStatus)
	if decodeErr := json.Unmarshal(raw, decoded); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal supervisor status: %w", decodeErr)
	}

	return decoded, nil
}

// GetActiveSupervisors returns the subset of GetAllSupervisors whose Status
// equals SupervisorStateActive. Any error from GetAllSupervisors is returned
// unchanged.
//
// The returned slice is never nil on success: when no supervisor is active an
// empty slice is returned, so callers that serialise the result to JSON emit
// [] instead of null.
func (sr *StatusRegistry) GetActiveSupervisors() ([]SupervisorStatus, error) {
	allSupervisors, err := sr.GetAllSupervisors()
	if err != nil {
		return nil, err
	}

	// At most every supervisor is active, so this capacity is an upper bound.
	activeSupervisors := make([]SupervisorStatus, 0, len(allSupervisors))
	for _, supervisor := range allSupervisors {
		if supervisor.Status == SupervisorStateActive {
			activeSupervisors = append(activeSupervisors, supervisor)
		}
	}

	return activeSupervisors, nil
}

// UpdateStatus serialises status to JSON and stores it under consumerID in
// the SupervisorStatusKey hash, replacing any value already held in that
// field. A successful write is logged at info level; marshal and Redis errors
// are returned wrapped.
func (sr *StatusRegistry) UpdateStatus(consumerID string, status SupervisorStatus) error {
	payload, marshalErr := json.Marshal(status)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal supervisor status: %w", marshalErr)
	}

	cmd := sr.redisClient.HSet(context.Background(), SupervisorStatusKey, consumerID, string(payload))
	if setErr := cmd.Err(); setErr != nil {
		return fmt.Errorf("failed to update supervisor status: %w", setErr)
	}

	sr.log.Info("supervisor status updated", "consumer_id", consumerID, "status", status.Status)
	return nil
}
Loading