Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions src/runpod_flash/cli/utils/skeleton_template/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ QB workers process jobs from a queue. Each call to `/runsync` sends a job and wa
for the result. Use QB for compute-heavy tasks that may take seconds to minutes.

**gpu_worker.py** — GPU serverless function:

```python
from runpod_flash import Endpoint, GpuType

Expand All @@ -90,6 +91,7 @@ async def gpu_hello(input_data: dict) -> dict:
```

**cpu_worker.py** — CPU serverless function:

```python
from runpod_flash import Endpoint

Expand All @@ -104,6 +106,7 @@ LB workers expose standard HTTP endpoints (GET, POST, etc.) behind a load balanc
Use LB for low-latency API endpoints that need horizontal scaling.

**lb_worker.py** — HTTP endpoints on a load-balanced container:

```python
from runpod_flash import Endpoint

Expand Down Expand Up @@ -156,17 +159,18 @@ Then run `flash run` -- the new worker appears automatically.

## GPU Types

| Config | Hardware | VRAM |
|--------|----------|------|
| `GpuType.ANY` | Any available GPU | varies |
| `GpuType.NVIDIA_GEFORCE_RTX_4090` | RTX 4090 | 24 GB |
| `GpuType.NVIDIA_GEFORCE_RTX_5090` | RTX 5090 | 32 GB |
| `GpuType.NVIDIA_RTX_6000_ADA_GENERATION` | RTX 6000 Ada | 48 GB |
| `GpuType.NVIDIA_L4` | L4 | 24 GB |
| `GpuType.NVIDIA_A100_80GB_PCIe` | A100 PCIe | 80 GB |
| `GpuType.NVIDIA_A100_SXM4_80GB` | A100 SXM4 | 80 GB |
| `GpuType.NVIDIA_H100_80GB_HBM3` | H100 | 80 GB |
| `GpuType.NVIDIA_H200` | H200 | 141 GB |
| Config | Hardware | VRAM |
| ----------------------------------------- | ----------------- | ------ |
| `GpuType.ANY` | Any available GPU | varies |
| `GpuType.NVIDIA_GEFORCE_RTX_4090` | RTX 4090 | 24 GB |
| `GpuType.NVIDIA_GEFORCE_RTX_5090` | RTX 5090 | 32 GB |
| `GpuType.NVIDIA_RTX_6000_ADA_GENERATION` | RTX 6000 Ada | 48 GB |
| `GpuType.NVIDIA_L4` | L4 | 24 GB |
| `GpuType.NVIDIA_A100_80GB_PCIe` | A100 PCIe | 80 GB |
| `GpuType.NVIDIA_A100_SXM4_80GB` | A100 SXM4 | 80 GB |
| `GpuType.NVIDIA_H100_80GB_HBM3` | H100 | 80 GB |
| `GpuType.NVIDIA_H200` | H200 | 141 GB |
| `GpuType.NVIDIA_B200` | B200 | 180 GB |

## CPU Types

Expand Down
30 changes: 27 additions & 3 deletions src/runpod_flash/core/resources/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ class GpuGroup(Enum):
HOPPER_141 = "HOPPER_141"
"""NVIDIA H200"""

BLACKWELL_96 = "BLACKWELL_96"
"""NVIDIA RTX PRO 6000 Blackwell Server Edition, NVIDIA RTX PRO 6000 Blackwell Workstation Edition, NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition"""

BLACKWELL_180 = "BLACKWELL_180"
"""NVIDIA B200"""

@classmethod
def all(cls) -> List["GpuGroup"]:
"""Returns all GPU groups."""
Expand Down Expand Up @@ -163,6 +169,15 @@ class GpuType(Enum):
NVIDIA_GEFORCE_RTX_4090 = "NVIDIA GeForce RTX 4090"
NVIDIA_GEFORCE_RTX_5090 = "NVIDIA GeForce RTX 5090"
NVIDIA_RTX_6000_ADA_GENERATION = "NVIDIA RTX 6000 Ada Generation"
NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION = (
"NVIDIA RTX PRO 6000 Blackwell Server Edition"
)
NVIDIA_RTX_PRO_6000_BLACKWELL_WORKSTATION_EDITION = (
"NVIDIA RTX PRO 6000 Blackwell Workstation Edition"
)
NVIDIA_RTX_PRO_6000_BLACKWELL_MAX_Q_WORKSTATION_EDITION = (
"NVIDIA RTX PRO 6000 Blackwell Max-Q Workstation Edition"
)
NVIDIA_H100_80GB_HBM3 = "NVIDIA H100 80GB HBM3"
NVIDIA_RTX_A4000 = "NVIDIA RTX A4000"
NVIDIA_RTX_A4500 = "NVIDIA RTX A4500"
Expand All @@ -176,6 +191,7 @@ class GpuType(Enum):
NVIDIA_A100_80GB_PCIe = "NVIDIA A100 80GB PCIe"
NVIDIA_A100_SXM4_80GB = "NVIDIA A100-SXM4-80GB"
NVIDIA_H200 = "NVIDIA H200"
NVIDIA_B200 = "NVIDIA B200"

