Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions cosmos_framework/data/vfm/action/action_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,14 @@
from cosmos_framework.utils import log


def load_action_stats(stats_path: str, stats_key: str = "global") -> dict[str, np.ndarray]:
def load_action_stats(stats_path: str) -> dict[str, np.ndarray]:
"""Load pre-computed action normalization stats from a JSON file."""
path = Path(stats_path)
if not path.exists():
raise FileNotFoundError(f"Action normalization stats not found at {stats_path}.")
log.info(f"Loading action normalization stats from {stats_path}")
with path.open("r") as f:
raw = json.load(f)
if stats_key in raw:
raw = raw[stats_key]
if not isinstance(raw, dict):
raise TypeError(f"Action normalization stats block {stats_key!r} in {stats_path} must be a dict.")
elif stats_key != "global":
raise KeyError(f"Action normalization stats block {stats_key!r} not found in {stats_path}.")
stat_keys = {"mean", "std", "min", "max", "q01", "q99"}
return {key: np.array(value, dtype=np.float32) for key, value in raw.items() if key in stat_keys}

Expand All @@ -39,11 +33,28 @@ def normalize_action(
if method == "quantile":
q01, q99 = stats["q01"], stats["q99"]
denom = (q99 - q01).clamp(min=1e-8)
return (2.0 * (action - q01) / denom - 1.0).clamp(-1.0, 1.0)
return 2.0 * (action - q01) / denom - 1.0
if method == "meanstd":
return (action - stats["mean"]) / stats["std"].clamp(min=1e-8)
if method == "minmax":
lo, hi = stats["min"], stats["max"]
denom = (hi - lo).clamp(min=1e-8)
return (2.0 * (action - lo) / denom - 1.0).clamp(-1.0, 1.0)
return 2.0 * (action - lo) / denom - 1.0
raise ValueError(f"Unknown normalization method: {method!r}")


def denormalize_action(
action: torch.Tensor,
method: str,
stats: dict[str, torch.Tensor],
) -> torch.Tensor:
"""Denormalize action tensor."""
if method == "quantile":
q01, q99 = stats["q01"], stats["q99"]
return 0.5 * (action + 1.0) * (q99 - q01) + q01
if method == "meanstd":
return action * stats["std"] + stats["mean"]
if method == "minmax":
lo, hi = stats["min"], stats["max"]
return 0.5 * (action + 1.0) * (hi - lo) + lo
raise ValueError(f"Unknown normalization method: {method!r}")
145 changes: 145 additions & 0 deletions cosmos_framework/data/vfm/action/action_normalization_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: OpenMDW-1.1

import json

import numpy as np
import pytest
import torch

from cosmos_framework.data.vfm.action.action_normalization import (
denormalize_action,
load_action_stats,
normalize_action,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

_RAW_STATS = {
"mean": [0.0, 1.0, -1.0],
"std": [1.0, 2.0, 0.5],
"min": [-2.0, -1.0, -3.0],
"max": [2.0, 3.0, 1.0],
"q01": [-1.0, 0.0, -2.0],
"q99": [1.0, 2.0, 0.0],
}


def _tensor_stats(raw=_RAW_STATS) -> dict[str, torch.Tensor]:
return {k: torch.tensor(v, dtype=torch.float32) for k, v in raw.items()}


def _action() -> torch.Tensor:
return torch.tensor([[0.0, 1.0, -1.0], [1.0, 2.0, 0.0]], dtype=torch.float32)


# ---------------------------------------------------------------------------
# load_action_stats
# ---------------------------------------------------------------------------


def test_load_action_stats_flat(tmp_path):
p = tmp_path / "stats.json"
p.write_text(json.dumps(_RAW_STATS))
result = load_action_stats(str(p))
assert set(result) == set(_RAW_STATS)
for key, value in result.items():
assert isinstance(value, np.ndarray)
assert value.dtype == np.float32
np.testing.assert_array_equal(value, np.array(_RAW_STATS[key], dtype=np.float32))


def test_load_action_stats_filters_unknown_keys(tmp_path):
raw = {**_RAW_STATS, "extra_field": [1.0, 2.0]}
p = tmp_path / "stats.json"
p.write_text(json.dumps(raw))
result = load_action_stats(str(p))
assert "extra_field" not in result


def test_load_action_stats_missing_file():
with pytest.raises(FileNotFoundError):
load_action_stats("/nonexistent/path/stats.json")


# ---------------------------------------------------------------------------
# normalize_action / denormalize_action — round-trip identity
# ---------------------------------------------------------------------------


@pytest.mark.parametrize("method", ["quantile", "meanstd", "minmax"])
def test_round_trip(method):
action = _action()
stats = _tensor_stats()
normalized = normalize_action(action, method, stats)
recovered = denormalize_action(normalized, method, stats)
torch.testing.assert_close(recovered, action, atol=1e-5, rtol=1e-5)


# ---------------------------------------------------------------------------
# normalize_action — endpoint correctness
# ---------------------------------------------------------------------------


def test_normalize_quantile_endpoints():
stats = _tensor_stats()
q01, q99 = stats["q01"], stats["q99"]
assert torch.allclose(normalize_action(q01.unsqueeze(0), "quantile", stats), torch.full((1, 3), -1.0))
assert torch.allclose(normalize_action(q99.unsqueeze(0), "quantile", stats), torch.full((1, 3), 1.0))


def test_normalize_minmax_endpoints():
stats = _tensor_stats()
lo, hi = stats["min"], stats["max"]
assert torch.allclose(normalize_action(lo.unsqueeze(0), "minmax", stats), torch.full((1, 3), -1.0))
assert torch.allclose(normalize_action(hi.unsqueeze(0), "minmax", stats), torch.full((1, 3), 1.0))


def test_normalize_meanstd_zero_mean():
stats = _tensor_stats()
result = normalize_action(stats["mean"].unsqueeze(0), "meanstd", stats)
assert torch.allclose(result, torch.zeros(1, 3))


# ---------------------------------------------------------------------------
# denormalize_action — endpoint correctness
# ---------------------------------------------------------------------------


def test_denormalize_quantile_endpoints():
stats = _tensor_stats()
q01, q99 = stats["q01"], stats["q99"]
assert torch.allclose(denormalize_action(torch.full((1, 3), -1.0), "quantile", stats), q01.unsqueeze(0))
assert torch.allclose(denormalize_action(torch.full((1, 3), 1.0), "quantile", stats), q99.unsqueeze(0))


def test_denormalize_minmax_endpoints():
stats = _tensor_stats()
lo, hi = stats["min"], stats["max"]
assert torch.allclose(denormalize_action(torch.full((1, 3), -1.0), "minmax", stats), lo.unsqueeze(0))
assert torch.allclose(denormalize_action(torch.full((1, 3), 1.0), "minmax", stats), hi.unsqueeze(0))


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


def test_normalize_zero_range_no_nan():
stats = {k: torch.zeros(3) for k in ("q01", "q99", "mean", "std", "min", "max")}
action = torch.ones(1, 3)
for method in ("quantile", "meanstd", "minmax"):
result = normalize_action(action, method, stats)
assert torch.isfinite(result).all(), f"{method} produced non-finite output with zero range"


def test_normalize_unknown_method_raises():
with pytest.raises(ValueError, match="Unknown normalization method"):
normalize_action(_action(), "unknown_method", _tensor_stats())


def test_denormalize_unknown_method_raises():
with pytest.raises(ValueError, match="Unknown normalization method"):
denormalize_action(_action(), "unknown_method", _tensor_stats())
Loading