From 84bbc33b7b9e0bc59e6cde1e11abe201715418f1 Mon Sep 17 00:00:00 2001 From: Rohithmatham12 Date: Sat, 13 Jun 2026 18:01:42 -0400 Subject: [PATCH 1/2] fix: prefer NVML v2 memory info for inference setup --- cosmos_framework/inference/args.py | 17 ++++++++++++++-- cosmos_framework/inference/args_test.py | 26 +++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/cosmos_framework/inference/args.py b/cosmos_framework/inference/args.py index 62c3c89..1524027 100644 --- a/cosmos_framework/inference/args.py +++ b/cosmos_framework/inference/args.py @@ -1272,8 +1272,7 @@ def _get_device_memory_bytes() -> int: try: pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) - info = pynvml.nvmlDeviceGetMemoryInfo(handle) - pynvml.nvmlShutdown() + info = _get_nvml_device_memory_info(handle) return info.total except Exception: # Fallback for unified memory architectures (e.g., GB10) where @@ -1282,3 +1281,17 @@ def _get_device_memory_bytes() -> int: if torch.cuda.is_available(): return int(torch.cuda.get_device_properties(0).total_memory) return 128 * 1024**3 # Default 128GB + finally: + try: + pynvml.nvmlShutdown() + except Exception: + pass + + +def _get_nvml_device_memory_info(handle: Any) -> Any: + try: + return pynvml.nvmlDeviceGetMemoryInfo_v2(handle) + except AttributeError: + return pynvml.nvmlDeviceGetMemoryInfo(handle) + except pynvml.NVMLError_NotSupported: + return pynvml.nvmlDeviceGetMemoryInfo(handle) diff --git a/cosmos_framework/inference/args_test.py b/cosmos_framework/inference/args_test.py index 631f87e..ec33f42 100644 --- a/cosmos_framework/inference/args_test.py +++ b/cosmos_framework/inference/args_test.py @@ -17,6 +17,7 @@ OmniSampleOverrides, OmniSetupOverrides, SoundDataOverrides, + _get_nvml_device_memory_info, ) from cosmos_framework.inference.common.config import structure_config @@ -95,6 +96,31 @@ def test_build_parallelism(monkeypatch: pytest.MonkeyPatch): assert parallelism_args.compile_dynamic is False +def test_get_nvml_device_memory_info_prefers_v2(monkeypatch: pytest.MonkeyPatch): + from cosmos_framework.inference import args + + expected_info = types.SimpleNamespace(total=96 * 1024**3) + + def fail_v1(_handle): + raise args.pynvml.NVMLError_NotSupported() + + monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo_v2", lambda _handle: expected_info) + monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo", fail_v1) + + assert _get_nvml_device_memory_info(object()) is expected_info + + +def test_get_nvml_device_memory_info_falls_back_when_v2_unavailable(monkeypatch: pytest.MonkeyPatch): + from cosmos_framework.inference import args + + expected_info = types.SimpleNamespace(total=80 * 1024**3) + + monkeypatch.delattr(args.pynvml, "nvmlDeviceGetMemoryInfo_v2", raising=False) + monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo", lambda _handle: expected_info) + + assert _get_nvml_device_memory_info(object()) is expected_info + + def test_checkpoints(): for name, ckpt in OmniSetupOverrides.CHECKPOINTS.items(): assert ckpt.hf.repository.split("/")[0] == "nvidia" From 9876dd649fee36458974818421ce72104adc1313 Mon Sep 17 00:00:00 2001 From: Rohithmatham12 Date: Sun, 14 Jun 2026 23:31:14 -0400 Subject: [PATCH 2/2] test: allow missing NVML v2 symbol --- cosmos_framework/inference/args_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cosmos_framework/inference/args_test.py b/cosmos_framework/inference/args_test.py index ec33f42..d373bf2 100644 --- a/cosmos_framework/inference/args_test.py +++ b/cosmos_framework/inference/args_test.py @@ -104,7 +104,7 @@ def test_get_nvml_device_memory_info_prefers_v2(monkeypatch: pytest.MonkeyPatch) def fail_v1(_handle): raise args.pynvml.NVMLError_NotSupported() - monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo_v2", lambda _handle: expected_info) + monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo_v2", lambda _handle: expected_info, raising=False) monkeypatch.setattr(args.pynvml, "nvmlDeviceGetMemoryInfo", fail_v1) assert _get_nvml_device_memory_info(object()) is expected_info