diff --git a/.github/unittest/linux_libs/scripts_mujoco/environment.yml b/.github/unittest/linux_libs/scripts_mujoco/environment.yml new file mode 100644 index 00000000000..5fc46369a5a --- /dev/null +++ b/.github/unittest/linux_libs/scripts_mujoco/environment.yml @@ -0,0 +1,29 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - hypothesis + - future + - cloudpickle + - pytest + - pytest-cov + - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-json-report + - pytest-error-for-skips + - pytest-asyncio + - expecttest + - pybind11[global] + - pyyaml + - scipy + - hydra-core + - psutil + # MuJoCo physics backends -- pinned for compatibility with mujoco-torch 0.2.0: + # mujoco>=3.8 removed the mjENBL_MULTICCD enum that mujoco-torch 0.2.0 references. + - mujoco==3.7.0 + - mujoco-mjx==3.7.0 + - mujoco-torch==0.2.0 + - jax[cuda12]>=0.7.0,<0.11 diff --git a/.github/unittest/linux_libs/scripts_mujoco/install.sh b/.github/unittest/linux_libs/scripts_mujoco/install.sh new file mode 100755 index 00000000000..6d518268a70 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_mujoco/install.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. 
+ +set -euxo pipefail + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with cu128" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu -U + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu128 + fi +else + printf "Failed to install pytorch" + exit 1 +fi + +# install tensordict +pip install git+https://github.com/pytorch/tensordict.git --progress-bar off + +# smoke test +python -c "import functorch;import tensordict" + +printf "* Installing torchrl\n" +python -m pip install -e . 
--no-build-isolation + +# smoke test +python -c "import torchrl" +python -c "import mujoco; import mujoco.mjx; import mujoco_torch; print('mujoco', mujoco.__version__, 'mujoco-torch ok')" diff --git a/.github/unittest/linux_libs/scripts_mujoco/post_process.sh b/.github/unittest/linux_libs/scripts_mujoco/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_mujoco/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_mujoco/run_all.sh b/.github/unittest/linux_libs/scripts_mujoco/run_all.sh new file mode 100755 index 00000000000..550c6f0bcc5 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_mujoco/run_all.sh @@ -0,0 +1,13 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +apt update +apt install -y libglfw3 libglfw3-dev libglew-dev libgl1-mesa-glx libgl1-mesa-dev mesa-common-dev libegl1-mesa-dev freeglut3 freeglut3-dev + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +bash ${this_dir}/setup_env.sh +bash ${this_dir}/install.sh +PYTHON=./env/bin/python bash "$(git rev-parse --show-toplevel)/.github/unittest/helpers/assert_torch_version.sh" "$TORCH_VERSION" +bash ${this_dir}/run_test.sh +bash ${this_dir}/post_process.sh diff --git a/.github/unittest/linux_libs/scripts_mujoco/run_test.sh b/.github/unittest/linux_libs/scripts_mujoco/run_test.sh new file mode 100755 index 00000000000..1e5e56f005b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_mujoco/run_test.sh @@ -0,0 +1,61 @@ +#!/usr/bin/env bash + +set -euxo pipefail + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + + +export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False + +# JAX (mjx backend) GPU initialization: mirrors scripts_brax/run_test.sh. 
+export XLA_PYTHON_CLIENT_PREALLOCATE=false +export XLA_PYTHON_CLIENT_ALLOCATOR=platform +export TF_FORCE_GPU_ALLOW_GROWTH=true +export CUDA_VISIBLE_DEVICES=0 + +# OpenGL backend for mujoco.Renderer (used by the mjx and mujoco backends' +# from_pixels path; the mujoco-torch backend uses its own torch raycaster +# and doesn't need a GL context). EGL works headless on the GPU runner; +# matches scripts_gym/run_all.sh. +export MUJOCO_GL=egl +export PYOPENGL_PLATFORM=egl + +python -m torch.utils.collect_env +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" +lib_dir="${env_dir}/lib" + +export MKL_THREADING_LAYER=GNU +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# Lib smoke checks +python -c "import mujoco; import mujoco.mjx; import mujoco_torch; print('mujoco', mujoco.__version__)" +python -c " +import jax +import os +os.environ.setdefault('XLA_PYTHON_CLIENT_PREALLOCATE', 'false') +os.environ.setdefault('XLA_PYTHON_CLIENT_ALLOCATOR', 'platform') +try: + devices = jax.devices() + print(f'JAX devices: {devices}') +except Exception as e: + print(f'JAX init error: {e}; falling back to CPU') + os.environ['JAX_PLATFORM_NAME'] = 'cpu' + jax.config.update('jax_platform_name', 'cpu') +" +python -c 'import torch;t = torch.ones([2,2], device="cuda:0" if torch.cuda.is_available() else "cpu");print(t);print("tensor device:" + str(t.device))' + +# JSON report for flaky test tracking +json_report_dir="${RUNNER_ARTIFACT_DIR:-${root_dir}}" +json_report_args="--json-report --json-report-file=${json_report_dir}/test-results-mujoco.json --json-report-indent=2" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/libs/test_mujoco.py ${json_report_args} --instafail -v --durations 200 --capture no -k TestMujoco --error-for-skips +coverage combine -q +coverage xml -i + +# Upload test results with metadata for flaky tracking +python .github/unittest/helpers/upload_test_results.py || echo 
"Warning: Failed to process test results for flaky tracking" diff --git a/.github/unittest/linux_libs/scripts_mujoco/setup_env.sh b/.github/unittest/linux_libs/scripts_mujoco/setup_env.sh new file mode 100755 index 00000000000..a03405276ec --- /dev/null +++ b/.github/unittest/linux_libs/scripts_mujoco/setup_env.sh @@ -0,0 +1,63 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -euxo pipefail + +apt-get update && apt-get upgrade -y && apt-get install -y git cmake +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +apt-get install -y wget \ + gcc \ + g++ \ + unzip \ + curl \ + patchelf \ + libosmesa6-dev \ + libgl1-mesa-glx \ + libglfw3 \ + libglew-dev \ + libglvnd0 \ + libgl1 \ + libglx0 \ + libegl1 \ + libgles2 + +apt-get upgrade -y libstdc++6 + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +printf "python: ${PYTHON_VERSION}\n" +if [ ! 
-d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-mujoco.yml b/.github/workflows/test-linux-mujoco.yml new file mode 100644 index 00000000000..4e6955e84a2 --- /dev/null +++ b/.github/workflows/test-linux-mujoco.yml @@ -0,0 +1,68 @@ +name: MuJoCo Custom Envs Tests on Linux + +on: + pull_request: + paths: + - "torchrl/envs/custom/mujoco/**" + - "test/libs/test_mujoco.py" + - ".github/workflows/test-linux-mujoco.yml" + - ".github/unittest/linux_libs/scripts_mujoco/**" + push: + branches: + - nightly + - main + - release/* + paths: + - "torchrl/envs/custom/mujoco/**" + - "test/libs/test_mujoco.py" + - ".github/workflows/test-linux-mujoco.yml" + - ".github/unittest/linux_libs/scripts_mujoco/**" + workflow_dispatch: + workflow_call: + +concurrency: + # Mirror the convention from test-linux-libs.yml: cancel in-progress runs + # for the same ref, but let main runs all complete (so we can pinpoint + # regressions to a specific commit if something breaks). 
+ group: test-linux-mujoco-${{ github.ref == 'refs/heads/main' && format('ci-master-{0}', github.sha) || format('ci-{0}', github.ref) }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + +jobs: + unittests-mujoco: + strategy: + matrix: + python_version: ["3.11"] + cuda_arch_version: ["12.8"] + uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "12.8" + docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04" + timeout: 90 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.11" + export CU_VERSION="12.8" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + export TD_GET_DEFAULTS_TO_NONE=1 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_mujoco/run_all.sh diff --git a/docs/source/reference/envs_api.rst b/docs/source/reference/envs_api.rst index 8485a24d8d5..a00399a60fe 100644 --- a/docs/source/reference/envs_api.rst +++ b/docs/source/reference/envs_api.rst @@ -244,6 +244,36 @@ TorchRL offers a series of custom built-in environments. PendulumEnv TicTacToeEnv +MuJoCo custom environments +-------------------------- + +A family of MuJoCo-backed envs sharing one base class +(:class:`~torchrl.envs.MujocoEnv`) with a swappable physics backend +(``"mujoco-torch"`` -- the default and ``torch.compile``-friendly +native-torch engine, ``"mjx"`` -- JAX-vectorized, or ``"mujoco"`` -- +official C-bindings). The XML asset can be a local path or an +``http(s)`` URL, so users can point at remote models without vendoring +them. Subclasses describe the *task* by overriding +:meth:`~torchrl.envs.MujocoEnv._compute_reward` and +:meth:`~torchrl.envs.MujocoEnv._compute_done`. 
+ +The locomotion classes mirror the Gymnasium ``-v4`` reward and +termination semantics. :class:`~torchrl.envs.SatelliteEnv` is an +attitude-control task with a 4- or 6-CMG cluster and a +manipulability-based singularity penalty driving the policy away from +internal singular configurations of the gimbal Jacobian. + +.. autosummary:: + :toctree: generated/ + :template: rl_template.rst + + MujocoEnv + AntEnv + HopperEnv + HumanoidEnv + SatelliteEnv + Walker2dEnv + Domain-specific --------------- diff --git a/examples/satellite/__init__.py b/examples/satellite/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/examples/satellite/_utils.py b/examples/satellite/_utils.py new file mode 100644 index 00000000000..ab629ad99bf --- /dev/null +++ b/examples/satellite/_utils.py @@ -0,0 +1,948 @@ +"""Shared utilities for the satellite PPO and SAC training scripts. + +Houses everything both scripts need so the entry-points stay short and the +PPO-vs-SAC comparison is apples-to-apples (same env wrapping, same eval +protocol, same metric definitions, same WandB project / group). + +Public surface: + +* :func:`pick_device` -- one-line GPU/MPS/CPU selection. +* :func:`make_train_env` -- vmapped :class:`SatelliteEnv` + transform stack. +* :func:`make_eval_env` -- same stack plus a stateless :class:`TestSetPrimer` + feeding fixed `(init_bus_quat, target_quat)` rows from a CSV. +* :class:`TestSetPrimer` -- thin :class:`TensorDictPrimer` subclass that + emits a deterministic batch of test-set rows on every reset (with safe + tiling when ``num_envs != len(test_set)``). +* :func:`generate_test_set` / :func:`dump_test_set_csv` / + :func:`load_test_set_csv` -- (re)producible eval set on disk. +* :func:`make_actor` / :func:`make_value_critic` / + :func:`make_qvalue_critic` -- shared TanhNormal actor and MLP critics. +* :func:`eval_metrics_fn` -- per-category metrics for the + :class:`Evaluator`. 
+* :func:`setup_wandb` -- exports the user's *personal* WandB key from + ``~/.tokens`` before importing wandb. + +Run ``python -m examples.satellite._utils generate-test-set`` to dump the CSV. +""" + +from __future__ import annotations + +import argparse +import math +import os +import sys +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch +import torch.nn as nn + +from tensordict import NestedKey, TensorDictBase +from tensordict.nn import ( + AddStateIndependentNormalScale, + NormalParamExtractor, + TensorDictModule, +) + +from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded +from torchrl.envs import ( + CatTensors, + ObservationNorm, + RandomTruncationTransform, + RewardScaling, + RewardSum, + StepCounter, + TensorDictPrimer, + TransformedEnv, +) +from torchrl.envs.custom.mujoco import SatelliteEnv +from torchrl.envs.custom.mujoco._math import ( + cmg_jacobian, + manipulability, + pyramid_4cmg_geometry, + quat_conj, + quat_log, + quat_mul, +) +from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator + + +# --------------------------------------------------------------------------- +# Constants -- a single source of truth for layout / file paths. +# --------------------------------------------------------------------------- + +PACKAGE_DIR = Path(__file__).resolve().parent +DEFAULT_TEST_SET_PATH = PACKAGE_DIR / "test_set.csv" +DEFAULT_OBS_NORM_PATH = PACKAGE_DIR / "obs_norm_stats.pt" + +# Observation sub-keys fed to the policy / critic. Manipulability is +# kept *outside* this list so the network never sees it -- it's a +# logging-only signal. +POLICY_OBS_KEYS: list[str] = [ + "quat_err", + "bus_omega", + "gimbal_angles", + "gimbal_rates", +] + +# Eval / test-set categories. Kept here so the metric reporter and the +# generator stay in sync. +CATEGORIES: tuple[str, ...] 
= ( + "uniform", + "large_err", + "near_singular", + "precision", + "off_axis", +) + + +# --------------------------------------------------------------------------- +# Device selection +# --------------------------------------------------------------------------- + + +def pick_device(prefer: str | None = None) -> torch.device: + """Return the best available device. + + Order: explicit ``prefer`` argument > ``CUDA`` > ``CPU``. + + .. note:: + + We deliberately skip MPS: the ``mujoco-torch`` backend keeps its + physics model in float64 internally, which MPS doesn't support. + The user said "GPU if available, CPU otherwise" -- on Apple + silicon that means CPU. + """ + if prefer is not None: + return torch.device(prefer) + if torch.cuda.is_available(): + return torch.device("cuda") + return torch.device("cpu") + + +# --------------------------------------------------------------------------- +# Test-set generation +# --------------------------------------------------------------------------- + + +def _normalize_quat(q: torch.Tensor) -> torch.Tensor: + return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +def _attitude_error_norm(init_q: torch.Tensor, target_q: torch.Tensor) -> torch.Tensor: + """Magnitude of the quaternion log-map between two attitudes (radians).""" + q_err = quat_mul(quat_conj(init_q), target_q) + return quat_log(q_err).norm(dim=-1) + + +def _manip_at_angles(angles: torch.Tensor, num_cmgs: int = 4) -> torch.Tensor: + """Manipulability at given gimbal angles for the 4-CMG pyramid geometry.""" + g, r0 = pyramid_4cmg_geometry(device=angles.device, dtype=angles.dtype) + jac = cmg_jacobian(angles, g, r0, 100.0) + return manipulability(jac) + + +def generate_test_set( + n: int = 256, + seed: int = 42, + num_cmgs: int = 4, + device: torch.device | str = "cpu", +) -> dict[str, torch.Tensor | list[str]]: + """Generate a deterministic test set covering the four hard categories. 
+ + Returns a dict with keys ``init_bus_quat`` (n, 4), ``target_quat`` + (n, 4) and ``category`` (list[str]). Every category gets ``n // 5`` + rows (the residual is padded with ``uniform``). + + The generation is fully seeded by ``seed`` so re-running this + function with the same arguments yields a byte-identical CSV. + """ + if num_cmgs != 4: + # The generator below only knows the 4-CMG pyramid geometry -- + # extending to 6-CMG is straightforward but unnecessary now. + raise NotImplementedError("Test-set generation currently assumes 4 CMGs.") + + g = torch.Generator(device="cpu").manual_seed(seed) + per_cat = n // len(CATEGORIES) + residual = n - per_cat * len(CATEGORIES) + counts = {c: per_cat for c in CATEGORIES} + counts["uniform"] += residual # absorb any leftover + + inits: list[torch.Tensor] = [] + targets: list[torch.Tensor] = [] + cats: list[str] = [] + + # uniform -- pure random pairs; the dynamics-relevant baseline. + n_u = counts["uniform"] + inits.append(_normalize_quat(torch.randn(n_u, 4, generator=g))) + targets.append(_normalize_quat(torch.randn(n_u, 4, generator=g))) + cats.extend(["uniform"] * n_u) + + # large_err -- reject any pair with attitude error < 120 degrees. 
+ n_l = counts["large_err"] + LARGE_THRESHOLD = math.radians(120.0) + accepted_init: list[torch.Tensor] = [] + accepted_target: list[torch.Tensor] = [] + while len(accepted_init) < n_l: + batch = max(64, n_l * 4) + i = _normalize_quat(torch.randn(batch, 4, generator=g)) + t = _normalize_quat(torch.randn(batch, 4, generator=g)) + err = _attitude_error_norm(i, t) + keep = err >= LARGE_THRESHOLD + for j in range(batch): + if keep[j]: + accepted_init.append(i[j]) + accepted_target.append(t[j]) + if len(accepted_init) == n_l: + break + inits.append(torch.stack(accepted_init)) + targets.append(torch.stack(accepted_target)) + cats.extend(["large_err"] * n_l) + + # near_singular -- pick (init, target) pairs whose initial slew + # direction (the quat-log of the relative attitude) aligns with a + # near-null direction of the gimbal Jacobian at the satellite's + # nominal posture. We can't prescribe ``qpos[7:]`` (gimbal angles) + # via the reset API, so the proxy here is: among many random pairs, + # pick the ones where the slew axis is most poorly conditioned for + # the *nominal* CMG configuration, i.e. has the lowest projection + # onto the column span of the Jacobian at gimbal angles=0. The + # bottom-N by that score is the "hardest to slew" family. + n_s = counts["near_singular"] + pool = max(8 * n_s, 4096) + pool_init = _normalize_quat(torch.randn(pool, 4, generator=g)) + pool_target = _normalize_quat(torch.randn(pool, 4, generator=g)) + # Initial slew direction: the quaternion log of the relative + # attitude error. + pool_qerr = quat_mul(quat_conj(pool_init), pool_target) + pool_log = quat_log(pool_qerr) # (pool, 3) + # Jacobian at nominal (zero) gimbal angles; shape (3, 4). We + # measure how poorly the slew direction aligns with the J column + # span by 1 - (||J^T s|| / ||s|| / sigma_max), where s is the slew + # direction. A small alignment score => the satellite has to push + # mostly through the null space => harder maneuver. 
+ g_axes, r0 = pyramid_4cmg_geometry(device="cpu", dtype=torch.float32) + jac = cmg_jacobian(torch.zeros(1, 4), g_axes, r0, 100.0).squeeze(0) # (3, 4) + sigma_max = torch.linalg.svdvals(jac).max() + s_unit = pool_log / pool_log.norm(dim=-1, keepdim=True).clamp_min(1e-6) + alignment = (jac.T @ s_unit.T).norm(dim=0) / sigma_max # (pool,) + worst_idx = torch.topk(-alignment, k=n_s).indices + inits.append(pool_init.index_select(0, worst_idx)) + targets.append(pool_target.index_select(0, worst_idx)) + cats.extend(["near_singular"] * n_s) + + # precision -- pairs with moderate attitude error so success at the + # 0.05 rad threshold is genuinely informative rather than trivial. + n_p = counts["precision"] + accepted_init = [] + accepted_target = [] + LO, HI = math.radians(30.0), math.radians(90.0) + while len(accepted_init) < n_p: + batch = max(64, n_p * 4) + i = _normalize_quat(torch.randn(batch, 4, generator=g)) + t = _normalize_quat(torch.randn(batch, 4, generator=g)) + err = _attitude_error_norm(i, t) + keep = (err >= LO) & (err <= HI) + for j in range(batch): + if keep[j]: + accepted_init.append(i[j]) + accepted_target.append(t[j]) + if len(accepted_init) == n_p: + break + inits.append(torch.stack(accepted_init)) + targets.append(torch.stack(accepted_target)) + cats.extend(["precision"] * n_p) + + # off_axis -- targets oriented along non-principal axes. We sample + # a random axis with all components > 0.4 in absolute value and a + # random angle, then build the quaternion (cos a/2, sin a/2 * axis). 
+ n_o = counts["off_axis"] + accepted_init = [] + accepted_target = [] + while len(accepted_init) < n_o: + ax = torch.randn(1, 3, generator=g) + if (ax.abs() < 0.4).any(): + continue + ax = ax / ax.norm(dim=-1, keepdim=True).clamp_min(1e-12) + angle = torch.empty(1).uniform_( + math.radians(60.0), math.radians(160.0), generator=g + ) + half = 0.5 * angle + target_q = torch.cat( + [half.cos().unsqueeze(0), half.sin().unsqueeze(0) * ax], dim=-1 + ).squeeze(0) + init_q = _normalize_quat(torch.randn(1, 4, generator=g)).squeeze(0) + accepted_init.append(init_q) + accepted_target.append(_normalize_quat(target_q.unsqueeze(0)).squeeze(0)) + inits.append(torch.stack(accepted_init)) + targets.append(torch.stack(accepted_target)) + cats.extend(["off_axis"] * n_o) + + init_t = torch.cat(inits, dim=0).to(device).float() + target_t = torch.cat(targets, dim=0).to(device).float() + assert init_t.shape == (n, 4) and target_t.shape == (n, 4) + return {"init_bus_quat": init_t, "target_quat": target_t, "category": cats} + + +def dump_test_set_csv(test_set: dict[str, Any], path: str | Path) -> None: + """Write the test set to a CSV (12 floats + 1 category string per row).""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + init = test_set["init_bus_quat"] + target = test_set["target_quat"] + cats = test_set["category"] + n = init.shape[0] + cols = ( + ["init_w", "init_x", "init_y", "init_z"] + + ["target_w", "target_x", "target_y", "target_z"] + + ["category"] + ) + with path.open("w") as f: + f.write(",".join(cols) + "\n") + for k in range(n): + row = ( + [f"{init[k, j].item():.10f}" for j in range(4)] + + [f"{target[k, j].item():.10f}" for j in range(4)] + + [cats[k]] + ) + f.write(",".join(row) + "\n") + + +def load_test_set_csv( + path: str | Path, +) -> tuple[torch.Tensor, torch.Tensor, list[str]]: + """Read the CSV back into ``(init_bus_quat, target_quat, categories)``.""" + path = Path(path) + init_rows: list[list[float]] = [] + target_rows: 
list[list[float]] = [] + cats: list[str] = [] + with path.open() as f: + header = f.readline().strip().split(",") # noqa: F841 -- documented schema + for line in f: + parts = line.strip().split(",") + init_rows.append([float(x) for x in parts[:4]]) + target_rows.append([float(x) for x in parts[4:8]]) + cats.append(parts[8]) + init = torch.tensor(init_rows, dtype=torch.float32) + target = torch.tensor(target_rows, dtype=torch.float32) + return init, target, cats + + +# --------------------------------------------------------------------------- +# Stateless test-set primer +# --------------------------------------------------------------------------- + + +class TestSetPrimer(TensorDictPrimer): + """Inject fixed ``(init_bus_quat, target_quat)`` pairs at every reset. + + Stateless and vectorized: the primer holds the full test set and + deterministically slices / tiles it to fill a reset of any size. + Designed to pair with the modified :class:`SatelliteEnv` whose + :meth:`_reset` reads these keys from the input tensordict. + + With ``num_envs == len(test_set)`` and full resets only (the standard + eval flow), each env always receives the same row -- so two + consecutive eval rollouts replay the exact same starts. + """ + + def __init__( + self, + init_bus_quat: torch.Tensor, + target_quat: torch.Tensor, + *, + device: torch.device | str = "cpu", + ) -> None: + n = init_bus_quat.shape[0] + if target_quat.shape[0] != n: + raise ValueError( + f"init/target length mismatch: {init_bus_quat.shape[0]} vs " + f"{target_quat.shape[0]}" + ) + device = torch.device(device) + self._init = init_bus_quat.to(device) + self._target = target_quat.to(device) + self._n = n + + # Specs declare the keys to the framework. ``random=False`` and + # ``default_value=callable`` triggers the per-reset closure + # below, which sees the reset mask shape and returns matching + # rows. 
+ primer_specs = Composite( + init_bus_quat=Unbounded(shape=(4,), dtype=torch.float32, device=device), + target_quat=Unbounded(shape=(4,), dtype=torch.float32, device=device), + ) + super().__init__( + primers=primer_specs, + random=False, + default_value=self._sample, + single_default_value=True, + expand_specs=True, + ) + + def _sample( + self, reset: torch.Tensor | None = None, **_: Any + ) -> dict[str, torch.Tensor]: + # ``reset`` is the bool mask, shape ``(B,)`` or ``(B, 1)``; + # may be ``None`` when called from a full-batch reset path. + # Strategy: generate ``arange(B) % n`` over the *full* batch + # then sub-select by the mask. This keeps row k always paired + # with env k for the standard ``B == n`` setup, and gracefully + # tiles / truncates otherwise. + if reset is None: + B = self._n + sel = torch.arange(B, device=self._init.device) % self._n + else: + mask = reset + if mask.ndim > 1: + mask = mask.squeeze(-1) + B = mask.shape[0] + full = torch.arange(B, device=self._init.device) % self._n + sel = full[mask] + return { + "init_bus_quat": self._init.index_select(0, sel), + "target_quat": self._target.index_select(0, sel), + } + + +# --------------------------------------------------------------------------- +# Env factories +# --------------------------------------------------------------------------- + + +def _attach_obs_norm( + env: TransformedEnv, + obs_norm_stats: tuple[torch.Tensor, torch.Tensor] | None, + device: torch.device, +) -> ObservationNorm: + """Append an :class:`ObservationNorm`, optionally pre-loaded with stats. + + Returns the transform instance so the caller can ``init_stats`` on + it later when ``obs_norm_stats is None``. + """ + if obs_norm_stats is not None: + loc, scale = obs_norm_stats + norm = ObservationNorm( + loc=loc.to(device), + scale=scale.to(device), + in_keys=["observation"], + standard_normal=True, + ) + else: + # Lazy: stats filled in later via ``init_stats``. 
+ norm = ObservationNorm( + in_keys=["observation"], + standard_normal=True, + ) + env.append_transform(norm) + return norm + + +def make_train_env( + *, + num_envs: int, + device: torch.device, + max_steps: int = 1500, + min_random_horizon: int | None = None, + random_horizon_prob: float = 0.0, + compile_step: bool = False, + obs_norm_stats: tuple[torch.Tensor, torch.Tensor] | None = None, + use_obs_norm: bool = True, + num_cmgs: int = 4, + action_scale: float = 3.0, + singularity_weight: float = 0.5, + singularity_clamp_min: float = 1e-6, + singularity_mode: str = "inverse", + singularity_exp_k: float = 5.0, + omega_weight: float = 0.1, + ctrl_cost_weight: float = 0.01, + frame_skip: int | None = None, + reward_scale: float = 1.0, + seed: int | None = None, +) -> tuple[TransformedEnv, ObservationNorm | None]: + """Vmapped training env. Returns ``(env, observation_norm_transform)``. + + When ``use_obs_norm=False`` the :class:`ObservationNorm` transform + is omitted entirely and the second tuple element is ``None``. The + raw observation is in physically meaningful units already + (``quat_err`` in radians, ``bus_omega`` in rad/s, gimbal sin/cos in + [-1, 1], gimbal_rates in rad/s) so a network with reasonable init + can train on them directly. + + When ``use_obs_norm=True``, the caller is expected to either: + + * pass ``obs_norm_stats=(loc, scale)`` so eval / training share stats; or + * leave ``obs_norm_stats=None``, then call + ``observation_norm_transform.init_stats(num_iter=N, ...)`` once. 
+ """ + base = SatelliteEnv( + num_cmgs=num_cmgs, + num_envs=num_envs, + backend="mujoco-torch", + device=device, + seed=seed, + max_episode_steps=max_steps, + compile_step=compile_step, + compile_kwargs={"dynamic": False} if compile_step else None, + action_scale=action_scale, + singularity_weight=singularity_weight, + singularity_clamp_min=singularity_clamp_min, + singularity_mode=singularity_mode, + singularity_exp_k=singularity_exp_k, + omega_weight=omega_weight, + ctrl_cost_weight=ctrl_cost_weight, + frame_skip=frame_skip, + ) + env = TransformedEnv(base) + # 1) Pack the dynamics-relevant sub-keys into a single ``observation`` + # tensor. Keep ``manipulability`` outside (logging-only). + env.append_transform( + CatTensors( + in_keys=POLICY_OBS_KEYS, + out_key="observation", + dim=-1, + del_keys=False, + sort=False, + ) + ) + # 2) Per-dim normalization, shared across PPO / SAC / eval. Skipped + # when ``use_obs_norm=False`` -- the network sees raw observations. + obs_norm: ObservationNorm | None = None + if use_obs_norm: + obs_norm = _attach_obs_norm(env, obs_norm_stats, device) + # 3) Episodic step counter (logging) + episode-return aggregator + # (consumed by the Evaluator). + env.append_transform(StepCounter(max_steps=max_steps)) + if min_random_horizon is not None: + env.append_transform( + RandomTruncationTransform( + min_horizon=min_random_horizon, + max_horizon=max_steps, + prob=random_horizon_prob, + ) + ) + if reward_scale != 1.0: + # ``reward = reward * scale + 0`` -- scale rewards before + # downstream consumers (RewardSum, replay buffer, loss). With + # raw reward in [-3.5, 0] and scale=1/3 the post-scale range + # is [-1.17, 0], much friendlier for Q-target regression. 
+ env.append_transform( + RewardScaling(loc=0.0, scale=reward_scale, in_keys=["reward"]) + ) + env.append_transform(RewardSum()) + return env, obs_norm + + +def make_eval_env( + *, + device: torch.device, + test_set_csv: str | Path = DEFAULT_TEST_SET_PATH, + max_steps: int = 1500, + obs_norm_stats: tuple[torch.Tensor, torch.Tensor] | None, + use_obs_norm: bool = True, + num_cmgs: int = 4, + action_scale: float = 3.0, + singularity_weight: float = 0.5, + singularity_clamp_min: float = 1e-6, + singularity_mode: str = "inverse", + singularity_exp_k: float = 5.0, + omega_weight: float = 0.1, + ctrl_cost_weight: float = 0.01, + frame_skip: int | None = None, + compile_step: bool = False, + reward_scale: float = 1.0, + seed: int = 0, +) -> TransformedEnv: + """Eval env: ``num_envs == len(test_set)`` so one rollout = full eval. + + The :class:`TestSetPrimer` injects matching ``init_bus_quat`` / + ``target_quat`` rows on every reset so the eval is byte-stable across + iterations. The ``action_scale`` / ``singularity_weight`` / + ``use_obs_norm`` settings must match the training env so eval + rewards are directly comparable. 
+ """ + init_q, target_q, _cats = load_test_set_csv(test_set_csv) + n = init_q.shape[0] + base = SatelliteEnv( + num_cmgs=num_cmgs, + num_envs=n, + backend="mujoco-torch", + device=device, + seed=seed, + max_episode_steps=max_steps, + action_scale=action_scale, + singularity_weight=singularity_weight, + singularity_clamp_min=singularity_clamp_min, + singularity_mode=singularity_mode, + singularity_exp_k=singularity_exp_k, + omega_weight=omega_weight, + ctrl_cost_weight=ctrl_cost_weight, + frame_skip=frame_skip, + compile_step=compile_step, + compile_kwargs={"dynamic": False} if compile_step else None, + ) + env = TransformedEnv(base) + env.append_transform(TestSetPrimer(init_q, target_q, device=device)) + env.append_transform( + CatTensors( + in_keys=POLICY_OBS_KEYS, + out_key="observation", + dim=-1, + del_keys=False, + sort=False, + ) + ) + if use_obs_norm: + if obs_norm_stats is None: + raise ValueError( + "make_eval_env(use_obs_norm=True) requires obs_norm_stats " + "so the eval env shares normalization with training. " + "Either compute stats once and save them, or load from " + "disk -- or pass use_obs_norm=False to skip the transform." 
+ ) + _attach_obs_norm(env, obs_norm_stats, device) + env.append_transform(StepCounter(max_steps=max_steps)) + if reward_scale != 1.0: + env.append_transform( + RewardScaling(loc=0.0, scale=reward_scale, in_keys=["reward"]) + ) + env.append_transform(RewardSum()) + return env + + +# --------------------------------------------------------------------------- +# Networks +# --------------------------------------------------------------------------- + + +def _obs_dim(obs_spec: TensorSpec) -> int: + """Resolve the policy-input dimension from the (possibly composite) spec.""" + if isinstance(obs_spec, Composite): + return obs_spec["observation"].shape[-1] + return obs_spec.shape[-1] + + +_ACTIVATIONS: dict[str, type[nn.Module]] = { + "relu": nn.ReLU, + "tanh": nn.Tanh, + "elu": nn.ELU, + "gelu": nn.GELU, + "silu": nn.SiLU, +} + + +def _activation_class(name: str | type[nn.Module]) -> type[nn.Module]: + if isinstance(name, str): + try: + return _ACTIVATIONS[name.lower()] + except KeyError as e: + raise ValueError( + f"Unknown activation {name!r}; valid: {sorted(_ACTIVATIONS)}." + ) from e + return name + + +def _small_gain_last_linear(module: nn.Module, gain: float = 0.01) -> None: + """Re-init the *last* :class:`nn.Linear` in ``module`` with + orthogonal weights at ``gain`` and zero bias. + + Standard policy-gradient init from Mnih et al. and the PPO paper: + keeps the output near zero so the policy starts as a near-uniform + sample (loc = 0, scale = whatever the param extractor emits) and + avoids the "confidently wrong" behaviour from default Kaiming init + on a 256-wide last layer. 
+ """ + last_linear = None + for sub in module.modules(): + if isinstance(sub, nn.Linear): + last_linear = sub + if last_linear is None: + raise RuntimeError("No nn.Linear found inside module to re-init.") + nn.init.orthogonal_(last_linear.weight, gain=gain) + if last_linear.bias is not None: + nn.init.zeros_(last_linear.bias) + + +def make_actor( + *, + obs_spec: TensorSpec, + action_spec: TensorSpec, + device: torch.device, + hidden: tuple[int, ...] = (256, 256), + activation: str | type[nn.Module] = "relu", + state_independent_scale: bool = False, + layer_norm: bool = False, + small_init_last_layer: bool = False, + scale_init: float = 1.0, +) -> ProbabilisticActor: + """TanhNormal actor. + + ``state_independent_scale=True`` matches the canonical PPO setup + (scale is a learned bias). ``False`` matches SAC (head outputs both + loc and scale). ``activation`` is the hidden-layer non-linearity; + string aliases (``"relu"``, ``"tanh"``, ``"elu"``, ``"gelu"``, + ``"silu"``) or a class are accepted. 
+ """ + in_dim = _obs_dim(obs_spec) + action_dim = action_spec.shape[-1] + act_cls = _activation_class(activation) + norm_kwargs = ( + { + "norm_class": nn.LayerNorm, + "norm_kwargs": [{"normalized_shape": h} for h in hidden], + } + if layer_norm + else {} + ) + + if state_independent_scale: + body = MLP( + in_features=in_dim, + num_cells=list(hidden), + out_features=action_dim, + activation_class=act_cls, + device=device, + **norm_kwargs, + ) + mlp = nn.Sequential( + body, + AddStateIndependentNormalScale(action_dim, scale_lb=1e-4).to(device), + ) + else: + body = MLP( + in_features=in_dim, + num_cells=list(hidden), + out_features=2 * action_dim, + activation_class=act_cls, + device=device, + **norm_kwargs, + ) + mlp = nn.Sequential( + body, + NormalParamExtractor( + scale_mapping=f"biased_softplus_{scale_init:.4f}", + scale_lb=1e-4, + ).to(device), + ) + + if small_init_last_layer: + _small_gain_last_linear(body, gain=0.01) + + td_module = TensorDictModule( + mlp, in_keys=["observation"], out_keys=["loc", "scale"] + ) + # Pass shape-``[action_dim]`` bounds (not batched per env): vmapped + # specs make ``space.low/high`` shape ``[num_envs, action_dim]``, which + # breaks broadcasting inside ``TanhNormal.update`` under ``torch.compile`` + # (the loc batch dim differs from num_envs after replay sampling). + low_b = action_spec.space.low + high_b = action_spec.space.high + while low_b.ndim > 1: + low_b = low_b[0] + while high_b.ndim > 1: + high_b = high_b[0] + actor = ProbabilisticActor( + module=td_module, + in_keys=["loc", "scale"], + spec=action_spec, + distribution_class=TanhNormal, + distribution_kwargs={ + "low": low_b, + "high": high_b, + "tanh_loc": True, + }, + return_log_prob=True, + ) + return actor + + +def make_value_critic( + *, + obs_spec: TensorSpec, + device: torch.device, + hidden: tuple[int, ...] 
= (256, 256), + activation: str | type[nn.Module] = "tanh", +) -> ValueOperator: + in_dim = _obs_dim(obs_spec) + net = MLP( + in_features=in_dim, + num_cells=list(hidden), + out_features=1, + activation_class=_activation_class(activation), + device=device, + ) + return ValueOperator(net, in_keys=["observation"]) + + +def make_qvalue_critic( + *, + obs_spec: TensorSpec, + action_spec: TensorSpec, + device: torch.device, + hidden: tuple[int, ...] = (256, 256), + activation: str | type[nn.Module] = "relu", + layer_norm: bool = False, + small_init_last_layer: bool = False, +) -> ValueOperator: + in_dim = _obs_dim(obs_spec) + action_spec.shape[-1] + norm_kwargs = ( + { + "norm_class": nn.LayerNorm, + "norm_kwargs": [{"normalized_shape": h} for h in hidden], + } + if layer_norm + else {} + ) + net = MLP( + in_features=in_dim, + num_cells=list(hidden), + out_features=1, + activation_class=_activation_class(activation), + device=device, + **norm_kwargs, + ) + if small_init_last_layer: + _small_gain_last_linear(net, gain=0.01) + return ValueOperator(net, in_keys=["observation", "action"]) + + +# --------------------------------------------------------------------------- +# Eval metrics +# --------------------------------------------------------------------------- + + +def make_eval_metrics_fn( + categories: list[str], +) -> Callable[[TensorDictBase], dict[str, float]]: + """Return a metrics function that breaks down by ``categories``. + + ``categories`` must be the per-row category list from the same CSV + the eval primer uses, in the same order. The eval env is built with + ``num_envs == len(categories)`` so positions line up. + """ + cat_tensor_idx: dict[str, list[int]] = {} + for i, c in enumerate(categories): + cat_tensor_idx.setdefault(c, []).append(i) + + def _flatten(td: TensorDictBase, key: NestedKey) -> torch.Tensor: + # Eval rollouts come back with shape ``(B, T, ...)``; reduce over T. 
+ return td.get(key) + + def fn(td: TensorDictBase) -> dict[str, float]: + # Episode return: collected by ``RewardSum`` under ``("next", + # "episode_reward")``. We take the *last* value per env. + ep_ret = _flatten(td, ("next", "episode_reward"))[..., -1, :] + # Final attitude error (radians) at the end of the rollout. + quat_err = _flatten(td, ("next", "quat_err")) # (B, T, 3) + final_err = quat_err[..., -1, :].norm(dim=-1) + # Manipulability per step: shape (B, T, 1). + manip = _flatten(td, ("next", "manipulability")).squeeze(-1) + min_manip = manip.min(dim=-1).values + mean_manip = manip.mean(dim=-1) + sum_inv_manip = (1.0 / manip.clamp_min(1e-6)).sum(dim=-1) + + out: dict[str, float] = { + "eval/return": ep_ret.mean().item(), + "eval/final_attitude_error_rad": final_err.mean().item(), + "eval/success_rate@0.10": (final_err <= 0.10).float().mean().item(), + "eval/success_rate@0.05": (final_err <= 0.05).float().mean().item(), + "eval/min_manipulability": min_manip.mean().item(), + "eval/mean_manipulability": mean_manip.mean().item(), + "eval/sum_inv_manipulability": sum_inv_manip.mean().item(), + } + for cat, idx_list in cat_tensor_idx.items(): + if not idx_list: + continue + idx = torch.tensor(idx_list, device=ep_ret.device) + out[f"eval/{cat}/return"] = ep_ret.index_select(0, idx).mean().item() + out[f"eval/{cat}/final_err"] = final_err.index_select(0, idx).mean().item() + out[f"eval/{cat}/success@0.10"] = ( + (final_err.index_select(0, idx) <= 0.10).float().mean().item() + ) + out[f"eval/{cat}/min_manipulability"] = ( + min_manip.index_select(0, idx).mean().item() + ) + return out + + return fn + + +# --------------------------------------------------------------------------- +# WandB key bootstrap +# --------------------------------------------------------------------------- + + +def setup_wandb_key(tokens_path: str | Path = "~/.tokens") -> None: + """Export the user's *personal* WandB key from ``~/.tokens``. 
+ + The shared file mixes a Periodic key and a personal key on adjacent + lines; we look specifically for ``Personal: WANDB_API_KEY=...``. + Sets ``WANDB_API_KEY`` in ``os.environ``; raises if the personal key + is not found (per the user's hard rule: never the Periodic key). + + If ``WANDB_API_KEY`` is already set in the environment (e.g. a remote + run where ``~/.tokens`` is not present), this is a no-op. + """ + if os.environ.get("WANDB_API_KEY"): + return + path = Path(tokens_path).expanduser() + if not path.exists(): + raise FileNotFoundError( + f"No tokens file at {path}; cannot find personal WANDB key." + ) + personal_key: str | None = None + for line in path.read_text().splitlines(): + s = line.strip() + if s.startswith("Personal:") and "WANDB_API_KEY=" in s: + personal_key = s.split("WANDB_API_KEY=", 1)[1].strip() + break + if not personal_key: + raise RuntimeError( + "Could not find a 'Personal: WANDB_API_KEY=...' line in " + f"{path}. Refusing to fall back to a non-personal key." 
+ ) + os.environ["WANDB_API_KEY"] = personal_key + + +# --------------------------------------------------------------------------- +# CLI: dump the test set CSV +# --------------------------------------------------------------------------- + + +def _cli_generate_test_set(argv: list[str]) -> int: + p = argparse.ArgumentParser(prog="examples.satellite._utils generate-test-set") + p.add_argument("--out", default=str(DEFAULT_TEST_SET_PATH)) + p.add_argument("--n", type=int, default=256) + p.add_argument("--seed", type=int, default=42) + p.add_argument("--num-cmgs", type=int, default=4) + args = p.parse_args(argv) + test_set = generate_test_set(n=args.n, seed=args.seed, num_cmgs=args.num_cmgs) + dump_test_set_csv(test_set, args.out) + cats = test_set["category"] + counts = {c: cats.count(c) for c in CATEGORIES} + print(f"Wrote {args.n} rows to {args.out}") + print("Per-category counts:", counts) + err = _attitude_error_norm(test_set["init_bus_quat"], test_set["target_quat"]) + print(f"Mean attitude error (rad): {err.mean().item():.3f}") + print(f"Max attitude error (rad): {err.max().item():.3f}") + return 0 + + +def main(argv: list[str] | None = None) -> int: + argv = list(sys.argv[1:] if argv is None else argv) + if not argv: + print( + "usage: python -m examples.satellite._utils generate-test-set [...]", + file=sys.stderr, + ) + return 2 + cmd, rest = argv[0], argv[1:] + if cmd == "generate-test-set": + return _cli_generate_test_set(rest) + print(f"unknown command: {cmd}", file=sys.stderr) + return 2 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/satellite/sac_per.py b/examples/satellite/sac_per.py new file mode 100644 index 00000000000..c862967db7e --- /dev/null +++ b/examples/satellite/sac_per.py @@ -0,0 +1,1036 @@ +"""SAC + Prioritized Replay Buffer training for the satellite task. + +Same env / eval / logging conventions as ``examples/satellite/ppo.py`` so the +two are directly comparable on the CSV-backed test set. 
+""" + +from __future__ import annotations + +import argparse +import sys +import time +from functools import partial +from pathlib import Path + +import torch + +from torchrl._utils import logger as torchrl_logger +from torchrl.collectors import Collector, Evaluator, MultiCollector +from torchrl.data import ( + LazyTensorStorage, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.data.postprocs import MultiStep +from torchrl.envs.utils import ExplorationType +from torchrl.objectives import SACLoss, SoftUpdate +from torchrl.record.loggers import generate_exp_name, get_logger +from torchrl.weight_update import MultiProcessWeightSyncScheme + +PACKAGE_DIR = Path(__file__).resolve().parent +if str(PACKAGE_DIR.parent.parent) not in sys.path: + sys.path.insert(0, str(PACKAGE_DIR.parent.parent)) + +from examples.satellite._utils import ( # noqa: E402 + DEFAULT_OBS_NORM_PATH, + DEFAULT_TEST_SET_PATH, + load_test_set_csv, + make_actor, + make_eval_env, + make_eval_metrics_fn, + make_qvalue_critic, + make_train_env, + pick_device, + setup_wandb_key, +) + + +# Module-level eval-process factories. The ``backend="process"`` +# Evaluator pickles these and re-constructs the env + policy inside a +# child process, so they must be importable top-level callables (not +# closures over local state). 
+ + +def _train_env_factory( + *, + num_envs: int, + device_str: str, + max_steps: int, + min_random_horizon: int | None, + random_horizon_prob: float, + compile_step: bool, + obs_norm_stats: tuple[torch.Tensor, torch.Tensor] | None, + use_obs_norm: bool, + num_cmgs: int, + action_scale: float, + singularity_weight: float, + singularity_clamp_min: float, + singularity_mode: str, + singularity_exp_k: float, + omega_weight: float, + ctrl_cost_weight: float, + frame_skip: int, + reward_scale: float = 1.0, + seed: int | None = None, +): + """Picklable training-env factory for ``aSyncDataCollector``.""" + env, _ = make_train_env( + num_envs=num_envs, + device=torch.device(device_str), + max_steps=max_steps, + min_random_horizon=min_random_horizon, + random_horizon_prob=random_horizon_prob, + compile_step=compile_step, + obs_norm_stats=obs_norm_stats, + use_obs_norm=use_obs_norm, + num_cmgs=num_cmgs, + action_scale=action_scale, + singularity_weight=singularity_weight, + singularity_clamp_min=singularity_clamp_min, + singularity_mode=singularity_mode, + singularity_exp_k=singularity_exp_k, + omega_weight=omega_weight, + ctrl_cost_weight=ctrl_cost_weight, + frame_skip=frame_skip, + reward_scale=reward_scale, + seed=seed, + ) + return env + + +def _eval_env_factory( + *, + device_str: str, + test_set_csv: str, + max_steps: int, + obs_norm_stats: tuple[torch.Tensor, torch.Tensor] | None, + use_obs_norm: bool, + num_cmgs: int, + action_scale: float, + singularity_weight: float, + singularity_clamp_min: float, + singularity_mode: str, + singularity_exp_k: float, + omega_weight: float, + ctrl_cost_weight: float, + frame_skip: int, + compile_step: bool, + reward_scale: float = 1.0, +): + return make_eval_env( + device=torch.device(device_str), + test_set_csv=test_set_csv, + max_steps=max_steps, + obs_norm_stats=obs_norm_stats, + use_obs_norm=use_obs_norm, + num_cmgs=num_cmgs, + action_scale=action_scale, + singularity_weight=singularity_weight, + 
singularity_clamp_min=singularity_clamp_min, + singularity_mode=singularity_mode, + singularity_exp_k=singularity_exp_k, + omega_weight=omega_weight, + ctrl_cost_weight=ctrl_cost_weight, + frame_skip=frame_skip, + compile_step=compile_step, + reward_scale=reward_scale, + ) + + +def _eval_policy_factory( + env=None, + *, + obs_spec=None, + action_spec=None, + hidden: tuple[int, ...], + activation: str, + device_str: str, +): + """Build an actor identical to the trainer's. + + The ``MultiProcessWeightSyncScheme`` calls this factory **without** + args on the sender side (just to enumerate parameter shapes), and + again **with** an env on the receiver side. We accept either by + falling back to specs passed via :func:`functools.partial`. + """ + if env is not None: + obs_spec = env.observation_spec + action_spec = env.action_spec + if obs_spec is None or action_spec is None: + raise RuntimeError( + "_eval_policy_factory needs either env or pre-bound " + "(obs_spec, action_spec) via partial." + ) + return make_actor( + obs_spec=obs_spec, + action_spec=action_spec, + device=torch.device(device_str), + hidden=hidden, + activation=activation, + state_independent_scale=False, + ) + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--num-envs", type=int, default=32_768) + p.add_argument("--num-cmgs", type=int, default=4) + p.add_argument("--max-steps", type=int, default=300) + p.add_argument( + "--frame-skip", + type=int, + default=50, + help=( + "Number of physics sub-steps per agent step. dt=0.001s, so " + "frame_skip=50 means each step is 50ms of sim time. Larger " + "frame skip = bigger per-step dynamics = stronger Q-action " + "signal." 
+ ), + ) + p.add_argument( + "--min-random-horizon", + type=int, + default=None, + help=( + "If set, decorrelate vectorized env phases by sampling train " + "episode truncation horizons in [min_random_horizon, max_steps]." + ), + ) + p.add_argument( + "--random-horizon-prob", + type=float, + default=0.0, + help="Probability of resampling a shortened horizon after first reset.", + ) + p.add_argument( + "--frames-per-env", + type=int, + default=8, + help=( + "Env-steps per collection iteration. Higher = fewer " + "Python collector cycles per env transition; lower = more " + "frequent grad-step updates with fresher data." + ), + ) + p.add_argument("--total-iters", type=int, default=3000) + # Replay buffer + p.add_argument("--buffer-size", type=int, default=1_000_000) + p.add_argument("--batch-size", type=int, default=8_192) + p.add_argument("--prb-alpha", type=float, default=0.7) + p.add_argument("--prb-beta", type=float, default=0.5) + p.add_argument( + "--no-prb", + action="store_true", + help=( + "Use a uniform replay buffer instead of prioritized. " + "Disables priority updates (no td_error key required)." + ), + ) + p.add_argument( + "--lr-decay-end-frac", + type=float, + default=None, + help=( + "If set (e.g. 0.1), cosine-anneal actor/critic/alpha LR " + "from --lr to lr*frac over the full --total-iters span." + ), + ) + p.add_argument("--gradient-steps", type=int, default=8) + p.add_argument( + "--init-random-frames-per-env", + type=int, + default=4, + help="Random-action warm-up steps per env before SAC updates start.", + ) + # Optim / SAC + p.add_argument("--lr", type=float, default=3e-4) + p.add_argument( + "--critic-lr", + type=float, + default=None, + help="Optional critic learning-rate override. 
Defaults to --lr.", + ) + p.add_argument("--gamma", type=float, default=0.99) + p.add_argument("--alpha-init", type=float, default=1.0) + p.add_argument( + "--fixed-alpha", + action="store_true", + help="Keep SAC temperature fixed at --alpha-init.", + ) + p.add_argument("--min-alpha", type=float, default=None) + p.add_argument("--max-alpha", type=float, default=None) + p.add_argument("--target-update-polyak", type=float, default=0.995) + p.add_argument("--max-grad-norm", type=float, default=0.0) + p.add_argument( + "--action-scale", + type=float, + default=3.0, + help="Scaling from agent action [-1, 1] to commanded gimbal rate (rad/s).", + ) + p.add_argument( + "--singularity-weight", + type=float, + default=0.5, + help="Weight on -1/manip_norm in the reward (rotor-speed-invariant).", + ) + p.add_argument( + "--singularity-clamp-min", + type=float, + default=1e-6, + help=( + "Floor on manip_norm before division (only used when " + "--singularity-mode=inverse). Higher values bound the worst-case " + "spike: max penalty = singularity_weight / singularity_clamp_min." + ), + ) + p.add_argument( + "--singularity-mode", + type=str, + default="inverse", + choices=["inverse", "exp"], + help=( + "Singularity penalty form. 'inverse' (default): -w/manip_norm " + "(unbounded near singularity, controllable via --singularity-clamp-min). " + "'exp': -w*exp(-k*manip_norm), bounded at -w." + ), + ) + p.add_argument( + "--singularity-exp-k", + type=float, + default=5.0, + help="Curvature for --singularity-mode=exp. Larger = steeper falloff.", + ) + p.add_argument( + "--omega-weight", + type=float, + default=0.1, + help="Weight on -||bus_omega||^2 in the reward (slew-and-stop incentive).", + ) + p.add_argument( + "--ctrl-cost-weight", + type=float, + default=0.01, + help="Weight on -||action||^2. 
Set to 0 to remove the control penalty.", + ) + p.add_argument( + "--reward-scale", + type=float, + default=1.0, + help=( + "Multiplicative scaling on the per-step reward (applied as " + "a transform). Brings raw reward in [-3.5, 0] into a Q-friendly " + "range. 1.0=no change; 0.333 maps to roughly [-1, 0]." + ), + ) + p.add_argument( + "--hidden", + type=int, + nargs="+", + default=[256, 256], + help="Hidden-layer sizes for both actor and critic MLPs.", + ) + p.add_argument( + "--activation", + type=str, + default="relu", + choices=["relu", "tanh", "elu", "gelu", "silu"], + help="Activation for both actor and critic hidden layers.", + ) + p.add_argument( + "--layer-norm", + action="store_true", + help="Add LayerNorm in actor + critic MLP hidden layers.", + ) + p.add_argument( + "--small-init-last-layer", + action="store_true", + help=( + "Orthogonal-init the last Linear of the actor and critic " + "MLPs with gain=0.01 (zero bias). Initial actor mean is 0, " + "initial Q is ~0 -- avoids 'confidently wrong' default init." + ), + ) + p.add_argument( + "--scale-init", + type=float, + default=1.0, + help=( + "Initial value of the actor's TanhNormal scale at " + "zero-input (via biased_softplus). 1.0 saturates samples " + "near +/-1 (heavy exploration); 0.31 gives moderate " + "samples in [-0.3, 0.3]; 0.5 is a middle ground." + ), + ) + p.add_argument( + "--n-step", + type=int, + default=1, + help=( + "n in n-step returns. >1 wraps the collector with a " + "MultiStep postproc (gamma already taken from --gamma) so " + "Q learns from n-step rewards." + ), + ) + p.add_argument("--seed", type=int, default=0) + p.add_argument("--device", default=None) + p.add_argument( + "--eval-device", + default=None, + help=( + "Device for the eval-process env+policy. Defaults to --device. " + "Use a separate CUDA index (e.g. cuda:1) to run eval on a " + "different GPU than training." 
+ ), + ) + p.add_argument( + "--buffer-device", + default=None, + help=( + "Device for the replay buffer storage. Defaults to --device. " + "Use 'cpu' on environments where TorchRL was not built with " + "CUDA support (no CUDA PRB extension)." + ), + ) + p.add_argument("--compile-env", action="store_true") + p.add_argument( + "--compile-eval-env", + action="store_true", + help="Pass compile_step=True to the eval env (separate from --compile-env).", + ) + p.add_argument( + "--compile-policy", + action="store_true", + help="torch.compile the policy forward pass during collection.", + ) + p.add_argument( + "--compile-loss", + action="store_true", + help=( + "torch.compile(loss_module) so the SAC objective is JIT-traced " + "and runs as a fused graph each gradient step." + ), + ) + p.add_argument( + "--async-collector", + action="store_true", + help=( + "Use aSyncDataCollector (collection runs in a separate process " + "and overlaps with gradient updates). Requires that the env " + "constructor be picklable; weights are synced once per outer " + "iter via MultiProcessWeightSyncScheme." + ), + ) + p.add_argument("--no-wandb", action="store_true") + p.add_argument("--wandb-project", default="torchrl-sat") + p.add_argument("--wandb-group", default="sac-per") + p.add_argument( + "--wandb-mode", + default="online", + choices=["online", "offline", "disabled"], + help=( + "wandb run mode. Use 'offline' when wandb.init() " + "handshake keeps timing out; sync later with " + "'wandb sync torchrl-sat/wandb/offline-run-*'." + ), + ) + p.add_argument("--test-set-csv", default=str(DEFAULT_TEST_SET_PATH)) + p.add_argument("--obs-norm-path", default=str(DEFAULT_OBS_NORM_PATH)) + p.add_argument( + "--no-obs-norm", + action="store_true", + help=( + "Skip the ObservationNorm transform; the policy sees raw " + "observations (already in physical units after the env's " + "sin/cos gimbal encoding)." 
+ ), + ) + p.add_argument("--eval-every", type=int, default=10) + p.add_argument( + "--no-eval", + action="store_true", + help="Skip the periodic eval rollout entirely (focus on train metrics only).", + ) + p.add_argument("--obs-norm-warmup", type=int, default=1024) + return p.parse_args(argv) + + +def _ensure_obs_norm_stats( + *, + num_envs: int, + device: torch.device, + max_steps: int, + num_cmgs: int, + action_scale: float, + singularity_weight: float, + omega_weight: float, + ctrl_cost_weight: float, + frame_skip: int, + seed: int, + path: Path, + warmup: int, +) -> tuple[torch.Tensor, torch.Tensor]: + if path.exists(): + torchrl_logger.info(f"Loading ObservationNorm stats from {path}") + d = torch.load(path, map_location="cpu", weights_only=True) + loc, scale = d["loc"], d["scale"] + expected_dim = 6 + 3 * num_cmgs + if loc.shape[-1] == expected_dim and scale.shape[-1] == expected_dim: + return loc, scale + torchrl_logger.warning( + f"Ignoring stale ObservationNorm stats at {path}: expected " + f"last dim {expected_dim}, got loc={tuple(loc.shape)} " + f"scale={tuple(scale.shape)}." + ) + torchrl_logger.info( + f"No ObservationNorm stats at {path}; running {warmup} warm-up steps." 
+ ) + env, obs_norm = make_train_env( + num_envs=min(num_envs, 1024), + device=device, + max_steps=max_steps, + min_random_horizon=None, + compile_step=False, + obs_norm_stats=None, + num_cmgs=num_cmgs, + action_scale=action_scale, + singularity_weight=singularity_weight, + omega_weight=omega_weight, + ctrl_cost_weight=ctrl_cost_weight, + frame_skip=frame_skip, + seed=seed, + ) + obs_norm.init_stats( + num_iter=warmup, reduce_dim=(0, 1), cat_dim=1, key="observation" + ) + loc = obs_norm.loc.detach().cpu() + scale = obs_norm.scale.detach().cpu() + path.parent.mkdir(parents=True, exist_ok=True) + torch.save({"loc": loc, "scale": scale}, path) + torchrl_logger.info(f"Wrote ObservationNorm stats to {path}") + env.close() + return loc, scale + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> int: + cfg = parse_args(argv) + device = pick_device(cfg.device) + eval_device = ( + torch.device(cfg.eval_device) if cfg.eval_device is not None else device + ) + torchrl_logger.info( + f"Device: {device} | eval_device: {eval_device} | " + f"num_envs={cfg.num_envs} | frames_per_env={cfg.frames_per_env} | " + f"total_iters={cfg.total_iters}" + ) + + if not cfg.no_wandb: + setup_wandb_key() + + obs_norm_stats: tuple[torch.Tensor, torch.Tensor] | None + if cfg.no_obs_norm: + torchrl_logger.info("ObservationNorm disabled (--no-obs-norm).") + obs_norm_stats = None + else: + loc, scale = _ensure_obs_norm_stats( + num_envs=cfg.num_envs, + device=device, + max_steps=cfg.max_steps, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + seed=cfg.seed, + path=Path(cfg.obs_norm_path), + warmup=cfg.obs_norm_warmup, + ) + obs_norm_stats = (loc, scale) + + train_env, _ = 
make_train_env( + num_envs=cfg.num_envs, + device=device, + max_steps=cfg.max_steps, + min_random_horizon=cfg.min_random_horizon, + random_horizon_prob=cfg.random_horizon_prob, + compile_step=cfg.compile_env, + obs_norm_stats=obs_norm_stats, + use_obs_norm=not cfg.no_obs_norm, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + singularity_clamp_min=cfg.singularity_clamp_min, + singularity_mode=cfg.singularity_mode, + singularity_exp_k=cfg.singularity_exp_k, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + reward_scale=cfg.reward_scale, + seed=cfg.seed, + ) + obs_spec = train_env.observation_spec + action_spec = train_env.action_spec + + hidden = tuple(cfg.hidden) + actor = make_actor( + obs_spec=obs_spec, + action_spec=action_spec, + device=device, + hidden=hidden, + activation=cfg.activation, + state_independent_scale=False, + layer_norm=cfg.layer_norm, + small_init_last_layer=cfg.small_init_last_layer, + scale_init=cfg.scale_init, + ) + qvalue = make_qvalue_critic( + obs_spec=obs_spec, + action_spec=action_spec, + device=device, + hidden=hidden, + activation=cfg.activation, + layer_norm=cfg.layer_norm, + small_init_last_layer=cfg.small_init_last_layer, + ) + + with torch.no_grad(): + td0 = train_env.reset() + actor(td0) + qvalue(actor(td0)) + + # ----- Loss + target net ----- + loss_module = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=2, + loss_function="l2", + alpha_init=cfg.alpha_init, + min_alpha=cfg.min_alpha, + max_alpha=cfg.max_alpha, + fixed_alpha=cfg.fixed_alpha, + delay_actor=False, + delay_qvalue=True, + delay_value=True, + target_entropy="auto", + action_spec=action_spec, + ) + loss_module.make_value_estimator(gamma=cfg.gamma) + target_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) + + # Optional torch.compile of the loss objective. 
The SACLoss forward + # is a fixed graph over a single sample TD; compiling it fuses + # actor + twin-Q + alpha computation per gradient step. The + # uncompiled module is kept around for soft-update bookkeeping + # (target_updater walks loss_module.qvalue_network_params). + if cfg.compile_loss: + torchrl_logger.info("torch.compile(loss_module) enabled.") + loss_call = torch.compile(loss_module) + else: + loss_call = loss_module + + critic_lr = cfg.critic_lr if cfg.critic_lr is not None else cfg.lr + optim_actor = torch.optim.Adam( + loss_module.actor_network_params.flatten_keys().values(), + lr=cfg.lr, + ) + optim_critic = torch.optim.Adam( + loss_module.qvalue_network_params.flatten_keys().values(), + lr=critic_lr, + ) + optim_alpha = None + if not cfg.fixed_alpha: + optim_alpha = torch.optim.Adam([loss_module.log_alpha], lr=cfg.lr) + + # Optional cosine LR decay across the full training span. Decays + # to ``lr * lr_decay_end_frac`` by ``total_iters``. + schedulers: list[torch.optim.lr_scheduler.LRScheduler] = [] + if cfg.lr_decay_end_frac is not None: + eta_min_actor = cfg.lr * cfg.lr_decay_end_frac + eta_min_critic = critic_lr * cfg.lr_decay_end_frac + schedulers.append( + torch.optim.lr_scheduler.CosineAnnealingLR( + optim_actor, + T_max=cfg.total_iters, + eta_min=eta_min_actor, + ) + ) + schedulers.append( + torch.optim.lr_scheduler.CosineAnnealingLR( + optim_critic, + T_max=cfg.total_iters, + eta_min=eta_min_critic, + ) + ) + if optim_alpha is not None: + schedulers.append( + torch.optim.lr_scheduler.CosineAnnealingLR( + optim_alpha, + T_max=cfg.total_iters, + eta_min=eta_min_actor, + ) + ) + torchrl_logger.info( + f"Cosine LR decay enabled: lr {cfg.lr:.4g} -> " + f"{eta_min_actor:.4g} over {cfg.total_iters} iters." 
+ ) + + # ----- Replay buffer ----- + buffer_device = ( + torch.device(cfg.buffer_device) if cfg.buffer_device is not None else device + ) + storage = LazyTensorStorage(cfg.buffer_size, device=buffer_device) + if cfg.no_prb: + replay_buffer = TensorDictReplayBuffer( + storage=storage, + batch_size=cfg.batch_size, + prefetch=3, + ) + torchrl_logger.info("Using uniform TensorDictReplayBuffer (no PER).") + else: + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=cfg.prb_alpha, + beta=cfg.prb_beta, + storage=storage, + batch_size=cfg.batch_size, + priority_key="td_error", + prefetch=3, + ) + + # ----- Collector ----- + frames_per_batch = cfg.num_envs * cfg.frames_per_env + init_random_frames = cfg.num_envs * cfg.init_random_frames_per_env + total_frames = frames_per_batch * cfg.total_iters + postproc = ( + MultiStep(gamma=cfg.gamma, n_steps=cfg.n_step) if cfg.n_step > 1 else None + ) + if postproc is not None: + torchrl_logger.info( + f"Using {cfg.n_step}-step returns " + f"(MultiStep postproc with gamma={cfg.gamma})." + ) + if cfg.async_collector: + # Build a picklable env factory; the worker process re-creates + # the env there. The actor lives in the main process and gets + # synced to the worker once per outer iter via update_policy_weights_. 
+ train_env_factory = partial( + _train_env_factory, + num_envs=cfg.num_envs, + device_str=str(device), + max_steps=cfg.max_steps, + min_random_horizon=cfg.min_random_horizon, + random_horizon_prob=cfg.random_horizon_prob, + compile_step=cfg.compile_env, + obs_norm_stats=obs_norm_stats, + use_obs_norm=not cfg.no_obs_norm, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + singularity_clamp_min=cfg.singularity_clamp_min, + singularity_mode=cfg.singularity_mode, + singularity_exp_k=cfg.singularity_exp_k, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + reward_scale=cfg.reward_scale, + seed=cfg.seed, + ) + # Free the in-process train env: the worker will build its own. + train_env.close() + torchrl_logger.info( + "Async collector enabled (MultiCollector sync=False, 1 worker). " + "Collection runs in a separate process and overlaps with grad updates." + ) + collector = MultiCollector( + create_env_fn=[train_env_factory], + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + init_random_frames=init_random_frames, + exploration_type=ExplorationType.RANDOM, + postproc=postproc, + update_at_each_batch=True, + sync=False, + ) + else: + collector = Collector( + train_env, + policy=actor, + frames_per_batch=frames_per_batch, + total_frames=total_frames, + device=device, + init_random_frames=init_random_frames, + exploration_type=ExplorationType.RANDOM, + compile_policy=cfg.compile_policy, + postproc=postproc, + ) + + # ----- WandB ----- + logger = None + if not cfg.no_wandb: + exp = generate_exp_name("SAC-PER", "satellite") + # ``--wandb-mode offline`` avoids the 90s wandb.init() handshake + # timeout that has been intermittently failing on this host. + # Sync offline runs with + # ``wandb sync torchrl-sat/wandb/offline-run-*``. 
+ wandb_mode = cfg.wandb_mode + logger = get_logger( + "wandb", + logger_name="torchrl-sat", + experiment_name=exp, + wandb_kwargs={ + "project": cfg.wandb_project, + "group": cfg.wandb_group, + "config": vars(cfg), + "mode": wandb_mode, + }, + ) + + # ----- Eval (process-isolated, non-blocking) ----- + evaluator = None + if not cfg.no_eval: + _, _, cats = load_test_set_csv(cfg.test_set_csv) + metrics_fn = make_eval_metrics_fn(cats) + + def _on_eval_result(result): + # Logged from the evaluator's coordination thread. + torchrl_logger.info( + f"[eval] step={int(result.get('eval/step', -1))} " + f"return={float(result.get('eval/eval/return', float('nan'))):.3f} " + f"final_err={float(result.get('eval/eval/final_attitude_error_rad', float('nan'))):.3f} " + f"success@0.10={float(result.get('eval/eval/success_rate@0.10', float('nan'))):.3f}" + ) + if logger is not None: + # The Evaluator emits keys with the prefix it was + # configured with (default ``eval/``); we store them as-is. + flat = { + k: v.item() if hasattr(v, "item") else float(v) + for k, v in result.items() + } + step = int(flat.pop("eval/step", -1)) + logger.log_metrics(flat, step=max(0, step)) + + n_eval_envs = len(cats) + env_factory = partial( + _eval_env_factory, + device_str=str(eval_device), + test_set_csv=cfg.test_set_csv, + max_steps=cfg.max_steps, + obs_norm_stats=obs_norm_stats, + use_obs_norm=not cfg.no_obs_norm, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + singularity_clamp_min=cfg.singularity_clamp_min, + singularity_mode=cfg.singularity_mode, + singularity_exp_k=cfg.singularity_exp_k, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + compile_step=cfg.compile_eval_env, + reward_scale=cfg.reward_scale, + ) + policy_factory = partial( + _eval_policy_factory, + obs_spec=obs_spec, + action_spec=action_spec, + hidden=tuple(cfg.hidden), + activation=cfg.activation, + 
device_str=str(eval_device), + ) + evaluator = Evaluator( + env=env_factory, + policy_factory=policy_factory, + num_trajectories=n_eval_envs, + # ``max_steps`` is intentionally None: the eval env already + # has a :class:`StepCounter` set to ``cfg.max_steps``, and + # the process-backend collector raises if both sources try + # to enforce the same horizon. + max_steps=None, + exploration_type=ExplorationType.DETERMINISTIC, + metrics_fn=metrics_fn, + on_result=_on_eval_result, + backend="process", + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, + ) + torchrl_logger.info( + "Evaluator(backend='process') ready; first eval will spawn " + "the child process on the first trigger." + ) + else: + torchrl_logger.info("Eval disabled (--no-eval); training-only metrics.") + + def _check_finite(td, key: tuple[str, ...], context: str) -> None: + value = td.get(key) + if not torch.isfinite(value).all(): + raise RuntimeError(f"Non-finite tensor at {context}: {key}") + + # ----- Training loop ----- + pbar_t0 = time.perf_counter() + collected_frames = 0 + for it, batch in enumerate(collector): + _check_finite(batch, ("next", "reward"), f"collector iter {it}") + _check_finite(batch, ("next", "observation"), f"collector iter {it}") + # Flatten (num_envs, frames_per_env, ...) -> (N, ...) + flat = batch.reshape(-1) + replay_buffer.extend(flat) + collected_frames += frames_per_batch + + # Skip SAC updates while warming up the buffer. 
+ if collected_frames < init_random_frames: + torchrl_logger.info(f"iter={it} (warmup) frames={collected_frames}") + continue + + # ----- SAC updates ----- + last_loss = None + for _ in range(cfg.gradient_steps): + sample = replay_buffer.sample().to(device) + losses = loss_call(sample) + + optim_actor.zero_grad(set_to_none=True) + optim_critic.zero_grad(set_to_none=True) + if optim_alpha is not None: + optim_alpha.zero_grad(set_to_none=True) + loss = losses["loss_actor"] + losses["loss_qvalue"] + losses["loss_alpha"] + if not torch.isfinite(loss): + loss_summary = { + key: value.detach().item() + for key, value in losses.items() + if value.numel() == 1 + } + raise RuntimeError(f"Non-finite SAC loss at iter {it}: {loss_summary}") + loss.backward() + # Per-component grad-norm logging. ``clip_grad_norm_`` + # returns the total norm BEFORE clipping, so we can log it + # directly. Calling it per parameter group also clips each + # group independently rather than treating actor + critic + # + alpha as one big vector. + grad_norm_actor = torch.nn.utils.clip_grad_norm_( + list(loss_module.actor_network_params.flatten_keys().values()), + cfg.max_grad_norm if cfg.max_grad_norm > 0 else float("inf"), + ) + grad_norm_critic = torch.nn.utils.clip_grad_norm_( + list(loss_module.qvalue_network_params.flatten_keys().values()), + cfg.max_grad_norm if cfg.max_grad_norm > 0 else float("inf"), + ) + if optim_alpha is not None: + grad_norm_alpha = torch.nn.utils.clip_grad_norm_( + [loss_module.log_alpha], + cfg.max_grad_norm if cfg.max_grad_norm > 0 else float("inf"), + ) + else: + grad_norm_alpha = torch.zeros((), device=device) + optim_actor.step() + optim_critic.step() + if optim_alpha is not None: + optim_alpha.step() + target_updater.step() + + # Push the per-sample TD error back as the priority signal + # (only meaningful for prioritized replay). 
+ if not cfg.no_prb: + replay_buffer.update_tensordict_priority(sample) + last_loss = losses + + done_mask = batch["next", "done"] + ep_reward = batch["next", "episode_reward"][done_mask] + ep_length = batch["next", "step_count"][done_mask].to(torch.float32) + ep_reward_mean = ep_reward.mean().item() if ep_reward.numel() else None + ep_length_mean = ep_length.mean().item() if ep_length.numel() else None + # Per-step reward is the policy-quality signal: cumulative + # reward divided by episode length removes the artifact where + # ``episode_reward`` drifts more negative simply because + # ``RandomTruncationTransform`` lengthens episodes once the + # first uniformly-sampled horizons clear out. + ep_reward_per_step = ( + (ep_reward / ep_length.clamp_min(1.0)).mean().item() + if ep_reward.numel() + else None + ) + ep_reward_text = ( + f"{ep_reward_mean:.3f}" if ep_reward_mean is not None else "n/a" + ) + # ``batch_reward_per_step`` is the snapshot of "how is the + # policy doing right now": mean reward across the current + # 1024-sample collection batch. Decoupled from episode length + # and episode-completion rate, so it is the cleanest training + # signal. + batch_reward_per_step = batch.get(("next", "reward")).mean().item() + # Mean ||bus_omega||² this batch -- shows whether the policy + # is actually moving the satellite (>>0) or sitting still (~0). + batch_omega_sq = ( + (batch.get(("next", "bus_omega")) ** 2).sum(dim=-1).mean().item() + ) + # Mean ||quat_err|| this batch -- direct read of how far the + # bus is from the target on average right now. 
+ batch_att_err = batch.get(("next", "quat_err")).norm(dim=-1).mean().item() + if logger is not None and last_loss is not None: + metrics = { + "train/loss_actor": last_loss["loss_actor"].item(), + "train/loss_qvalue": last_loss["loss_qvalue"].item(), + "train/loss_alpha": last_loss["loss_alpha"].item(), + "train/alpha": loss_module.log_alpha.detach().exp().item(), + "train/iter_per_sec": (it + 1) + / max(1e-6, time.perf_counter() - pbar_t0), + "train/buffer_size": len(replay_buffer), + "train/critic_lr": optim_critic.param_groups[0]["lr"], + "train/batch_reward_per_step": batch_reward_per_step, + "train/batch_omega_sq": batch_omega_sq, + "train/batch_attitude_error": batch_att_err, + "train/grad_norm_actor": grad_norm_actor.item(), + "train/grad_norm_critic": grad_norm_critic.item(), + "train/grad_norm_alpha": grad_norm_alpha.item(), + } + if ep_reward_mean is not None: + metrics["train/episode_reward"] = ep_reward_mean + metrics["train/episode_reward_per_step"] = ep_reward_per_step + metrics["train/episode_length"] = ep_length_mean + logger.log_metrics(metrics, step=collected_frames) + if last_loss is not None: + torchrl_logger.info( + f"iter={it} frames={collected_frames} " + f"r/step={batch_reward_per_step:.3f} " + f"|omega|²={batch_omega_sq:.3f} " + f"|q_err|={batch_att_err:.3f} " + f"ep_r={ep_reward_text} " + f"loss_q={last_loss['loss_qvalue'].item():.3f} " + f"loss_a={last_loss['loss_actor'].item():.3f} " + f"alpha={loss_module.log_alpha.detach().exp().item():.3f}" + ) + + # Step LR schedulers once per outer iteration (each iter is one + # collector batch; gradient_steps inner steps share the same lr). + for sched in schedulers: + sched.step() + + if ( + evaluator is not None + and (it + 1) % cfg.eval_every == 0 + and not evaluator.pending + ): + evaluator.trigger_eval(actor, step=collected_frames) + + if evaluator is not None: + # Drain any pending eval, then shut the worker process down. 
+ if evaluator.pending: + try: + evaluator.wait(timeout=120.0) + except Exception as e: # noqa: BLE001 + torchrl_logger.warning(f"Final eval wait failed: {e}") + evaluator.shutdown() + collector.shutdown() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/satellite/sac_per_fully_async.py b/examples/satellite/sac_per_fully_async.py new file mode 100644 index 00000000000..4b313c3df2f --- /dev/null +++ b/examples/satellite/sac_per_fully_async.py @@ -0,0 +1,688 @@ +"""SAC + Prioritized Replay Buffer with fully-async collection. + +Like ``sac_per.py`` but the collector and trainer are fully decoupled: +the collector is constructed with ``replay_buffer=rb`` and started via +``collector.start()`` so it pushes new transitions to the buffer in the +background. The main loop only iterates over gradient steps, sampling +from the buffer and updating the policy/value/alpha networks. + +Defaults match the ``1rlrmzo4`` configuration (the strongest +synchronous run so far): 16k vmapped envs, 64 grad steps per "outer +chunk", 6400 outer chunks => 409,600 gradient updates total. UTD ratio +is no longer enforced -- collection and training proceed at their own +hardware-determined rates. 
+""" + +from __future__ import annotations + +import argparse +import sys +import time +from functools import partial +from pathlib import Path + +import torch + +from torchrl._utils import logger as torchrl_logger +from torchrl.collectors import Evaluator, MultiCollector +from torchrl.data import ( + LazyTensorStorage, + TensorDictPrioritizedReplayBuffer, + TensorDictReplayBuffer, +) +from torchrl.envs.utils import ExplorationType +from torchrl.objectives import SACLoss, SoftUpdate +from torchrl.record.loggers import generate_exp_name, get_logger +from torchrl.weight_update import MultiProcessWeightSyncScheme + +PACKAGE_DIR = Path(__file__).resolve().parent +if str(PACKAGE_DIR.parent.parent) not in sys.path: + sys.path.insert(0, str(PACKAGE_DIR.parent.parent)) + +from examples.satellite._utils import ( # noqa: E402 + DEFAULT_OBS_NORM_PATH, + DEFAULT_TEST_SET_PATH, + load_test_set_csv, + make_actor, + make_eval_metrics_fn, + make_qvalue_critic, + make_train_env, + pick_device, + setup_wandb_key, +) +from examples.satellite.sac_per import ( # noqa: E402 + _eval_env_factory, + _eval_policy_factory, + _train_env_factory, +) + + +# --------------------------------------------------------------------------- +# Config +# --------------------------------------------------------------------------- + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """CLI mirroring ``sac_per.py`` so launch commands are interchangeable. + + The async-only knobs are at the bottom. Unused-in-async flags + (``--init-random-frames-per-env``, ``--gradient-steps``, + ``--frames-per-env``) are still parsed so launch scripts don't break, + but their semantics shift -- see help text on each. + """ + p = argparse.ArgumentParser(description=__doc__) + # Env / collection sizing. 
+ p.add_argument("--num-envs", type=int, default=16_384) + p.add_argument("--num-cmgs", type=int, default=4) + p.add_argument("--max-steps", type=int, default=300) + p.add_argument("--frame-skip", type=int, default=50) + p.add_argument("--min-random-horizon", type=int, default=100) + p.add_argument("--random-horizon-prob", type=float, default=0.02) + p.add_argument("--frames-per-env", type=int, default=1) + p.add_argument( + "--total-iters", + type=int, + default=6400, + help=( + "Number of outer chunks. Each chunk runs --gradient-steps " + "grad updates. Total grad updates = total_iters * " + "gradient_steps. Eval cadence and LR scheduler horizon both " + "use total_iters." + ), + ) + p.add_argument("--gradient-steps", type=int, default=64) + p.add_argument("--buffer-size", type=int, default=1_000_000) + p.add_argument( + "--init-buffer-size", + type=int, + default=16_384, + help=( + "Number of transitions the collector must produce before " + "the trainer starts taking gradient steps. Default = " + "num_envs (one outer-iter's worth)." + ), + ) + p.add_argument( + "--init-random-frames-per-env", + type=int, + default=0, + help="Parsed for CLI parity; not used by the async loop.", + ) + p.add_argument("--batch-size", type=int, default=4096) + p.add_argument("--prb-alpha", type=float, default=0.7) + p.add_argument("--prb-beta", type=float, default=0.5) + p.add_argument("--no-prb", action="store_true") + # SAC. 
+    p.add_argument("--lr", type=float, default=9e-4)
+    p.add_argument("--critic-lr", type=float, default=9e-4)
+    p.add_argument("--gamma", type=float, default=0.99)
+    p.add_argument("--alpha-init", type=float, default=1.0)
+    p.add_argument("--fixed-alpha", action="store_true")
+    p.add_argument("--min-alpha", type=float, default=None)
+    p.add_argument("--max-alpha", type=float, default=None)
+    p.add_argument("--target-update-polyak", type=float, default=0.995)
+    p.add_argument("--max-grad-norm", type=float, default=1.0)
+    p.add_argument(
+        "--lr-decay-end-frac",
+        type=float,
+        default=0.1,
+        help="Cosine LR end value as a fraction of starting lr.",
+    )
+    p.add_argument(
+        "--lr-warmup-iters",
+        type=int,
+        default=0,
+        help=(
+            "Linear warmup over this many outer iters from "
+            "lr * start_factor -> lr, then cosine decay over the remaining "
+            "(total_iters - lr_warmup_iters) iters."
+        ),
+    )
+    p.add_argument(
+        "--lr-warmup-start-factor",
+        type=float,
+        default=1e-2,
+        help="Start factor for the warmup phase. lr_start = lr * start_factor.",
+    )
+    # Reward / env physics.
+    p.add_argument("--action-scale", type=float, default=3.0)
+    p.add_argument("--singularity-weight", type=float, default=0.0)
+    p.add_argument("--singularity-clamp-min", type=float, default=1e-6)
+    p.add_argument(
+        "--singularity-mode",
+        type=str,
+        default="inverse",
+        choices=["inverse", "exp"],
+    )
+    p.add_argument("--singularity-exp-k", type=float, default=5.0)
+    p.add_argument("--omega-weight", type=float, default=0.1)
+    p.add_argument("--ctrl-cost-weight", type=float, default=0.0)
+    p.add_argument("--reward-scale", type=float, default=0.333)
+    # Networks.
+ p.add_argument("--hidden", type=int, nargs="+", default=[256, 256, 256, 256]) + p.add_argument("--activation", type=str, default="tanh") + p.add_argument("--small-init-last-layer", action="store_true", default=True) + p.add_argument( + "--no-small-init-last-layer", dest="small_init_last_layer", action="store_false" + ) + p.add_argument("--scale-init", type=float, default=0.31) + p.add_argument("--layer-norm", action="store_true") + p.add_argument("--no-obs-norm", action="store_true", default=True) + p.add_argument("--obs-norm-path", type=str, default=str(DEFAULT_OBS_NORM_PATH)) + p.add_argument("--obs-norm-warmup", type=int, default=10_000) + # Devices / compile. + p.add_argument("--seed", type=int, default=0) + p.add_argument("--device", default=None) + p.add_argument("--eval-device", default=None) + p.add_argument( + "--buffer-device", + default="cpu", + help=( + "Device for the replay buffer storage. Must be 'cpu' for " + "shared-memory multiprocess use unless TorchRL was built " + "with CUDA support." + ), + ) + p.add_argument("--compile-env", action="store_true") + p.add_argument("--compile-eval-env", action="store_true") + p.add_argument("--compile-policy", action="store_true") + p.add_argument("--compile-loss", action="store_true") + # Eval. + p.add_argument( + "--eval-every", type=int, default=50, help="Eval every N outer chunks." + ) + p.add_argument("--no-eval", action="store_true") + p.add_argument("--test-set-csv", type=str, default=str(DEFAULT_TEST_SET_PATH)) + # WandB. 
+ p.add_argument("--no-wandb", action="store_true") + p.add_argument("--wandb-project", default="torchrl-sat") + p.add_argument("--wandb-group", default="sac-per-async") + p.add_argument( + "--wandb-mode", + default="online", + choices=["online", "offline", "disabled"], + ) + return p.parse_args(argv) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> int: + cfg = parse_args(argv) + device = pick_device(cfg.device) + eval_device = ( + torch.device(cfg.eval_device) if cfg.eval_device is not None else device + ) + total_grad_steps = cfg.total_iters * cfg.gradient_steps + torchrl_logger.info( + f"Device: {device} | eval_device: {eval_device} | " + f"num_envs={cfg.num_envs} | total_iters={cfg.total_iters} | " + f"gradient_steps={cfg.gradient_steps} | " + f"total_grad_steps={total_grad_steps}" + ) + + # ----- Reference env (for spec extraction only) ----- + # The async collector builds its own env in a worker. We only need + # one in the main process so we can read obs/action specs to build + # the actor + critic. Closed immediately after. 
+ spec_env, _ = make_train_env( + num_envs=cfg.num_envs, + device=device, + max_steps=cfg.max_steps, + min_random_horizon=cfg.min_random_horizon, + random_horizon_prob=cfg.random_horizon_prob, + compile_step=False, + obs_norm_stats=None, + use_obs_norm=not cfg.no_obs_norm, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + singularity_clamp_min=cfg.singularity_clamp_min, + singularity_mode=cfg.singularity_mode, + singularity_exp_k=cfg.singularity_exp_k, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + reward_scale=cfg.reward_scale, + seed=cfg.seed, + ) + obs_spec = spec_env.observation_spec + action_spec = spec_env.action_spec + spec_env.close() + obs_norm_stats = None # --no-obs-norm only path supported in async + + # ----- Networks + loss ----- + hidden = tuple(cfg.hidden) + actor = make_actor( + obs_spec=obs_spec, + action_spec=action_spec, + device=device, + hidden=hidden, + activation=cfg.activation, + state_independent_scale=False, + layer_norm=cfg.layer_norm, + small_init_last_layer=cfg.small_init_last_layer, + scale_init=cfg.scale_init, + ) + qvalue = make_qvalue_critic( + obs_spec=obs_spec, + action_spec=action_spec, + device=device, + hidden=hidden, + activation=cfg.activation, + layer_norm=cfg.layer_norm, + ) + loss_module = SACLoss( + actor_network=actor, + qvalue_network=qvalue, + num_qvalue_nets=2, + loss_function="l2", + alpha_init=cfg.alpha_init, + min_alpha=cfg.min_alpha, + max_alpha=cfg.max_alpha, + fixed_alpha=cfg.fixed_alpha, + delay_actor=False, + delay_qvalue=True, + delay_value=True, + target_entropy="auto", + action_spec=action_spec, + ) + loss_module.make_value_estimator(gamma=cfg.gamma) + target_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) + + if cfg.compile_loss: + torchrl_logger.info("torch.compile(loss_module) enabled.") + loss_call = torch.compile(loss_module) + else: + loss_call = loss_module + + critic_lr = 
cfg.critic_lr if cfg.critic_lr is not None else cfg.lr + optim_actor = torch.optim.Adam( + loss_module.actor_network_params.flatten_keys().values(), + lr=cfg.lr, + ) + optim_critic = torch.optim.Adam( + loss_module.qvalue_network_params.flatten_keys().values(), + lr=critic_lr, + ) + optim_alpha = None + if not cfg.fixed_alpha: + optim_alpha = torch.optim.Adam([loss_module.log_alpha], lr=cfg.lr) + + schedulers: list[torch.optim.lr_scheduler.LRScheduler] = [] + + def _build_lr_scheduler(optim, eta_min): + """Linear warmup -> cosine decay (or just cosine if warmup=0).""" + decay_iters = max(1, cfg.total_iters - cfg.lr_warmup_iters) + cosine = torch.optim.lr_scheduler.CosineAnnealingLR( + optim, + T_max=decay_iters, + eta_min=eta_min, + ) + if cfg.lr_warmup_iters > 0: + warmup = torch.optim.lr_scheduler.LinearLR( + optim, + start_factor=cfg.lr_warmup_start_factor, + end_factor=1.0, + total_iters=cfg.lr_warmup_iters, + ) + return torch.optim.lr_scheduler.SequentialLR( + optim, + schedulers=[warmup, cosine], + milestones=[cfg.lr_warmup_iters], + ) + return cosine + + if cfg.lr_decay_end_frac is not None: + eta_min_actor = cfg.lr * cfg.lr_decay_end_frac + eta_min_critic = critic_lr * cfg.lr_decay_end_frac + schedulers.append(_build_lr_scheduler(optim_actor, eta_min_actor)) + schedulers.append(_build_lr_scheduler(optim_critic, eta_min_critic)) + if optim_alpha is not None: + schedulers.append(_build_lr_scheduler(optim_alpha, eta_min_actor)) + if cfg.lr_warmup_iters > 0: + torchrl_logger.info( + f"LR schedule: linear warmup {cfg.lr * cfg.lr_warmup_start_factor:.4g} " + f"-> {cfg.lr:.4g} over {cfg.lr_warmup_iters} iters, then " + f"cosine decay -> {eta_min_actor:.4g} over the remaining " + f"{cfg.total_iters - cfg.lr_warmup_iters} iters." + ) + else: + torchrl_logger.info( + f"Cosine LR decay: lr {cfg.lr:.4g} -> {eta_min_actor:.4g} " + f"over {cfg.total_iters} outer chunks." 
+ ) + + # ----- Replay buffer (shared across processes) ----- + buffer_device = torch.device(cfg.buffer_device) + storage = LazyTensorStorage(cfg.buffer_size, device=buffer_device) + # Note: ``prefetch`` is incompatible with ``shared=True`` (the + # multiprocess SyncManager wrappers can't pickle the prefetch + # thread state). + if cfg.no_prb: + replay_buffer = TensorDictReplayBuffer( + storage=storage, + batch_size=cfg.batch_size, + shared=True, + ) + torchrl_logger.info("Using uniform TensorDictReplayBuffer (no PER), shared.") + else: + # ``sync=False`` (added in PR pytorch/rl#3714): writer procs use a + # shareable RandomSampler, while the learner owns a local + # PrioritizedSampler that catches up from shared write_count + # before sampling. Required for the fully-async collector flow. + replay_buffer = TensorDictPrioritizedReplayBuffer( + alpha=cfg.prb_alpha, + beta=cfg.prb_beta, + storage=storage, + batch_size=cfg.batch_size, + priority_key="td_error", + shared=True, + sync=False, + ) + torchrl_logger.info( + f"Using TensorDictPrioritizedReplayBuffer " + f"(alpha={cfg.prb_alpha}, beta={cfg.prb_beta}), shared, sync=False." 
+ ) + + # ----- Async collector ----- + train_env_factory = partial( + _train_env_factory, + num_envs=cfg.num_envs, + device_str=str(device), + max_steps=cfg.max_steps, + min_random_horizon=cfg.min_random_horizon, + random_horizon_prob=cfg.random_horizon_prob, + compile_step=cfg.compile_env, + obs_norm_stats=obs_norm_stats, + use_obs_norm=not cfg.no_obs_norm, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + singularity_clamp_min=cfg.singularity_clamp_min, + singularity_mode=cfg.singularity_mode, + singularity_exp_k=cfg.singularity_exp_k, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + reward_scale=cfg.reward_scale, + seed=cfg.seed, + ) + frames_per_batch = cfg.num_envs * cfg.frames_per_env + collector = MultiCollector( + create_env_fn=[train_env_factory], + policy=actor, + replay_buffer=replay_buffer, + frames_per_batch=frames_per_batch, + total_frames=-1, + device=device, + exploration_type=ExplorationType.RANDOM, + update_at_each_batch=True, + sync=False, + ) + torchrl_logger.info( + "MultiCollector(sync=False, replay_buffer=rb) ready -- starting " + "async collection now." 
+ ) + + # ----- WandB ----- + logger = None + if not cfg.no_wandb: + setup_wandb_key() + exp = generate_exp_name("SAC-PER-async", "satellite") + logger = get_logger( + "wandb", + logger_name="torchrl-sat", + experiment_name=exp, + wandb_kwargs={ + "project": cfg.wandb_project, + "group": cfg.wandb_group, + "mode": cfg.wandb_mode, + "config": vars(cfg), + }, + ) + + # ----- Eval setup (process backend, identical to sac_per.py) ----- + metrics_fn = None + evaluator = None + if not cfg.no_eval: + _, _, cats = load_test_set_csv(cfg.test_set_csv) + metrics_fn = make_eval_metrics_fn(cats) + + def _on_eval_result(result): + torchrl_logger.info( + f"[eval] step={int(result.get('eval/step', -1))} " + f"return={float(result.get('eval/eval/return', float('nan'))):.3f} " + f"final_err={float(result.get('eval/eval/final_attitude_error_rad', float('nan'))):.3f} " + f"success@0.10={float(result.get('eval/eval/success_rate@0.10', float('nan'))):.3f}" + ) + if logger is not None: + flat = { + k: v.item() if hasattr(v, "item") else float(v) + for k, v in result.items() + } + step = int(flat.pop("eval/step", -1)) + logger.log_metrics(flat, step=max(0, step)) + + env_factory = partial( + _eval_env_factory, + device_str=str(eval_device), + test_set_csv=cfg.test_set_csv, + max_steps=cfg.max_steps, + obs_norm_stats=obs_norm_stats, + use_obs_norm=not cfg.no_obs_norm, + num_cmgs=cfg.num_cmgs, + action_scale=cfg.action_scale, + singularity_weight=cfg.singularity_weight, + singularity_clamp_min=cfg.singularity_clamp_min, + singularity_mode=cfg.singularity_mode, + singularity_exp_k=cfg.singularity_exp_k, + omega_weight=cfg.omega_weight, + ctrl_cost_weight=cfg.ctrl_cost_weight, + frame_skip=cfg.frame_skip, + compile_step=cfg.compile_eval_env, + reward_scale=cfg.reward_scale, + ) + policy_factory = partial( + _eval_policy_factory, + obs_spec=obs_spec, + action_spec=action_spec, + hidden=tuple(cfg.hidden), + activation=cfg.activation, + device_str=str(eval_device), + ) + evaluator = Evaluator( + 
env=env_factory, + policy_factory=policy_factory, + num_trajectories=len(cats), + max_steps=None, + exploration_type=ExplorationType.DETERMINISTIC, + metrics_fn=metrics_fn, + on_result=_on_eval_result, + backend="process", + weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()}, + ) + + # ----- Start async collection ----- + collector.start() + + # Wait for the buffer to warm up enough to start sampling. + while len(replay_buffer) < cfg.init_buffer_size: + torchrl_logger.info( + f"warmup: buffer size = {len(replay_buffer)} " + f"< {cfg.init_buffer_size}; sleeping 1s" + ) + time.sleep(1.0) + torchrl_logger.info( + f"warmup done: buffer size = {len(replay_buffer)}; " + f"starting gradient updates." + ) + + # ----- Training loop ----- + pbar_t0 = time.perf_counter() + grad_step = 0 + last_log_t = pbar_t0 + last_loss = None + + for outer_iter in range(cfg.total_iters): + for _ in range(cfg.gradient_steps): + sample = replay_buffer.sample().to(device) + losses = loss_call(sample) + + optim_actor.zero_grad(set_to_none=True) + optim_critic.zero_grad(set_to_none=True) + if optim_alpha is not None: + optim_alpha.zero_grad(set_to_none=True) + loss = losses["loss_actor"] + losses["loss_qvalue"] + losses["loss_alpha"] + if not torch.isfinite(loss): + loss_summary = { + key: value.detach().item() + for key, value in losses.items() + if value.numel() == 1 + } + raise RuntimeError( + f"Non-finite SAC loss at grad_step {grad_step}: " f"{loss_summary}" + ) + loss.backward() + grad_norm_actor = torch.nn.utils.clip_grad_norm_( + list(loss_module.actor_network_params.flatten_keys().values()), + cfg.max_grad_norm if cfg.max_grad_norm > 0 else float("inf"), + ) + grad_norm_critic = torch.nn.utils.clip_grad_norm_( + list(loss_module.qvalue_network_params.flatten_keys().values()), + cfg.max_grad_norm if cfg.max_grad_norm > 0 else float("inf"), + ) + if optim_alpha is not None: + grad_norm_alpha = torch.nn.utils.clip_grad_norm_( + [loss_module.log_alpha], + cfg.max_grad_norm if 
cfg.max_grad_norm > 0 else float("inf"), + ) + else: + grad_norm_alpha = torch.zeros((), device=device) + optim_actor.step() + optim_critic.step() + if optim_alpha is not None: + optim_alpha.step() + target_updater.step() + + if not cfg.no_prb: + replay_buffer.update_tensordict_priority(sample) + last_loss = losses + grad_step += 1 + + for sched in schedulers: + sched.step() + + # Per-outer-iter logging. + now = time.perf_counter() + log_dt = now - last_log_t + last_log_t = now + # Reproduce the same step semantic as ``sac_per.py``: "frames" + # in wandb step indexing tracks collector frames; use the + # collector's running write counter as the canonical step. + try: + collected_frames = int(replay_buffer.write_count) + except (AttributeError, TypeError): + collected_frames = len(replay_buffer) + torchrl_logger.info( + f"iter={outer_iter} grad_step={grad_step} " + f"buffer={len(replay_buffer)} " + f"frames={collected_frames} " + f"loss_q={last_loss['loss_qvalue'].item():.3f} " + f"loss_a={last_loss['loss_actor'].item():.3f} " + f"alpha={loss_module.log_alpha.detach().exp().item():.3f} " + f"dt={log_dt:.2f}s" + ) + # Diagnostic batch metrics from the most recent sampled minibatch. + # Mirrors what ``sac_per.py`` and ``ppo_buffer.py`` log so all + # three scripts share a comparable training-quality dashboard. 
+ with torch.no_grad(): + batch_reward_per_step = sample.get(("next", "reward")).mean().item() + qerr_t = sample.get(("next", "quat_err"), default=None) + batch_attitude_error = ( + qerr_t.norm(dim=-1).mean().item() + if qerr_t is not None + else float("nan") + ) + omega_t = sample.get(("next", "bus_omega"), default=None) + batch_omega_sq = ( + (omega_t**2).sum(dim=-1).mean().item() + if omega_t is not None + else float("nan") + ) + done_mask = sample.get(("next", "done")) + ep_r_t = sample.get(("next", "episode_reward"), default=None) + episode_reward = ( + ep_r_t[done_mask].mean().item() + if ep_r_t is not None and done_mask.any() + else None + ) + + if logger is not None and last_loss is not None: + metrics = { + "train/loss_actor": last_loss["loss_actor"].item(), + "train/loss_qvalue": last_loss["loss_qvalue"].item(), + "train/loss_alpha": last_loss["loss_alpha"].item(), + "train/alpha": loss_module.log_alpha.detach().exp().item(), + "train/n_updates": grad_step, + "train/buffer_size": len(replay_buffer), + # ``train/lr`` is the canonical key shared with PPO; keep + # ``train/critic_lr`` as a back-compat alias. 
+ "train/lr": optim_critic.param_groups[0]["lr"], + "train/critic_lr": optim_critic.param_groups[0]["lr"], + "train/grad_norm_actor": grad_norm_actor.item(), + "train/grad_norm_critic": grad_norm_critic.item(), + "train/grad_norm_alpha": grad_norm_alpha.item() + if isinstance(grad_norm_alpha, torch.Tensor) + else grad_norm_alpha, + "train/iter_per_sec": (outer_iter + 1) / max(1e-6, now - pbar_t0), + "train/iter_dt_sec": log_dt, + "train/batch_reward_per_step": batch_reward_per_step, + "train/batch_attitude_error": batch_attitude_error, + "train/batch_omega_sq": batch_omega_sq, + } + if episode_reward is not None: + metrics["train/episode_reward"] = episode_reward + logger.log_metrics(metrics, step=collected_frames) + + # Eval trigger -- skip if the previous eval is still running + # (the async Evaluator's default ``busy_policy='error'`` raises + # when overlapping triggers come in faster than eval can drain). + if ( + evaluator is not None + and (outer_iter + 1) % cfg.eval_every == 0 + and not evaluator.pending + ): + evaluator.trigger_eval(actor, step=collected_frames) + + # ----- Shutdown ----- + torchrl_logger.info( + f"Training complete: {grad_step} gradient updates. " "Shutting down collector." 
+ ) + try: + collector.async_shutdown(timeout=30.0) + except Exception as e: + torchrl_logger.warning(f"collector.async_shutdown failed: {e}") + if evaluator is not None: + if evaluator.pending: + try: + evaluator.wait(timeout=120.0) + except Exception as e: + torchrl_logger.warning(f"Final eval wait failed: {e}") + try: + evaluator.shutdown() + except Exception as e: + torchrl_logger.warning(f"evaluator.shutdown failed: {e}") + if logger is not None: + try: + logger.experiment.finish() + except Exception: + pass + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/satellite/test_set.csv b/examples/satellite/test_set.csv new file mode 100644 index 00000000000..373ecf23c85 --- /dev/null +++ b/examples/satellite/test_set.csv @@ -0,0 +1,257 @@ +init_w,init_x,init_y,init_z,target_w,target_x,target_y,target_z,category +0.5765609145,0.4450169802,0.2695076466,-0.6300024986,-0.5368856788,-0.6730548739,-0.3261828721,-0.3903275728,uniform +0.3176580369,-0.5780549049,-0.0201656241,-0.7513582706,0.0651217476,0.1860047728,-0.0625796914,-0.9783890843,uniform +-0.3234239817,0.7089627385,-0.1687686443,-0.6035610437,-0.1529447436,0.4152177870,-0.6843644381,0.5795235038,uniform +-0.5127438307,-0.3940810859,-0.5415958166,0.5370919704,0.1497044712,-0.2420543730,0.5007444024,0.8174675703,uniform +0.9233906269,-0.0897333920,-0.2796611190,0.2471584976,-0.0022736422,0.8412088752,0.0776815489,-0.5350963473,uniform +-0.3323571086,0.4727236331,0.3510629237,0.7367672324,-0.2714111507,0.0396070890,0.1895352155,0.9427850842,uniform +0.5468608141,0.5542563796,0.2609911859,0.5706370473,-0.0289293360,-0.3641286194,-0.8233078718,-0.4344394207,uniform +-0.2500519454,0.0450818054,-0.2715901732,0.9282674193,0.0383138470,0.1700676084,0.4199368060,0.8906526566,uniform +-0.5812217593,-0.3657043576,-0.0937586203,0.7208682299,-0.7732198834,0.2298453748,-0.4484505355,0.3849599063,uniform 
+0.3228825927,-0.4298479259,0.3095580041,-0.7843158841,-0.5010041595,-0.0155826882,-0.2029817551,0.8411601782,uniform +-0.7299830317,0.4666219950,-0.4123267829,-0.2817359865,0.7153947353,-0.4074093103,0.4865961373,0.2923220694,uniform +-0.4535493553,0.7556306124,-0.4394895434,-0.1736787558,0.5638076067,0.5907313824,0.5283543468,0.2323770076,uniform +-0.7338157296,-0.5284956098,0.0626545250,0.4222335219,-0.2882999182,0.2370217144,-0.6338878870,-0.6774141192,uniform +-0.2884652913,0.7042509913,-0.4811822176,-0.4350655377,-0.8536746502,0.2494996637,-0.0721620396,0.4514223337,uniform +-0.9000198841,0.0230922550,-0.0407132693,0.4333281815,-0.9108331800,-0.0621723793,0.4074559808,0.0222982075,uniform +-0.0377033800,0.7110687494,-0.4566248953,0.5333415866,0.5480587482,0.1034929231,-0.8299785852,0.0075108213,uniform +0.5104403496,0.3024962544,0.7834537029,0.1847889870,-0.4603649378,0.2088720798,-0.8627701998,-0.0080102123,uniform +0.2033989727,-0.1157865226,-0.6187928915,0.7498785257,-0.0653332621,-0.3993984759,0.9139218926,0.0309697241,uniform +-0.2462537736,0.7490810752,0.0809760690,0.6096552610,0.4619418979,0.7501518726,0.2821698785,0.3798182607,uniform +0.2314069569,-0.2582576275,-0.8879501820,-0.3021559417,0.3294924498,0.9191653728,0.2138279974,-0.0291108228,uniform +0.0072374791,-0.2255806178,-0.8928058743,-0.3898188174,0.6198113561,-0.6464206576,-0.3451029658,-0.2808526456,uniform +0.3923373222,0.3838746250,0.8350352049,0.0377884731,-0.3955938816,-0.1630397886,-0.9018449187,-0.0599942580,uniform +0.4958071113,-0.3209520280,-0.6994168758,0.4024686217,0.8591461778,0.1079570651,0.3981393576,0.3028168380,uniform +-0.7235037088,-0.3477301002,0.5606834888,0.2031257302,-0.0208346378,0.0666650236,0.8810570836,0.4678249657,uniform +-0.9383618832,0.1824720055,0.2933705449,0.0107116848,0.2670309246,-0.4721468687,0.5513368249,-0.6338765621,uniform +0.4430379272,0.4032751918,0.7377059460,-0.3112499416,0.1534914821,0.4216271043,-0.7804291844,-0.4354322255,uniform 
+-0.2075705081,0.8433781862,0.4534837306,0.1999502629,0.0756342784,-0.8942091465,-0.0563761704,-0.4375969470,uniform +0.1983736306,0.9533935785,-0.0009816211,-0.2273491621,-0.4411619008,-0.0318920128,0.6950677633,-0.5667802095,uniform +-0.8834578991,-0.0620501526,-0.3632916510,0.2892593443,0.2186851948,-0.3618485928,0.5324716568,-0.7332914472,uniform +0.7389573455,0.0927560627,-0.3959124386,0.5372074246,0.7745558619,0.6184306741,-0.0077543482,0.1324634850,uniform +-0.0156172011,0.2964956462,0.1631654203,0.9408631325,0.1818922609,0.3391383886,-0.8290458918,-0.4056885839,uniform +0.7738668919,0.2402082384,0.5088164210,0.2907505929,0.5564295650,0.3399969637,-0.7507383823,0.1057364047,uniform +0.6788300276,0.3556843400,-0.5049170256,-0.3971615732,0.7380903363,0.1520929188,0.6555600166,-0.0482846648,uniform +-0.0732010230,0.8800240159,0.3523198664,0.3099516332,0.5490892529,-0.6877480745,-0.4723846614,0.0485420339,uniform +0.4902859032,0.0079736188,-0.7756062746,0.3974809051,0.2964138091,0.7945515513,-0.0412766859,0.5283207893,uniform +-0.2599313855,0.7154278159,-0.4755158424,0.4410025775,0.7076929212,-0.5052364469,-0.4278694391,-0.2466467470,uniform +-0.3788765371,0.3733390868,0.6307119727,0.5650424361,-0.5212318897,0.2463437170,0.0577136278,-0.8150468469,uniform +0.3305187225,-0.2580269575,-0.8987855315,0.1279218346,-0.1421530396,-0.3747986853,-0.2447292358,0.8828510642,uniform +0.1777378768,0.4931247830,-0.2373335361,0.8178691268,0.7444463968,0.6640471220,0.0110040624,-0.0687009916,uniform +-0.8549933434,-0.1439434588,-0.4121190906,0.2800439000,-0.6538252831,-0.2597763240,-0.1454942077,0.6956006289,uniform +-0.9769384861,0.1174948439,0.0127784414,-0.1778282672,-0.0793506429,0.0604090691,0.7788802385,0.6191928983,uniform +0.2112462521,-0.5382511616,0.1287517846,-0.8056572676,0.3385745883,0.6288666129,-0.1996344626,0.6708503366,uniform +-0.6507762671,0.5495287776,0.5235046744,0.0212417617,-0.0773573071,-0.6538837552,-0.7104691863,0.2483654320,uniform 
+-0.5584033132,0.2731661201,0.2799576521,0.7315666676,-0.4909583330,0.7098066807,-0.4072924554,-0.2987427413,uniform +0.0425674729,0.1433137804,-0.9574918747,-0.2466948628,0.9230594635,-0.1028727442,-0.2903441489,0.2303880453,uniform +-0.4183198810,-0.7141310573,-0.5099012256,-0.2345764041,0.7956089377,-0.5081574321,-0.3253849745,-0.0539171956,uniform +-0.2074877620,-0.6004481316,-0.7705816627,0.0511351787,-0.3486448824,-0.8806648850,-0.1172536835,0.2985426486,uniform +0.6665403843,-0.3621613383,0.4589110613,0.4625621736,0.6351379156,-0.2135300934,-0.5194701552,0.5302408934,uniform +0.7905637622,-0.4105891287,0.4288114011,-0.1501544416,0.0180287696,-0.7767347693,-0.5548743010,-0.2974434793,uniform +-0.6219360828,0.1894684434,0.2533435225,0.7163199782,0.9764693379,-0.1447325349,0.1549520046,-0.0393710770,uniform +0.2169947326,0.8711921573,-0.2148064971,-0.3844420314,0.3962334394,-0.6637709141,0.4140706360,0.4805753827,uniform +0.1052517593,0.6074793935,0.7426767945,0.2613849640,-0.1651657820,-0.3218509257,-0.2250332534,0.9047056437,uniform +0.6968679428,-0.4218562245,-0.3468739092,0.4648557603,-0.6726810932,-0.3868969679,0.2420839965,0.5824142098,large_err +0.5647940636,-0.0318775363,0.4450093508,0.6942319870,-0.4443878829,0.5519148707,0.2542763650,0.6582195163,large_err +-0.0644768551,0.7432582974,0.4753297865,-0.4663383663,-0.2682366669,-0.2185077965,0.3207903802,0.8817012310,large_err +-0.7788249850,0.2740024924,-0.5639241338,0.0185429957,-0.7508643866,-0.4195079505,0.2688220441,0.4335325658,large_err +-0.2865368426,-0.6554612517,-0.4286468029,0.5518416166,0.5906928182,0.0118015651,-0.6601318121,-0.4638628364,large_err +0.0404954851,-0.4843718708,-0.6624060869,0.5700545311,-0.3194694817,-0.0783349648,-0.4211389422,-0.8452484012,large_err +0.8474854231,-0.5283589959,0.0510383770,-0.0005982571,-0.2519455552,-0.2391362041,0.9260072112,0.1478107423,large_err 
+-0.7681078911,0.2215996534,0.3008430302,-0.5199974179,-0.3658339083,-0.7401066422,-0.1369227022,-0.5474120378,large_err +0.6827676892,0.4706366062,0.4935742617,-0.2621333301,-0.0588072129,0.7222244143,-0.6850427985,-0.0751661286,large_err +0.8226979375,0.4318535030,-0.2031614482,-0.3088626564,0.0232535582,0.0761324763,0.5616278052,0.8235516548,large_err +-0.2414807230,-0.2939476371,-0.7827830911,-0.4924757183,-0.3598566949,-0.6018050909,-0.2963801622,0.6484540701,large_err +-0.3343604803,-0.0541786700,-0.4068666995,-0.8483673930,0.1237304509,-0.2341439426,-0.6175463796,0.7406104207,large_err +0.8486857414,-0.2383830398,0.4671629667,0.0682999417,0.2799378037,-0.3420477211,-0.8505390286,-0.2849939466,large_err +-0.2953739762,-0.2540234327,-0.5032595396,0.7713339925,0.4416933656,0.6535217762,-0.0337161086,0.6137422323,large_err +-0.3269045651,-0.8705685139,0.2406763285,0.2780625224,0.8151741624,-0.2093605697,-0.4827665985,-0.2420652658,large_err +0.6381177902,0.4345389307,-0.3479897380,0.5318691134,0.4990907907,0.2190208882,-0.5639416575,-0.6204095483,large_err +-0.8999763131,0.1273362190,-0.4027706981,0.1077213883,0.0143957650,-0.9176358581,0.3195224404,-0.2358867824,large_err +0.7788752913,0.1094965339,-0.3740244806,-0.4913955033,-0.8256134987,0.0149692502,-0.2068684548,-0.5247320533,large_err +-0.6629487872,0.0345392153,-0.7090710998,-0.2377477288,0.5739606023,0.6820913553,-0.0778645128,-0.4463829398,large_err +0.6349475980,-0.6105579734,0.3620378375,0.3049410284,-0.2974488139,-0.3962547183,0.2737978399,-0.8243429065,large_err +0.1106180996,-0.8288899064,-0.4563334584,-0.3040804267,0.1065907031,0.4724985659,-0.7830964923,-0.3900555670,large_err +-0.7014446855,-0.0506028645,-0.5513593554,0.4487956762,0.2756977081,0.6350089312,-0.1090783626,0.7133416533,large_err +-0.3220693171,-0.4081122875,-0.0867877081,0.8498138785,-0.8489464521,0.3648563027,-0.2889395356,-0.2503671050,large_err 
+0.6348542571,-0.3963636756,0.3539052308,0.5608984232,-0.3937693834,-0.8429794908,-0.0566757470,-0.3621037602,large_err +0.5707576871,0.8048466444,0.1132776812,0.1167291403,0.7493536472,-0.3034982681,0.5047899485,-0.3025640249,large_err +0.3186406791,-0.3780684173,0.4858687222,-0.7207384109,-0.8480805755,0.0441955663,-0.4898403585,-0.1971358806,large_err +0.3344392776,0.4589494169,0.8230267167,0.0119519150,-0.5250146985,0.0825933889,0.3918833137,-0.7509762645,large_err +0.5198995471,-0.5891577005,-0.1701085269,0.5946937799,0.5095557570,0.0118257981,0.8586930633,-0.0534713790,large_err +-0.5029318333,-0.7018501163,0.1225001067,-0.4893462062,-0.5929372311,0.4619339108,-0.0966047868,-0.6524645686,large_err +-0.3077037632,-0.8249747157,0.4735643566,0.0217238087,0.6980845332,0.0780912116,0.7079037428,-0.0738386512,large_err +-0.1964939833,-0.1599262059,-0.7585188746,-0.6003856063,0.6182367206,0.7572979331,0.1917170435,0.0867625996,large_err +-0.5418920517,0.1866806298,0.4801559746,-0.6640433669,-0.5653663874,0.3372915089,-0.5599923134,0.5029948950,large_err +0.3615823388,-0.8709657192,0.0922627524,0.3196316957,0.4260075092,-0.1918314844,0.1507914960,-0.8711947203,large_err +0.3205989897,0.0803235844,0.8896522522,0.3150925934,-0.6301438212,-0.4130046368,0.0705481395,0.6537346244,large_err +-0.5863404274,0.7278684378,-0.2880268991,0.2084538192,-0.5894134045,0.3330538869,0.5799854994,-0.4530825317,large_err +-0.1630355567,-0.8583679795,0.4794000089,0.0824586153,0.1638925374,0.1914934665,0.9049834609,-0.3427453935,large_err +-0.3461409509,0.6728027463,0.6530655026,-0.0320660546,-0.9124129415,-0.0843395516,0.2295368910,-0.3281804025,large_err +0.2845056057,0.2431261241,0.5041337013,0.7783286572,-0.3191023469,-0.3096354604,0.6199600101,-0.6464897990,large_err +-0.3911590576,0.4900604188,-0.6631744504,-0.4086990952,0.4833739102,-0.3301065862,-0.7586250901,0.2861245275,large_err 
+-0.5269294381,-0.1890888214,0.3950764835,0.7283579707,0.2498888522,0.9547296166,-0.0464533046,0.1545612514,large_err +0.0283443984,0.9326224923,0.1498247534,0.3270539939,-0.1439840496,-0.1216090247,-0.5781298876,0.7938801050,large_err +0.0773261189,-0.4712623358,-0.7361610532,0.4795824885,0.1033695489,0.7430937290,-0.5602025390,-0.3511404395,large_err +-0.1780701727,-0.2964829504,-0.0987749621,-0.9330767989,-0.8356988430,0.1909192055,0.1445923597,0.4942169189,large_err +-0.6826238632,-0.0661941692,-0.7277449965,-0.0054944376,0.1538330466,-0.7887527347,0.1076927111,0.5853263736,large_err +-0.8234325647,-0.2096153349,0.1543100625,-0.5041910410,0.1768541634,-0.9127289653,-0.3629643917,-0.0624930859,large_err +0.6717652678,-0.2161907107,0.0985926390,0.7016212940,-0.6108610630,0.1770737469,-0.5515127778,0.5397474766,large_err +0.0531594083,0.2597351670,0.0959360823,0.9594309926,-0.0690760016,0.0607221127,-0.9396282434,0.3296061158,large_err +-0.0655037388,-0.2523321509,0.8664602637,-0.4257749617,-0.7366220951,-0.0863018110,-0.3159785271,-0.5916903615,large_err +-0.3437717855,0.8306443095,0.3696675599,0.2349400371,-0.5250837207,-0.6317701340,-0.2276587486,0.5228050351,large_err +0.1559921205,0.2160713077,0.7377265096,0.6202735901,-0.6491004825,-0.3617776632,0.5592410564,0.3674708605,large_err +-0.0659323558,0.7779129148,-0.2729792297,-0.5621269941,0.3327767551,-0.5462961793,-0.6788402796,0.3605498374,large_err +0.1253060251,0.8297061920,-0.5169183016,0.1693561673,0.7808852196,-0.1130248681,0.0037157251,0.6143531799,near_singular +-0.1366664022,0.2688074112,-0.8672754169,-0.3961037397,-0.1728432328,-0.2392090559,0.0336129554,0.9548687935,near_singular +0.8941589594,0.1693999767,-0.4139359891,-0.0209829137,0.2932556272,-0.8428772092,-0.1479943991,-0.4262120128,near_singular +0.3642675281,0.5725809336,-0.5593920350,0.4759631455,0.5889151692,-0.7659914494,0.1136486009,-0.2313438207,near_singular 
+-0.1108813733,-0.1790249795,-0.2987135351,-0.9308198690,-0.1260232180,-0.1523701102,0.7771361470,0.5974620581,near_singular +-0.1696088761,0.2510158718,0.3578359187,-0.8832764626,0.3751700521,-0.3994451761,0.8296582699,-0.1065751761,near_singular +0.4391006827,0.2870866656,0.8504198790,-0.0394694284,-0.2320598811,0.5062256455,0.6243227124,-0.5478183627,near_singular +0.5880202055,0.6891258955,0.0680706799,-0.4179762006,0.1581024081,-0.7647993565,0.4184598923,0.4636559188,near_singular +0.6624516845,-0.1891448796,0.0844115093,0.7199004292,-0.2000017911,0.5707195401,-0.7936362028,-0.0664835498,near_singular +-0.7033249736,-0.5415078998,0.0384537913,-0.4589383900,-0.2854847014,-0.6072996259,0.6781337261,0.2997002602,near_singular +0.5104073882,0.3843522072,0.1934930980,-0.7445254326,-0.3206688464,-0.2772272527,-0.9001556635,-0.1001816243,near_singular +-0.6751686931,-0.5495695472,0.1774043739,0.4589642584,0.3139085472,0.0207795594,0.8288320303,0.4626738131,near_singular +0.5546681881,0.2564232647,-0.2431124896,-0.7533169985,-0.2563979030,-0.8709082007,-0.3690562248,-0.1989385337,near_singular +0.5237284303,0.8209043145,0.0992799103,-0.2048612982,0.8441168666,-0.4646601975,0.2337268144,0.1301131397,near_singular +0.4829994142,0.7770447731,-0.3977569044,-0.0685742348,0.6784280539,-0.4787372947,-0.0444278717,-0.5554926991,near_singular +-0.1121124029,-0.9544025660,0.0944874138,-0.2600357831,0.8156655431,0.4858420789,-0.2880380154,-0.1252249032,near_singular +0.2816623151,-0.4159805477,-0.3856071234,-0.7739080787,0.2996439338,-0.1199385226,-0.8850911260,0.3353237212,near_singular +-0.0810814872,0.4424323142,0.6969001889,0.5585781336,0.5240030289,0.3189566135,-0.0047590593,-0.7897245884,near_singular +-0.1084988043,0.1602014154,0.0659902170,0.9788814187,-0.0269851610,0.8844300508,0.4140766561,0.2135316432,near_singular +0.7838524580,-0.5375248194,0.1522083580,0.2710629106,-0.5388365388,-0.7398316264,0.3073480129,-0.2604639828,near_singular 
+0.2086893469,-0.1474526972,0.0718902573,-0.9641256332,-0.0298555717,0.0181757417,0.8828220963,-0.4684052765,near_singular +0.0896132886,0.9583822489,-0.2537099123,-0.0954161137,0.6637288332,-0.6919069290,0.2397125512,-0.1525344551,near_singular +-0.1756517887,0.8275718689,0.5303179622,0.0550815053,0.2403610349,0.2860612273,-0.0369084440,0.9268404841,near_singular +0.6733121872,0.4742806554,-0.4695495069,0.3181695938,-0.1221178249,-0.5229095817,-0.4372692704,-0.7214210629,near_singular +-0.1928677410,0.5834074616,-0.3438722789,-0.7100630999,0.6243361831,0.1618155241,0.7395491600,-0.1925803274,near_singular +-0.1742137372,0.5668209195,0.3972574770,0.7003928423,0.1379095465,0.6965302825,0.1092914790,0.6956161857,near_singular +0.1819849461,0.9725120068,0.1188717261,-0.0834946856,-0.4444175959,0.8623691797,0.0267449655,-0.2410337776,near_singular +-0.3732139170,0.2091331482,-0.8935489058,-0.1361806244,0.5111750960,0.3777841032,-0.6889425516,-0.3483351171,near_singular +-0.2439089566,-0.9682663083,-0.0301145352,-0.0454084016,0.1919811666,-0.1072731763,0.2340465784,0.9470258951,near_singular +0.5488935113,0.3867770731,0.0299569536,-0.7404201031,0.4788212776,-0.5018733740,-0.1407979876,-0.7064200044,near_singular +0.2355781496,0.5668969154,0.7873584032,0.0565481782,0.9881789088,0.0962109268,0.0834159404,0.0853686631,near_singular +0.0856738165,0.1742558926,-0.8797746897,-0.4339255393,0.5338418484,0.3591652215,-0.0628290549,0.7629323006,near_singular +0.0711834729,0.0400444269,0.0684986487,0.9943024516,-0.0927231163,-0.8206062913,-0.2041304260,-0.5256791115,near_singular +0.4977307618,0.2424022257,0.5407144427,-0.6333506703,0.6927196383,-0.7065383196,-0.0649947152,-0.1293014288,near_singular +0.9362745881,0.0214701220,-0.1240693554,0.3279263377,-0.8981961012,-0.2149042785,-0.1699608266,-0.3437630534,near_singular +-0.1737618446,-0.9015474319,-0.3934307396,0.0472369939,-0.4057266414,-0.7855293751,-0.4105836749,-0.2230482101,near_singular 
+-0.2042062730,-0.3039425611,-0.5461176634,0.7534415722,0.0847639963,0.2827377617,0.0935226232,-0.9508564472,near_singular +-0.0025485710,-0.1501984894,0.9047443867,0.3985868394,-0.8289795518,-0.2973423898,-0.4641213119,-0.0947197601,near_singular +0.7509312630,-0.0417818092,0.6438624263,-0.1407043785,0.9503879547,0.1151386648,0.0380866975,-0.2864528596,near_singular +0.2647968829,-0.1714508086,-0.8896347284,-0.3302080035,0.5752843022,0.1447261274,-0.7542284727,0.2814991772,near_singular +-0.2746165991,0.9239203334,0.2293483764,-0.1354853511,-0.1633588970,0.0380790569,-0.2851632535,0.9436873198,near_singular +-0.1257797629,0.0407170877,0.5040342808,0.8535051346,0.0051471307,0.0529473349,0.9955478907,-0.0778105780,near_singular +-0.4891841114,0.6178389788,0.5898881555,-0.1760850102,-0.8532334566,-0.0060885609,-0.4594660997,0.2466708869,near_singular +-0.8064907789,0.4981707335,0.2741786242,0.1619402170,-0.3851086795,-0.6406468153,-0.6223835349,0.2321673334,near_singular +-0.0310608465,-0.3322150707,0.3253640234,0.8847635984,-0.0556532703,-0.6409766078,0.4514475763,-0.6182610989,near_singular +0.0035461958,-0.2574644685,0.3540952206,0.8990639448,0.2434497923,0.3769963980,0.3163479567,-0.8357810378,near_singular +-0.4110143781,0.7692613006,-0.0257770196,0.4885075986,-0.8193839788,-0.1607086062,0.2871046662,0.4694181979,near_singular +-0.6542116404,-0.1848922968,-0.6433438659,-0.3520379364,-0.0054482385,-0.4455439448,-0.8720747828,0.2023520470,near_singular +-0.2703700066,-0.7149355412,0.6446377039,0.0144701218,-0.6063002944,-0.1848825365,0.4140280187,0.6532987356,near_singular +-0.1457846463,0.6840152144,-0.5759202242,-0.4233037531,-0.4379869998,-0.1357305348,-0.3177236915,0.8299375176,near_singular +0.1192364618,0.6212294698,0.7369441390,0.2382644862,0.7974092960,-0.0453348719,-0.2352293134,0.5538504124,near_singular +-0.3387504220,0.1961085498,-0.3881158829,0.8343594074,-0.0144764083,0.5221565366,-0.0379524939,0.8518818021,precision 
+0.0061861072,0.5483211875,-0.7466587424,-0.3765718043,-0.3653986454,-0.3934561908,0.7621444464,0.3616794050,precision +0.8086666465,-0.1054298282,-0.5518583059,0.1743421555,-0.7080537081,-0.3530876935,0.5810059309,-0.1908434182,precision +-0.1012888104,0.3090709746,0.6746972203,0.6625701785,-0.2170276642,-0.0943711177,-0.8476867080,-0.4747845232,precision +-0.4765225649,0.0760065019,0.1424238533,-0.8642132878,0.2183807045,0.2435441464,0.3074236810,-0.8935809135,precision +0.9351258278,-0.2019770890,-0.0089607667,-0.2909721732,-0.6247051358,-0.1624485403,0.2227261961,0.7305799723,precision +0.5254984498,-0.6422654390,-0.4714173675,0.2985166311,0.3464554548,-0.4699550569,-0.7206335664,0.3738958836,precision +0.2000606656,0.7168529034,0.1509593129,0.6506218910,0.5730660558,0.7427632213,0.2521032691,0.2373646498,precision +-0.2075038850,0.9446325898,-0.2028354406,0.1531965733,0.2091992050,-0.9225215316,0.0018484895,-0.3243243396,precision +-0.0001165057,-0.2673927844,-0.9603832960,-0.0785169676,-0.4424977005,0.2785128355,-0.7900624871,-0.3200432658,precision +-0.8545114994,-0.0575328432,0.2422058284,-0.4558908343,-0.5898409486,-0.4843844473,-0.1165734455,-0.6355077028,precision +-0.8137446642,0.4482749701,0.3484625816,0.1242693886,-0.9221937656,0.0160517599,0.1251604855,0.3655622900,precision +-0.2117313147,-0.2038156986,0.9534536600,-0.0674921945,0.0039148810,0.1430569887,-0.9770224690,-0.1579444855,precision +0.2590807378,0.3009169400,0.7297839522,0.5565443635,-0.0023295670,-0.1410570741,0.9774150252,0.1573448330,precision +-0.0865643471,0.5323529243,-0.7549852729,0.3729668260,-0.1830106378,0.2353199571,-0.8360537887,0.4605928063,precision +0.0588358119,0.9210214019,0.0885456055,0.3747235239,-0.4765682220,-0.8616161942,-0.0838553831,-0.1531949639,precision +0.1019992903,0.1048168167,0.4459411800,0.8830323219,-0.4569721520,0.2710000277,-0.1195931658,-0.8387090564,precision 
+0.1668585837,0.0567871444,-0.7527346015,-0.6342901587,0.0535958111,-0.5182295442,-0.8398047686,-0.1526224464,precision +-0.7567061186,-0.5406005979,-0.1422332525,-0.3389934003,-0.4721406996,-0.3244054615,-0.1141489744,-0.8116737604,precision +0.9304831028,-0.1642141640,-0.3200774789,-0.0691761672,-0.9201987386,0.2049579769,0.0367969051,-0.3314700127,precision +-0.1104842573,0.7643997073,0.3573668897,0.5251429677,-0.0430133753,-0.5624576807,-0.5720103383,-0.5954791307,precision +0.6352842450,-0.5222283006,0.5309429169,0.2044290006,0.2825617492,-0.3317571282,0.8813337684,0.1826112717,precision +-0.4464294314,-0.0519787744,-0.8045563102,0.3881856501,-0.1718042046,0.2475963831,-0.7779501081,0.5513373613,precision +0.5040262938,0.1312221587,-0.2024377137,-0.8293113112,0.4445734322,0.2072319388,0.1001443118,-0.8656676412,precision +0.8236836791,-0.4482718110,0.2220860273,0.2669745088,-0.9842389822,0.1192486361,0.0462683812,-0.1221171319,precision +0.4781199992,-0.1203269884,0.5609267354,-0.6650443077,0.3694638610,0.0548892319,0.8345638514,-0.4049529731,precision +0.4789435267,0.3560050726,-0.7835884690,-0.1728082746,-0.7879298329,0.2728313208,0.5100920796,0.2110351026,precision +-0.1967675239,0.9486905336,-0.1915451735,-0.1567773521,-0.5687038302,0.6731565595,-0.4519563019,-0.1384616941,precision +0.2278582901,-0.4713589251,-0.2389827520,0.8177949786,-0.3226634562,0.0869105905,-0.3295232952,-0.8830341697,precision +0.4334940016,0.1581609398,0.8805400133,0.1082465649,0.2906801999,-0.1034581661,0.8188303709,-0.4840642214,precision +0.6312527061,0.1358741969,-0.6055377722,-0.4651690125,-0.1011749730,-0.3938803971,0.7093100548,0.5757612586,precision +-0.0028808906,0.7238988280,-0.2989606559,0.6217594743,-0.0421455204,-0.2407351434,0.7540364265,-0.6096716523,precision +-0.0765471160,0.2638218701,0.9275199771,-0.2534662485,-0.2926414013,-0.2683823109,-0.6497194767,0.6482257843,precision 
+-0.3931856453,0.0235186554,0.8791167140,0.2683387697,-0.6289770603,0.3247342110,0.6379321814,-0.3032788038,precision +0.7736812234,-0.3149532080,0.0562075041,-0.5468661785,-0.8068801165,0.5796535015,0.0981577709,-0.0575437844,precision +-0.1348114163,-0.2783644497,-0.6518107653,-0.6924462318,0.3434258997,0.0813325420,-0.6222872138,-0.6987146735,precision +-0.7449692488,-0.5154733658,-0.3904960155,-0.1637711078,-0.6803236604,-0.3993207514,0.0371668786,-0.6134504080,precision +-0.3964410126,0.6145939827,0.6098064184,-0.3053603470,-0.4283410311,0.7589061260,0.1631444544,-0.4625681639,precision +0.5057138801,-0.5278573632,0.4546603262,-0.5088262558,-0.6126950383,0.3575898707,-0.1839443147,0.6803665757,precision +0.3298825622,0.7220147252,-0.5810136199,-0.1797090918,-0.0641000047,0.3072360754,-0.8996073008,-0.3036508560,precision +-0.0024017729,-0.9813342094,-0.0720761195,0.1782765239,-0.1696422398,-0.6183058023,-0.2699713111,0.7183557153,precision +0.2426537275,-0.4099556208,0.8738018274,0.0976013467,0.2833406925,-0.7742570043,0.5584419370,0.0915790573,precision +0.7091106176,-0.1389959604,-0.6388345957,-0.2640692294,-0.1886707097,0.4507237971,0.8608942032,0.1418187022,precision +0.7565515041,-0.5028947592,0.3302069008,-0.2563007772,-0.2249078751,0.8756054640,-0.4197920263,-0.0806615055,precision +0.6430086493,0.3834329844,0.6621972919,0.0318390355,0.3478351533,0.3338105679,0.8070112467,-0.3410485983,precision +0.6666319370,-0.1288073957,0.2884028554,-0.6751551032,0.3270417750,0.1496325284,0.8271850348,-0.4317623973,precision +0.1821030527,-0.8627601266,0.0846384317,-0.4640256763,0.3876724541,-0.5085322261,-0.5397988558,-0.5474688411,precision +0.3454990983,0.3853314519,-0.0864977837,0.8512745500,-0.1774160862,-0.4080548882,0.6594276428,-0.6059455276,precision +0.2611908913,0.7615833282,0.5852715969,0.0960587338,-0.2828935385,0.9200071096,0.2654040158,-0.0558461882,precision 
+-0.0950481817,0.5277545452,-0.1941010207,-0.8214412928,-0.5175274014,0.2140915096,-0.5775576830,-0.5939338207,precision +0.1978553087,-0.4690381289,0.7498471737,0.4225941300,-0.5903618336,0.0744636655,-0.4189461470,-0.6858659983,precision +0.3657827079,-0.2488426864,-0.2150893360,0.8706416488,0.7093685269,-0.4719077945,0.4146018922,0.3196945786,off_axis +0.4159088433,0.0531812124,0.3074172139,0.8542168140,0.7944093347,-0.3731456995,-0.3730162680,-0.3008903861,off_axis +0.2733111978,-0.5549582243,-0.3140222430,-0.7202168703,0.7661744356,0.3077702224,-0.4844837487,-0.2890151143,off_axis +-0.3021493256,-0.5505224466,0.7310578823,-0.2668055296,0.3920069337,-0.2425784767,0.3564361334,-0.8126743436,off_axis +-0.4896101058,0.1203272045,-0.5849639177,0.6353114843,0.7944104075,0.4089204073,0.3625744879,0.2650206983,off_axis +0.7433882952,-0.2631540596,0.4912641644,-0.3698422313,0.8555234671,-0.2088967860,0.2807210386,-0.3816247880,off_axis +-0.6671281457,0.0607365966,-0.0744004026,-0.7387257814,0.7026898861,-0.4762889743,-0.3032164872,0.4329382777,off_axis +-0.8088250756,0.2112846971,-0.3736040592,-0.4019713104,0.5117912292,0.3968138397,-0.5701189041,-0.5055422783,off_axis +-0.1588096917,0.0995824486,0.0242953729,0.9819737673,0.8170246482,0.1138916090,0.4806018472,0.2975254655,off_axis +-0.2230105996,-0.9074729681,-0.2968116403,-0.1966266781,0.7978664637,-0.2755243480,-0.5075240135,-0.1729589850,off_axis +0.4394114316,-0.8058360815,-0.0985207781,-0.3844988048,0.3591170907,-0.2865315676,-0.8235929608,0.3326098323,off_axis +-0.9192651510,-0.0658734813,-0.3874958754,-0.0214294605,0.3816950917,-0.3386258185,0.3451678157,-0.7877186537,off_axis +-0.2867800891,0.2882508337,0.5508955717,-0.7288228869,0.6647337079,-0.3399763405,0.3517611325,0.5646320581,off_axis +-0.1663702726,-0.6047533751,-0.6746099591,-0.3892242610,0.8099514842,-0.5144972205,0.2411636859,-0.1452971995,off_axis 
+0.3999395370,-0.4207392037,-0.3364665210,0.7414965034,0.6664898396,-0.5405593514,-0.3584864140,0.3675246537,off_axis +-0.5369109511,0.4567735493,-0.2613247633,0.6593890786,0.5709999204,0.2178629488,0.3670171499,0.7012796998,off_axis +-0.7394815087,-0.0535243936,0.2054878920,0.6388089657,0.5000947118,0.1761332154,0.8143514991,-0.2360382825,off_axis +0.9428015947,0.2453801036,0.0339312926,-0.2230750769,0.1937250197,0.2838715315,0.3767279983,-0.8602114320,off_axis +-0.3838754892,0.8893271089,-0.1739889532,-0.1773830354,0.7297095060,-0.3015848398,0.2784036398,0.5468656421,off_axis +-0.1023444459,0.2333052307,0.5595436692,-0.7886730433,0.2510482371,0.2103316039,-0.7591030002,-0.5625814795,off_axis +0.4347971380,0.5475036502,-0.6120938063,0.3695029914,0.6983520985,0.3028165698,-0.5343798995,-0.3674841821,off_axis +0.0909157023,-0.9392173290,0.3300426602,-0.0260199439,0.5307925344,-0.2404746413,0.7628666162,0.2801176012,off_axis +0.5312511325,0.6927239299,-0.4730537534,-0.1188524589,0.7119163871,-0.2896908820,-0.4067001641,-0.4938110709,off_axis +0.2749144137,0.1929744482,0.8609386086,-0.3820572793,0.6203848720,0.4311186969,0.3181777894,-0.5727322698,off_axis +0.8216028214,0.3095712960,-0.4731952250,-0.0722529143,0.3257389069,0.8094465733,0.3215044141,0.3678659797,off_axis +0.1291329116,-0.0122732678,-0.0426687859,-0.9906328321,0.2616447210,-0.5891177058,0.3318631053,-0.6887301803,off_axis +0.7220728397,-0.2821916640,-0.4986415803,-0.3877309859,0.2592317760,-0.3889842629,-0.8150291443,0.3423707783,off_axis +-0.3759568930,0.8232302666,0.4149942398,-0.0934249833,0.8074967265,0.3993062377,-0.2308284044,-0.3677253425,off_axis +-0.2121598274,0.6678100228,-0.7052567005,0.1078466177,0.4960836172,-0.2944883406,-0.7994435430,-0.1675341576,off_axis +-0.5031794310,-0.2785966992,0.8014529943,-0.1639127284,0.7113929391,0.3759950697,0.5169498920,0.2920798063,off_axis +-0.4167300165,-0.1451184601,-0.8972877860,0.0122970343,0.7631041408,0.2663827240,0.5640166402,-0.1691083461,off_axis 
+0.1281247586,0.5461177826,-0.8100746870,0.1706410199,0.2133316994,0.4151563942,0.8616775870,-0.1991147399,off_axis +-0.7934812307,0.4571343362,0.2585727870,-0.3074994087,0.4510550201,-0.4859864414,0.7075994015,0.2442737222,off_axis +0.4091331363,0.1467105150,0.8010534644,-0.4115816951,0.3407166898,0.1663883775,-0.5634717941,0.7339800000,off_axis +-0.6521788836,-0.4443181455,0.1921496838,0.5833717585,0.7919062972,-0.2953449190,0.3436406255,0.4093494117,off_axis +0.8902254105,0.4228579402,-0.1467277706,-0.0846218690,0.5740323663,0.4612068534,-0.6539881229,0.1734202504,off_axis +0.4922516346,0.3192203939,0.3947248757,0.7070918679,0.8476251364,-0.4120882154,0.2525842190,-0.2188975811,off_axis +0.2578463554,-0.3152136207,0.4778327048,0.7783518434,0.6808111668,0.4082299471,-0.4819263816,-0.3709332943,off_axis +-0.3795522153,-0.7652170658,-0.4534837902,-0.2544316351,0.1926669776,-0.4178189635,0.7856814265,0.4135353565,off_axis +-0.4667710364,-0.7060041428,0.4394766390,0.3009040654,0.2020235062,-0.5356283784,-0.7028787136,-0.4221970737,off_axis +-0.7670376897,-0.3402480483,0.0931896642,-0.5359105468,0.7744576335,0.1510931551,0.5468286872,0.2799370289,off_axis +0.2500271499,0.7545146346,0.3742992580,0.4775920808,0.8551702499,0.3042804897,0.2261524200,-0.3534861803,off_axis +0.1317458302,0.2196974754,0.5193811655,-0.8152418733,0.5471379757,0.7196413279,0.2488981038,0.3475717306,off_axis +-0.5644424558,-0.6052215099,-0.4227528274,0.3693126142,0.2299744636,0.5456702113,0.3445337415,0.7284588814,off_axis +0.5378409624,0.7938348651,-0.2385083586,0.1538411379,0.2486838251,0.4699462950,-0.8190599680,-0.2155171484,off_axis +0.4145342410,0.7678257823,0.2041255087,-0.4437765181,0.5715116262,-0.3534813225,-0.4034751058,-0.6209938526,off_axis +-0.8968217969,-0.3675900698,-0.2066753209,0.1336928159,0.2871094644,0.3646900952,-0.3494209945,-0.8139252663,off_axis +0.6001796722,0.2096502036,0.2418205738,0.7330442667,0.2623543143,0.8639948368,-0.3755383193,-0.2089357078,off_axis 
+-0.3170050383,-0.2117804736,0.9190400839,-0.1001102999,0.8405498266,0.2295046896,-0.4485359192,0.1990455240,off_axis +0.6126387119,0.1922095418,-0.0267061684,0.7661697865,0.6033971906,0.4131008089,0.4429579377,-0.5186983347,off_axis +0.0675699264,-0.6571859121,0.1710746586,-0.7309407592,0.4776850045,0.4219897985,0.3810240924,-0.6697478890,off_axis diff --git a/test/libs/test_mujoco.py b/test/libs/test_mujoco.py new file mode 100644 index 00000000000..b2e86a5f8c2 --- /dev/null +++ b/test/libs/test_mujoco.py @@ -0,0 +1,712 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Tests for the MuJoCo custom envs (humanoid / ant / walker / hopper / +satellite) across the three physics backends.""" + +from __future__ import annotations + +import pytest +import torch + +from tensordict import TensorDict +from torchrl.envs import ( + AntEnv, + HopperEnv, + HumanoidEnv, + MujocoEnv, + ParallelEnv, + SatelliteEnv, + SerialEnv, + Walker2dEnv, +) +from torchrl.envs.custom.mujoco._backends import ( + _has_jax, + _has_mjx, + _has_mujoco, + _has_mujoco_torch, +) +from torchrl.envs.custom.mujoco._math import ( + cmg_jacobian, + pyramid_4cmg_geometry, + quat_conj, + quat_log, + quat_mul, + random_unit_quat, +) +from torchrl.envs.utils import check_env_specs + +_AVAILABLE_BACKENDS: list[str] = [] +if _has_mujoco_torch: + _AVAILABLE_BACKENDS.append("mujoco-torch") +if _has_mjx and _has_jax: + _AVAILABLE_BACKENDS.append("mjx") +if _has_mujoco: + _AVAILABLE_BACKENDS.append("mujoco") + +_VMAP_BACKENDS = [b for b in _AVAILABLE_BACKENDS if b in ("mujoco-torch", "mjx")] +_LOCOMOTION_ENVS = [HumanoidEnv, AntEnv, Walker2dEnv, HopperEnv] + + +@pytest.mark.skipif( + not _AVAILABLE_BACKENDS, + reason="No MuJoCo backend installed (mujoco-torch / mjx / mujoco).", +) +class TestMujoco: + # ------------------------------------------------------------------ + # Spec 
/ rollout coverage across all available backends. + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + @pytest.mark.parametrize("cls", _LOCOMOTION_ENVS) + def test_locomotion_env_specs(self, cls, backend): + if backend == "mujoco": + # Single-env semantics for C-bindings backend. + env = cls(num_envs=1, seed=0, backend=backend) + else: + env = cls(num_envs=2, seed=0, backend=backend) + check_env_specs(env) + assert env.observation_spec["observation"].shape[0] == env.batch_size[0] + assert env.action_spec.shape[0] == env.batch_size[0] + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + @pytest.mark.parametrize("cls", _LOCOMOTION_ENVS) + def test_locomotion_rollout(self, cls, backend): + n = 1 if backend == "mujoco" else 2 + env = cls(num_envs=n, seed=0, backend=backend) + td = env.rollout(5) + reward = td.get(("next", "reward")) + assert reward.shape[-1] == 1 + assert torch.isfinite(reward).all() + + # ------------------------------------------------------------------ + # Satellite: spec, dim sanity, finite singularity reward. + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + @pytest.mark.parametrize("num_cmgs", [4, 6]) + def test_satellite_specs(self, num_cmgs, backend): + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv(num_cmgs=num_cmgs, num_envs=n, seed=0, backend=backend) + check_env_specs(env) + # action_spec dim = N_GIMBALS, not nu (rotors are held constant). + assert env.action_spec.shape == torch.Size([n, num_cmgs]) + # The observation is exposed as named sub-keys so a + # CatTensors transform can pack the dynamics-relevant ones into + # a single policy input while keeping ``manipulability`` + # available for logging. 
+ obs_spec = env.observation_spec + assert obs_spec["quat_err"].shape == torch.Size([n, 3]) + assert obs_spec["bus_omega"].shape == torch.Size([n, 3]) + assert obs_spec["gimbal_angles"].shape == torch.Size([n, 2 * num_cmgs]) + assert obs_spec["gimbal_rates"].shape == torch.Size([n, num_cmgs]) + assert obs_spec["manipulability"].shape == torch.Size([n, 1]) + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + @pytest.mark.parametrize("num_cmgs", [4, 6]) + def test_satellite_reward_finite(self, num_cmgs, backend): + """Singularity term must never explode: ``+eps`` in ``manipulability`` + guards ``1/sqrt(det(JJ^T))`` against rank-deficient configurations. + """ + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv(num_cmgs=num_cmgs, num_envs=n, seed=0, backend=backend) + td = env.rollout(50) + assert torch.isfinite(td.get(("next", "reward"))).all() + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + def test_satellite_reward_guard_reports_nonfinite_component(self, backend): + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv(num_cmgs=4, num_envs=n, seed=0, backend=backend) + env.reset() + state = env._state_td() + action = torch.zeros(env.action_spec.shape, dtype=env.dtype, device=env.device) + action[0, 0] = torch.finfo(env.dtype).max + with pytest.raises(RuntimeError, match="reward/control_cost"): + env._compute_reward(state, action, state) + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + def test_satellite_action_changes_gimbal_state(self, backend): + n = 1 if backend == "mujoco" else 2 + env_zero = SatelliteEnv(num_cmgs=4, num_envs=n, seed=0, backend=backend) + env_one = SatelliteEnv(num_cmgs=4, num_envs=n, seed=0, backend=backend) + td_zero = env_zero.reset() + td_one = env_one.reset() + zero_action = torch.zeros( + env_zero.action_spec.shape, dtype=env_zero.dtype, device=env_zero.device + ) + one_action = torch.ones( + env_one.action_spec.shape, dtype=env_one.dtype, device=env_one.device + ) + + td_zero = 
env_zero.step(td_zero.set("action", zero_action))["next"] + td_one = env_one.step(td_one.set("action", one_action))["next"] + + assert not torch.allclose(td_zero["gimbal_angles"], td_one["gimbal_angles"]) + + def test_quat_log_uses_short_arc(self): + q = random_unit_quat((1024,), generator=torch.Generator().manual_seed(0)) + log_q = quat_log(q) + log_neg_q = quat_log(-q) + assert torch.allclose(log_q, log_neg_q, atol=1e-5, rtol=1e-5) + assert log_q.norm(dim=-1).max() <= torch.pi + 1e-5 + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + def test_satellite_gimbal_observation_is_periodic(self, backend): + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv(num_cmgs=4, num_envs=n, seed=0, backend=backend) + td = env.rollout(1000) + gimbal_obs = td["next", "gimbal_angles"] + assert torch.isfinite(gimbal_obs).all() + assert (gimbal_obs.abs() <= 1.0 + 1e-6).all() + + # ------------------------------------------------------------------ + # Satellite physics-correctness: specific (state, action) -> (next + # state, reward) transitions verified against analytical + # predictions. These tests catch bugs that pure spec / finite-value + # tests miss (e.g. a wrong gimbal index in the obs builder, an + # inverted torque sign, an off-by-one in the reset override). 
    @pytest.mark.parametrize("backend", _VMAP_BACKENDS)
    def test_satellite_zero_action_preserves_orientation(self, backend):
        """Zero gimbal command + symmetric pyramid CMG cluster (sum of
        rotor moments == 0 at theta=0) means the bus has zero net
        torque applied to it. Roll out 200 steps with zero action and
        confirm the bus quaternion drifts < 1 deg from its initial
        attitude.
        """
        env = self._make_sat(backend, n=2)
        # Use a non-trivial init_bus_quat so we'd notice if the env
        # were silently re-initialising every step.
        init_q = torch.tensor(
            [[0.7071, 0.0, 0.7071, 0.0], [0.7071, 0.7071, 0.0, 0.0]],
            dtype=env.dtype,
            device=env.device,
        )
        env.reset(TensorDict({"init_bus_quat": init_q}, batch_size=env.batch_size))
        bus0 = self._bus_quat(env)

        zero_action = torch.zeros(
            env.action_spec.shape, dtype=env.dtype, device=env.device
        )
        td = TensorDict({"action": zero_action}, batch_size=env.batch_size)
        # Manually re-feed the zero action each step (the rollout helper
        # would sample random actions instead).
        for _ in range(200):
            td = env.step(td)
            td = td["next"].select(*env.observation_spec.keys())
            td.set("action", zero_action)

        bus1 = self._bus_quat(env)
        # Compare via the shortest-arc angle: for unit quaternions,
        # cos(angle / 2) = |<bus0, bus1>| (4-D dot product); the abs()
        # folds q and -q -- the same physical rotation -- onto one arc.
        cos_half = (bus0 * bus1).sum(dim=-1).abs().clamp(-1.0, 1.0)
        angle_deg = (2.0 * torch.acos(cos_half)).rad2deg()
        assert angle_deg.max().item() < 1.0, (
            f"Bus drifted {angle_deg.tolist()} deg under zero action; "
            "expected < 1 deg from rotor-induced numerical noise alone."
        )
        # Bus angular velocity should also stay near zero.
        omega = self._bus_omega(env).norm(dim=-1)
        assert omega.max().item() < 0.05, (
            f"Bus omega = {omega.tolist()} rad/s under zero action; "
            "the satellite should be inertially still."
        )
+ init_q_norm = init_q / init_q.norm(dim=-1, keepdim=True) + torch.testing.assert_close( + self._bus_quat(env), init_q_norm, rtol=1e-4, atol=1e-4 + ) + # quat_err observation = quat_log(init^-1 * target). + target_q_norm = target_q / target_q.norm(dim=-1, keepdim=True) + expected_qerr = quat_log(quat_mul(quat_conj(init_q_norm), target_q_norm)) + torch.testing.assert_close(td["quat_err"], expected_qerr, rtol=1e-4, atol=1e-4) + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_quat_err_is_zero_at_target(self, backend): + """Setting ``init_bus_quat == target_quat`` makes the + observation ``quat_err`` start at zero (within reset noise).""" + env = self._make_sat(backend, n=2) + q = torch.tensor( + [[0.5, 0.5, 0.5, 0.5], [1.0, 0.0, 0.0, 0.0]], + dtype=env.dtype, + device=env.device, + ) + td = env.reset( + TensorDict( + {"init_bus_quat": q, "target_quat": q}, + batch_size=env.batch_size, + ) + ) + # Reset noise is RESET_NOISE_SCALE = 1e-3 on qpos, so quat_err + # should be small (a few mrad) but not exactly zero. + assert ( + td["quat_err"].abs().max().item() < 5e-2 + ), f"quat_err = {td['quat_err']} when init == target; expected near zero." + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_180deg_target_gives_pi_attitude_error(self, backend): + """A 180-deg rotation about an axis is the maximum SO(3) + distance. ``||quat_log(q_err)||`` should equal ``pi`` (within + reset noise).""" + env = self._make_sat(backend, n=2) + identity = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]], + dtype=env.dtype, + device=env.device, + ) + # 180 deg about +x and 180 deg about +y respectively. 
+ target = torch.tensor( + [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], + dtype=env.dtype, + device=env.device, + ) + td = env.reset( + TensorDict( + {"init_bus_quat": identity, "target_quat": target}, + batch_size=env.batch_size, + ) + ) + att_err = td["quat_err"].norm(dim=-1) + assert torch.allclose( + att_err, torch.full_like(att_err, torch.pi), atol=1e-2 + ), f"||quat_err|| = {att_err.tolist()}, expected ~pi" + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_gimbal_action_torques_bus(self, backend): + """Driving CMG #1 alone with action=+1 produces a bus torque + whose direction is **opposite** to the column of + :func:`cmg_jacobian` for that CMG. + + Why opposite: ``cmg_jacobian`` returns ``h * (g_i x r_i)`` per + unit gimbal rate. By Newton's third law that's the torque on + the *rotor* (whose angular momentum is rotating with the + gimbal), and the *bus* sees the reaction torque + ``-h * (g_i x r_i)``. The manipulability metric used in the + reward (``sqrt(det(J J^T))``) is sign-invariant so the env + reward is unaffected, but the body-frame slewing direction + flips. This test pins the sign convention so a future + refactor can't silently invert it. + """ + env = self._make_sat(backend, n=2, action_scale=3.0) + identity = torch.tensor( + [[1.0, 0.0, 0.0, 0.0]] * env.num_envs, + dtype=env.dtype, + device=env.device, + ) + env.reset(TensorDict({"init_bus_quat": identity}, batch_size=env.batch_size)) + + action = torch.zeros(env.action_spec.shape, dtype=env.dtype, device=env.device) + action[..., 0] = 1.0 # only CMG 1 + td = TensorDict({"action": action}, batch_size=env.batch_size) + for _ in range(20): + td = env.step(td) + td = td["next"].select(*env.observation_spec.keys()) + td.set("action", action) + + omega = self._bus_omega(env) + # Predicted torque on the BUS = -h * (g_1 x r_1(0)). 
+ g, r0 = pyramid_4cmg_geometry(device=env.device, dtype=env.dtype) + jac = cmg_jacobian( + torch.zeros(1, 4, device=env.device, dtype=env.dtype), + g, + r0, + float(env.ROTOR_SPEED), + ).squeeze(0) + bus_torque_dir = -jac[:, 0] # reaction on the bus + # Bus omega should align with the (sign of) bus_torque_dir on + # the axes where the predicted torque is large; the y-axis + # contribution is structurally zero for CMG 1 in the pyramid. + omega_signs = torch.sign(omega) + torque_signs = torch.sign(bus_torque_dir).unsqueeze(0).expand_as(omega) + big_axes = bus_torque_dir.abs() > 0.1 + match = omega_signs[..., big_axes] == torque_signs[..., big_axes] + assert match.all(), ( + f"Bus omega = {omega.tolist()} does not match expected reaction " + f"sign pattern -cmg_jacobian[:, 0] = {bus_torque_dir.tolist()}." + ) + # Magnitude must actually grow (not just numerical noise). + assert omega.norm(dim=-1).min().item() > 0.05, ( + f"|bus_omega| = {omega.norm(dim=-1).tolist()} rad/s after 20 " + "steps of saturated CMG-1 command; expected the bus to slew." + ) + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_reward_at_zero_error_is_baseline(self, backend): + """With ``init_bus_quat == target_quat`` and zero action, the + reward should equal the singularity baseline: + + r ~= -singularity_weight / (manip / rotor_h^3) + + which for the nominal pyramid with ``rotor_h = 100`` is + approximately ``-0.5 / 1.0 = -0.5`` per step (control cost is + zero, attitude error is at the reset-noise floor). 
+ """ + env = self._make_sat(backend, n=2, action_scale=3.0, singularity_weight=0.5) + q = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.5, 0.5, 0.5, 0.5]], + dtype=env.dtype, + device=env.device, + ) + td = env.reset( + TensorDict( + {"init_bus_quat": q, "target_quat": q}, batch_size=env.batch_size + ) + ) + zero_action = torch.zeros( + env.action_spec.shape, dtype=env.dtype, device=env.device + ) + td.set("action", zero_action) + td = env.step(td) + reward = td["next", "reward"].squeeze(-1) + # Allow generous tolerance: reset noise + 1 step of dynamics. + assert (reward > -0.7).all() and (reward < -0.3).all(), ( + f"Reward at zero attitude error = {reward.tolist()}; " + "expected ~-0.5 (singularity baseline only)." + ) + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_reward_at_180deg_is_pi_plus_baseline(self, backend): + """At a 180-deg attitude error with zero action, the reward + should equal ``-pi - singularity_weight/manip_norm`` per step, + i.e. about ``-3.64`` for the default weights. + """ + env = self._make_sat(backend, n=2, action_scale=3.0, singularity_weight=0.5) + identity = torch.tensor( + [[1.0, 0.0, 0.0, 0.0]] * env.num_envs, + dtype=env.dtype, + device=env.device, + ) + target = torch.tensor( + [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]], + dtype=env.dtype, + device=env.device, + ) + td = env.reset( + TensorDict( + {"init_bus_quat": identity, "target_quat": target}, + batch_size=env.batch_size, + ) + ) + zero_action = torch.zeros( + env.action_spec.shape, dtype=env.dtype, device=env.device + ) + td.set("action", zero_action) + td = env.step(td) + reward = td["next", "reward"].squeeze(-1) + expected = -torch.pi - 0.5 + # Bus has barely moved in 1 step, so attitude error stays near pi. + assert (reward > expected - 0.2).all() and ( + reward < expected + 0.2 + ).all(), ( + f"Reward at 180-deg error = {reward.tolist()}; expected ~{expected:.2f}." 
+ ) + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_observation_matches_state(self, backend): + """Observation channels are read directly off ``qpos`` / + ``qvel`` -- not synthesised. After one non-trivial step: + + * ``bus_omega == qvel[..., 3:6]`` + * ``gimbal_rates == qvel[..., gimbal_rate_idx]`` + * ``gimbal_angles == [sin(qpos_gimbals), cos(qpos_gimbals)]`` + """ + env = self._make_sat(backend, n=2, action_scale=3.0) + env.reset() + action = torch.full( + env.action_spec.shape, 0.5, dtype=env.dtype, device=env.device + ) + td = TensorDict({"action": action}, batch_size=env.batch_size) + for _ in range(10): + td = env.step(td) + td = td["next"].select(*env.observation_spec.keys()) + td.set("action", action) + + qpos = env._backend.qpos.to(env.dtype) + qvel = env._backend.qvel.to(env.dtype) + gimbal_idx = [7 + 2 * i for i in range(env.N_GIMBALS)] + gimbal_rate_idx = [6 + 2 * i for i in range(env.N_GIMBALS)] + + torch.testing.assert_close( + td["bus_omega"], qvel[..., 3:6], rtol=1e-5, atol=1e-5 + ) + torch.testing.assert_close( + td["gimbal_rates"], + qvel[..., gimbal_rate_idx], + rtol=1e-5, + atol=1e-5, + ) + # gimbal_angles is concat([sin, cos]) over the gimbal qpos. + gimbals = qpos[..., gimbal_idx] + expected = torch.cat([gimbals.sin(), gimbals.cos()], dim=-1) + torch.testing.assert_close(td["gimbal_angles"], expected, rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_satellite_reset_is_reproducible(self, backend): + """Same ``(init_bus_quat, target_quat)`` and same action sequence + must produce byte-identical bus quaternion trajectories. This is + the determinism guarantee that the eval pipeline relies on + (the :class:`TestSetPrimer` replays the same starts every + iteration -- if the env is non-deterministic, eval comparisons + between iterations are meaningless). 
+ """ + n = 2 + init_q = torch.tensor( + [[0.5, 0.5, 0.5, 0.5], [0.7071, 0.7071, 0.0, 0.0]], + ) + target_q = torch.tensor( + [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]], + ) + + def _trajectory(seed: int) -> torch.Tensor: + env = SatelliteEnv( + num_cmgs=4, + num_envs=n, + seed=seed, + backend=backend, + action_scale=3.0, + ) + env.reset( + TensorDict( + { + "init_bus_quat": init_q.to(env.dtype).to(env.device), + "target_quat": target_q.to(env.dtype).to(env.device), + }, + batch_size=env.batch_size, + ) + ) + action = torch.full( + env.action_spec.shape, 0.3, dtype=env.dtype, device=env.device + ) + traj = [] + td = TensorDict({"action": action}, batch_size=env.batch_size) + for _ in range(20): + td = env.step(td) + traj.append(self._bus_quat(env)) + td = td["next"].select(*env.observation_spec.keys()) + td.set("action", action) + return torch.stack(traj, dim=0) + + # Same (init, target, seed, action) -> identical trajectory. + torch.testing.assert_close(_trajectory(0), _trajectory(0)) + + # ------------------------------------------------------------------ + # Backend dispatch: vmap backends reject num_workers / parallel; + # the C-bindings backend composes via ParallelEnv / SerialEnv. 
+ # ------------------------------------------------------------------ + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_vmap_backend_rejects_num_workers(self, backend): + with pytest.raises(ValueError, match="num_envs"): + HopperEnv(backend=backend, num_workers=2, seed=0) + + @pytest.mark.parametrize("backend", _VMAP_BACKENDS) + def test_vmap_backend_rejects_parallel(self, backend): + with pytest.raises(ValueError, match="parallel"): + HopperEnv(backend=backend, parallel=True, seed=0) + + @pytest.mark.skipif(not _has_mujoco, reason="mujoco not installed") + def test_mujoco_backend_num_envs_aliases_num_workers(self): + """For the C-bindings backend, ``num_envs`` and ``num_workers`` + are aliases; both produce a :class:`ParallelEnv` of N copies.""" + env_a = HopperEnv(backend="mujoco", num_envs=2, seed=0) + env_b = HopperEnv(backend="mujoco", num_workers=2, seed=0) + # Lazy ParallelEnvs -- don't start workers, just shape-check. + assert isinstance(env_a, ParallelEnv) + assert isinstance(env_b, ParallelEnv) + assert env_a.batch_size == env_b.batch_size + + @pytest.mark.skipif(not _has_mujoco, reason="mujoco not installed") + def test_mujoco_backend_rejects_both_envs_and_workers(self): + with pytest.raises(ValueError, match="aliases"): + HopperEnv(backend="mujoco", num_envs=2, num_workers=2, seed=0) + + @pytest.mark.skipif(not _has_mujoco, reason="mujoco not installed") + def test_mujoco_backend_serial_dispatch(self): + env = HopperEnv(backend="mujoco", num_envs=2, parallel=False, seed=0) + assert isinstance(env, SerialEnv) + td = env.rollout(3) + assert torch.isfinite(td.get(("next", "reward"))).all() + env.close() + + @pytest.mark.skipif(not _has_mujoco, reason="mujoco not installed") + def test_mujoco_backend_parallel_rollout(self): + env = HopperEnv(backend="mujoco", num_envs=2, seed=0) + assert isinstance(env, ParallelEnv) + td = env.rollout(3) + assert torch.isfinite(td.get(("next", "reward"))).all() + env.close() + + @pytest.mark.skipif(not 
_has_mujoco, reason="mujoco not installed") + def test_mujoco_backend_single_env_passthrough(self): + """``backend='mujoco'`` with N=1 returns a bare ``HopperEnv``, + not a ``ParallelEnv`` wrapper.""" + env = HopperEnv(backend="mujoco", num_envs=1, seed=0) + assert isinstance(env, HopperEnv) + assert env.batch_size == torch.Size([1]) + + # ------------------------------------------------------------------ + # Compile / unknown-backend / custom XML. + # ------------------------------------------------------------------ + + @pytest.mark.skipif(not _has_mujoco_torch, reason="mujoco-torch not installed") + def test_torch_backend_compile_smoke(self): + """``compile_step=True`` must not raise on the default backend.""" + env = HopperEnv(num_envs=2, seed=0, compile_step=True) + td = env.rollout(3) + assert torch.isfinite(td.get(("next", "reward"))).all() + + def test_unknown_backend_raises(self): + with pytest.raises(ValueError, match="unknown backend"): + HopperEnv(num_envs=1, seed=0, backend="not-a-backend") + + # ------------------------------------------------------------------ + # Rendering / from_pixels. + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + def test_from_pixels_spec_and_rollout(self, backend): + """``from_pixels=True`` adds a ``pixels`` key with a ``uint8`` spec + of shape ``(num_envs, H, W, 3)`` and values in ``[0, 255]``. + + Uses :class:`SatelliteEnv` because locomotion envs terminate + early under random actions, masking rollout shape assertions. 
+ """ + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv( + num_cmgs=4, + num_envs=n, + seed=0, + backend=backend, + from_pixels=True, + render_width=32, + render_height=32, + ) + check_env_specs(env) + assert env.observation_spec["pixels"].shape == torch.Size([n, 32, 32, 3]) + assert env.observation_spec["pixels"].dtype == torch.uint8 + + td = env.rollout(2) + pixels = td.get(("next", "pixels")) + assert pixels.shape == torch.Size([n, 2, 32, 32, 3]) + assert pixels.dtype == torch.uint8 + assert (pixels >= 0).all() and (pixels <= 255).all() + # Real RGB content -- not all-zero, not saturated. + assert pixels.float().std().item() > 0 + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + def test_pixel_only_drops_observation_key(self, backend): + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv( + num_cmgs=4, + num_envs=n, + seed=0, + backend=backend, + from_pixels=True, + pixel_only=True, + render_width=32, + render_height=32, + ) + check_env_specs(env) + keys = set(env.observation_spec.keys()) + assert keys == {"pixels"}, f"pixel_only must drop 'observation', got {keys}" + + def test_pixel_only_without_from_pixels_raises(self): + with pytest.raises(ValueError, match="pixel_only"): + HopperEnv(num_envs=1, seed=0, pixel_only=True) + + @pytest.mark.parametrize("backend", _AVAILABLE_BACKENDS) + def test_render_method(self, backend): + """``env.render()`` returns the same shape as the pixel obs.""" + n = 1 if backend == "mujoco" else 2 + env = SatelliteEnv(num_cmgs=4, num_envs=n, seed=0, backend=backend) + env.reset() + rgb = env.render(width=24, height=24) + assert rgb.shape == torch.Size([n, 24, 24, 3]) + assert rgb.dtype == torch.uint8 + + def test_xml_path_kwarg_overrides_class_attr(self, tmp_path): + """Custom ``xml_path=`` overrides the class-level :attr:`XML_PATH`.""" + backend = _AVAILABLE_BACKENDS[0] + # A trivial single-hinge model -- pure-XML, no external mesh deps. 
+ xml = ( + "" + "" + "" + "" + "" + "" + "" + ) + path = tmp_path / "tiny.xml" + path.write_text(xml) + + class TinyEnv(MujocoEnv): + FRAME_SKIP = 2 + + def _compute_reward(self, state, action, next_state): + return torch.zeros(self.num_envs, 1, device=self.device) + + def _compute_done(self, state, next_state): + return torch.zeros( + self.num_envs, 1, dtype=torch.bool, device=self.device + ) + + env = TinyEnv(xml_path=str(path), backend=backend, num_envs=1, seed=0) + check_env_specs(env) + td = env.rollout(3) + assert td.shape == torch.Size([1, 3]) diff --git a/test/test_configs.py b/test/test_configs.py index 97ce8d31ed8..ab074ad7b9a 100644 --- a/test/test_configs.py +++ b/test/test_configs.py @@ -1631,7 +1631,7 @@ def test_ppo_trainer_config_optional_fields(self): optimizer=optimizer_config, logger=None, # Optional field save_trainer_file="/tmp/test.pt", - replay_buffer=replay_buffer_config + replay_buffer=replay_buffer_config, # All optional fields are omitted to test defaults ) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 52af517af69..baf2d0c763b 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -7,11 +7,17 @@ from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict from .custom import ( + AntEnv, ChessEnv, FinancialRegimeEnv, + HopperEnv, + HumanoidEnv, LLMHashingEnv, + MujocoEnv, PendulumEnv, + SatelliteEnv, TicTacToeEnv, + Walker2dEnv, ) from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv @@ -143,6 +149,7 @@ __all__ = [ "ActionDiscretizer", "ActionMask", + "AntEnv", "VecNormV2", "IsaacLabWrapper", "AutoResetEnv", @@ -191,6 +198,8 @@ "GenesisWrapper", "HabitatEnv", "Hash", + "HopperEnv", + "HumanoidEnv", "InitTracker", "ImaginedEnv", "IsaacGymEnv", @@ -206,6 +215,7 @@ "MeltingpotEnv", "MeltingpotWrapper", "ModelBasedEnvBase", + "MujocoEnv", "MultiAction", "MultiStepTransform", 
"MultiThreadedEnv", @@ -236,6 +246,7 @@ "RewardSum", "RoboHiveEnv", "SMACv2Env", + "SatelliteEnv", "SMACv2Wrapper", "SelectTransform", "SerialEnv", @@ -264,6 +275,7 @@ "VecNorm", "VmasEnv", "VmasWrapper", + "Walker2dEnv", "check_env_specs", "check_marl_grouping", "default_info_dict_reader", diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index a710312f78d..3a73c9a2437 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -5,14 +5,21 @@ from .chess import ChessEnv from .llm import LLMHashingEnv +from .mujoco import AntEnv, HopperEnv, HumanoidEnv, MujocoEnv, SatelliteEnv, Walker2dEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv from .trading import FinancialRegimeEnv __all__ = [ + "AntEnv", "ChessEnv", "FinancialRegimeEnv", + "HopperEnv", + "HumanoidEnv", "LLMHashingEnv", + "MujocoEnv", "PendulumEnv", + "SatelliteEnv", "TicTacToeEnv", + "Walker2dEnv", ] diff --git a/torchrl/envs/custom/mujoco/__init__.py b/torchrl/envs/custom/mujoco/__init__.py new file mode 100644 index 00000000000..6a4131022f4 --- /dev/null +++ b/torchrl/envs/custom/mujoco/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""MuJoCo-backed custom envs with selectable physics backend. + +The base class :class:`MujocoEnv` accepts an XML asset (path or URL) +and dispatches the simulation to one of three engines: + +* ``mujoco-torch`` (default) -- native torch, batched, ``torch.compile``-friendly. +* ``mjx`` -- JAX-vectorized via :func:`jax.vmap` + :func:`jax.jit`, + bridged to torch through DLPack. +* ``mujoco`` -- official C-bindings, batched by Python loop. + +Subclasses describe the *task*: reward, termination, optional +observation map. 
The locomotion subclasses (:class:`HumanoidEnv`, +:class:`AntEnv`, :class:`Walker2dEnv`, :class:`HopperEnv`) mirror the +Gymnasium ``-v4`` reward / termination spec. :class:`SatelliteEnv` +implements an attitude-control task with 4- or 6-CMG clusters and a +manipulability-based singularity penalty. +""" + +from torchrl.envs.custom.mujoco.ant import AntEnv +from torchrl.envs.custom.mujoco.base import MujocoEnv +from torchrl.envs.custom.mujoco.hopper import HopperEnv +from torchrl.envs.custom.mujoco.humanoid import HumanoidEnv +from torchrl.envs.custom.mujoco.satellite import SatelliteEnv +from torchrl.envs.custom.mujoco.walker import Walker2dEnv + +__all__ = [ + "AntEnv", + "HopperEnv", + "HumanoidEnv", + "MujocoEnv", + "SatelliteEnv", + "Walker2dEnv", +] diff --git a/torchrl/envs/custom/mujoco/_backends.py b/torchrl/envs/custom/mujoco/_backends.py new file mode 100644 index 00000000000..399240d00d4 --- /dev/null +++ b/torchrl/envs/custom/mujoco/_backends.py @@ -0,0 +1,692 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Physics-engine adapters for the MuJoCo custom envs. + +Three backends share a common contract (:class:`_PhysicsBackend`): + +* ``mujoco-torch`` -- native torch, batched via :func:`torch.vmap`. +* ``mjx`` -- JAX-vectorized via :func:`jax.vmap` + :func:`jax.jit`, + bridged to torch through DLPack. +* ``mujoco`` -- official C-bindings, batched by Python loop. + +Each backend owns the model and the per-env simulation state. The env +calls :meth:`reset`, :meth:`step`, and reads ``qpos``/``qvel``/``time`` +to compute observations and rewards. 
+""" + +from __future__ import annotations + +import abc +import importlib.util +import urllib.request +from pathlib import Path +from typing import Literal + +import torch + +_has_mujoco_torch = importlib.util.find_spec("mujoco_torch") is not None +_has_mujoco = importlib.util.find_spec("mujoco") is not None +_has_jax = importlib.util.find_spec("jax") is not None +_has_mjx = _has_mujoco and importlib.util.find_spec("mujoco.mjx") is not None + + +BackendName = Literal["mujoco-torch", "mjx", "mujoco"] + + +def resolve_xml_string(path_or_url: str | Path) -> str: + """Read XML from a local path or http(s) URL. + + URLs let users point at remote XML assets without vendoring them. + """ + p = str(path_or_url) + if p.startswith(("http://", "https://")): + with urllib.request.urlopen(p) as resp: + return resp.read().decode("utf-8") + return Path(p).read_text() + + +class _PhysicsBackend(abc.ABC): + """Common contract across the three engines. + + After construction, attributes ``nq, nv, nu, timestep, qpos0, + actuator_lo, actuator_hi`` are populated from the model. ``qpos0`` is + the initial ``qpos`` produced by ``mj_forward`` on a fresh ``MjData``. + """ + + nq: int + nv: int + nu: int + timestep: float + qpos0: torch.Tensor + qvel0: torch.Tensor + actuator_lo: torch.Tensor + actuator_hi: torch.Tensor + actuator_ctrllimited: torch.Tensor + + def __init__( + self, xml_string: str, *, num_envs: int, device: torch.device | None + ) -> None: + self.num_envs = num_envs + self.device = ( + torch.device(device) if device is not None else torch.device("cpu") + ) + self._init_model(xml_string) + + @abc.abstractmethod + def _init_model(self, xml_string: str) -> None: + """Parse XML and prepare the batched data state. + + Populates ``nq, nv, nu, timestep, qpos0, qvel0, actuator_lo, + actuator_hi`` from the parsed model. + """ + + @abc.abstractmethod + def reset(self, qpos: torch.Tensor, qvel: torch.Tensor) -> None: + """Set the full batched state to the given ``qpos, qvel``. 
+ + Shapes: ``(num_envs, nq)`` and ``(num_envs, nv)``. + """ + + @abc.abstractmethod + def reset_mask( + self, mask: torch.Tensor, qpos: torch.Tensor, qvel: torch.Tensor + ) -> None: + """Reset the subset of envs where ``mask`` is True. + + ``mask`` shape ``(num_envs,)`` bool. ``qpos, qvel`` cover only the + masked envs: ``(mask.sum(), nq)`` and ``(mask.sum(), nv)``. + """ + + @abc.abstractmethod + def step(self, ctrl: torch.Tensor, frame_skip: int) -> None: + """Advance state by ``frame_skip`` physics substeps.""" + + @property + @abc.abstractmethod + def qpos(self) -> torch.Tensor: + """Current ``qpos``, shape ``(num_envs, nq)`` on ``self.device``.""" + + @property + @abc.abstractmethod + def qvel(self) -> torch.Tensor: + """Current ``qvel``, shape ``(num_envs, nv)`` on ``self.device``.""" + + @property + @abc.abstractmethod + def time(self) -> torch.Tensor: + """Current simulation time per env, shape ``(num_envs,)``.""" + + def render( + self, + *, + camera_id: int = 0, + width: int = 64, + height: int = 64, + background: tuple[float, float, float] | None = None, + ) -> torch.Tensor: + """Render every env to RGB pixels. + + Returns a ``(num_envs, height, width, 3)`` ``uint8`` tensor on + ``self.device``. Default raises ``NotImplementedError`` -- + backends override. + """ + raise NotImplementedError( + f"{type(self).__name__} does not implement rendering." + ) + + +# ---------------------------------------------------------------------- +# mujoco-torch backend +# ---------------------------------------------------------------------- + + +class _TorchBackend(_PhysicsBackend): + """Native-torch backend powered by ``mujoco-torch``. + + Uses :func:`torch.vmap` to batch and (optionally) :func:`torch.compile` + for the per-step physics. State lives in a ``mujoco_torch.Data`` object + with mutable ``qpos`` / ``qvel`` / ``ctrl`` torch tensors. 
+ """ + + def __init__( + self, + xml_string: str, + *, + num_envs: int, + device: torch.device | None, + compile_step: bool = False, + compile_kwargs: dict | None = None, + ) -> None: + if not _has_mujoco_torch: + raise ImportError( + "backend='mujoco-torch' requires the `mujoco-torch` package. " + "Install with `pip install mujoco-torch`." + ) + self._compile_step = compile_step + self._compile_kwargs = compile_kwargs or {} + super().__init__(xml_string, num_envs=num_envs, device=device) + + def _init_model(self, xml_string: str) -> None: + import mujoco + import mujoco_torch + + m_mj = mujoco.MjModel.from_xml_string(xml_string) + d_mj = mujoco.MjData(m_mj) + mujoco.mj_forward(m_mj, d_mj) + + mx = mujoco_torch.device_put(m_mj) + dx0 = mujoco_torch.device_put(d_mj) + if self.device != torch.device("cpu"): + mx = mx.to(self.device) + dx0 = dx0.to(self.device) + # One step so all derived dtypes match what vmap(step) produces. + dx0 = mujoco_torch.step(mx, dx0) + + self._mx = mx + self._dx0 = dx0 + self._sim_dtype = dx0.qpos.dtype + self._ctrl_dtype = dx0.ctrl.dtype + + self.nq = int(m_mj.nq) + self.nv = int(m_mj.nv) + self.nu = int(m_mj.nu) + self.timestep = float(m_mj.opt.timestep) + self.qpos0 = dx0.qpos.detach().clone() + self.qvel0 = dx0.qvel.detach().clone() + ar = torch.as_tensor(m_mj.actuator_ctrlrange, device=self.device) + self.actuator_lo = ar[:, 0].to(self._ctrl_dtype) + self.actuator_hi = ar[:, 1].to(self._ctrl_dtype) + self.actuator_ctrllimited = torch.as_tensor( + m_mj.actuator_ctrllimited, device=self.device, dtype=torch.bool + ) + + # Build the batched state and the (compiled) physics step. 
+ self._dx = self._dx0.expand(self.num_envs).clone() + self._build_step_fn() + + def _build_step_fn(self) -> None: + import mujoco_torch + + mx = self._mx + single = self.num_envs == 1 + + if single: + + def _one_step(d): + return mujoco_torch.step(mx, d) + + base = _one_step + else: + _vmap_step = torch.vmap(lambda d: mujoco_torch.step(mx, d)) + + def _vmap_one(d): + return _vmap_step(d) + + base = _vmap_one + + # When ``compile_step`` is set, compile only the per-step fn and + # keep the frame_skip loop in Python. Compiling the unrolled loop + # blows up the graph (50x more nodes), which sends inductor's + # fusion analysis into multi-hour territory on CUDA backends. + if self._compile_step: + base_compiled = torch.compile(base, **self._compile_kwargs) + + def _multi_step(d, frame_skip: int): + for _ in range(frame_skip): + d = base_compiled(d) + return d + + else: + + def _multi_step(d, frame_skip: int): + for _ in range(frame_skip): + d = base(d) + return d + + self._physics_step = _multi_step + self._single_env = single + + def reset(self, qpos: torch.Tensor, qvel: torch.Tensor) -> None: + # Refresh from the warm reference so all derived data is sane, + # then overwrite qpos / qvel. + self._dx = self._dx0.expand(self.num_envs).clone() + self._dx.qpos.copy_(qpos.to(self._sim_dtype)) + self._dx.qvel.copy_(qvel.to(self._sim_dtype)) + + def reset_mask( + self, mask: torch.Tensor, qpos: torch.Tensor, qvel: torch.Tensor + ) -> None: + n = int(mask.sum().item()) + if n == 0: + return + fresh = self._dx0.expand(n).clone() + fresh.qpos.copy_(qpos.to(self._sim_dtype)) + fresh.qvel.copy_(qvel.to(self._sim_dtype)) + # Per-leaf masked assignment: tensorclass.__setitem__ iterates + # over leaves and chokes on 0-dim entries (e.g. ``nefc``, + # ``ncon`` -- scalars shared across all envs and not subject to + # masking). Walk the underlying tensordict and copy leaf-by-leaf, + # skipping anything that doesn't look batched along the env dim. 
+ cur_td = self._dx._tensordict + new_td = fresh._tensordict + for k, dst in cur_td.items(): + src = new_td.get(k, default=None) + if src is None: + continue + if dst.ndim == 0 or dst.shape[0] != mask.shape[0]: + # 0-dim or non-env-batched leaf; the simulator treats + # these as shared, so we leave them alone. + continue + dst[mask] = src + + def step(self, ctrl: torch.Tensor, frame_skip: int) -> None: + ctrl = ctrl.to(self._ctrl_dtype) + # Clamp only actuators that MuJoCo marks as ctrl-limited. Unlimited + # actuators report a default ctrlrange of [0, 0], which is not an + # active range and must not zero out controls. + clamped = torch.minimum(torch.maximum(ctrl, self.actuator_lo), self.actuator_hi) + ctrl = torch.where(self.actuator_ctrllimited, clamped, ctrl) + self._dx.update_(ctrl=ctrl) + if self._single_env: + stepped = self._physics_step(self._dx[0], frame_skip) + self._dx = stepped.unsqueeze(0) + else: + self._dx = self._physics_step(self._dx, frame_skip) + + @property + def qpos(self) -> torch.Tensor: + return self._dx.qpos + + @property + def qvel(self) -> torch.Tensor: + return self._dx.qvel + + @property + def time(self) -> torch.Tensor: + # mujoco-torch's Data exposes a per-env time scalar. + t = self._dx.time + if t.ndim == 0: + t = t.expand(self.num_envs) + return t + + def render( + self, + *, + camera_id: int = 0, + width: int = 64, + height: int = 64, + background: tuple[float, float, float] | None = None, + ) -> torch.Tensor: + import mujoco_torch + + # Cache precomputed render data lazily (depends on the model only). 
+ if not hasattr(self, "_render_precomp"): + self._render_precomp = mujoco_torch.precompute_render_data(self._mx) + frames = [] + for i in range(self.num_envs): + rgb, _, _ = mujoco_torch.render( + self._mx, + self._dx[i], + camera_id=camera_id, + width=width, + height=height, + precomp=self._render_precomp, + background=background, + ) + frames.append((rgb * 255).clamp(0, 255).to(torch.uint8)) + return torch.stack(frames) + + +# ---------------------------------------------------------------------- +# mujoco C-bindings backend +# ---------------------------------------------------------------------- + + +class _MujocoBackend(_PhysicsBackend): + """Reference backend: official ``mujoco`` C-bindings, **single env**. + + The C-bindings can't vmap, so we don't fake batching here. To run + multiple environments in parallel with this backend, compose with + :class:`~torchrl.envs.ParallelEnv` (multiprocess) or + :class:`~torchrl.envs.SerialEnv` (in-process loop). The + :class:`~torchrl.envs.custom.mujoco.MujocoEnv` metaclass does this + dispatch automatically when ``num_workers > 1`` or ``num_envs > 1`` + is requested with ``backend='mujoco'``. + """ + + def __init__( + self, + xml_string: str, + *, + num_envs: int, + device: torch.device | None, + ) -> None: + if not _has_mujoco: + raise ImportError( + "backend='mujoco' requires the `mujoco` package. " + "Install with `pip install mujoco`." + ) + if num_envs != 1: + raise ValueError( + "backend='mujoco' is single-env. To batch, wrap with " + "torchrl.envs.ParallelEnv or torchrl.envs.SerialEnv " + "(MujocoEnv does this automatically when num_workers>1 " + "or num_envs>1 is passed)." 
+ ) + super().__init__(xml_string, num_envs=num_envs, device=device) + + def _init_model(self, xml_string: str) -> None: + import mujoco + + m_mj = mujoco.MjModel.from_xml_string(xml_string) + d_mj = mujoco.MjData(m_mj) + mujoco.mj_forward(m_mj, d_mj) + + self._m = m_mj + self._d = d_mj + + self.nq = int(m_mj.nq) + self.nv = int(m_mj.nv) + self.nu = int(m_mj.nu) + self.timestep = float(m_mj.opt.timestep) + self.qpos0 = torch.as_tensor(d_mj.qpos.copy(), device=self.device).to( + torch.float32 + ) + self.qvel0 = torch.as_tensor(d_mj.qvel.copy(), device=self.device).to( + torch.float32 + ) + ar = torch.as_tensor( + m_mj.actuator_ctrlrange, device=self.device, dtype=torch.float32 + ) + self.actuator_lo = ar[:, 0] + self.actuator_hi = ar[:, 1] + self.actuator_ctrllimited = torch.as_tensor( + m_mj.actuator_ctrllimited, device=self.device, dtype=torch.bool + ) + + def reset(self, qpos: torch.Tensor, qvel: torch.Tensor) -> None: + import mujoco + + self._d.qpos[:] = qpos.detach().cpu().double().numpy()[0] + self._d.qvel[:] = qvel.detach().cpu().double().numpy()[0] + self._d.time = 0.0 + mujoco.mj_forward(self._m, self._d) + + def reset_mask( + self, mask: torch.Tensor, qpos: torch.Tensor, qvel: torch.Tensor + ) -> None: + # Single-env: mask is shape (1,). Either reset or no-op. 
+ if bool(mask.any()): + self.reset(qpos, qvel) + + def step(self, ctrl: torch.Tensor, frame_skip: int) -> None: + import mujoco + + clamped = torch.minimum(torch.maximum(ctrl, self.actuator_lo), self.actuator_hi) + clamped = torch.where(self.actuator_ctrllimited, clamped, ctrl) + self._d.ctrl[:] = clamped.detach().cpu().double().numpy()[0] + for _ in range(frame_skip): + mujoco.mj_step(self._m, self._d) + + @property + def qpos(self) -> torch.Tensor: + return torch.as_tensor( + self._d.qpos.copy(), device=self.device, dtype=torch.float32 + ).unsqueeze(0) + + @property + def qvel(self) -> torch.Tensor: + return torch.as_tensor( + self._d.qvel.copy(), device=self.device, dtype=torch.float32 + ).unsqueeze(0) + + @property + def time(self) -> torch.Tensor: + return torch.tensor([self._d.time], device=self.device, dtype=torch.float32) + + def render( + self, + *, + camera_id: int = 0, + width: int = 64, + height: int = 64, + background: tuple[float, float, float] | None = None, + ) -> torch.Tensor: + import mujoco + import numpy as np + + if ( + not hasattr(self, "_renderer") + or self._renderer.height != height + or self._renderer.width != width + ): + self._renderer = mujoco.Renderer(self._m, height=height, width=width) + self._renderer.update_scene(self._d, camera=camera_id) + rgb = self._renderer.render() # (H, W, 3) uint8 numpy + if background is not None: + # Tint background pixels (those at the far plane). Approximation: + # uses the geom-id buffer is overkill -- skip when not requested. + pass + return torch.as_tensor(np.ascontiguousarray(rgb), device=self.device).unsqueeze( + 0 + ) + + +# ---------------------------------------------------------------------- +# MJX (JAX) backend +# ---------------------------------------------------------------------- + + +class _MJXBackend(_PhysicsBackend): + """JAX-vectorized backend via ``mujoco.mjx``. 
+ + Mirrors the pattern in :mod:`torchrl.envs.libs.brax`: build an + ``mjx.Model``, ``vmap+jit`` the step function, bridge JAX arrays to + torch tensors via DLPack on each call. + """ + + def __init__( + self, + xml_string: str, + *, + num_envs: int, + device: torch.device | None, + ) -> None: + if not (_has_mjx and _has_jax): + raise ImportError( + "backend='mjx' requires `mujoco>=3.0` (with mjx) and `jax`. " + "Install with `pip install mujoco-mjx jax`." + ) + super().__init__(xml_string, num_envs=num_envs, device=device) + + def _init_model(self, xml_string: str) -> None: + import jax + import mujoco + from mujoco import mjx + + m_mj = mujoco.MjModel.from_xml_string(xml_string) + mx = mjx.put_model(m_mj) + dx0_single = mjx.make_data(mx) + dx0_single = mjx.forward(mx, dx0_single) + + self._mjx = mjx + self._jax = jax + self._m_mj = m_mj # kept for rendering via mujoco.Renderer + self._mx = mx + self._dx0_single = dx0_single + # All JAX arrays must live on this device or jit will fail with + # "Received incompatible devices for jitted computation". The + # model's data state lives wherever `mjx.put_model` placed it + # (JAX default device), and any tensors we splice in (qpos / + # qvel / time) must be put_d to match. + self._jax_device = dx0_single.qpos.device + + self.nq = int(m_mj.nq) + self.nv = int(m_mj.nv) + self.nu = int(m_mj.nu) + self.timestep = float(m_mj.opt.timestep) + self.qpos0 = self._jax_to_torch(dx0_single.qpos) + self.qvel0 = self._jax_to_torch(dx0_single.qvel) + ar = torch.as_tensor( + m_mj.actuator_ctrlrange, device=self.device, dtype=torch.float32 + ) + self.actuator_lo = ar[:, 0] + self.actuator_hi = ar[:, 1] + self.actuator_ctrllimited = torch.as_tensor( + m_mj.actuator_ctrllimited, device=self.device, dtype=torch.bool + ) + + self._vmap_step = jax.jit(jax.vmap(lambda d: mjx.step(mx, d))) + self._vmap_forward = jax.jit(jax.vmap(lambda d: mjx.forward(mx, d))) + + # Initial batched state. 
+ self._dx = jax.vmap(lambda _: dx0_single)(jax.numpy.arange(self.num_envs)) + + def _jax_to_torch(self, x) -> torch.Tensor: + from torchrl.envs.libs.jax_utils import _ndarray_to_tensor + + return _ndarray_to_tensor(x).to(self.device).to(torch.float32) + + def _torch_to_jax(self, x: torch.Tensor): + from torchrl.envs.libs.jax_utils import _tensor_to_ndarray + + arr = _tensor_to_ndarray(x.contiguous()) + # Force the array onto the model's JAX device. Torch tensors on + # CPU would otherwise yield a CPU JAX array even when the model + # lives on GPU, causing a mixed-device pytree at jit time. + return self._jax.device_put(arr, self._jax_device) + + def _broadcast_dx0(self): + """Build a batched copy of ``_dx0_single`` on ``self._jax_device``.""" + jax = self._jax + idx = jax.device_put(jax.numpy.arange(self.num_envs), self._jax_device) + return jax.vmap(lambda _: self._dx0_single)(idx) + + def reset(self, qpos: torch.Tensor, qvel: torch.Tensor) -> None: + qpos_j = self._torch_to_jax(qpos) + qvel_j = self._torch_to_jax(qvel) + # Broadcast _dx0_single (with time=0) over the batch, splice in + # the new qpos/qvel, then re-run mjx.forward so derived quantities + # are consistent. Time stays at zero from the reference state. + dx = self._broadcast_dx0() + dx = dx.replace(qpos=qpos_j, qvel=qvel_j) + self._dx = self._vmap_forward(dx) + + def reset_mask( + self, mask: torch.Tensor, qpos: torch.Tensor, qvel: torch.Tensor + ) -> None: + n = int(mask.sum().item()) + if n == 0: + return + # Splice the reset envs into the current full state on the torch + # side, then push back to JAX. Partial reset isn't on the hot path. 
+ full_qpos = self.qpos.clone() + full_qvel = self.qvel.clone() + full_qpos[mask] = qpos.to(full_qpos.dtype) + full_qvel[mask] = qvel.to(full_qvel.dtype) + time = self.time.clone() + time[mask] = 0.0 + + qpos_j = self._torch_to_jax(full_qpos) + qvel_j = self._torch_to_jax(full_qvel) + time_j = self._torch_to_jax(time) + dx = self._broadcast_dx0() + dx = dx.replace(qpos=qpos_j, qvel=qvel_j, time=time_j) + self._dx = self._vmap_forward(dx) + + def step(self, ctrl: torch.Tensor, frame_skip: int) -> None: + clamped = torch.minimum(torch.maximum(ctrl, self.actuator_lo), self.actuator_hi) + ctrl = torch.where(self.actuator_ctrllimited, clamped, ctrl) + ctrl_j = self._torch_to_jax(ctrl) + self._dx = self._dx.replace(ctrl=ctrl_j) + for _ in range(frame_skip): + self._dx = self._vmap_step(self._dx) + + @property + def qpos(self) -> torch.Tensor: + return self._jax_to_torch(self._dx.qpos) + + @property + def qvel(self) -> torch.Tensor: + return self._jax_to_torch(self._dx.qvel) + + @property + def time(self) -> torch.Tensor: + return self._jax_to_torch(self._dx.time) + + def render( + self, + *, + camera_id: int = 0, + width: int = 64, + height: int = 64, + background: tuple[float, float, float] | None = None, + ) -> torch.Tensor: + """Render via mujoco's CPU renderer after copying mjx state to MjData. + + Slow path (one MjData per env, sequential render). Acceptable for + eval / video; not for high-throughput pixel-based training. 
+ """ + import mujoco + import numpy as np + + if ( + not hasattr(self, "_renderer") + or self._renderer.height != height + or self._renderer.width != width + ): + self._renderer = mujoco.Renderer(self._m_mj, height=height, width=width) + if not hasattr(self, "_render_d"): + self._render_d = mujoco.MjData(self._m_mj) + + qpos = self.qpos.detach().cpu().double().numpy() + qvel = self.qvel.detach().cpu().double().numpy() + frames = [] + for i in range(self.num_envs): + self._render_d.qpos[:] = qpos[i] + self._render_d.qvel[:] = qvel[i] + mujoco.mj_forward(self._m_mj, self._render_d) + self._renderer.update_scene(self._render_d, camera=camera_id) + frames.append(self._renderer.render().copy()) + rgb = np.stack(frames, axis=0) + return torch.as_tensor(np.ascontiguousarray(rgb), device=self.device) + + +# ---------------------------------------------------------------------- +# Dispatch +# ---------------------------------------------------------------------- + + +def make_backend( + name: BackendName, + xml_string: str, + *, + num_envs: int, + device: torch.device | None, + compile_step: bool = False, + compile_kwargs: dict | None = None, +) -> _PhysicsBackend: + """Instantiate the requested backend. + + Raises ``ImportError`` with an actionable message when the underlying + package is missing. ``compile_step`` is honored only by the + ``mujoco-torch`` backend (mjx already JITs; the C-bindings backend is + a Python loop). 
+ """ + if name == "mujoco-torch": + return _TorchBackend( + xml_string, + num_envs=num_envs, + device=device, + compile_step=compile_step, + compile_kwargs=compile_kwargs, + ) + if name == "mjx": + return _MJXBackend(xml_string, num_envs=num_envs, device=device) + if name == "mujoco": + return _MujocoBackend(xml_string, num_envs=num_envs, device=device) + raise ValueError( + f"unknown backend {name!r}; expected one of 'mujoco-torch', 'mjx', 'mujoco'" + ) diff --git a/torchrl/envs/custom/mujoco/_math.py b/torchrl/envs/custom/mujoco/_math.py new file mode 100644 index 00000000000..66905aca9dc --- /dev/null +++ b/torchrl/envs/custom/mujoco/_math.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Pure-torch math helpers for the MuJoCo custom envs. + +Quaternions follow the MuJoCo convention ``(w, x, y, z)``. CMG geometry +helpers support the satellite env's manipulability-based singularity penalty. +""" + +from __future__ import annotations + +import math + +import torch + + +def quat_mul(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: + """Hamilton product of two unit quaternions, ``(..., 4)`` -> ``(..., 4)``.""" + w1, x1, y1, z1 = q1.unbind(-1) + w2, x2, y2, z2 = q2.unbind(-1) + return torch.stack( + [ + w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2, + w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2, + w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2, + w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2, + ], + dim=-1, + ) + + +def quat_conj(q: torch.Tensor) -> torch.Tensor: + """Conjugate of a unit quaternion.""" + w, x, y, z = q.unbind(-1) + return torch.stack([w, -x, -y, -z], dim=-1) + + +def quat_log(q: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + """Logarithm map of a unit quaternion to a 3-vector axis-angle. + + For ``q = (cos(a/2), sin(a/2) n)`` with axis ``n`` and angle ``a``, + ``quat_log(q) = a * n``. 
Range: ``[0, pi]`` in magnitude. + """ + q = q / q.norm(dim=-1, keepdim=True).clamp_min(eps) + # q and -q encode the same SO(3) rotation. Use the representative + # with non-negative scalar part so the log map is the shortest arc. + sign = torch.where(q[..., 0:1] < 0, -1.0, 1.0) + q = q * sign + w = q[..., 0:1].clamp(-1.0, 1.0) + v = q[..., 1:] + v_norm = v.norm(dim=-1, keepdim=True).clamp_min(eps) + angle = 2.0 * torch.atan2(v_norm, w) + return v / v_norm * angle + + +def random_unit_quat( + shape: tuple[int, ...], + *, + generator: torch.Generator | None = None, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Uniform random unit quaternion on ``SO(3)``.""" + q = torch.randn(*shape, 4, generator=generator, device=device, dtype=dtype) + return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +def _rodrigues_rotate( + g: torch.Tensor, r0: torch.Tensor, theta: torch.Tensor +) -> torch.Tensor: + """Rotate ``r0`` around unit axis ``g`` by angle ``theta``. + + ``g``, ``r0``: ``(3, N)``. ``theta``: ``(..., N)``. Output: ``(..., 3, N)``. + """ + cos_t = torch.cos(theta).unsqueeze(-2) + sin_t = torch.sin(theta).unsqueeze(-2) + g_dot_r = (g * r0).sum(dim=0) + g_cross_r = torch.linalg.cross(g, r0, dim=0) + return cos_t * r0 + sin_t * g_cross_r + (1.0 - cos_t) * g_dot_r * g + + +def cmg_jacobian( + gimbal_angles: torch.Tensor, + gimbal_axes: torch.Tensor, + rotor_axes_ref: torch.Tensor, + h: float, +) -> torch.Tensor: + """CMG output-torque Jacobian over gimbal rates. + + For each CMG with gimbal axis ``g_i`` and rotor axis ``r_i(theta_i)``, + the rate-of-change of the rotor's angular momentum per unit gimbal + rate is ``h * (g_i x r_i(theta_i))``. Stacked into a ``(..., 3, N)`` + matrix. 
The torque applied to the *bus* is the Newton's-third-law + reaction, ``-h * (g_i x r_i)``; this function returns the + rotor-frame quantity because the manipulability metric + ``sqrt(det(J J^T))`` -- the only consumer in this module -- is + sign-invariant. Callers that care about the body-frame slewing + direction must negate the result. + + Args: + gimbal_angles: ``(..., N)`` current gimbal angles in radians. + gimbal_axes: ``(3, N)`` fixed gimbal axes in body frame, unit norm. + rotor_axes_ref: ``(3, N)`` rotor axes at ``theta=0``, unit norm, + perpendicular to the corresponding gimbal axis. + h: scalar rotor angular momentum magnitude. + + Returns: + ``(..., 3, N)`` Jacobian whose ``i``-th column is the + rotor-momentum-rate per unit ``i``-th gimbal rate. Negate to + get the bus torque per unit gimbal rate. + """ + r = _rodrigues_rotate(gimbal_axes, rotor_axes_ref, gimbal_angles) + g_expanded = gimbal_axes.expand_as(r) + return h * torch.linalg.cross(g_expanded, r, dim=-2) + + +def manipulability(jac: torch.Tensor, eps: float = 1e-8) -> torch.Tensor: + """``sqrt(det(J J^T) + eps)`` -- proxy for distance from singularity. + + ``jac`` shape: ``(..., 3, N)`` with ``N >= 3``. Output: ``(...,)``. + """ + jjt = jac @ jac.transpose(-1, -2) + det = torch.linalg.det(jjt).clamp_min(0.0) + return torch.sqrt(det + eps) + + +def pyramid_4cmg_geometry( + skew_deg: float = 54.7356, + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """Standard 4-CMG pyramid: gimbal axes tilted by ``skew_deg`` from +z. + + Produces ``(gimbal_axes, rotor_axes_ref)``, both ``(3, 4)``. + The default skew angle ``arctan(sqrt(2)) ~ 54.74 deg`` gives a + spherical momentum envelope (the textbook configuration). 
+ """ + beta = math.radians(skew_deg) + cb, sb = math.cos(beta), math.sin(beta) + g = torch.tensor( + [ + [sb, 0.0, -sb, 0.0], + [0.0, sb, 0.0, -sb], + [cb, cb, cb, cb], + ], + device=device, + dtype=dtype, + ) + # Reference rotor axes: in-plane perpendicular to each gimbal axis. + # For gimbal i tilted from +z, pick the axis lying in the body xy-plane + # rotated 90 deg around the gimbal axis from +z. + r0 = torch.tensor( + [ + [0.0, 1.0, 0.0, -1.0], + [-1.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0], + ], + device=device, + dtype=dtype, + ) + return g, r0 + + +def orthogonal_6cmg_geometry( + *, + device: torch.device | None = None, + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor]: + """6-CMG redundant cluster with gimbal axes along ``+/-x, +/-y, +/-z``. + + Reference rotor axes lie in the plane perpendicular to each gimbal, + chosen so the cluster is full-rank at ``theta = 0``. + """ + g = torch.tensor( + [ + [1.0, -1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, -1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, -1.0], + ], + device=device, + dtype=dtype, + ) + r0 = torch.tensor( + [ + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0], + [1.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 0.0, 0.0], + ], + device=device, + dtype=dtype, + ) + return g, r0 diff --git a/torchrl/envs/custom/mujoco/ant.py b/torchrl/envs/custom/mujoco/ant.py new file mode 100644 index 00000000000..2846ebba343 --- /dev/null +++ b/torchrl/envs/custom/mujoco/ant.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Ant-v4 quadruped locomotion env.""" + +from __future__ import annotations + +import re + +import torch +from tensordict import TensorDictBase +from torchrl.envs.custom.mujoco.base import MujocoEnv + + +class AntEnv(MujocoEnv): + """Quadruped locomotion (15-DoF, 8 actuators). 
+ + The bundled ``ant.xml`` is a fixed-base ant; we patch it at load + time to insert a free joint on the torso and set ``timestep=0.01``, + matching Gymnasium ``Ant-v4`` semantics. + + Args: see :class:`~torchrl.envs.custom.mujoco.MujocoEnv`. + + Example: + >>> from torchrl.envs import AntEnv # doctest: +SKIP + >>> env = AntEnv(num_envs=8) # doctest: +SKIP + >>> td = env.rollout(50) # doctest: +SKIP + """ + + XML_PATH = "ant.xml" + FRAME_SKIP = 5 + RESET_NOISE_SCALE = 0.1 + SKIP_QPOS = 2 + HEALTHY_Z_LOW = 0.2 + HEALTHY_Z_HIGH = 1.0 + HEALTHY_REWARD = 1.0 + CTRL_COST_WEIGHT = 0.5 + + @classmethod + def _patch_xml(cls, xml: str) -> str: + xml = super()._patch_xml(xml) + xml = re.sub( + r'(]*>)', + r"\1\n ", + xml, + count=1, + ) + xml = re.sub( + r"(\s*)", + r'\1