1 change: 1 addition & 0 deletions docs/source/reference/objectives.rst
@@ -50,4 +50,5 @@ Documentation Sections
objectives_policy
objectives_actorcritic
objectives_offline
objectives_multiagent
objectives_other
1 change: 1 addition & 0 deletions docs/source/reference/objectives_common.rst
@@ -28,6 +28,7 @@ Value Estimators
TD1Estimator
TDLambdaEstimator
GAE
MultiAgentGAE

.. currentmodule:: torchrl.objectives

58 changes: 58 additions & 0 deletions docs/source/reference/objectives_multiagent.rst
@@ -0,0 +1,58 @@
.. currentmodule:: torchrl.objectives.multiagent

Multi-Agent Objectives
======================

Loss modules for multi-agent reinforcement learning algorithms. These losses
follow the torchrl multi-agent tensordict convention (per-agent tensors
nested under group keys such as ``("agents", "observation")``; see
:class:`~torchrl.envs.libs.vmas.VmasEnv` and
:class:`~torchrl.envs.libs.pettingzoo.PettingZooEnv`).
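
A rollout from such an environment carries the per-agent entries under the
group key and the team-shared flags at the root. A minimal sketch (assuming
VMAS is installed; shapes are indicative only):

.. code-block:: python

    from torchrl.envs.libs.vmas import VmasEnv

    env = VmasEnv(scenario="navigation", num_envs=4, max_steps=50)
    rollout = env.rollout(10)
    obs = rollout["agents", "observation"]       # [4, 10, n_agents, obs_dim]
    rew = rollout["next", "agents", "reward"]    # [4, 10, n_agents, 1]
    done = rollout["next", "done"]               # [4, 10, 1], shared by the team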

MAPPO and IPPO
--------------

:class:`MAPPOLoss` implements Multi-Agent PPO (Yu et al. 2022) — a
decentralised actor paired with a *centralised critic* that conditions on the
joint observation / state. :class:`IPPOLoss` is the independent-learner
counterpart from de Witt et al. 2020: each agent has its own local critic and
there is no centralised information at training time.

Both are thin specialisations of :class:`~torchrl.objectives.ClipPPOLoss`
that:

- default the value estimator to
:class:`~torchrl.objectives.value.MultiAgentGAE`, which broadcasts
team-shared rewards / done flags across the agent dimension before
computing returns;
- default ``normalize_advantage_exclude_dims`` to ``(-2,)`` so the agent dim
is excluded from advantage standardisation;
- optionally accept a :class:`~torchrl.modules.ValueNorm` subclass — either
:class:`~torchrl.modules.PopArtValueNorm` (EMA, recommended for drifting
reward scales) or :class:`~torchrl.modules.RunningValueNorm` (exact
Welford running stats, recommended for stationary scales) — to stabilise
the critic loss. The MAPPO paper credits this trick for its strong SMAC
results.

See ``sota-implementations/multiagent/mappo_ippo.py`` for a hydra-configured
recipe and ``examples/multiagent/mappo_vmas.py`` for a minimal one.
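
A construction sketch mirroring that example (``env``, ``actor`` and
``critic`` are the modules built there; the keyword arguments follow the
example script rather than a finalised API):

.. code-block:: python

    from torchrl.modules import PopArtValueNorm
    from torchrl.objectives import MAPPOLoss

    loss_module = MAPPOLoss(
        actor_network=actor,                  # decentralised policy
        critic_network=critic,                # centralised critic for MAPPO
        value_norm=PopArtValueNorm(shape=1),  # optional value normalisation
        clip_epsilon=0.2,
        entropy_coeff=0.01,
    )
    loss_module.set_keys(
        value=("agents", "state_value"),
        action=env.action_key,
        reward=env.reward_key,
    )
    # MultiAgentGAE is already the default value estimator, so no
    # make_value_estimator call is needed.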

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

MAPPOLoss
IPPOLoss

QMixer
------

:class:`QMixerLoss` mixes local per-agent Q values into a global team Q
value via a learnable mixing network, and trains them jointly with a DQN
update on the global value (Rashid et al. 2018).
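
A minimal pairing sketch, following TorchRL's existing QMix example
(``env``, ``device`` and a per-agent Q network ``qnet`` writing
``("agents", "chosen_action_value")`` are assumed to be defined as in that
example):

.. code-block:: python

    from tensordict.nn import TensorDictModule
    from torchrl.modules import QMixer
    from torchrl.objectives import QMixerLoss

    # mix the per-agent chosen-action values into a single team value
    mixer = TensorDictModule(
        QMixer(
            state_shape=env.observation_spec["agents", "observation"].shape[-2:],
            mixing_embed_dim=32,
            n_agents=env.n_agents,
            device=device,
        ),
        in_keys=[("agents", "chosen_action_value"), ("agents", "observation")],
        out_keys=["chosen_action_value"],
    )
    loss_module = QMixerLoss(qnet, mixer, loss_function="l2", delay_value=True)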

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

QMixerLoss
221 changes: 221 additions & 0 deletions examples/multiagent/mappo_vmas.py
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Minimal MAPPO / IPPO recipe on VMAS using the new
:class:`~torchrl.objectives.multiagent.MAPPOLoss` /
:class:`~torchrl.objectives.multiagent.IPPOLoss` classes.

For the full, hydra-configured, wandb-logged version see
``sota-implementations/multiagent/mappo_ippo.py``. This file is intentionally
short: it's there to show that the new loss classes collapse the boilerplate
that previously required ``ClipPPOLoss`` + manual ``set_keys(done=...,
terminated=...)`` + manual ``make_value_estimator(GAE, ...)`` into a single
construction call.

Usage::

python examples/multiagent/mappo_vmas.py --algo mappo --frames 200_000
python examples/multiagent/mappo_vmas.py --algo ippo --frames 200_000

The two should reach similar reward on the easy navigation scenario; MAPPO
typically pulls ahead on harder coordination tasks (Yu et al. 2022).
"""
from __future__ import annotations

import argparse
import time

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.envs import RewardSum, TransformedEnv
from torchrl.modules import (
MultiAgentMLP,
PopArtValueNorm,
ProbabilisticActor,
TanhNormal,
)
from torchrl.objectives import IPPOLoss, MAPPOLoss


def make_actor(env, *, share_params: bool = True) -> ProbabilisticActor:
obs_dim = env.observation_spec["agents", "observation"].shape[-1]
action_dim = env.action_spec.shape[-1]
backbone = nn.Sequential(
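        # per-agent MLP emitting 2 * action_dim values, split into loc / scale by NormalParamExtractor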
MultiAgentMLP(
n_agent_inputs=obs_dim,
n_agent_outputs=2 * action_dim,
n_agents=env.n_agents,
centralized=False,
share_params=share_params,
depth=2,
num_cells=256,
activation_class=nn.Tanh,
),
NormalParamExtractor(),
)
module = TensorDictModule(
backbone,
in_keys=[("agents", "observation")],
out_keys=[("agents", "loc"), ("agents", "scale")],
)
return ProbabilisticActor(
module=module,
in_keys=[("agents", "loc"), ("agents", "scale")],
out_keys=[env.action_key],
distribution_class=TanhNormal,
distribution_kwargs={
"low": env.full_action_spec_unbatched[("agents", "action")].space.low,
"high": env.full_action_spec_unbatched[("agents", "action")].space.high,
},
return_log_prob=True,
)


def make_critic(
env, *, centralized: bool, share_params: bool = True
) -> TensorDictModule:
obs_dim = env.observation_spec["agents", "observation"].shape[-1]
return TensorDictModule(
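        # centralized=True: the value head sees all agents' observations (MAPPO critic);
        # centralized=False: each agent keeps a local critic (IPPO)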
MultiAgentMLP(
n_agent_inputs=obs_dim,
n_agent_outputs=1,
n_agents=env.n_agents,
centralized=centralized,
share_params=share_params,
depth=2,
num_cells=256,
activation_class=nn.Tanh,
),
in_keys=[("agents", "observation")],
out_keys=[("agents", "state_value")],
)


def main(args: argparse.Namespace) -> None:
try:
from torchrl.envs.libs.vmas import VmasEnv
except ImportError as exc:
raise SystemExit(
"This example requires VMAS. Install it with `pip install vmas`."
) from exc

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(args.seed)

n_envs = max(1, args.frames_per_batch // args.max_steps)
env = TransformedEnv(
VmasEnv(
scenario=args.scenario,
num_envs=n_envs,
continuous_actions=True,
max_steps=args.max_steps,
device=device,
seed=args.seed,
),
        RewardSum(
            in_keys=[("agents", "reward")],
            out_keys=[("agents", "episode_reward")],
        ),
)

actor = make_actor(env)
    centralized = args.algo == "mappo"
    critic = make_critic(env, centralized=centralized)

LossCls = MAPPOLoss if args.algo == "mappo" else IPPOLoss
value_norm = (
PopArtValueNorm(shape=1, device=device) if args.algo == "mappo" else None
)
loss_module = LossCls(
actor_network=actor,
critic_network=critic,
value_norm=value_norm,
clip_epsilon=0.2,
entropy_coeff=0.01,
)
loss_module.set_keys(
value=("agents", "state_value"),
action=env.action_key,
reward=env.reward_key,
)

collector = SyncDataCollector(
env,
actor,
device=device,
storing_device=device,
frames_per_batch=args.frames_per_batch,
total_frames=args.frames,
)

replay_buffer = TensorDictReplayBuffer(
storage=LazyTensorStorage(args.frames_per_batch, device=device),
sampler=SamplerWithoutReplacement(),
batch_size=args.minibatch_size,
)

optim = torch.optim.Adam(loss_module.parameters(), lr=args.lr)

total_frames = 0
start = time.time()
for it, td in enumerate(collector):
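        # compute advantages / value targets once per collected batch
        # (MultiAgentGAE is the default value estimator of these losses)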
with torch.no_grad():
loss_module.value_estimator(
td,
params=loss_module.critic_network_params,
target_params=loss_module.target_critic_network_params,
)
replay_buffer.extend(td.reshape(-1))
total_frames += td.numel()

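        # PPO-style updates: several epochs of shuffled minibatches over the batch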
for _ in range(args.epochs):
for _ in range(args.frames_per_batch // args.minibatch_size):
subdata = replay_buffer.sample()
losses = loss_module(subdata)
loss = (
losses["loss_objective"]
+ losses["loss_critic"]
+ losses["loss_entropy"]
)
optim.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 1.0)
optim.step()

collector.update_policy_weights_()
        ep_reward = td.get(("next", "agents", "episode_reward")).mean().item()
print(
f"[{args.algo}] iter={it:03d} frames={total_frames:>7d} "
f"reward={ep_reward:+.3f} elapsed={time.time() - start:5.1f}s"
)

collector.shutdown()
if not env.is_closed:
env.close()


def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser()
p.add_argument("--algo", choices=("mappo", "ippo"), default="mappo")
p.add_argument("--scenario", default="navigation")
p.add_argument("--frames", type=int, default=200_000)
p.add_argument("--frames_per_batch", type=int, default=6_000)
p.add_argument("--minibatch_size", type=int, default=400)
p.add_argument("--max_steps", type=int, default=100)
p.add_argument("--epochs", type=int, default=4)
p.add_argument("--lr", type=float, default=3e-4)
p.add_argument("--seed", type=int, default=0)
return p.parse_args()


if __name__ == "__main__":
main(parse_args())
8 changes: 4 additions & 4 deletions test/objectives/test_cql.py
@@ -182,7 +182,7 @@ def test_cql(
delay_qvalue=delay_qvalue,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -361,7 +361,7 @@ def test_cql_deactivate_vmap(
deactivate_vmap=False,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn_vmap.make_value_estimator(td_est)
return
@@ -389,7 +389,7 @@ def test_cql_deactivate_vmap(
deactivate_vmap=True,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn_no_vmap.make_value_estimator(td_est)
return
@@ -845,7 +845,7 @@ def test_dcql(self, delay_value, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = DiscreteCQLLoss(actor, loss_function="l2", delay_value=delay_value)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
10 changes: 5 additions & 5 deletions test/objectives/test_ddpg.py
@@ -248,7 +248,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
delay_actor=delay_actor,
delay_value=delay_value,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -1024,7 +1024,7 @@ def test_td3(
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -1136,7 +1136,7 @@ def test_td3_deactivate_vmap(
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn_vmap.make_value_estimator(td_est)
return
@@ -1166,7 +1166,7 @@ def test_td3_deactivate_vmap(
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn_no_vmap.make_value_estimator(td_est)
return
@@ -1937,7 +1937,7 @@ def test_td3bc(
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
4 changes: 2 additions & 2 deletions test/objectives/test_dqn.py
@@ -258,7 +258,7 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
delay_value=delay_value,
double_dqn=double_dqn,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return
@@ -922,7 +922,7 @@ def test_qmixer(self, delay_value, device, action_spec_type, td_est):
action_spec_type=action_spec_type, device=device
)
loss_fn = QMixerLoss(actor, mixer, loss_function="l2", delay_value=delay_value)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
if td_est in (ValueEstimators.GAE, ValueEstimators.MAGAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
loss_fn.make_value_estimator(td_est)
return