diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index 08c832a4cc3..bf2da4fea0e 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -50,4 +50,5 @@ Documentation Sections objectives_policy objectives_actorcritic objectives_offline + objectives_multiagent objectives_other diff --git a/docs/source/reference/objectives_common.rst b/docs/source/reference/objectives_common.rst index 33ebaa44fce..3bdf05f174f 100644 --- a/docs/source/reference/objectives_common.rst +++ b/docs/source/reference/objectives_common.rst @@ -28,6 +28,7 @@ Value Estimators TD1Estimator TDLambdaEstimator GAE + MultiAgentGAE .. currentmodule:: torchrl.objectives diff --git a/docs/source/reference/objectives_multiagent.rst b/docs/source/reference/objectives_multiagent.rst new file mode 100644 index 00000000000..79703fc65f6 --- /dev/null +++ b/docs/source/reference/objectives_multiagent.rst @@ -0,0 +1,58 @@ +.. currentmodule:: torchrl.objectives.multiagent + +Multi-Agent Objectives +====================== + +Loss modules for multi-agent reinforcement learning algorithms. These losses +follow the torchrl multi-agent tensordict convention (per-agent tensors +nested under group keys such as ``("agents", "observation")``; see +:class:`~torchrl.envs.libs.vmas.VmasEnv` and +:class:`~torchrl.envs.libs.pettingzoo.PettingZooEnv`). + +MAPPO and IPPO +-------------- + +:class:`MAPPOLoss` implements Multi-Agent PPO (Yu et al. 2022) — a +decentralised actor paired with a *centralised critic* that conditions on the +joint observation / state. :class:`IPPOLoss` is the independent-learner +counterpart from de Witt et al. 2020: each agent has its own local critic and +there is no centralised information at training time. + +Both are thin specialisations of :class:`~torchrl.objectives.ClipPPOLoss` +that: + +- default the value estimator to + :class:`~torchrl.objectives.value.MultiAgentGAE`, which broadcasts + team-shared rewards / done flags across the agent dimension before + computing returns; +- default ``normalize_advantage_exclude_dims`` to ``(-2,)`` so the agent dim + is excluded from advantage standardisation; +- optionally accept a :class:`~torchrl.modules.ValueNorm` subclass — either + :class:`~torchrl.modules.PopArtValueNorm` (EMA, recommended for drifting + reward scales) or :class:`~torchrl.modules.RunningValueNorm` (exact + Welford running stats, recommended for stationary scales) — to stabilise + the critic loss. The MAPPO paper credits this trick for its strong SMAC + results. + +See ``sota-implementations/multiagent/mappo_ippo.py`` for a hydra-configured +recipe and ``examples/multiagent/mappo_vmas.py`` for a minimal one. + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + MAPPOLoss + IPPOLoss + +QMixer +------ + +:class:`QMixerLoss` mixes local per-agent Q values into a global team Q +value via a learnable mixing network, and trains them jointly with a DQN +update on the global value (Rashid et al. 2018). + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + QMixerLoss diff --git a/examples/multiagent/mappo_vmas.py b/examples/multiagent/mappo_vmas.py new file mode 100644 index 00000000000..a2fa4cc3ea5 --- /dev/null +++ b/examples/multiagent/mappo_vmas.py @@ -0,0 +1,221 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+"""Minimal MAPPO / IPPO recipe on VMAS using the new +:class:`~torchrl.objectives.multiagent.MAPPOLoss` / +:class:`~torchrl.objectives.multiagent.IPPOLoss` classes. + +For the full, hydra-configured, wandb-logged version see +``sota-implementations/multiagent/mappo_ippo.py``. This file is intentionally +short: it's there to show that the new loss classes collapse the boilerplate +that previously required ``ClipPPOLoss`` + manual ``set_keys(done=..., +terminated=...)`` + manual ``make_value_estimator(GAE, ...)`` into a single +construction call. + +Usage:: + + python examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000 + python examples/multiagent/mappo_vmas.py --algo ippo --frames 200_000 + +The two should reach similar reward on the easy navigation scenario; MAPPO +typically pulls ahead on harder coordination tasks (Yu et al. 2022). +""" +from __future__ import annotations + +import argparse +import time + +import torch +from tensordict.nn import TensorDictModule +from tensordict.nn.distributions import NormalParamExtractor +from torch import nn + +from torchrl.collectors import SyncDataCollector +from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer +from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement +from torchrl.envs import RewardSum, TransformedEnv +from torchrl.modules import ( + MultiAgentMLP, + PopArtValueNorm, + ProbabilisticActor, + TanhNormal, +) +from torchrl.objectives import IPPOLoss, MAPPOLoss + + +def make_actor(env, *, share_params: bool = True) -> ProbabilisticActor: + obs_dim = env.observation_spec["agents", "observation"].shape[-1] + action_dim = env.action_spec.shape[-1] + backbone = nn.Sequential( + MultiAgentMLP( + n_agent_inputs=obs_dim, + n_agent_outputs=2 * action_dim, + n_agents=env.n_agents, + centralized=False, + share_params=share_params, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + NormalParamExtractor(), + ) + module = TensorDictModule( + backbone, + in_keys=[("agents", "observation")], + out_keys=[("agents", "loc"), ("agents", "scale")], + ) + return ProbabilisticActor( + module=module, + in_keys=[("agents", "loc"), ("agents", "scale")], + out_keys=[env.action_key], + distribution_class=TanhNormal, + distribution_kwargs={ + "low": env.full_action_spec_unbatched[("agents", "action")].space.low, + "high": env.full_action_spec_unbatched[("agents", "action")].space.high, + }, + return_log_prob=True, + ) + + +def make_critic( + env, *, centralized: bool, share_params: bool = True +) -> TensorDictModule: + obs_dim = env.observation_spec["agents", "observation"].shape[-1] + return TensorDictModule( + MultiAgentMLP( + n_agent_inputs=obs_dim, + n_agent_outputs=1, + n_agents=env.n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + num_cells=256, + activation_class=nn.Tanh, + ), + in_keys=[("agents", "observation")], + out_keys=[("agents", "state_value")], + ) + + +def main(args: argparse.Namespace) -> None: + try: + from torchrl.envs.libs.vmas import VmasEnv + except ImportError as exc: + raise SystemExit( + "This example requires VMAS. Install it with `pip install vmas`." 
) from exc
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.manual_seed(args.seed)
+
+    n_envs = max(1, args.frames_per_batch // args.max_steps)
+    env = TransformedEnv(
+        VmasEnv(
+            scenario=args.scenario,
+            num_envs=n_envs,
+            continuous_actions=True,
+            max_steps=args.max_steps,
+            device=device,
+            seed=args.seed,
+        ),
+        # VMAS rewards are per-agent under ("agents", "reward"); accumulate
+        # them into a per-agent episode return for logging.
+        RewardSum(
+            in_keys=[("agents", "reward")],
+            out_keys=[("agents", "episode_reward")],
+        ),
+    )
+
+    actor = make_actor(env)
+    centralised = args.algo == "mappo"
+    critic = make_critic(env, centralized=centralised)
+
+    LossCls = MAPPOLoss if args.algo == "mappo" else IPPOLoss
+    value_norm = (
+        PopArtValueNorm(shape=1, device=device) if args.algo == "mappo" else None
+    )
+    loss_module = LossCls(
+        actor_network=actor,
+        critic_network=critic,
+        value_norm=value_norm,
+        clip_epsilon=0.2,
+        entropy_coeff=0.01,
+    )
+    loss_module.set_keys(
+        value=("agents", "state_value"),
+        action=env.action_key,
+        reward=env.reward_key,
+    )
+
+    collector = SyncDataCollector(
+        env,
+        actor,
+        device=device,
+        storing_device=device,
+        frames_per_batch=args.frames_per_batch,
+        total_frames=args.frames,
+    )
+
+    replay_buffer = TensorDictReplayBuffer(
+        storage=LazyTensorStorage(args.frames_per_batch, device=device),
+        sampler=SamplerWithoutReplacement(),
+        batch_size=args.minibatch_size,
+    )
+
+    optim = torch.optim.Adam(loss_module.parameters(), lr=args.lr)
+
+    total_frames = 0
+    start = time.time()
+    for it, td in enumerate(collector):
+        with torch.no_grad():
+            loss_module.value_estimator(
+                td,
+                params=loss_module.critic_network_params,
+                target_params=loss_module.target_critic_network_params,
+            )
+        replay_buffer.extend(td.reshape(-1))
+        total_frames += td.numel()
+
+        for _ in range(args.epochs):
+            for _ in range(args.frames_per_batch // args.minibatch_size):
+                subdata = replay_buffer.sample()
+                losses = loss_module(subdata)
+                loss = (
+                    losses["loss_objective"]
+                    + losses["loss_critic"]
+                    + losses["loss_entropy"]
+                )
+                optim.zero_grad()
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 1.0)
+                optim.step()
+
+        collector.update_policy_weights_()
+        ep_reward = td.get(("next", "agents", "episode_reward")).mean().item()
+        print(
+            f"[{args.algo}] iter={it:03d} frames={total_frames:>7d} "
+            f"reward={ep_reward:+.3f} elapsed={time.time() - start:5.1f}s"
+        )
+
+    collector.shutdown()
+    if not env.is_closed:
+        env.close()
+
+
+def parse_args() -> argparse.Namespace:
+    p = argparse.ArgumentParser()
+    p.add_argument("--algo", choices=("mappo", "ippo"), default="mappo")
+    p.add_argument("--scenario", default="navigation")
+    p.add_argument("--frames", type=int, default=200_000)
+    p.add_argument("--frames_per_batch", type=int, default=6_000)
+    p.add_argument("--minibatch_size", type=int, default=400)
+    p.add_argument("--max_steps", type=int, default=100)
+    p.add_argument("--epochs", type=int, default=4)
+    p.add_argument("--lr", type=float, default=3e-4)
+    p.add_argument("--seed", type=int, default=0)
+    return p.parse_args()
+
+
+if __name__ == "__main__":
+    main(parse_args())
diff --git a/test/objectives/test_cql.py b/test/objectives/test_cql.py
index b850518fae8..c5a3e01a4dd 100644
--- a/test/objectives/test_cql.py
+++ b/test/objectives/test_cql.py
@@ -182,7 +182,7 @@ def test_cql(
             delay_qvalue=delay_qvalue,
         )
 
-        if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
+        if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
         with
pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -361,7 +361,7 @@ def test_cql_deactivate_vmap( deactivate_vmap=False, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_vmap.make_value_estimator(td_est) return @@ -389,7 +389,7 @@ def test_cql_deactivate_vmap( deactivate_vmap=True, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_no_vmap.make_value_estimator(td_est) return @@ -845,7 +845,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return diff --git a/test/objectives/test_ddpg.py b/test/objectives/test_ddpg.py index 2974179f250..d363cf289d3 100644 --- a/test/objectives/test_ddpg.py +++ b/test/objectives/test_ddpg.py @@ -248,7 +248,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): delay_actor=delay_actor, delay_value=delay_value, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1024,7 +1024,7 @@ def test_td3( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1136,7 +1136,7 @@ def test_td3_deactivate_vmap( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_vmap.make_value_estimator(td_est) return @@ -1166,7 +1166,7 @@ def test_td3_deactivate_vmap( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_no_vmap.make_value_estimator(td_est) return @@ -1937,7 +1937,7 @@ def test_td3bc( delay_actor=delay_actor, delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return diff --git a/test/objectives/test_dqn.py b/test/objectives/test_dqn.py index 7dcbac2dea0..9b2bf247ceb 100644 --- a/test/objectives/test_dqn.py +++ b/test/objectives/test_dqn.py @@ -258,7 +258,7 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): delay_value=delay_value, double_dqn=double_dqn, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): 
loss_fn.make_value_estimator(td_est) return @@ -922,7 +922,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return diff --git a/test/objectives/test_dreamer.py b/test/objectives/test_dreamer.py index 4fe8cdb3ec6..027d9bf77cc 100644 --- a/test/objectives/test_dreamer.py +++ b/test/objectives/test_dreamer.py @@ -406,7 +406,7 @@ def test_dreamer_actor(self, device, imagination_horizon, discount_loss, td_est) imagination_horizon=imagination_horizon, discount_loss=discount_loss, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_module.make_value_estimator(td_est) return diff --git a/test/objectives/test_iql.py b/test/objectives/test_iql.py index 2a40ed9c010..658eae26920 100644 --- a/test/objectives/test_iql.py +++ b/test/objectives/test_iql.py @@ -310,7 +310,7 @@ def test_iql( expectile=expectile, loss_function="l2", ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -439,7 +439,7 @@ def test_iql_deactivate_vmap( loss_function="l2", deactivate_vmap=False, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_vmap.make_value_estimator(td_est) return @@ -463,7 +463,7 @@ def test_iql_deactivate_vmap( loss_function="l2", deactivate_vmap=True, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_no_vmap.make_value_estimator(td_est) return @@ -1206,7 +1206,7 @@ def test_discrete_iql( loss_function="l2", action_space="one-hot", ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return diff --git a/test/objectives/test_mappo.py b/test/objectives/test_mappo.py new file mode 100644 index 00000000000..7be435716de --- /dev/null +++ b/test/objectives/test_mappo.py @@ -0,0 +1,478 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Tests for MAPPOLoss, IPPOLoss, MultiAgentGAE, and the ValueNorm family. + +These tests use synthetic tensordicts so they don't depend on any external +MARL env (VMAS / PettingZoo). They follow the layout pattern from +``test/test_cost.py::TestQMixer`` — per-agent observations under +``("agents", "observation")`` and team-shared reward / done at the root. 
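+
+A sketch of the synthetic layout produced by ``_make_data`` (illustrative only;
+the action and log-prob entries are attached afterwards by
+``_attach_action_and_logprob``)::
+
+    TensorDict(batch_size=[B, T])
+        ("agents", "observation"):          [B, T, n_agents, obs_dim]
+        ("next", "agents", "observation"):  [B, T, n_agents, obs_dim]
+        ("next", "reward"):                 [B, T, 1]  (team-shared by default,
+                                            [B, T, n_agents, 1] if per-agent)
+        ("next", "done") / ("next", "terminated"):  [B, T, 1]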
+""" +from __future__ import annotations + +import pytest +import torch +from tensordict import TensorDict +from tensordict.nn import TensorDictModule + +from torchrl.modules import ( + MultiAgentMLP, + PopArtValueNorm, + ProbabilisticActor, + RunningValueNorm, + ValueNorm, +) +from torchrl.modules.distributions import NormalParamExtractor, TanhNormal +from torchrl.objectives import IPPOLoss, MAPPOLoss +from torchrl.objectives.utils import ValueEstimators +from torchrl.objectives.value import GAE, MultiAgentGAE + + +# -------------------------------------------------------------------------- +# helpers +# -------------------------------------------------------------------------- + + +def _make_actor(n_agents=3, obs_dim=6, action_dim=2, share_params=True): + backbone = torch.nn.Sequential( + MultiAgentMLP( + n_agent_inputs=obs_dim, + n_agent_outputs=2 * action_dim, + n_agents=n_agents, + centralized=False, + share_params=share_params, + ), + NormalParamExtractor(), + ) + module = TensorDictModule( + backbone, + in_keys=[("agents", "observation")], + out_keys=[("agents", "loc"), ("agents", "scale")], + ) + return ProbabilisticActor( + module=module, + in_keys={"loc": ("agents", "loc"), "scale": ("agents", "scale")}, + out_keys=[("agents", "action")], + distribution_class=TanhNormal, + return_log_prob=True, + ) + + +def _make_critic(n_agents=3, obs_dim=6, centralized=True, share_params=True): + return TensorDictModule( + MultiAgentMLP( + n_agent_inputs=obs_dim, + n_agent_outputs=1, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + ), + in_keys=[("agents", "observation")], + out_keys=[("agents", "state_value")], + ) + + +def _make_data( + B=2, T=10, n_agents=3, obs_dim=6, action_dim=2, per_agent_reward=False, device="cpu" +): + torch.manual_seed(0) + obs = torch.randn(B, T, n_agents, obs_dim, device=device) + next_obs = torch.randn(B, T, n_agents, obs_dim, device=device) + if per_agent_reward: + reward_shape = (B, T, n_agents, 1) + else: + reward_shape = (B, T, 1) + reward = torch.randn(*reward_shape, device=device) + done = torch.zeros(B, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(B, T, 1, dtype=torch.bool, device=device) + return TensorDict( + { + "agents": TensorDict({"observation": obs}, [B, T, n_agents]), + "next": TensorDict( + { + "agents": TensorDict({"observation": next_obs}, [B, T, n_agents]), + "reward": reward, + "done": done, + "terminated": terminated, + }, + [B, T], + ), + }, + [B, T], + device=device, + ) + + +def _attach_action_and_logprob(td: TensorDict, actor: ProbabilisticActor, loss): + with torch.no_grad(): + sampled = actor(td.clone()) + td[("agents", "action")] = sampled[("agents", "action")] + td[loss.tensor_keys.sample_log_prob] = sampled[loss.tensor_keys.sample_log_prob] + return td + + +# -------------------------------------------------------------------------- +# MultiAgentGAE +# -------------------------------------------------------------------------- + + +class TestMultiAgentGAE: + def test_team_reward_broadcast(self): + n_agents = 3 + critic = _make_critic(n_agents=n_agents) + gae = MultiAgentGAE(gamma=0.99, lmbda=0.95, value_network=critic) + gae.set_keys(value=("agents", "state_value")) + td = _make_data(n_agents=n_agents, per_agent_reward=False) + gae(td) + # advantage matches the critic's per-agent value shape + assert td[gae.tensor_keys.advantage].shape[-2] == n_agents + assert td[gae.tensor_keys.advantage].shape[-1] == 1 + assert ( + td[gae.tensor_keys.value_target].shape + == 
td[gae.tensor_keys.advantage].shape + ) + + def test_per_agent_reward_passthrough(self): + n_agents = 3 + critic = _make_critic(n_agents=n_agents, centralized=False) + gae = MultiAgentGAE(gamma=0.99, lmbda=0.95, value_network=critic) + gae.set_keys(value=("agents", "state_value")) + td = _make_data(n_agents=n_agents, per_agent_reward=True) + gae(td) + assert td[gae.tensor_keys.advantage].shape[-2] == n_agents + + def test_broadcast_error_on_bad_shape(self): + gae = MultiAgentGAE(gamma=0.99, lmbda=0.95, value_network=None, agent_dim=-2) + # value has ndim=4 (B, T, n_agents, 1); a 2-D reward is neither + # per-agent nor team-shared, so the helper must reject it. + bad_tensor = torch.zeros(4, 5) + target = torch.zeros(4, 5, 3, 1) + with pytest.raises(ValueError, match="expected the reward/done/terminated"): + gae._broadcast_to_agents(bad_tensor, target, agent_dim=-2) + + def test_value_estimator_enum_registered(self): + # MAGAE is wired up in default_value_kwargs and the enum. + from torchrl.objectives.utils import default_value_kwargs + + kw = default_value_kwargs(ValueEstimators.MAGAE) + assert "gamma" in kw and "lmbda" in kw + + +# -------------------------------------------------------------------------- +# Value-estimator registry +# -------------------------------------------------------------------------- + + +class TestValueEstimatorRegistry: + def test_all_builtins_registered(self): + """Every ValueEstimators enum member must have a registry entry.""" + from torchrl.objectives.utils import ( + _VALUE_ESTIMATOR_REGISTRY, + get_value_estimator_entry, + ) + + for member in ValueEstimators: + assert member in _VALUE_ESTIMATOR_REGISTRY, f"missing: {member}" + entry = get_value_estimator_entry(member) + assert entry.cls is not None + assert "gamma" in entry.default_kwargs + + def test_string_alias_resolves(self): + from torchrl.objectives.utils import get_value_estimator_entry + + assert ( + get_value_estimator_entry("gae").cls + is get_value_estimator_entry(ValueEstimators.GAE).cls + ) + assert ( + get_value_estimator_entry("magae").cls + is get_value_estimator_entry(ValueEstimators.MAGAE).cls + ) + + def test_unknown_alias_raises(self): + from torchrl.objectives.utils import get_value_estimator_entry + + with pytest.raises(KeyError, match="Unknown value estimator alias"): + get_value_estimator_entry("not_a_real_estimator") + + def test_unknown_type_raises(self): + from torchrl.objectives.utils import get_value_estimator_entry + + with pytest.raises(TypeError, match="must be a ValueEstimators"): + get_value_estimator_entry(42) + + def test_register_and_dispatch_custom_estimator(self): + """Adding a new estimator must not require touching any loss file.""" + from enum import Enum + + from torchrl.objectives.utils import ( + _VALUE_ESTIMATOR_REGISTRY, + register_value_estimator, + ) + from torchrl.objectives.value.advantages import GAE + + # We have to extend the enum at runtime for this test. Python's Enum + # forbids appending, so we monkey-patch the registry directly with a + # sentinel key — that's the path a third-party custom enum would take. + class _Custom(Enum): + FAKE = "fake" + + # Pretend we registered against a "real" entry by abusing _Custom. 
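+        # (Assumed shape of the new helper, exercised by the assertions below:
+        # ``register_value_estimator(key, default_kwargs=...)`` acts as a class
+        # decorator and records an entry exposing ``.cls`` and ``.default_kwargs``
+        # in ``_VALUE_ESTIMATOR_REGISTRY`` under ``key``.)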
+ @register_value_estimator( + _Custom.FAKE, default_kwargs={"gamma": 0.99, "lmbda": 0.5} + ) + class _MyGAE(GAE): + pass + + try: + entry = _VALUE_ESTIMATOR_REGISTRY[_Custom.FAKE] + assert entry.cls is _MyGAE + assert entry.default_kwargs == {"gamma": 0.99, "lmbda": 0.5} + finally: + _VALUE_ESTIMATOR_REGISTRY.pop(_Custom.FAKE, None) + + def test_default_value_kwargs_reads_registry(self): + """Back-compat shim must agree with the registry.""" + from torchrl.objectives.utils import ( + _VALUE_ESTIMATOR_REGISTRY, + default_value_kwargs, + ) + + for member, entry in _VALUE_ESTIMATOR_REGISTRY.items(): + assert default_value_kwargs(member) == entry.default_kwargs + + +# -------------------------------------------------------------------------- +# ValueNorm — abstract base + the two concrete implementations +# -------------------------------------------------------------------------- + + +class TestValueNormBase: + def test_abstract_base_is_not_instantiable(self): + with pytest.raises(TypeError): + ValueNorm(shape=1) # type: ignore[abstract] + + +class TestPopArtValueNorm: + def test_running_stats_converge(self): + torch.manual_seed(0) + vn = PopArtValueNorm(shape=1) + x = torch.randn(4096, 1) * 5.0 + 2.0 + for _ in range(200): + vn.update(x) + mean, var = vn._running_stats() + assert abs(mean.item() - 2.0) < 0.2 + assert abs(var.sqrt().item() - 5.0) < 0.5 + + def test_denormalize_inverts_normalize(self): + torch.manual_seed(0) + vn = PopArtValueNorm(shape=1) + x = torch.randn(512, 1) * 3.0 + 1.0 + for _ in range(50): + vn.update(x) + y = torch.randn(64, 1) * 3.0 + 1.0 + recovered = vn.denormalize(vn.normalize(y)) + torch.testing.assert_close(recovered, y, rtol=1e-4, atol=1e-4) + + def test_bad_shape_raises(self): + vn = PopArtValueNorm(shape=1) + with pytest.raises(ValueError, match="trailing shape"): + vn.update(torch.randn(4, 8)) # trailing 8 != 1 + + +class TestRunningValueNorm: + def test_running_stats_converge(self): + """Exact running stats should be very tight even after few updates.""" + torch.manual_seed(0) + vn = RunningValueNorm(shape=1) + x = torch.randn(4096, 1) * 5.0 + 2.0 + for _ in range(20): + vn.update(x) + assert abs(vn.mean.item() - 2.0) < 0.1 + assert abs(vn._var().sqrt().item() - 5.0) < 0.1 + + def test_denormalize_inverts_normalize(self): + torch.manual_seed(0) + vn = RunningValueNorm(shape=1) + for _ in range(10): + vn.update(torch.randn(256, 1) * 3.0 + 1.0) + y = torch.randn(64, 1) * 3.0 + 1.0 + recovered = vn.denormalize(vn.normalize(y)) + torch.testing.assert_close(recovered, y, rtol=1e-4, atol=1e-4) + + def test_no_decay(self): + """RunningValueNorm should not be biased by sample order (no EMA).""" + torch.manual_seed(0) + vn = RunningValueNorm(shape=1) + # Feed two batches with very different scales; running stats should + # land at the true combined mean rather than getting dominated by + # whichever batch came last (which is what an EMA would do). + a = torch.full((1000, 1), 1.0) + b = torch.full((1000, 1), 5.0) + vn.update(a) + vn.update(b) + # Combined mean of 1000 ones + 1000 fives = 3.0 exactly. 
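+        # Worked Welford merge (matches ``RunningValueNorm.update``):
+        #   delta = 5 - 1 = 4, total = 2000, mean' = 1 + 4 * 1000 / 2000 = 3.0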
+ assert abs(vn.mean.item() - 3.0) < 1e-4 + + +# -------------------------------------------------------------------------- +# MAPPOLoss +# -------------------------------------------------------------------------- + + +class TestMAPPOLoss: + def test_forward_shapes_and_backward(self): + actor = _make_actor() + critic = _make_critic(centralized=True) + loss_mod = MAPPOLoss(actor, critic) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + + td = _make_data() + _attach_action_and_logprob(td, actor, loss_mod) + out = loss_mod(td) + + for k in ("loss_objective", "loss_entropy", "loss_critic"): + assert k in out, f"missing key {k}" + assert out[k].shape == torch.Size( + [] + ), f"{k} should be scalar, got {out[k].shape}" + + # Gradients reach both actor and critic. + total = out["loss_objective"] + out["loss_entropy"] + out["loss_critic"] + total.backward() + actor_grads = [ + p.grad + for p in loss_mod.actor_network_params.values(True, True) + if isinstance(p, torch.nn.Parameter) and p.grad is not None + ] + critic_grads = [ + p.grad + for p in loss_mod.critic_network_params.values(True, True) + if isinstance(p, torch.nn.Parameter) and p.grad is not None + ] + assert len(actor_grads) > 0, "actor received no grads" + assert len(critic_grads) > 0, "critic received no grads" + + def test_centralized_critic_uses_full_team_obs(self): + """Perturbing one agent's obs must change every agent's value.""" + actor = _make_actor() + critic = _make_critic(centralized=True) + loss_mod = MAPPOLoss(actor, critic) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + + td = _make_data() + with torch.no_grad(): + v_before = critic(td.clone())[("agents", "state_value")] + td2 = td.clone() + td2["agents", "observation"][..., 0, :] += 5.0 # perturb agent 0 + v_after = critic(td2)[("agents", "state_value")] + + # Other agents' values must change (centralised critic saw the perturb). + diff = (v_before - v_after).abs().mean(dim=(0, 1, 3)) + assert diff[1] > 1e-6, "Centralised critic ignored cross-agent obs change" + assert diff[2] > 1e-6, "Centralised critic ignored cross-agent obs change" + + @pytest.mark.parametrize("share_params", [True, False]) + def test_share_params_modes(self, share_params): + actor = _make_actor(share_params=share_params) + critic = _make_critic(centralized=True, share_params=share_params) + loss_mod = MAPPOLoss(actor, critic) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + td = _make_data() + _attach_action_and_logprob(td, actor, loss_mod) + out = loss_mod(td) + assert out["loss_objective"].shape == torch.Size([]) + + def test_value_norm_round_trip(self): + """With PopArtValueNorm, critic loss should remain bounded across many updates.""" + torch.manual_seed(0) + actor = _make_actor() + critic = _make_critic(centralized=True) + vn = PopArtValueNorm(shape=1) + loss_mod = MAPPOLoss(actor, critic, value_norm=vn) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + + critic_losses = [] + for step in range(8): + td = _make_data(B=2, T=10) + # inflate reward scale over time to stress normalisation + td["next", "reward"] *= (step + 1) * 10.0 + _attach_action_and_logprob(td, actor, loss_mod) + out = loss_mod(td) + critic_losses.append(out["loss_critic"].item()) + + # Without ValueNorm a 10x reward inflation would blow critic loss up + # quadratically; with ValueNorm it should stay roughly bounded. 
+ assert max(critic_losses) < 10.0, f"critic loss exploded: {critic_losses}" + + def test_default_value_estimator_is_magae(self): + actor = _make_actor() + critic = _make_critic(centralized=True) + loss_mod = MAPPOLoss(actor, critic) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + loss_mod.make_value_estimator() + assert loss_mod.value_type == ValueEstimators.MAGAE + assert isinstance(loss_mod._value_estimator, MultiAgentGAE) + + def test_make_value_estimator_falls_through_for_non_magae(self): + """Selecting a non-MAGAE estimator goes through the parent class. + + The downstream call may still fail at runtime if the user feeds + team-shared reward/done tensors (that is the whole reason + :class:`MultiAgentGAE` exists), but we want to confirm the dispatch + actually selects the requested estimator class. + """ + actor = _make_actor() + critic = _make_critic(centralized=False) + loss_mod = MAPPOLoss(actor, critic) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + loss_mod.make_value_estimator(ValueEstimators.GAE) + assert isinstance(loss_mod._value_estimator, GAE) + assert not isinstance(loss_mod._value_estimator, MultiAgentGAE) + + +# -------------------------------------------------------------------------- +# IPPOLoss +# -------------------------------------------------------------------------- + + +class TestIPPOLoss: + def test_forward_shapes_and_backward(self): + actor = _make_actor() + critic = _make_critic(centralized=False) + loss_mod = IPPOLoss(actor, critic) + loss_mod.set_keys(value=("agents", "state_value"), action=("agents", "action")) + + td = _make_data() + _attach_action_and_logprob(td, actor, loss_mod) + out = loss_mod(td) + + for k in ("loss_objective", "loss_entropy", "loss_critic"): + assert out[k].shape == torch.Size([]) + + (out["loss_objective"] + out["loss_critic"]).backward() + + def test_decentralized_critic_ignores_other_agents(self): + """IPPO critic must depend only on the agent's own observation.""" + critic = _make_critic(centralized=False) + td = _make_data() + with torch.no_grad(): + v_before = critic(td.clone())[("agents", "state_value")] + td2 = td.clone() + td2["agents", "observation"][..., 0, :] += 5.0 # perturb agent 0 + v_after = critic(td2)[("agents", "state_value")] + + diff = (v_before - v_after).abs().mean(dim=(0, 1, 3)) + # Agent 0's value changed. + assert diff[0] > 1e-6 + # Agents 1, 2 must be unaffected. 
+ assert diff[1] < 1e-6 + assert diff[2] < 1e-6 + + +if __name__ == "__main__": + import argparse + + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/objectives/test_sac.py b/test/objectives/test_sac.py index 84281021351..ad7b98dd7b1 100644 --- a/test/objectives/test_sac.py +++ b/test/objectives/test_sac.py @@ -436,7 +436,7 @@ def test_sac( **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -611,7 +611,7 @@ def test_sac_deactivate_vmap( **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_vmap.make_value_estimator(td_est) return @@ -636,7 +636,7 @@ def test_sac_deactivate_vmap( **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_no_vmap.make_value_estimator(td_est) return @@ -1671,7 +1671,7 @@ def test_discrete_sac( action_space="one-hot", **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -1790,7 +1790,7 @@ def test_discrete_sac_deactivate_vmap( deactivate_vmap=False, **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_vmap.make_value_estimator(td_est) return @@ -1818,7 +1818,7 @@ def test_discrete_sac_deactivate_vmap( deactivate_vmap=True, **kwargs, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_no_vmap.make_value_estimator(td_est) return @@ -2466,7 +2466,7 @@ def test_crossq( loss_function="l2", ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn.make_value_estimator(td_est) return @@ -2567,7 +2567,7 @@ def test_crossq_deactivate_vmap( deactivate_vmap=False, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_vmap.make_value_estimator(td_est) return @@ -2590,7 +2590,7 @@ def test_crossq_deactivate_vmap( deactivate_vmap=True, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): loss_fn_no_vmap.make_value_estimator(td_est) return @@ -3300,7 +3300,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): loss_function="l2", delay_qvalue=delay_qvalue, ) - if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): 
                loss_fn.make_value_estimator(td_est)
            return
@@ -3687,7 +3687,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
             loss_function="l2",
             delay_qvalue=delay_qvalue,
         )
-        if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
+        if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
                 loss_fn.make_value_estimator(td_est)
             return
@@ -3704,7 +3704,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
             loss_function="l2",
             delay_qvalue=delay_qvalue,
         )
-        if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
+        if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
                 loss_fn_deprec.make_value_estimator(td_est)
             return
diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py
index 55cb92b793c..296b6e36a3b 100644
--- a/torchrl/modules/__init__.py
+++ b/torchrl/modules/__init__.py
@@ -100,6 +100,7 @@
     set_exploration_modules_spec_from_env,
 )
 from .utils import get_env_transforms_from_module, get_primers_from_module
+from .value_norm import PopArtValueNorm, RunningValueNorm, ValueNorm
 from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner  # usort:skip
 from .mcts import (  # usort:skip
     EXP3Score,
@@ -175,6 +176,7 @@
     "Ordinal",
     "OrnsteinUhlenbeckProcessModule",
     "OrnsteinUhlenbeckProcessWrapper",
+    "PopArtValueNorm",
     "ProbabilisticActor",
     "PUCTScore",
     "QMixer",
@@ -186,6 +188,7 @@
     "RSSMPrior",
     "RSSMRollout",
     "ReparamGradientStrategy",
+    "RunningValueNorm",
     "SafeModule",
     "SafeProbabilisticModule",
     "SafeProbabilisticTensorDictSequential",
@@ -199,6 +202,7 @@
     "UCB1TunedScore",
     "UCBScore",
     "VDNMixer",
+    "ValueNorm",
     "ValueOperator",
     "VmapModule",
     "WorldModelWrapper",
diff --git a/torchrl/modules/value_norm.py b/torchrl/modules/value_norm.py
new file mode 100644
index 00000000000..fad5ede6c4e
--- /dev/null
+++ b/torchrl/modules/value_norm.py
@@ -0,0 +1,237 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""Value normalisation for actor-critic algorithms.
+
+Defines the abstract :class:`ValueNorm` interface and two concrete
+implementations:
+
+- :class:`PopArtValueNorm` — exponential-moving-average mean / mean-of-squares
+  with debiasing (Hessel et al., *Multi-task Deep RL with PopArt*,
+  AAAI 2019, https://arxiv.org/abs/1809.04474). Used by MAPPO
+  (Yu et al. 2022) to stabilise the critic loss when reward scales drift.
+- :class:`RunningValueNorm` — exact Welford running mean / variance with no
+  decay. Cheaper and more stable when value targets are stationary; tends to
+  be the better default for shorter / non-curriculum runs.
+
+Plug an instance of any subclass into
+:class:`~torchrl.objectives.multiagent.MAPPOLoss` (or your own actor-critic
+loss) via ``value_norm=...``.
+"""
+from __future__ import annotations
+
+from abc import ABCMeta, abstractmethod
+
+import torch
+from torch import nn
+
+
+class ValueNorm(nn.Module, metaclass=ABCMeta):
+    """Abstract base class for value normalisers.
+
+    A *value normaliser* keeps a running estimate of the location and scale of
+    the value target seen during training.
Critics use it to: + + - **normalize** the regression target before computing MSE, keeping the + critic loss on a fixed scale across episodes / reward inflations; + - **denormalize** the critic's output back to the real reward scale when + forming bootstrapped value estimates inside GAE / TD. + + Subclasses must implement :meth:`update`, :meth:`normalize`, and + :meth:`denormalize`. The convention is that all three operate on tensors + whose trailing dims match :attr:`shape` (the per-element value shape, + usually ``(1,)``). + """ + + shape: tuple[int, ...] + + def __init__( + self, + *, + shape: int | tuple[int, ...] = 1, + epsilon: float = 1e-5, + device: torch.device | None = None, + ) -> None: + super().__init__() + if isinstance(shape, int): + shape = (shape,) + self.shape = tuple(shape) + self.epsilon = epsilon + self._device = device + + # ------------------------------------------------------------------ API + + @abstractmethod + def update(self, value_target: torch.Tensor) -> None: + """Fold a batch of value targets into the running stats.""" + + @abstractmethod + def normalize(self, value_target: torch.Tensor) -> torch.Tensor: + """Standardise ``value_target`` using the current running stats.""" + + @abstractmethod + def denormalize(self, normalised_value: torch.Tensor) -> torch.Tensor: + """Inverse of :meth:`normalize` — recover real-scale values.""" + + # ------------------------------------------------------- shared helpers + + def _check_trailing_shape(self, value_target: torch.Tensor) -> tuple[int, ...]: + if value_target.shape[-len(self.shape) :] != self.shape: + raise ValueError( + f"{type(self).__name__} was initialised with shape={self.shape} " + f"but got a value_target with trailing shape " + f"{tuple(value_target.shape[-len(self.shape) :])}." + ) + return tuple(range(value_target.ndim - len(self.shape))) + + +class PopArtValueNorm(ValueNorm): + """PopArt-style EMA value normaliser. + + Maintains exponentially-weighted running estimates of the value-target + mean and mean-of-squares, with debiasing (so the early-training estimates + are unbiased even before the EMA has had time to wash out the zero + initialisation). Equivalent to the value-normaliser used by the reference + MAPPO implementation. + + Keyword Args: + shape: per-element shape of the value tensor (everything except the + leading batch / time / agent dims that get reduced). Defaults to + ``1``. + beta: exponential decay for the running stats. Higher = slower + adaptation. Defaults to ``0.99999`` (the MAPPO default). + epsilon: numerical stabiliser added to the running variance and used + as a floor for the debiasing term. Defaults to ``1e-5``. + device: device for the running-stats buffers. + + Example: + >>> vn = PopArtValueNorm(shape=1) + >>> target = torch.randn(64, 1) * 5.0 + 2.0 # mean 2, std 5 + >>> for _ in range(100): + ... vn.update(target) + >>> normed = vn.normalize(target) # ~ N(0, 1) + >>> recovered = vn.denormalize(normed) # back to real scale + """ + + def __init__( + self, + *, + shape: int | tuple[int, ...] = 1, + beta: float = 0.99999, + epsilon: float = 1e-5, + device: torch.device | None = None, + ) -> None: + super().__init__(shape=shape, epsilon=epsilon, device=device) + self.beta = beta + # Both running buffers start at zero. The debiasing term tracks + # \sum_{s<=t} beta^{t-s}, which also starts at zero; dividing the + # zero-init buffers by the (clamped) debias gives an unbiased EMA. 
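+        # A sketch of the update that ``update`` performs on each call,
+        # matching the code below:
+        #   running_mean    <- beta * running_mean    + (1 - beta) * batch_mean
+        #   running_mean_sq <- beta * running_mean_sq + (1 - beta) * batch_mean_sq
+        #   debiasing_term  <- beta * debiasing_term  + (1 - beta)
+        # Debiased stats: mean = running_mean / debiasing_term,
+        #                 var  = running_mean_sq / debiasing_term - mean**2.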
+ self.register_buffer("running_mean", torch.zeros(self.shape, device=device)) + self.register_buffer("running_mean_sq", torch.zeros(self.shape, device=device)) + self.register_buffer("debiasing_term", torch.zeros((), device=device)) + + def _running_stats(self) -> tuple[torch.Tensor, torch.Tensor]: + debias = self.debiasing_term.clamp(min=self.epsilon) + mean = self.running_mean / debias + mean_sq = self.running_mean_sq / debias + var = (mean_sq - mean.pow(2)).clamp(min=self.epsilon) + return mean, var + + @torch.no_grad() + def update(self, value_target: torch.Tensor) -> None: + value_target = value_target.detach() + reduce_dims = self._check_trailing_shape(value_target) + if reduce_dims: + batch_mean = value_target.mean(dim=reduce_dims) + batch_mean_sq = value_target.pow(2).mean(dim=reduce_dims) + else: + batch_mean = value_target + batch_mean_sq = value_target.pow(2) + + self.running_mean.mul_(self.beta).add_(batch_mean, alpha=1.0 - self.beta) + self.running_mean_sq.mul_(self.beta).add_(batch_mean_sq, alpha=1.0 - self.beta) + self.debiasing_term.mul_(self.beta).add_(1.0 - self.beta) + + def normalize(self, value_target: torch.Tensor) -> torch.Tensor: + mean, var = self._running_stats() + return (value_target - mean) / var.sqrt() + + def denormalize(self, normalised_value: torch.Tensor) -> torch.Tensor: + mean, var = self._running_stats() + return normalised_value * var.sqrt() + mean + + +class RunningValueNorm(ValueNorm): + """Exact running mean / variance (Welford's online algorithm). + + Unlike :class:`PopArtValueNorm`, this normaliser does not decay older + samples — it accumulates the true sample mean and variance over every + target it has ever seen. Useful when value targets are roughly stationary + (no curriculum, no reward-shaping schedule), where the EMA's adaptivity + is unnecessary and the exact running stats give a slightly tighter + estimate. + + Keyword Args: + shape: per-element shape of the value tensor. Defaults to ``1``. + epsilon: numerical stabiliser added to the running variance. + Defaults to ``1e-5``. + device: device for the running-stats buffers. + + Example: + >>> vn = RunningValueNorm(shape=1) + >>> for _ in range(10): + ... vn.update(torch.randn(64, 1) * 3.0 + 1.0) + >>> normed = vn.normalize(torch.randn(8, 1)) + """ + + def __init__( + self, + *, + shape: int | tuple[int, ...] = 1, + epsilon: float = 1e-5, + device: torch.device | None = None, + ) -> None: + super().__init__(shape=shape, epsilon=epsilon, device=device) + self.register_buffer("mean", torch.zeros(self.shape, device=device)) + # m2 stores the running sum of squared deviations from the mean + # (Welford's M2). var = m2 / max(count - 1, 1). + self.register_buffer("m2", torch.zeros(self.shape, device=device)) + self.register_buffer("count", torch.zeros((), device=device)) + + @torch.no_grad() + def update(self, value_target: torch.Tensor) -> None: + value_target = value_target.detach() + reduce_dims = self._check_trailing_shape(value_target) + if reduce_dims: + batch_count = float( + torch.tensor([value_target.shape[d] for d in reduce_dims]).prod() + ) + batch_mean = value_target.mean(dim=reduce_dims) + batch_var = value_target.var(dim=reduce_dims, unbiased=False) + else: + batch_count = 1.0 + batch_mean = value_target + batch_var = torch.zeros_like(value_target) + + # Chan et al. parallel variance update. 
+ delta = batch_mean - self.mean + total = self.count + batch_count + new_mean = self.mean + delta * (batch_count / total) + new_m2 = ( + self.m2 + + batch_var * batch_count + + delta.pow(2) * (self.count * batch_count / total) + ) + self.mean.copy_(new_mean) + self.m2.copy_(new_m2) + self.count.fill_(total) + + def _var(self) -> torch.Tensor: + denom = self.count.clamp(min=1.0) + return (self.m2 / denom).clamp(min=self.epsilon) + + def normalize(self, value_target: torch.Tensor) -> torch.Tensor: + return (value_target - self.mean) / self._var().sqrt() + + def denormalize(self, normalised_value: torch.Tensor) -> torch.Tensor: + return normalised_value * self._var().sqrt() + self.mean diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index fe620506c42..09d44bbf626 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -30,7 +30,7 @@ ) from torchrl.objectives.gail import GAILLoss from torchrl.objectives.iql import DiscreteIQLLoss, IQLLoss -from torchrl.objectives.multiagent import QMixerLoss +from torchrl.objectives.multiagent import IPPOLoss, MAPPOLoss, QMixerLoss from torchrl.objectives.pilco import ExponentialQuadraticCost from torchrl.objectives.ppo import ClipPPOLoss, KLPENPPOLoss, PPOLoss from torchrl.objectives.redq import REDQLoss @@ -75,9 +75,11 @@ "ExponentialQuadraticCost", "GAILLoss", "HardUpdate", + "IPPOLoss", "IQLLoss", "KLPENPPOLoss", "LossModule", + "MAPPOLoss", "OnlineDTLoss", "PPOLoss", "QMixerLoss", diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index eba56d98756..dcf65eea89b 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -from copy import deepcopy from dataclasses import dataclass import torch @@ -34,18 +33,11 @@ _GAMMA_LMBDA_DEPREC_ERROR, _get_default_device, _reduce, - default_value_kwargs, + build_value_estimator, distance_loss, ValueEstimators, ) -from torchrl.objectives.value import ( - GAE, - TD0Estimator, - TD1Estimator, - TDLambdaEstimator, - ValueEstimatorBase, - VTrace, -) +from torchrl.objectives.value import ValueEstimatorBase class A2CLoss(LossModule): @@ -613,40 +605,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams return LossModule.make_value_estimator(self, value_type, **hyperparams) self.value_type = value_type - hp = dict(default_value_kwargs(value_type)) - hp.update(hyperparams) - - device = _get_default_device(self) - hp["device"] = device - + hp = dict(hyperparams) + hp["device"] = _get_default_device(self) if hasattr(self, "gamma"): - hp["gamma"] = self.gamma - if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic_network, **hp) - elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.VTrace: - # VTrace currently does not support functional call on the actor - if self.functional: - actor_with_params = deepcopy(self.actor_network) - self.actor_network_params.to_module(actor_with_params) - else: - actor_with_params = self.actor_network - self._value_estimator = VTrace( - value_network=self.critic_network, actor_network=actor_with_params, **hp 
- ) - else: - raise NotImplementedError(f"Unknown value type {value_type}") + hp.setdefault("gamma", self.gamma) + # Registry-driven dispatch — adding a new estimator only needs + # @register_value_estimator on the class, not edits here. + self._value_estimator = build_value_estimator(self, value_type, **hp) tensor_keys = { "advantage": self.tensor_keys.advantage, diff --git a/torchrl/objectives/multiagent/__init__.py b/torchrl/objectives/multiagent/__init__.py index cec01e0ca0c..6e25271b326 100644 --- a/torchrl/objectives/multiagent/__init__.py +++ b/torchrl/objectives/multiagent/__init__.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from .mappo import IPPOLoss, MAPPOLoss from .qmixer import QMixerLoss -__all__ = ["QMixerLoss"] +__all__ = ["IPPOLoss", "MAPPOLoss", "QMixerLoss"] diff --git a/torchrl/objectives/multiagent/mappo.py b/torchrl/objectives/multiagent/mappo.py new file mode 100644 index 00000000000..00925659dc2 --- /dev/null +++ b/torchrl/objectives/multiagent/mappo.py @@ -0,0 +1,291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Multi-agent PPO objectives. + +Implements :class:`MAPPOLoss` (centralised critic) and :class:`IPPOLoss` +(independent / decentralised critic). + +References: + - Yu, C. et al. *The Surprising Effectiveness of PPO in Cooperative + Multi-Agent Games.* NeurIPS 2022. https://arxiv.org/abs/2103.01955 + - de Witt, C. S. et al. *Is Independent Learning All You Need in the + StarCraft Multi-Agent Challenge?* 2020. https://arxiv.org/abs/2011.09533 +""" +from __future__ import annotations + +import contextlib +from typing import Any + +from tensordict import TensorDictBase +from tensordict.nn import TensorDictModule +from tensordict.nn.probabilistic import ProbabilisticTensorDictSequential + +from torchrl.modules.value_norm import ValueNorm +from torchrl.objectives.ppo import ClipPPOLoss +from torchrl.objectives.utils import distance_loss, ValueEstimators + + +class _MultiAgentPPOMixin: + """Shared plumbing for :class:`MAPPOLoss` and :class:`IPPOLoss`. + + Two pieces: + + 1. Default the value estimator to :class:`MultiAgentGAE` so per-agent value + outputs broadcast cleanly against team-shared reward / done signals. + Dispatch goes through + :func:`~torchrl.objectives.utils.build_value_estimator`, so we no + longer need a ``make_value_estimator`` override here — the registry + resolves ``ValueEstimators.MAGAE`` to :class:`MultiAgentGAE` + automatically. + 2. Wrap the parent's :meth:`loss_critic` so that, when a + :class:`~torchrl.modules.ValueNorm` is attached, the running value + target stats are updated and both target and prediction are + normalised before the MSE. This stabilises critic-loss magnitude + when reward scales drift during training (Yu et al. 2022, Table 13). + """ + + default_value_estimator = ValueEstimators.MAGAE + + def loss_critic(self, tensordict: TensorDictBase): + # Delegate to ClipPPOLoss; if no value_norm is attached this is a + # no-op wrapper. Otherwise we normalise target and prediction so the + # MSE lives on a fixed scale across training. 
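+        # Sketch of the normalised path implemented below:
+        #   value_norm.update(value_target)
+        #   loss = distance(value_norm.normalize(value_target.detach()),
+        #                   value_norm.normalize(critic(obs)))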
+ value_norm: ValueNorm | None = getattr(self, "value_norm", None) + if value_norm is None: + return super().loss_critic(tensordict) + + target_return = tensordict.get(self.tensor_keys.value_target, None) + if target_return is None: + raise KeyError( + f"the key {self.tensor_keys.value_target} was not found in the " + "input tensordict. Make sure the value estimator ran before " + "computing the loss." + ) + + # Forward the critic ourselves so we can normalise its output. + with self.critic_network_params.to_module( + self.critic_network + ) if self.functional else contextlib.nullcontext(): + state_value_td = self.critic_network(tensordict) + state_value = state_value_td.get(self.tensor_keys.value) + if state_value is None: + raise KeyError( + f"the key {self.tensor_keys.value} was not found in the critic " + "output tensordict." + ) + + value_norm.update(target_return) + normalised_target = value_norm.normalize(target_return.detach()) + normalised_pred = value_norm.normalize(state_value) + loss_value = distance_loss( + normalised_target, normalised_pred, loss_function=self.loss_critic_type + ) + + self._clear_weakrefs( + tensordict, + "actor_network_params", + "critic_network_params", + "target_actor_network_params", + "target_critic_network_params", + ) + if self._has_critic: + return self.critic_coef * loss_value, None, None + return loss_value, None, None + + +class MAPPOLoss(_MultiAgentPPOMixin, ClipPPOLoss): + """Multi-Agent PPO loss with a centralised critic (Yu et al. 2022). + + MAPPO trains a *decentralised actor* (each agent's policy conditions only + on its local observation) together with a *centralised critic* (single + value function that conditions on the full team state or concatenated + observations). The decentralised actor lets policies run independently at + execution time, while the centralised critic reduces variance during + training by giving every agent the same value baseline derived from full + state information. + + This class is a thin specialisation of :class:`ClipPPOLoss`. The + differences: + + - The default value estimator is :class:`~torchrl.objectives.value.MultiAgentGAE`, + which broadcasts team-shared rewards / done flags along the agent + dimension before computing returns. + - ``normalize_advantage_exclude_dims`` defaults to ``(-2,)`` so the agent + dim is excluded when standardising advantages. + - An optional :class:`~torchrl.modules.ValueNorm` can be supplied via + ``value_norm=PopArtValueNorm(shape=1)`` to stabilise the critic loss; + the MAPPO paper reports this is load-bearing on SMAC (their Table 13). + :class:`~torchrl.modules.RunningValueNorm` is a no-decay alternative + for stationary reward scales. + + Args: + actor_network (ProbabilisticTensorDictSequential): per-agent policy + operator. Conventionally built with + :class:`~torchrl.modules.MultiAgentMLP` using + ``centralized=False, share_params=True`` for cooperative + homogeneous teams. + critic_network (TensorDictModule): centralised value operator. Build + this with :class:`~torchrl.modules.MultiAgentMLP` and + ``centralized=True, share_params=True``, or with any module that + consumes a global ``"state"`` key and returns + ``("agents", "state_value")`` of shape ``[*B, n_agents, 1]``. + + Keyword Args: + value_norm (ValueNorm, optional): if supplied, the critic target and + prediction are normalised by this running normaliser before the + MSE / smooth-L1 distance. Defaults to ``None`` (no value norm). + clip_epsilon (float): PPO ratio clip. Defaults to ``0.2``. 
+ entropy_coeff (float): entropy bonus weight. Defaults to ``0.01`` + (MAPPO default). + critic_coef (float, optional): critic loss weight. Defaults to ``1.0``. + normalize_advantage (bool): whether to standardise the advantage. + Defaults to ``True`` (MAPPO default; differs from base + :class:`ClipPPOLoss` which defaults to ``False``). + normalize_advantage_exclude_dims (tuple of int): dimensions to + exclude from advantage standardisation. Defaults to ``(-2,)`` + (the agent dim). + **kwargs: forwarded to :class:`ClipPPOLoss`. + + The expected tensordict layout follows the torchrl multi-agent convention + (see :class:`~torchrl.envs.libs.vmas.VmasEnv`, + :class:`~torchrl.envs.libs.pettingzoo.PettingZooEnv`): + + - ``("agents", "observation")``: ``[*B, T, n_agents, obs_dim]`` + - ``("agents", "action")``: ``[*B, T, n_agents, action_dim]`` + - Optional ``"state"`` at the root for centralised critics + - Team-shared ``("next", "reward")``, ``("next", "done")``, + ``("next", "terminated")`` of shape ``[*B, T, 1]`` (or per-agent under + ``("next", "agents", "reward")`` for competitive settings). + + Example: + >>> import torch + >>> from tensordict import TensorDict + >>> from tensordict.nn import TensorDictModule + >>> from torchrl.modules import ( + ... MultiAgentMLP, PopArtValueNorm, ProbabilisticActor, + ... ) + >>> from torchrl.modules.distributions import NormalParamExtractor, TanhNormal + >>> from torchrl.objectives.multiagent import MAPPOLoss + >>> n_agents, obs_dim, action_dim, state_dim = 3, 6, 2, 12 + >>> # Decentralised actor + >>> actor_net = torch.nn.Sequential( + ... MultiAgentMLP( + ... n_agent_inputs=obs_dim, n_agent_outputs=2 * action_dim, + ... n_agents=n_agents, centralized=False, share_params=True, + ... ), + ... NormalParamExtractor(), + ... ) + >>> actor_module = TensorDictModule( + ... actor_net, + ... in_keys=[("agents", "observation")], + ... out_keys=[("agents", "loc"), ("agents", "scale")], + ... ) + >>> actor = ProbabilisticActor( + ... module=actor_module, + ... in_keys=[("agents", "loc"), ("agents", "scale")], + ... out_keys=[("agents", "action")], + ... distribution_class=TanhNormal, + ... ) + >>> # Centralised critic — same agent-dim layout as the actor, with + >>> # centralized=True so each agent's effective input is the full + >>> # team's observation concatenated. + >>> critic = TensorDictModule( + ... MultiAgentMLP( + ... n_agent_inputs=obs_dim, n_agent_outputs=1, + ... n_agents=n_agents, centralized=True, share_params=True, + ... ), + ... in_keys=[("agents", "observation")], + ... out_keys=[("agents", "state_value")], + ... ) + >>> loss = MAPPOLoss(actor, critic, value_norm=PopArtValueNorm(shape=1)) + >>> loss.set_keys(value=("agents", "state_value"), action=("agents", "action")) + """ + + actor_network: TensorDictModule + critic_network: TensorDictModule + + def __init__( + self, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, + *, + value_norm: ValueNorm | None = None, + entropy_coeff: float | dict[str, float] = 0.01, + normalize_advantage: bool = True, + normalize_advantage_exclude_dims: tuple[int, ...] = (-2,), + **kwargs: Any, + ) -> None: + super().__init__( + actor_network, + critic_network, + entropy_coeff=entropy_coeff, + normalize_advantage=normalize_advantage, + normalize_advantage_exclude_dims=normalize_advantage_exclude_dims, + **kwargs, + ) + # Registered as a submodule so it moves with .to(device) and shows up + # in state_dict(). 
None is a valid value — we still keep the + # attribute for the loss_critic override to query. + self.value_norm = value_norm + if value_norm is not None: + self.add_module("_value_norm_module", value_norm) + + +class IPPOLoss(_MultiAgentPPOMixin, ClipPPOLoss): + """Independent PPO loss (de Witt et al. 2020). + + IPPO is the decentralised counterpart of MAPPO: each agent has its *own* + value function that conditions only on its local observation. There is no + centralised critic and no global state required. Surprisingly competitive + with MAPPO on many SMAC scenarios (the de Witt et al. paper is titled + *Is Independent Learning All You Need...*). + + Structurally this loss is identical to :class:`MAPPOLoss`; the difference + lives entirely in the critic the user passes in. We expose it as a + separate class so the API is self-documenting: when you import + ``IPPOLoss`` it is unambiguous which algorithm you are running, and the + docstring spells out the critic-construction recipe. + + Args: + actor_network (ProbabilisticTensorDictSequential): per-agent policy. + Build with ``MultiAgentMLP(centralized=False, share_params=True)``. + critic_network (TensorDictModule): per-agent value operator. Build + with ``MultiAgentMLP(centralized=False, share_params=True)`` so + each agent values its own observation. + + Keyword Args: + value_norm (ValueNorm, optional): rarely used with IPPO; defaults to + ``None``. + entropy_coeff (float): defaults to ``0.01``. + normalize_advantage (bool): defaults to ``True``. + normalize_advantage_exclude_dims (tuple of int): defaults to ``(-2,)``. + **kwargs: forwarded to :class:`ClipPPOLoss`. + """ + + actor_network: TensorDictModule + critic_network: TensorDictModule + + def __init__( + self, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, + *, + value_norm: ValueNorm | None = None, + entropy_coeff: float | dict[str, float] = 0.01, + normalize_advantage: bool = True, + normalize_advantage_exclude_dims: tuple[int, ...] 
= (-2,), + **kwargs: Any, + ) -> None: + super().__init__( + actor_network, + critic_network, + entropy_coeff=entropy_coeff, + normalize_advantage=normalize_advantage, + normalize_advantage_exclude_dims=normalize_advantage_exclude_dims, + **kwargs, + ) + self.value_norm = value_norm + if value_norm is not None: + self.add_module("_value_norm_module", value_norm) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 88ed2f739d2..28fd2b5cd89 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -7,7 +7,6 @@ import contextlib import warnings from collections.abc import Mapping -from copy import deepcopy from dataclasses import dataclass import torch @@ -39,18 +38,11 @@ _maybe_get_or_select, _reduce, _sum_td_features, - default_value_kwargs, + build_value_estimator, distance_loss, ValueEstimators, ) -from torchrl.objectives.value import ( - GAE, - TD0Estimator, - TD1Estimator, - TDLambdaEstimator, - ValueEstimatorBase, - VTrace, -) +from torchrl.objectives.value import ValueEstimatorBase class PPOLoss(LossModule): @@ -926,36 +918,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams return LossModule.make_value_estimator(self, value_type, **hyperparams) self.value_type = value_type - hp = dict(default_value_kwargs(value_type)) + hp = dict(hyperparams) if hasattr(self, "gamma"): - hp["gamma"] = self.gamma - hp.update(hyperparams) - if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic_network, **hp) - elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.VTrace: - # VTrace currently does not support functional call on the actor - if self.functional: - actor_with_params = deepcopy(self.actor_network) - self.actor_network_params.to_module(actor_with_params) - else: - actor_with_params = self.actor_network - self._value_estimator = VTrace( - value_network=self.critic_network, actor_network=actor_with_params, **hp - ) - else: - raise NotImplementedError(f"Unknown value type {value_type}") + hp.setdefault("gamma", self.gamma) + # Dispatch entirely through the registry; new estimators can be + # added with @register_value_estimator without touching this file. 
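+        # Illustrative call patterns on a loss instance, sketched here for
+        # orientation (the public ``make_value_estimator`` signature itself
+        # is unchanged):
+        #
+        #     loss.make_value_estimator(ValueEstimators.GAE, lmbda=0.9)
+        #     loss.make_value_estimator(ValueEstimators.MAGAE)
+        #
+        # ``build_value_estimator`` also accepts lowercase string aliases
+        # such as "gae" or "magae" (see ``_coerce_value_type`` in
+        # torchrl/objectives/utils.py), which suits config-driven setups.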
+ self._value_estimator = build_value_estimator(self, value_type, **hp) tensor_keys = { "advantage": self.tensor_keys.advantage, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 39639632f66..8e75f34f320 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -5,7 +5,6 @@ from __future__ import annotations import contextlib -from copy import deepcopy from dataclasses import dataclass import torch @@ -24,18 +23,11 @@ _clip_value_loss, _GAMMA_LMBDA_DEPREC_ERROR, _reduce, - default_value_kwargs, + build_value_estimator, distance_loss, ValueEstimators, ) -from torchrl.objectives.value import ( - GAE, - TD0Estimator, - TD1Estimator, - TDLambdaEstimator, - ValueEstimatorBase, - VTrace, -) +from torchrl.objectives.value import ValueEstimatorBase class ReinforceLoss(LossModule): @@ -488,36 +480,12 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams return LossModule.make_value_estimator(self, value_type, **hyperparams) self.value_type = value_type - hp = dict(default_value_kwargs(value_type)) + hp = dict(hyperparams) if hasattr(self, "gamma"): - hp["gamma"] = self.gamma - hp.update(hyperparams) - if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic_network, **hp) - elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator( - value_network=self.critic_network, **hp - ) - elif value_type == ValueEstimators.VTrace: - # VTrace currently does not support functional call on the actor - if self.functional: - actor_with_params = deepcopy(self.actor_network) - self.actor_network_params.to_module(actor_with_params) - else: - actor_with_params = self.actor_network - self._value_estimator = VTrace( - value_network=self.critic_network, actor_network=actor_with_params, **hp - ) - else: - raise NotImplementedError(f"Unknown value type {value_type}") + hp.setdefault("gamma", self.gamma) + # Registry-driven dispatch — adding a new estimator only needs + # @register_value_estimator on the class, not edits here. + self._value_estimator = build_value_estimator(self, value_type, **hp) tensor_keys = { "advantage": self.tensor_keys.advantage, diff --git a/torchrl/objectives/utils.py b/torchrl/objectives/utils.py index b97ef5ea6a9..8bc8ac1908f 100644 --- a/torchrl/objectives/utils.py +++ b/torchrl/objectives/utils.py @@ -60,33 +60,139 @@ class ValueEstimators(Enum): TD1 = "TD(1) (infinity-step return)" TDLambda = "TD(lambda)" GAE = "Generalized advantage estimate" + MAGAE = "Multi-agent generalized advantage estimate" VTrace = "V-trace" +# --------------------------------------------------------------------------- +# Value-estimator registry +# --------------------------------------------------------------------------- +# +# Historically, every loss that wanted to pick between TD0 / GAE / V-Trace / +# etc. shipped its own ``make_value_estimator`` body with a hard-coded +# ``if/elif`` chain that knew the class names, the default kwargs, and any +# per-estimator construction quirks (e.g. V-Trace needs the actor). Adding a +# new estimator therefore meant touching ~15 loss files. 
+# +# The registry below decouples those three things: +# - which class implements a given ``ValueEstimators`` enum entry +# - what default hyper-parameters that class expects +# - how to wire the estimator against a particular ``LossModule`` +# +# Estimators self-register via the :func:`register_value_estimator` decorator +# at class-definition time. Loss modules can then build the right estimator +# with a single call to :func:`build_value_estimator`, regardless of how many +# concrete estimator classes exist. +# +# The registry accepts either an enum value or its lowercase string alias +# (e.g. ``"gae"``), which is convenient for config-driven setups. + + +class _ValueEstimatorRegistryEntry: + """One row of the value-estimator registry.""" + + __slots__ = ("cls", "default_kwargs") + + def __init__(self, cls: type, default_kwargs: dict) -> None: + self.cls = cls + self.default_kwargs = dict(default_kwargs) + + +_VALUE_ESTIMATOR_REGISTRY: dict[ValueEstimators, _ValueEstimatorRegistryEntry] = {} + + +def register_value_estimator( + value_type: ValueEstimators, *, default_kwargs: dict | None = None +): + """Decorator: register an estimator class against a :class:`ValueEstimators` entry. + + Args: + value_type: the enum entry this class implements. + default_kwargs: hyperparameter defaults applied when a loss calls + ``make_value_estimator(value_type)`` without overriding them. + + Example: + >>> @register_value_estimator( + ... ValueEstimators.GAE, + ... default_kwargs={"gamma": 0.99, "lmbda": 0.95, "differentiable": True}, + ... ) + ... class GAE(ValueEstimatorBase): + ... ... + """ + + def _decorator(cls): + _VALUE_ESTIMATOR_REGISTRY[value_type] = _ValueEstimatorRegistryEntry( + cls, default_kwargs or {} + ) + return cls + + return _decorator + + +def _coerce_value_type(value_type) -> ValueEstimators: + """Allow string aliases like ``"gae"`` alongside the enum values.""" + if isinstance(value_type, ValueEstimators): + return value_type + if isinstance(value_type, str): + # Accept both the enum *member* name ("GAE") and a lowercase alias + # ("gae") for ergonomics with hydra / yaml configs. + key = value_type.lower() + for member in ValueEstimators: + if member.name.lower() == key: + return member + raise KeyError( + f"Unknown value estimator alias {value_type!r}. " + f"Known aliases: {[m.name.lower() for m in ValueEstimators]}." + ) + raise TypeError( + f"value_type must be a ValueEstimators enum value or a string alias, " + f"got {type(value_type).__name__}." + ) + + +def get_value_estimator_entry(value_type) -> _ValueEstimatorRegistryEntry: + """Look up the registry entry for ``value_type`` (enum or string alias).""" + coerced = _coerce_value_type(value_type) + try: + return _VALUE_ESTIMATOR_REGISTRY[coerced] + except KeyError as exc: + raise NotImplementedError( + f"No value estimator registered for {coerced!r}. " + "Register one with @register_value_estimator(...) at class definition time." + ) from exc + + +def build_value_estimator(loss_module, value_type, **hyperparams): + """Construct a value estimator for ``loss_module`` using the registry. + + Resolves the class via :func:`get_value_estimator_entry`, merges the + registry defaults with the caller's ``hyperparams``, then delegates the + final wiring to ``cls.for_loss(loss_module, **merged)``. Estimator + subclasses with construction quirks (V-Trace needs the actor network) + override ``for_loss`` rather than every loss owning the quirk. 
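+
+    A minimal, self-contained sketch (``_FakeLoss`` is a hypothetical
+    stand-in for any loss module exposing ``critic_network`` or
+    ``value_network``; it is not a real torchrl class):
+
+    Example:
+        >>> import torch.nn as nn
+        >>> from tensordict.nn import TensorDictModule
+        >>> from torchrl.objectives.utils import build_value_estimator
+        >>> value_net = TensorDictModule(
+        ...     nn.Linear(4, 1), in_keys=["observation"], out_keys=["state_value"]
+        ... )
+        >>> class _FakeLoss:
+        ...     critic_network = value_net
+        >>> estimator = build_value_estimator(_FakeLoss(), "gae", lmbda=0.9)
+        >>> type(estimator).__name__
+        'GAE'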
+ """ + entry = get_value_estimator_entry(value_type) + merged = {**entry.default_kwargs, **hyperparams} + return entry.cls.for_loss(loss_module, **merged) + + def default_value_kwargs(value_type: ValueEstimators): """Default value function keyword argument generator. + Now reads from :data:`_VALUE_ESTIMATOR_REGISTRY` so any + :func:`register_value_estimator`-decorated class is picked up + automatically. Retained as a top-level function for back-compat with + callers that don't want to touch the registry directly. + Args: value_type (Enum.value): the value function type, from the :class:`~torchrl.objectives.utils.ValueEstimators` class. Examples: >>> kwargs = default_value_kwargs(ValueEstimators.TDLambda) - {"gamma": 0.99, "lmbda": 0.95} - + {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} """ - if value_type == ValueEstimators.TD1: - return {"gamma": 0.99, "differentiable": True} - elif value_type == ValueEstimators.TD0: - return {"gamma": 0.99, "differentiable": True} - elif value_type == ValueEstimators.GAE: - return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} - elif value_type == ValueEstimators.TDLambda: - return {"gamma": 0.99, "lmbda": 0.95, "differentiable": True} - elif value_type == ValueEstimators.VTrace: - return {"gamma": 0.99, "differentiable": True} - else: - raise NotImplementedError(f"Unknown value type {value_type}.") + return dict(get_value_estimator_entry(value_type).default_kwargs) class _context_manager: diff --git a/torchrl/objectives/value/__init__.py b/torchrl/objectives/value/__init__.py index 4c8a29d6da3..e8f56be08f1 100644 --- a/torchrl/objectives/value/__init__.py +++ b/torchrl/objectives/value/__init__.py @@ -5,6 +5,7 @@ from .advantages import ( GAE, + MultiAgentGAE, TD0Estimate, TD0Estimator, TD1Estimate, @@ -17,6 +18,7 @@ __all__ = [ "GAE", + "MultiAgentGAE", "TD0Estimate", "TD0Estimator", "TD1Estimate", diff --git a/torchrl/objectives/value/advantages.py b/torchrl/objectives/value/advantages.py index 0201ed6da25..8315b6c4a3a 100644 --- a/torchrl/objectives/value/advantages.py +++ b/torchrl/objectives/value/advantages.py @@ -9,6 +9,7 @@ import warnings from collections.abc import Callable from contextlib import nullcontext +from copy import deepcopy from dataclasses import asdict, dataclass from functools import wraps @@ -34,6 +35,8 @@ _pseudo_vmap, _vmap_func, hold_out_net, + register_value_estimator, + ValueEstimators, ) from torchrl.objectives.value.functional import ( generalized_advantage_estimate, @@ -105,6 +108,24 @@ class ValueEstimatorBase(TensorDictModuleBase): """ + @classmethod + def for_loss(cls, loss_module, **hyperparams): + """Construct an instance configured against ``loss_module``. + + Used by the value-estimator registry + (:func:`~torchrl.objectives.utils.build_value_estimator`) to keep + per-estimator wiring quirks out of every loss class. The default + implementation picks up ``loss_module.critic_network`` if present, + falling back to ``loss_module.value_network``, and forwards the + remaining ``hyperparams`` to the constructor. Subclasses with + additional dependencies (e.g. :class:`VTrace` needing the actor) + override this method. + """ + value_network = getattr(loss_module, "critic_network", None) + if value_network is None: + value_network = getattr(loss_module, "value_network", None) + return cls(value_network=value_network, **hyperparams) + @dataclass class _AcceptedKeys: """Maintains default values for all configurable tensordict keys. 
@@ -619,6 +640,10 @@ def _call_value_net(data_in: TensorDictBase) -> torch.Tensor: return value, value_ +@register_value_estimator( + ValueEstimators.TD0, + default_kwargs={"gamma": 0.99, "differentiable": True}, +) class TD0Estimator(ValueEstimatorBase): """Temporal Difference (TD(0)) estimate of advantage function. @@ -841,6 +866,10 @@ def value_estimate( return value_target +@register_value_estimator( + ValueEstimators.TD1, + default_kwargs={"gamma": 0.99, "differentiable": True}, +) class TD1Estimator(ValueEstimatorBase): r""":math:`\infty`-Temporal Difference (TD(1)) estimate of advantage function. @@ -1071,6 +1100,10 @@ def value_estimate( return value_target +@register_value_estimator( + ValueEstimators.TDLambda, + default_kwargs={"gamma": 0.99, "lmbda": 0.95, "differentiable": True}, +) class TDLambdaEstimator(ValueEstimatorBase): r"""TD(:math:`\lambda`) estimate of advantage function. @@ -1335,6 +1368,10 @@ def value_estimate( return val +@register_value_estimator( + ValueEstimators.GAE, + default_kwargs={"gamma": 0.99, "lmbda": 0.95, "differentiable": True}, +) class GAE(ValueEstimatorBase): """A class wrapper around the generalized advantage estimate functional. @@ -1614,10 +1651,18 @@ def forward( terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) time_dim = self._get_time_dim(time_dim, tensordict) + # Subclass extension hook: lets subclasses reshape / broadcast the + # reward and done signals to match the value tensor before the + # advantage recursion is run. Default: identity. + reward, done, terminated = self._prepare_signals( + reward, done, terminated, value + ) + if self.auto_reset_env: truncated = tensordict.get(("next", "truncated")) + truncated = self._broadcast_optional(truncated, value) if truncated.any(): - reward += gamma * value * truncated + reward = reward + gamma * value * truncated if self.vectorized: adv, value_target = vec_generalized_advantage_estimate( @@ -1643,16 +1688,49 @@ def forward( ) if self.average_gae: - loc = adv.mean() - scale = adv.std().clamp_min(1e-4) - adv = adv - loc - adv = adv / scale + adv = self._normalize_advantage(adv) tensordict.set(self.tensor_keys.advantage, adv) tensordict.set(self.tensor_keys.value_target, value_target) return tensordict + # -- extension hooks ----------------------------------------------------- + + def _prepare_signals( + self, + reward: Tensor, + done: Tensor, + terminated: Tensor, + value: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + """Hook to reshape reward / done / terminated before the recursion. + + Default implementation is identity. :class:`MultiAgentGAE` overrides + this to broadcast team-shared signals across the agent dim. + """ + return reward, done, terminated + + def _broadcast_optional(self, tensor: Tensor, value: Tensor) -> Tensor: + """Optional broadcast for the truncated signal used in auto_reset_env. + + Default: return ``tensor`` unchanged. Subclasses that broadcast + rewards / done flags should typically override this with the same + broadcasting policy. + """ + return tensor + + def _normalize_advantage(self, adv: Tensor) -> Tensor: + """Standardise the advantage tensor. + + Default standardises globally (single mean/std over the whole tensor). + :class:`MultiAgentGAE` overrides this to leave the agent dim + independent. 
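+
+        As a shape-level illustration (assuming the cooperative layout used
+        elsewhere in this module): for ``adv`` of shape
+        ``[B, T, n_agents, 1]``, this default yields scalar ``loc`` /
+        ``scale``, whereas the multi-agent override reduces over every dim
+        except the agent one, giving ``loc`` / ``scale`` of shape
+        ``[1, 1, n_agents, 1]``.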
+ """ + loc = adv.mean() + scale = adv.std().clamp_min(1e-4) + return (adv - loc) / scale + def value_estimate( self, tensordict, @@ -1708,6 +1786,9 @@ def value_estimate( next_value = tensordict.get(("next", self.tensor_keys.value)) done = tensordict.get(("next", self.tensor_keys.done)) terminated = tensordict.get(("next", self.tensor_keys.terminated), default=done) + reward, done, terminated = self._prepare_signals( + reward, done, terminated, value + ) _, value_target = vec_generalized_advantage_estimate( gamma, lmbda, @@ -1721,6 +1802,104 @@ def value_estimate( return value_target +@register_value_estimator( + ValueEstimators.MAGAE, + default_kwargs={"gamma": 0.99, "lmbda": 0.95, "differentiable": True}, +) +class MultiAgentGAE(GAE): + """Multi-agent Generalized Advantage Estimator. + + Drop-in replacement for :class:`GAE` when the value network produces per-agent + state values (shape ``[*B, T, n_agents, 1]``) but the reward / done / + terminated signals are shared across agents at the team level + (shape ``[*B, T, 1]``) — the standard cooperative-MARL layout in torchrl + (see e.g. ``torchrl/envs/libs/vmas.py`` and + ``torchrl/envs/libs/pettingzoo.py``). + + The estimator detects whether the reward/done/terminated tensors are missing + the agent dimension relative to the value tensor, and broadcasts them along + that dimension before running the standard vectorised GAE recursion. If the + reward is already per-agent (e.g. a competitive setting), it is passed + through unchanged. + + The output ``"advantage"`` and ``"value_target"`` entries match the shape + of the value tensor (``[*B, T, n_agents, 1]``), which is what + :class:`~torchrl.objectives.multiagent.MAPPOLoss` expects. + + Keyword Args: + agent_dim (int, optional): the dimension that holds the agent index in + the value tensor. Negative dimensions are taken modulo + ``value.ndim``. Defaults to ``-2`` (penultimate), matching the + convention used by :class:`~torchrl.modules.MultiAgentMLP`. + + Other args/kwargs are forwarded to :class:`GAE`. + """ + + def __init__(self, *, agent_dim: int = -2, **kwargs): + super().__init__(**kwargs) + self.agent_dim = agent_dim + + @staticmethod + def _broadcast_to_agents( + tensor: torch.Tensor, target: torch.Tensor, agent_dim: int + ) -> torch.Tensor: + """Expand ``tensor`` along ``agent_dim`` to match ``target``'s shape. + + If ``tensor`` already has the same number of dims as ``target`` we + assume it is per-agent and return it unchanged. Otherwise we unsqueeze + at ``agent_dim`` and expand. + """ + if tensor.ndim == target.ndim: + return tensor + if tensor.ndim != target.ndim - 1: + raise ValueError( + f"MultiAgentGAE expected the reward/done/terminated tensor to " + f"have either the same number of dims as the value tensor " + f"(per-agent) or one fewer (team-shared). Got " + f"tensor.shape={tuple(tensor.shape)}, " + f"value.shape={tuple(target.shape)}." 
+ ) + dim = agent_dim if agent_dim >= 0 else target.ndim + agent_dim + n_agents = target.shape[dim] + unsqueezed = tensor.unsqueeze(dim) + expand_shape = list(unsqueezed.shape) + expand_shape[dim] = n_agents + return unsqueezed.expand(expand_shape) + + # -- GAE extension hooks ------------------------------------------------- + + def _prepare_signals( + self, + reward: Tensor, + done: Tensor, + terminated: Tensor, + value: Tensor, + ) -> tuple[Tensor, Tensor, Tensor]: + return ( + self._broadcast_to_agents(reward, value, self.agent_dim), + self._broadcast_to_agents(done, value, self.agent_dim), + self._broadcast_to_agents(terminated, value, self.agent_dim), + ) + + def _broadcast_optional(self, tensor: Tensor, value: Tensor) -> Tensor: + # Used by GAE for the auto_reset_env ``truncated`` tensor — same + # broadcasting policy as the other team signals. + return self._broadcast_to_agents(tensor, value, self.agent_dim) + + def _normalize_advantage(self, adv: Tensor) -> Tensor: + # Per-agent standardisation: normalise over batch + time but keep the + # agent dim independent so high-variance agents are not flattened. + agent_dim = self.agent_dim if self.agent_dim >= 0 else adv.ndim + self.agent_dim + reduce_dims = [d for d in range(adv.ndim) if d != agent_dim] + loc = adv.mean(dim=reduce_dims, keepdim=True) + scale = adv.std(dim=reduce_dims, keepdim=True).clamp_min(1e-4) + return (adv - loc) / scale + + +@register_value_estimator( + ValueEstimators.VTrace, + default_kwargs={"gamma": 0.99, "differentiable": True}, +) class VTrace(ValueEstimatorBase): """A class wrapper around V-Trace estimate functional. @@ -1785,6 +1964,27 @@ class VTrace(ValueEstimatorBase): """ + @classmethod + def for_loss(cls, loss_module, **hyperparams): + """V-Trace needs both the critic *and* the actor. + + When the loss is functional, the actor stored on the loss module is + a stateless template — we deep-copy it and bake the current params + in, since V-Trace doesn't support a functional actor call. + """ + value_network = getattr(loss_module, "critic_network", None) + if value_network is None: + value_network = getattr(loss_module, "value_network", None) + actor_network = loss_module.actor_network + if getattr(loss_module, "functional", False): + actor_network = deepcopy(actor_network) + loss_module.actor_network_params.to_module(actor_network) + return cls( + value_network=value_network, + actor_network=actor_network, + **hyperparams, + ) + def __init__( self, *,