// gpuType is a wire-facing accelerator type submitted to the training service.
// The number in the name is the partition count (e.g. GPU_8xH100 is 8 GPUs).
type gpuType string

const (
	gpuType1xA10  gpuType = "GPU_1xA10"
	gpuType1xH100 gpuType = "GPU_1xH100"
	gpuType8xH100 gpuType = "GPU_8xH100"
)

// gpuTypes lists every valid type, in the order they are named in validation
// error messages. parseGPUType accepts exactly the entries of this slice, so
// the hint shown to the user and the set that is actually accepted cannot
// disagree. A new type must also get a case in gpusPerNode.
var gpuTypes = []gpuType{gpuType1xA10, gpuType1xH100, gpuType8xH100}

// validGPUTypesHint returns the "valid types are: ..." text appended to
// validation errors, naming every entry of gpuTypes in slice order.
func validGPUTypesHint() string {
	names := make([]string, len(gpuTypes))
	for i, g := range gpuTypes {
		names[i] = string(g)
	}
	return "valid types are: " + strings.Join(names, ", ")
}

// parseGPUType resolves a YAML accelerator_type string to a gpuType. The match is
// exact: the server's lookup is case-sensitive.
//
// It returns the matching entry of gpuTypes, or an empty gpuType and an error
// that quotes the rejected value and lists the valid types.
func parseGPUType(value string) (gpuType, error) {
	candidate := gpuType(value)
	// Membership is checked against gpuTypes rather than a second hard-coded
	// list, so the accepted set always matches validGPUTypesHint.
	for _, known := range gpuTypes {
		if candidate == known {
			return known, nil
		}
	}
	return "", fmt.Errorf("invalid GPU type %q: %s", value, validGPUTypesHint())
}

// gpusPerNode returns the per-node GPU count, which is the partition count from
// the name (GPU_1xH100 -> 1, GPU_8xH100 -> 8). num_accelerators must be a
// round multiple of this since accelerators are allocated in whole nodes.
//
// A type without a case below yields 0 and an "invalid GPU type" error.
func gpusPerNode(g gpuType) (int, error) {
	switch g {
	case gpuType1xA10, gpuType1xH100:
		return 1, nil
	case gpuType8xH100:
		return 8, nil
	}
	return 0, fmt.Errorf("invalid GPU type %q", string(g))
}

// computeConfig is the `compute` block of the run YAML: which accelerators to
// use and how many.
+type computeConfig struct { + NumAccelerators int `yaml:"num_accelerators"` + AcceleratorType string `yaml:"accelerator_type"` +} + +// validate rejects an unknown accelerator_type, a non-positive num_accelerators, and a count that is not a multiple of the type's per-node GPU count. +func (c computeConfig) validate() error { + g, err := parseGPUType(c.AcceleratorType) + if err != nil { + return fmt.Errorf("compute.accelerator_type: %w", err) + } + + if c.NumAccelerators <= 0 { + return fmt.Errorf("compute.num_accelerators must be positive, got %d", c.NumAccelerators) + } + + perNode, err := gpusPerNode(g) + if err != nil { + return err + } + if c.NumAccelerators%perNode != 0 { + return fmt.Errorf("compute.num_accelerators for %s must be a multiple of %d, got %d", c.AcceleratorType, perNode, c.NumAccelerators) + } + + return nil +} diff --git a/experimental/air/cmd/compute_test.go b/experimental/air/cmd/compute_test.go new file mode 100644 index 0000000000..3464afbe9e --- /dev/null +++ b/experimental/air/cmd/compute_test.go @@ -0,0 +1,86 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseGPUType(t *testing.T) { + tests := []struct { + in string + want gpuType + }{ + {"GPU_1xA10", gpuType1xA10}, + {"GPU_8xH100", gpuType8xH100}, + {"GPU_1xH100", gpuType1xH100}, + } + for _, tt := range tests { + t.Run(tt.in, func(t *testing.T) { + got, err := parseGPUType(tt.in) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseGPUTypeInvalid(t *testing.T) { + // Wrong casing is rejected rather than fixed up; legacy types (h100_80gb, a10) + // can no longer be submitted; unknown types and the empty string are rejected. 
+ for _, in := range []string{"gpu_1xa10", "GPU_1XA10", "GPU_2xH100", "h100_80gb", "a10", "b200", ""} { + t.Run(in, func(t *testing.T) { + _, err := parseGPUType(in) + require.Error(t, err) + assert.Contains(t, err.Error(), "valid types are") + }) + } +} + +func TestGPUsPerNode(t *testing.T) { + tests := []struct { + in gpuType + want int + }{ + {gpuType1xA10, 1}, + {gpuType1xH100, 1}, + {gpuType8xH100, 8}, + } + for _, tt := range tests { + t.Run(string(tt.in), func(t *testing.T) { + got, err := gpusPerNode(tt.in) + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } + + _, err := gpusPerNode(gpuType("nonsense")) + require.Error(t, err) +} + +func TestComputeConfigValidate(t *testing.T) { + tests := []struct { + name string + cfg computeConfig + wantErr string // expected substring of the error; empty means the config is valid + }{ + {"single node", computeConfig{NumAccelerators: 8, AcceleratorType: "GPU_8xH100"}, ""}, + {"multiple nodes", computeConfig{NumAccelerators: 16, AcceleratorType: "GPU_8xH100"}, ""}, + {"single-gpu partitions", computeConfig{NumAccelerators: 3, AcceleratorType: "GPU_1xH100"}, ""}, + {"unknown type", computeConfig{NumAccelerators: 8, AcceleratorType: "b200"}, "accelerator_type"}, + {"legacy type rejected", computeConfig{NumAccelerators: 8, AcceleratorType: "h100_80gb"}, "accelerator_type"}, + {"non-positive count", computeConfig{NumAccelerators: 0, AcceleratorType: "GPU_1xH100"}, "must be positive"}, + {"count not a multiple", computeConfig{NumAccelerators: 4, AcceleratorType: "GPU_8xH100"}, "multiple of 8"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.validate() + if tt.wantErr == "" { + require.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +}