@@ -64,6 +64,9 @@ class GpuGroup(Enum):
6464 HOPPER_141 = "HOPPER_141"
6565 """NVIDIA H200"""
6666
67+ BLACKWELL_96 = "BLACKWELL_96"
68+ """NVIDIA RTX PRO 6000 Blackwell Server Edition, NVIDIA RTX PRO 6000 Blackwell Workstation Edition, NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition"""
69+
6770 BLACKWELL_180 = "BLACKWELL_180"
6871 """NVIDIA B200"""
6972
@@ -172,6 +175,9 @@ class GpuType(Enum):
172175 RTX_PRO_6000_BLACKWELL_WORKSTATION_EDITION = (
173176 "NVIDIA RTX PRO 6000 Blackwell Workstation Edition"
174177 )
178+ RTX_PRO_6000_BLACKWELL_MAX_Q_WORKSTATION_EDITION = (
179+ "NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition"
180+ )
175181 H100_80GB_HBM3 = "NVIDIA H100 80GB HBM3"
176182 RTX_A4000 = "NVIDIA RTX A4000"
177183 RTX_A4500 = "NVIDIA RTX A4500"
@@ -200,6 +206,26 @@ def is_gpu_type(cls, gpu_type: str) -> bool:
200206 return gpu_type in {m .value for m in cls }
201207
202208
209+ # deprecated aliases (old NVIDIA_ prefix names) set outside the class
210+ # so they don't interfere with cloudpickle serialization
211+ GpuType .NVIDIA_GEFORCE_RTX_4090 = GpuType .GEFORCE_RTX_4090
212+ GpuType .NVIDIA_GEFORCE_RTX_5090 = GpuType .GEFORCE_RTX_5090
213+ GpuType .NVIDIA_RTX_6000_ADA_GENERATION = GpuType .RTX_6000_ADA_GENERATION
214+ GpuType .NVIDIA_H100_80GB_HBM3 = GpuType .H100_80GB_HBM3
215+ GpuType .NVIDIA_RTX_A4000 = GpuType .RTX_A4000
216+ GpuType .NVIDIA_RTX_A4500 = GpuType .RTX_A4500
217+ GpuType .NVIDIA_RTX_4000_ADA_GENERATION = GpuType .RTX_4000_ADA_GENERATION
218+ GpuType .NVIDIA_RTX_2000_ADA_GENERATION = GpuType .RTX_2000_ADA_GENERATION
219+ GpuType .NVIDIA_RTX_A5000 = GpuType .RTX_A5000
220+ GpuType .NVIDIA_L4 = GpuType .L4
221+ GpuType .NVIDIA_GEFORCE_RTX_3090 = GpuType .GEFORCE_RTX_3090
222+ GpuType .NVIDIA_A40 = GpuType .A40
223+ GpuType .NVIDIA_RTX_A6000 = GpuType .RTX_A6000
224+ GpuType .NVIDIA_A100_80GB_PCIe = GpuType .A100_80GB_PCIe
225+ GpuType .NVIDIA_A100_SXM4_80GB = GpuType .A100_SXM4_80GB
226+ GpuType .NVIDIA_H200 = GpuType .H200
227+
228+
203229POOLS_TO_TYPES = {
204230 GpuGroup .ADA_24 : [GpuType .GEFORCE_RTX_4090 ],
205231 GpuGroup .ADA_32_PRO : [GpuType .GEFORCE_RTX_5090 ],
@@ -219,6 +245,11 @@ def is_gpu_type(cls, gpu_type: str) -> bool:
219245 GpuGroup .AMPERE_48 : [GpuType .A40 , GpuType .RTX_A6000 ],
220246 GpuGroup .AMPERE_80 : [GpuType .A100_80GB_PCIe , GpuType .A100_SXM4_80GB ],
221247 GpuGroup .HOPPER_141 : [GpuType .H200 ],
248+ GpuGroup .BLACKWELL_96 : [
249+ GpuType .RTX_PRO_6000_BLACKWELL_SERVER_EDITION ,
250+ GpuType .RTX_PRO_6000_BLACKWELL_WORKSTATION_EDITION ,
251+ GpuType .RTX_PRO_6000_BLACKWELL_MAX_Q_WORKSTATION_EDITION ,
252+ ],
222253 GpuGroup .BLACKWELL_180 : [GpuType .B200 ],
223254}
224255
0 commit comments