Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions cosmos_framework/inference/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,8 +1272,7 @@ def _get_device_memory_bytes() -> int:
try:
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
info = pynvml.nvmlDeviceGetMemoryInfo(handle)
pynvml.nvmlShutdown()
info = _get_nvml_device_memory_info(handle)
return info.total
except Exception:
# Fallback for unified memory architectures (e.g., GB10) where
Expand All @@ -1282,3 +1281,17 @@ def _get_device_memory_bytes() -> int:
if torch.cuda.is_available():
return int(torch.cuda.get_device_properties(0).total_memory)
return 128 * 1024**3 # Default 128GB
finally:
try:
pynvml.nvmlShutdown()
except Exception:
pass


def _get_nvml_device_memory_info(handle: Any) -> Any:
try:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is a bit mis-understanding: we split the exception capture into two parts: one is in inside _get_nvml_device_memory_info, the other is outside this func. Should we keep the exception capture unified in one-place?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, this is trivial, this pr is good for me.

return pynvml.nvmlDeviceGetMemoryInfo_v2(handle)
except AttributeError:
return pynvml.nvmlDeviceGetMemoryInfo(handle)
except pynvml.NVMLError_NotSupported:
return pynvml.nvmlDeviceGetMemoryInfo(handle)
26 changes: 26 additions & 0 deletions cosmos_framework/inference/args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
OmniSampleOverrides,
OmniSetupOverrides,
SoundDataOverrides,
_get_nvml_device_memory_info,
)
from cosmos_framework.inference.common.config import structure_config

Expand Down Expand Up @@ -95,6 +96,31 @@ def test_build_parallelism(monkeypatch: pytest.MonkeyPatch):
assert parallelism_args.compile_dynamic is False


def test_get_nvml_device_memory_info_prefers_v2(monkeypatch: pytest.MonkeyPatch):
from cosmos_framework.inference import args

expected_info = types.SimpleNamespace(total=96 * 1024**3)

def fail_v1(_handle):
raise args.pynvml.NVMLError_NotSupported()

monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo_v2", lambda _handle: expected_info, raising=False)
monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo", fail_v1)

assert _get_nvml_device_memory_info(object()) is expected_info


def test_get_nvml_device_memory_info_falls_back_when_v2_unavailable(monkeypatch: pytest.MonkeyPatch):
from cosmos_framework.inference import args

expected_info = types.SimpleNamespace(total=80 * 1024**3)

monkeypatch.delattr(args.pynvml, "nvmlDeviceGetMemoryInfo_v2", raising=False)
monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo", lambda _handle: expected_info)

assert _get_nvml_device_memory_info(object()) is expected_info


def test_checkpoints():
for name, ckpt in OmniSetupOverrides.CHECKPOINTS.items():
assert ckpt.hf.repository.split("/")[0] == "nvidia"
Expand Down