@classmethod
def all(cls) -> List["GpuType"]:
Expand All @@ -185,9 +201,11 @@ def all(cls) -> List["GpuType"]:
@classmethod
def is_gpu_type(cls, gpu_type: str) -> bool:
"""
Check if a string is a valid GPU type.
Check if a string is a valid GPU type, excluding ANY.
Uses all() to avoid matching the "any" value, which should be
treated as a GpuGroup in from_gpu_ids_str.
"""
return gpu_type in {m.value for m in cls}
return gpu_type in {m.value for m in cls.all()}


POOLS_TO_TYPES = {
Expand All @@ -209,10 +227,16 @@ def is_gpu_type(cls, gpu_type: str) -> bool:
GpuGroup.AMPERE_48: [GpuType.NVIDIA_A40, GpuType.NVIDIA_RTX_A6000],
GpuGroup.AMPERE_80: [GpuType.NVIDIA_A100_80GB_PCIe, GpuType.NVIDIA_A100_SXM4_80GB],
GpuGroup.HOPPER_141: [GpuType.NVIDIA_H200],
GpuGroup.BLACKWELL_96: [
GpuType.NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION,
GpuType.NVIDIA_RTX_PRO_6000_BLACKWELL_WORKSTATION_EDITION,
GpuType.NVIDIA_RTX_PRO_6000_BLACKWELL_MAX_Q_WORKSTATION_EDITION,
],
GpuGroup.BLACKWELL_180: [GpuType.NVIDIA_B200],
}


def _pool_from_gpu_type(gpu_type: GpuType) -> str:
def _pool_from_gpu_type(gpu_type: GpuType) -> Optional[GpuGroup]:
for group, types in POOLS_TO_TYPES.items():
if gpu_type in types:
return group
Expand Down
50 changes: 47 additions & 3 deletions tests/unit/resources/test_gpu_ids.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from runpod_flash.core.resources.gpu import GpuGroup, GpuType
from runpod_flash.core.resources.gpu import GpuGroup, GpuType, POOLS_TO_TYPES


class TestGpuIdsImports:
Expand All @@ -20,6 +20,50 @@ def test_from_gpu_ids_str_pool_only_returns_group(self):
parsed = GpuGroup.from_gpu_ids_str("AMPERE_24")
assert parsed == [GpuGroup.AMPERE_24]

def test_gpu_type_is_gpu_type_checks_enum_member_names(self):
assert GpuType.is_gpu_type("NVIDIA_L4") is False
def test_gpu_type_is_gpu_type_checks_enum_values(self):
assert GpuType.is_gpu_type("L4") is False
assert GpuType.is_gpu_type("NVIDIA L4") is True

def test_gpu_type_is_gpu_type_excludes_any(self):
assert GpuType.is_gpu_type("any") is False

def test_any_round_trip(self):
gpu_ids = GpuGroup.to_gpu_ids_str([GpuGroup.ANY])
assert gpu_ids == "any"
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
assert parsed == [GpuGroup.ANY]

def test_blackwell_groups_round_trip(self):
gpu_ids = GpuGroup.to_gpu_ids_str([GpuGroup.BLACKWELL_96])
assert "BLACKWELL_96" in gpu_ids
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
assert parsed == [GpuGroup.BLACKWELL_96]

gpu_ids = GpuGroup.to_gpu_ids_str([GpuGroup.BLACKWELL_180])
assert "BLACKWELL_180" in gpu_ids
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
assert parsed == [GpuGroup.BLACKWELL_180]

def test_b200_type_maps_to_blackwell_180(self):
gpu_ids = GpuGroup.to_gpu_ids_str([GpuType.NVIDIA_B200])
assert "BLACKWELL_180" in gpu_ids
# b200 is the only type in BLACKWELL_180, so no negations needed
# and from_gpu_ids_str returns the group
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
assert parsed == [GpuGroup.BLACKWELL_180]

def test_rtx_pro_6000_type_maps_to_blackwell_96(self):
gpu_ids = GpuGroup.to_gpu_ids_str(
[GpuType.NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION]
)
assert "BLACKWELL_96" in gpu_ids
# other RTX PRO 6000 variants are negated
parsed = GpuGroup.from_gpu_ids_str(gpu_ids)
assert GpuType.NVIDIA_RTX_PRO_6000_BLACKWELL_SERVER_EDITION in parsed

def test_every_gpu_type_has_pool_mapping(self):
all_mapped = set()
for types in POOLS_TO_TYPES.values():
all_mapped.update(types)
for gpu_type in GpuType.all():
assert gpu_type in all_mapped, f"{gpu_type.name} has no pool mapping"