Skip to content

Commit 18e88e0

Browse files
committed
fix: add BLACKWELL_96 pool, RTX PRO 6000 Max-Q type, and deprecated NVIDIA_ aliases
1 parent 961cd7a commit 18e88e0

2 files changed

Lines changed: 73 additions & 1 deletion

File tree

src/runpod_flash/core/resources/gpu.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class GpuGroup(Enum):
6464
HOPPER_141 = "HOPPER_141"
6565
"""NVIDIA H200"""
6666

67+
BLACKWELL_96 = "BLACKWELL_96"
68+
"""NVIDIA RTX PRO 6000 Blackwell Server Edition, NVIDIA RTX PRO 6000 Blackwell Workstation Edition, NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition"""
69+
6770
BLACKWELL_180 = "BLACKWELL_180"
6871
"""NVIDIA B200"""
6972

@@ -172,6 +175,9 @@ class GpuType(Enum):
172175
RTX_PRO_6000_BLACKWELL_WORKSTATION_EDITION = (
173176
"NVIDIA RTX PRO 6000 Blackwell Workstation Edition"
174177
)
178+
RTX_PRO_6000_BLACKWELL_MAX_Q_WORKSTATION_EDITION = (
179+
"NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition"
180+
)
175181
H100_80GB_HBM3 = "NVIDIA H100 80GB HBM3"
176182
RTX_A4000 = "NVIDIA RTX A4000"
177183
RTX_A4500 = "NVIDIA RTX A4500"
@@ -200,6 +206,26 @@ def is_gpu_type(cls, gpu_type: str) -> bool:
200206
return gpu_type in {m.value for m in cls}
201207

202208

209+
# deprecated aliases (old NVIDIA_ prefix names) set outside the class
210+
# so they don't interfere with cloudpickle serialization
211+
GpuType.NVIDIA_GEFORCE_RTX_4090 = GpuType.GEFORCE_RTX_4090
212+
GpuType.NVIDIA_GEFORCE_RTX_5090 = GpuType.GEFORCE_RTX_5090
213+
GpuType.NVIDIA_RTX_6000_ADA_GENERATION = GpuType.RTX_6000_ADA_GENERATION
214+
GpuType.NVIDIA_H100_80GB_HBM3 = GpuType.H100_80GB_HBM3
215+
GpuType.NVIDIA_RTX_A4000 = GpuType.RTX_A4000
216+
GpuType.NVIDIA_RTX_A4500 = GpuType.RTX_A4500
217+
GpuType.NVIDIA_RTX_4000_ADA_GENERATION = GpuType.RTX_4000_ADA_GENERATION
218+
GpuType.NVIDIA_RTX_2000_ADA_GENERATION = GpuType.RTX_2000_ADA_GENERATION
219+
GpuType.NVIDIA_RTX_A5000 = GpuType.RTX_A5000
220+
GpuType.NVIDIA_L4 = GpuType.L4
221+
GpuType.NVIDIA_GEFORCE_RTX_3090 = GpuType.GEFORCE_RTX_3090
222+
GpuType.NVIDIA_A40 = GpuType.A40
223+
GpuType.NVIDIA_RTX_A6000 = GpuType.RTX_A6000
224+
GpuType.NVIDIA_A100_80GB_PCIe = GpuType.A100_80GB_PCIe
225+
GpuType.NVIDIA_A100_SXM4_80GB = GpuType.A100_SXM4_80GB
226+
GpuType.NVIDIA_H200 = GpuType.H200
227+
228+
203229
POOLS_TO_TYPES = {
204230
GpuGroup.ADA_24: [GpuType.GEFORCE_RTX_4090],
205231
GpuGroup.ADA_32_PRO: [GpuType.GEFORCE_RTX_5090],
@@ -219,6 +245,11 @@ def is_gpu_type(cls, gpu_type: str) -> bool:
219245
GpuGroup.AMPERE_48: [GpuType.A40, GpuType.RTX_A6000],
220246
GpuGroup.AMPERE_80: [GpuType.A100_80GB_PCIe, GpuType.A100_SXM4_80GB],
221247
GpuGroup.HOPPER_141: [GpuType.H200],
248+
GpuGroup.BLACKWELL_96: [
249+
GpuType.RTX_PRO_6000_BLACKWELL_SERVER_EDITION,
250+
GpuType.RTX_PRO_6000_BLACKWELL_WORKSTATION_EDITION,
251+
GpuType.RTX_PRO_6000_BLACKWELL_MAX_Q_WORKSTATION_EDITION,
252+
],
222253
GpuGroup.BLACKWELL_180: [GpuType.B200],
223254
}
224255

tests/unit/resources/test_gpu_ids.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from runpod_flash.core.resources.gpu import GpuGroup, GpuType
1+
from runpod_flash.core.resources.gpu import GpuGroup, GpuType, POOLS_TO_TYPES
22

33

44
class TestGpuIdsImports:
@@ -23,3 +23,44 @@ def test_from_gpu_ids_str_pool_only_returns_group(self):
2323
def test_gpu_type_is_gpu_type_checks_enum_values(self):
2424
assert GpuType.is_gpu_type("L4") is False
2525
assert GpuType.is_gpu_type("NVIDIA L4") is True
26+
27+
def test_blackwell_groups_round_trip(self):
28+
gpu_ids = GpuGroup.to_gpu_ids_str([GpuGroup.BLACKWELL_96])
29+
assert "BLACKWELL_96" in gpu_ids
30+
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
31+
assert parsed == [GpuGroup.BLACKWELL_96]
32+
33+
gpu_ids = GpuGroup.to_gpu_ids_str([GpuGroup.BLACKWELL_180])
34+
assert "BLACKWELL_180" in gpu_ids
35+
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
36+
assert parsed == [GpuGroup.BLACKWELL_180]
37+
38+
def test_b200_type_maps_to_blackwell_180(self):
39+
gpu_ids = GpuGroup.to_gpu_ids_str([GpuType.B200])
40+
assert "BLACKWELL_180" in gpu_ids
41+
# b200 is the only type in BLACKWELL_180, so no negations needed
42+
# and from_gpu_ids_str returns the group
43+
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
44+
assert parsed == [GpuGroup.BLACKWELL_180]
45+
46+
def test_rtx_pro_6000_type_maps_to_blackwell_96(self):
47+
gpu_ids = GpuGroup.to_gpu_ids_str(
48+
[GpuType.RTX_PRO_6000_BLACKWELL_SERVER_EDITION]
49+
)
50+
assert "BLACKWELL_96" in gpu_ids
51+
# other RTX PRO 6000 variants are negated
52+
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
53+
assert GpuType.RTX_PRO_6000_BLACKWELL_SERVER_EDITION in parsed
54+
55+
def test_every_gpu_type_has_pool_mapping(self):
56+
all_mapped = set()
57+
for types in POOLS_TO_TYPES.values():
58+
all_mapped.update(types)
59+
for gpu_type in GpuType.all():
60+
assert gpu_type in all_mapped, f"{gpu_type.name} has no pool mapping"
61+
62+
def test_deprecated_aliases_resolve_to_canonical(self):
63+
assert GpuType.NVIDIA_L4 is GpuType.L4
64+
assert GpuType.NVIDIA_A40 is GpuType.A40
65+
assert GpuType.NVIDIA_H200 is GpuType.H200
66+
assert GpuType.NVIDIA_GEFORCE_RTX_4090 is GpuType.GEFORCE_RTX_4090

0 commit comments

Comments
 (0)