From 598dd3a6cf7d44aad55c7733e841ef132ea9d3d8 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Wed, 8 Oct 2025 16:39:00 -0400 Subject: [PATCH 1/9] feat: add qsm and idem --- examples/online/config/mujoco/algo/idem.yaml | 23 ++ examples/online/config/mujoco/algo/qsm.yaml | 22 ++ examples/online/main_mujoco_offpolicy.py | 3 + flowrl/agent/online/__init__.py | 2 + flowrl/agent/online/idem.py | 100 ++++++++ flowrl/agent/online/qsm.py | 253 +++++++++++++++++++ flowrl/config/online/mujoco/__init__.py | 4 + flowrl/config/online/mujoco/algo/idem.py | 32 +++ flowrl/config/online/mujoco/algo/qsm.py | 31 +++ 9 files changed, 470 insertions(+) create mode 100644 examples/online/config/mujoco/algo/idem.yaml create mode 100644 examples/online/config/mujoco/algo/qsm.yaml create mode 100644 flowrl/agent/online/idem.py create mode 100644 flowrl/agent/online/qsm.py create mode 100644 flowrl/config/online/mujoco/algo/idem.py create mode 100644 flowrl/config/online/mujoco/algo/qsm.py diff --git a/examples/online/config/mujoco/algo/idem.yaml b/examples/online/config/mujoco/algo/idem.yaml new file mode 100644 index 0000000..f4a9b36 --- /dev/null +++ b/examples/online/config/mujoco/algo/idem.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: idem + critic_hidden_dims: [256, 256] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [256, 256] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/qsm.yaml b/examples/online/config/mujoco/algo/qsm.yaml new file mode 100644 index 0000000..5815408 --- /dev/null +++ b/examples/online/config/mujoco/algo/qsm.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [256, 256] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [256, 256] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/main_mujoco_offpolicy.py b/examples/online/main_mujoco_offpolicy.py index f408493..8e85f9e 100644 --- a/examples/online/main_mujoco_offpolicy.py +++ b/examples/online/main_mujoco_offpolicy.py @@ -11,6 +11,7 @@ from tqdm import tqdm from flowrl.agent.online import * +from flowrl.agent.online.idem import IDEMAgent from flowrl.config.online.mujoco import Config from flowrl.dataset.buffer.state import ReplayBuffer from flowrl.types import * @@ -25,6 +26,8 @@ "td7": TD7Agent, "sdac": SDACAgent, "dpmd": DPMDAgent, + "qsm": QSMAgent, + "idem": IDEMAgent, } class OffPolicyTrainer(): diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index 9041320..d8068bd 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -2,6 +2,7 @@ from .ctrl.ctrl import CtrlTD3Agent from .dpmd import DPMDAgent from .ppo import PPOAgent +from .qsm import QSMAgent from .sac import SACAgent from .sdac import SDACAgent from .td3 import TD3Agent @@ -16,4 +17,5 @@ "DPMDAgent", "PPOAgent", "CtrlTD3Agent", + "QSMAgent", ] diff --git a/flowrl/agent/online/idem.py b/flowrl/agent/online/idem.py new file mode 100644 index 0000000..b24c906 --- /dev/null +++ b/flowrl/agent/online/idem.py @@ -0,0 +1,100 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp 
+import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.qsm import QSMAgent, jit_update_qsm_critic +from flowrl.config.online.mujoco.algo.idem import IDEMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + +jit_update_idem_critic = jit_update_qsm_critic + +@partial(jax.jit, static_argnames=("num_reverse_samples", "temp",)) +def jit_update_idem_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + critic_target: Model, + batch: Batch, + num_reverse_samples: int, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + obs_repeat = batch.obs[jnp.newaxis, ...].repeat(num_reverse_samples, axis=0) + + rng, tnormal_rng, clipped_rng = jax.random.split(rng, 3) + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + lower_bound = - 1.0 / alpha2 * at - alpha1 / alpha2 + upper_bound = - 1.0 / alpha2 * at + alpha1 / alpha2 + tnormal_noise = jax.random.truncated_normal(tnormal_rng, lower_bound, upper_bound, (num_reverse_samples, *at.shape)) + normal_noise = jax.random.normal(clipped_rng, (num_reverse_samples, *at.shape)) + normal_noise_clipped = jnp.clip(normal_noise, lower_bound, upper_bound) + eps_reverse = jnp.where(jnp.isnan(tnormal_noise), normal_noise_clipped, tnormal_noise) + a0_hat = 1 / alpha1 * at + alpha2 / alpha1 * eps_reverse + + q_value_and_grad_fn = jax.vmap( + jax.vmap( + jax.value_and_grad(lambda a, s: critic_target(s, a).min(axis=0).mean()), + ) + ) + q_value, q_grad = q_value_and_grad_fn(a0_hat, obs_repeat) + q_grad = q_grad / temp + weight = jax.nn.softmax(q_value / temp, axis=0) + eps_estimation = - (alpha2 / alpha1) * jnp.sum(weight[:, :, jnp.newaxis] * q_grad, axis=0) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class IDEMAgent(QSMAgent): + """ + Iterative Denoising Energy Matching (iDEM) Agent. 
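+
+    Reuses the QSM critic update (jit_update_idem_critic is an alias of
+    jit_update_qsm_critic). For the actor, num_reverse_samples candidate clean
+    actions are reconstructed per noisy action via truncated-normal noise,
+    weighted by softmax(Q / temp), and their temperature-scaled Q-gradients are
+    aggregated into the epsilon-prediction target (jit_update_idem_actor).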
+ """ + name = "IDEMAgent" + model_names = ["actor", "critic", "critic_target"] + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.critic, self.critic_target, critic_metrics = jit_update_idem_critic( + self.rng, + self.actor, + self.critic, + self.critic_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + ema=self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_idem_actor( + self.rng, + self.actor, + self.critic_target, + batch, + num_reverse_samples=self.cfg.num_reverse_samples, + temp=self.cfg.temp, + ) + self._n_training_steps += 1 + return {**critic_metrics, **actor_metrics} diff --git a/flowrl/agent/online/qsm.py b/flowrl/agent/online/qsm.py new file mode 100644 index 0000000..3e66f88 --- /dev/null +++ b/flowrl/agent/online/qsm.py @@ -0,0 +1,253 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.config.online.mujoco.algo.qsm import QSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + qs = critic(obs_repeat, actions) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver", "ema")) +def jit_update_qsm_critic( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + critic_target: Model, + batch: Batch, + discount: float, + solver: str, + ema: float, +) -> Tuple[PRNGKey, Model, Model, Metric]: + # update critic + rng, next_xT_rng = jax.random.split(rng) + next_xT = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample(rng, next_xT, batch.next_obs, training=False, solver=solver) + q_target = critic_target(batch.next_obs, next_action) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target.min(axis=0) + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q = critic.apply( + {"params": critic_params}, + batch.obs, + batch.action, + training=True, + rngs={"dropout": dropout_rng}, + ) + critic_loss = ((q - q_target[jnp.newaxis, :])**2).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + + new_critic_target = ema_update(new_critic, critic_target, ema) + return 
rng, new_critic, new_critic_target, critic_metrics + +@partial(jax.jit, static_argnames=("temp",)) +def jit_update_qsm_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class QSMAgent(BaseAgent): + """ + Q Score Matching (QSM) agent and beyond. + """ + name = "QSMAgent" + model_names = ["actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: QSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, actor_rng, critic_rng = jax.random.split(self.rng, 3) + + # define the actor + time_embedding = partial(LearnableFourierEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + + if cfg.diffusion.lr_decay_steps is not None: + actor_lr = optax.linear_schedule( + init_value=cfg.diffusion.lr, + end_value=cfg.diffusion.end_lr, + transition_steps=cfg.diffusion.lr_decay_steps, + transition_begin=cfg.diffusion.lr_decay_begin, + ) + else: + actor_lr = cfg.diffusion.lr + + + self.actor = ContinuousDDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + t_schedule_n=1.0, + optimizer=optax.adam(learning_rate=actor_lr), + ) + # CHECK: is this really necessary, since we are not using the target actor for policy evaluation? 
+ self.actor_target = ContinuousDDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + t_schedule_n=1.0, + ) + + # define the critic + critic_def = EnsembleCritic( + hidden_dims=cfg.critic_hidden_dims, + activation=jax.nn.relu, + layer_norm=False, + dropout=None, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.critic, self.critic_target, critic_metrics = jit_update_qsm_critic( + self.rng, + self.actor, + self.critic, + self.critic_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + ema=self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_qsm_actor( + self.rng, + self.actor, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + self._n_training_steps += 1 + return {**critic_metrics, **actor_metrics} + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + return action, {} diff --git a/flowrl/config/online/mujoco/__init__.py b/flowrl/config/online/mujoco/__init__.py index e775d25..6928331 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -3,6 +3,8 @@ from .algo.base import BaseAlgoConfig from .algo.ctrl_td3 import CtrlTD3Config from .algo.dpmd import DPMDConfig +from .algo.idem import IDEMConfig +from .algo.qsm import QSMConfig from .algo.sac import SACConfig from .algo.sdac import SDACConfig from .algo.td3 import TD3Config @@ -24,6 +26,8 @@ "td7": TD7Config, "dpmd": DPMDConfig, "ctrl": CtrlTD3Config, + "qsm": QSMConfig, + "idem": IDEMConfig } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/idem.py b/flowrl/config/online/mujoco/algo/idem.py new file mode 100644 index 0000000..c46e7b6 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/idem.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class IDEMDiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class IDEMConfig(BaseAlgoConfig): + name: str + critic_hidden_dims: List[int] + critic_lr: float + discount: float + num_samples: int + num_reverse_samples: int + ema: float + temp: float + diffusion: IDEMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/qsm.py 
b/flowrl/config/online/mujoco/algo/qsm.py new file mode 100644 index 0000000..158b821 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/qsm.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class QSMDiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class QSMConfig(BaseAlgoConfig): + name: str + critic_hidden_dims: List[int] + critic_lr: float + discount: float + num_samples: int + ema: float + temp: float + diffusion: QSMDiffusionConfig From 7b35efb4270d837025a556c732a59e870d921d2c Mon Sep 17 00:00:00 2001 From: typoverflow Date: Wed, 8 Oct 2025 16:45:28 -0400 Subject: [PATCH 2/9] update: some logging --- flowrl/agent/online/idem.py | 5 +++++ flowrl/agent/online/qsm.py | 1 + 2 files changed, 6 insertions(+) diff --git a/flowrl/agent/online/idem.py b/flowrl/agent/online/idem.py index b24c906..4fdfe41 100644 --- a/flowrl/agent/online/idem.py +++ b/flowrl/agent/online/idem.py @@ -64,6 +64,11 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra loss = ((eps_pred - eps_estimation) ** 2).mean() return loss, { "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/weights": weight.mean(), + "misc/weight_std": weight.std(0).mean(), + "misc/weight_max": weight.max(0).mean(), + "misc/weight_min": weight.min(0).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) diff --git a/flowrl/agent/online/qsm.py b/flowrl/agent/online/qsm.py index 3e66f88..6f20fa5 100644 --- a/flowrl/agent/online/qsm.py +++ b/flowrl/agent/online/qsm.py @@ -110,6 +110,7 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra loss = ((eps_pred - eps_estimation) ** 2).mean() return loss, { "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) From 9ba90e2ddd1da7056eced292d8411f86860d37ce Mon Sep 17 00:00:00 2001 From: Edward Chen Date: Thu, 9 Oct 2025 19:44:37 -0400 Subject: [PATCH 3/9] qsm scripts --- examples/online/config/dmc/algo/qsm.yaml | 22 ++++++++ examples/online/main_dmc_offpolicy.py | 1 + flowrl/agent/online/__init__.py | 2 + scripts/dmc/qsm.sh | 68 ++++++++++++++++++++++++ 4 files changed, 93 insertions(+) create mode 100644 examples/online/config/dmc/algo/qsm.yaml create mode 100644 scripts/dmc/qsm.sh diff --git a/examples/online/config/dmc/algo/qsm.yaml b/examples/online/config/dmc/algo/qsm.yaml new file mode 100644 index 0000000..1a3fdd1 --- /dev/null +++ b/examples/online/config/dmc/algo/qsm.yaml @@ -0,0 +1,22 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512, 512] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/main_dmc_offpolicy.py b/examples/online/main_dmc_offpolicy.py index 0f8e59f..df1d54b 100644 --- a/examples/online/main_dmc_offpolicy.py +++ b/examples/online/main_dmc_offpolicy.py @@ -27,6 +27,7 @@ "sdac": SDACAgent, "dpmd": DPMDAgent, "ctrl_td3": CtrlTD3Agent, + "qsm": QSMAgent, } class OffPolicyTrainer(): diff --git 
a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index d8068bd..949cbd2 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -1,6 +1,7 @@ from ..base import BaseAgent from .ctrl.ctrl import CtrlTD3Agent from .dpmd import DPMDAgent +from .idem import IDEMAgent from .ppo import PPOAgent from .qsm import QSMAgent from .sac import SACAgent @@ -18,4 +19,5 @@ "PPOAgent", "CtrlTD3Agent", "QSMAgent", + "IDEMAgent" ] diff --git a/scripts/dmc/qsm.sh b/scripts/dmc/qsm.sh new file mode 100644 index 0000000..fea0999 --- /dev/null +++ b/scripts/dmc/qsm.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi From 4d2a765029c0fceed9498c52aa75c8030533fc98 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Wed, 29 Oct 2025 12:43:17 -0400 Subject: [PATCH 4/9] update --- examples/online/config/mujoco/algo/alac.yaml | 24 ++ examples/online/config/mujoco/algo/qsm.yaml | 6 +- examples/online/config/mujoco/config.yaml | 1 - examples/online/main_mujoco_offpolicy.py | 4 +- flowrl/agent/online/__init__.py | 5 +- flowrl/agent/online/alac/alac.py | 339 +++++++++++++++++++ flowrl/agent/online/alac/network.py | 87 +++++ flowrl/agent/online/qsm.py | 20 +- flowrl/config/online/mujoco/__init__.py | 4 +- flowrl/config/online/mujoco/algo/alac.py | 32 ++ flowrl/flow/langevin_dynamics.py | 293 ++++++++++++++++ scripts/mujoco/alac.sh | 60 ++++ scripts/mujoco/qsm.sh | 53 +++ 13 files changed, 904 insertions(+), 24 deletions(-) create mode 100644 examples/online/config/mujoco/algo/alac.yaml create mode 100644 flowrl/agent/online/alac/alac.py create mode 100644 flowrl/agent/online/alac/network.py create mode 100644 flowrl/config/online/mujoco/algo/alac.py create mode 100644 flowrl/flow/langevin_dynamics.py create mode 100644 scripts/mujoco/alac.sh create mode 100644 scripts/mujoco/qsm.sh diff --git a/examples/online/config/mujoco/algo/alac.yaml b/examples/online/config/mujoco/algo/alac.yaml new file mode 100644 index 0000000..94fe9e3 --- /dev/null +++ b/examples/online/config/mujoco/algo/alac.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +algo: + name: alac + discount: 0.99 + num_samples: 10 + ema: 0.005 + ld: + resnet: false + activation: relu + ensemble_size: 2 + time_dim: 64 + hidden_dims: [512, 512] + cond_hidden_dims: [128, 128] + steps: 20 + step_size: 0.05 + noise_scale: 1.0 + noise_schedule: "none" + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + epsilon: 0.001 + lr: 0.0003 + clip_grad_norm: null diff --git a/examples/online/config/mujoco/algo/qsm.yaml b/examples/online/config/mujoco/algo/qsm.yaml index 5815408..a6d14e3 100644 --- a/examples/online/config/mujoco/algo/qsm.yaml +++ b/examples/online/config/mujoco/algo/qsm.yaml @@ -2,15 +2,15 @@ algo: name: qsm - critic_hidden_dims: [256, 256] + critic_hidden_dims: [512, 512] critic_lr: 0.0003 discount: 0.99 num_samples: 10 ema: 0.005 - temp: 0.2 + temp: 0.1 diffusion: time_dim: 64 - mlp_hidden_dims: [256, 256] + mlp_hidden_dims: [512, 512] lr: 0.0003 end_lr: null lr_decay_steps: null diff --git a/examples/online/config/mujoco/config.yaml b/examples/online/config/mujoco/config.yaml index f662225..7658820 100644 --- a/examples/online/config/mujoco/config.yaml +++ b/examples/online/config/mujoco/config.yaml @@ -27,7 +27,6 @@ random_frames: 5_000 eval_frames: 10_000 log_frames: 1_000 lap_reset_frames: 250 -eval_episodes: 10 log: dir: logs tag: debug diff --git a/examples/online/main_mujoco_offpolicy.py b/examples/online/main_mujoco_offpolicy.py index 8e85f9e..4432994 100644 --- a/examples/online/main_mujoco_offpolicy.py +++ b/examples/online/main_mujoco_offpolicy.py @@ -6,12 +6,11 @@ import jax import numpy as np import omegaconf -import wandb from omegaconf import OmegaConf from tqdm import tqdm +import wandb from flowrl.agent.online import * -from flowrl.agent.online.idem import IDEMAgent from flowrl.config.online.mujoco import Config from flowrl.dataset.buffer.state import ReplayBuffer from flowrl.types import * @@ 
-28,6 +27,7 @@ "dpmd": DPMDAgent, "qsm": QSMAgent, "idem": IDEMAgent, + "alac": ALACAgent, } class OffPolicyTrainer(): diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index 949cbd2..915b658 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -1,5 +1,6 @@ from ..base import BaseAgent from .ctrl.ctrl import CtrlTD3Agent +from .alac.alac import ALACAgent from .dpmd import DPMDAgent from .idem import IDEMAgent from .ppo import PPOAgent @@ -19,5 +20,7 @@ "PPOAgent", "CtrlTD3Agent", "QSMAgent", - "IDEMAgent" + "IDEMAgent", + "ALACAgent", + "CtrlTD3Agent", ] diff --git a/flowrl/agent/online/alac/alac.py b/flowrl/agent/online/alac/alac.py new file mode 100644 index 0000000..45bbe3c --- /dev/null +++ b/flowrl/agent/online/alac/alac.py @@ -0,0 +1,339 @@ +from functools import partial +from typing import Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.alac.network import EnsembleEnergyNet +from flowrl.config.online.mujoco.algo.alac import ALACConfig +from flowrl.flow.langevin_dynamics import AnnealedLangevinDynamics +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples")) +def jit_sample_actions( + rng: PRNGKey, + actor: AnnealedLangevinDynamics, + ld: AnnealedLangevinDynamics, + obs, + training: bool, + num_samples: int +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, x_init_rng = jax.random.split(rng) + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + x_init = jax.random.normal(x_init_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, x_init, obs_repeat, training=training) + if num_samples == 1: + actions = actions[:, 0] + else: + qs = ld(actions, t=jnp.zeros((B, num_samples, 1), dtype=jnp.float32), condition=obs_repeat) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "ema")) +def jit_update_ld( + rng: PRNGKey, + ld: AnnealedLangevinDynamics, + ld_target: AnnealedLangevinDynamics, + actor: AnnealedLangevinDynamics, + batch: Batch, + discount: float, + ema: float, +) -> Tuple[PRNGKey, AnnealedLangevinDynamics, AnnealedLangevinDynamics, Metric]: + B, A = batch.action.shape[0], batch.action.shape[1] + feed_t = jnp.zeros((B, 1), dtype=jnp.float32) + + rng, next_xT_rng = jax.random.split(rng) + # next_action_init = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], ld.x_dim)) + next_action_init = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + # rng, next_action, history = ld_target.sample( + # rng, + # next_action_init, + # batch.next_obs, + # training=False, + # ) + rng, next_action, history = actor.sample( + rng, + next_action_init, + batch.next_obs, + training=False, + ) + q_target = ld_target(next_action, feed_t, batch.next_obs, training=False) + # q_target = ld_target(batch.next_obs, next_action, training=False) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target.min(axis=0) + + def ld_loss_fn(params: Param, dropout_rng: PRNGKey) -> 
Tuple[jnp.ndarray, Metric]: + q_pred = ld.apply( + {"params": params}, + batch.action, + t=feed_t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + # q_pred = ld.apply( + # {"params": params}, + # batch.obs, + # batch.action, + # training=True, + # rngs={"dropout": dropout_rng}, + # ) + ld_loss = ((q_pred - q_target[jnp.newaxis, :])**2).mean() + return ld_loss, { + "loss/ld_loss": ld_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + # "misc/q_grad_l1": jnp.abs(history[1]).mean(), + } + + new_ld, ld_metrics = ld.apply_gradient(ld_loss_fn) + new_ld_target = ema_update(new_ld, ld_target, ema) + + # record energy + # num_checkpoints = 5 + # stepsize_checkpoint = ld.steps // num_checkpoints + # energy_history = history[2][jnp.arange(0, ld.steps, stepsize_checkpoint)] + # energy_history = energy_history.mean(axis=[-2, -1]) + # ld_metrics.update({ + # f"info/energy_step{i}": energy for i, energy in enumerate(energy_history) + # }) + + return rng, new_ld, new_ld_target, ld_metrics + + +@partial(jax.jit, static_argnames=()) +def jit_update_actor( + rng: PRNGKey, + actor: AnnealedLangevinDynamics, + critic_target: AnnealedLangevinDynamics, + batch: Batch, +) -> Tuple[PRNGKey, AnnealedLangevinDynamics, Metric]: + x0 = batch.action + rng, xt, t, eps = actor.add_noise(rng, x0) + # rng, t_rng, noise_rng = jax.random.split(rng, 3) + # t = jax.random.uniform(t_rng, (*x0.shape[:-1], 1), dtype=jnp.float32, minval=actor.t_diffusion[0], maxval=actor.t_diffusion[1]) + # eps = jax.random.normal(noise_rng, x0.shape, dtype=jnp.float32) + alpha, sigma = actor.noise_schedule_func(t) + xt = alpha * x0 + sigma * eps + + q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(a, None, condition=s).min(axis=0).mean())) + # q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) + q_grad = q_grad_fn(xt, batch.obs) + q_grad = alpha * q_grad - sigma * xt + eps_estimation = sigma * q_grad / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + xt, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class ALACAgent(BaseAgent): + """ + Annealed Langevin Dynamics Actor-Critic (ALAC) agent. 
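+
+    The critic is an ensemble energy network wrapped in AnnealedLangevinDynamics
+    and trained with standard TD targets (jit_update_ld); the actor is a second
+    AnnealedLangevinDynamics model whose noise prediction is regressed onto a
+    normalized, noise-scaled gradient of the target energy with respect to the
+    noised action (jit_update_actor). Action selection draws num_samples
+    candidates and keeps the one with the highest ensemble-minimum Q value from
+    the energy network.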
+ """ + name = "ALACAgent" + model_names = ["actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: ALACConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, ld_rng = jax.random.split(self.rng, 2) + + # define the critic + # from flowrl.module.critic import EnsembleCritic + # from flowrl.module.model import Model + # critic_def = EnsembleCritic( + # hidden_dims=cfg.ld.hidden_dims, + # activation=jax.nn.relu, + # layer_norm=False, + # dropout=None, + # ensemble_size=2, + # ) + # self.ld = Model.create( + # critic_def, + # ld_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + # optimizer=optax.adam(learning_rate=cfg.ld.lr), + # ) + # self.ld_target = Model.create( + # critic_def, + # ld_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + # ) + + mlp_impl = ResidualMLP if cfg.ld.resnet else MLP + activation = {"mish": mish, "relu": jax.nn.relu}[cfg.ld.activation] + energy_def = EnsembleEnergyNet( + mlp_impl=mlp_impl, + hidden_dims=cfg.ld.hidden_dims, + output_dim=1, + activation=activation, + layer_norm=False, + dropout=None, + ensemble_size=cfg.ld.ensemble_size, + # time_embedding=partial(LearnableFourierEmbedding, output_dim=cfg.ld.time_dim), + time_embedding=None, + # cond_embedding=partial(MLP, hidden_dims=cfg.ld.cond_hidden_dims, activation=activation), + cond_embedding=None, + ) + self.ld = AnnealedLangevinDynamics.create( + network=energy_def, + rng=ld_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=False, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule=cfg.ld.noise_schedule, + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + optimizer=optax.adam(learning_rate=cfg.ld.lr), + clip_grad_norm=cfg.ld.clip_grad_norm, + ) + self.ld_target = AnnealedLangevinDynamics.create( + network=energy_def, + rng=ld_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=False, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule=cfg.ld.noise_schedule, + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + ) + + # DEBUG define the actor + from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone + self.rng, actor_rng = jax.random.split(self.rng, 2) + time_embedding = partial[LearnableFourierEmbedding](LearnableFourierEmbedding, output_dim=cfg.ld.time_dim) + cond_embedding = partial(MLP, hidden_dims=[128, 128], activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.ld.hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + # self.actor = ContinuousDDPM.create( + # network=backbone_def, + # rng=actor_rng, + # inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + # x_dim=self.act_dim, + # steps=cfg.ld.steps, + # noise_schedule="cosine", + # noise_schedule_params={}, + # clip_sampler=cfg.ld.clip_sampler, + # x_min=cfg.ld.x_min, + # x_max=cfg.ld.x_max, + # 
t_schedule_n=1.0, + # optimizer=optax.adam(learning_rate=cfg.ld.lr), + # ) + self.actor = AnnealedLangevinDynamics.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=True, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + optimizer=optax.adam(learning_rate=cfg.ld.lr), + clip_grad_norm=cfg.ld.clip_grad_norm, + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.ld, self.ld_target, ld_metrics = jit_update_ld( + self.rng, + self.ld, + self.ld_target, + self.actor, + batch, + self.cfg.discount, + self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_actor( + self.rng, + self.actor, + self.ld_target, + batch, + ) + + self._n_training_steps += 1 + return {**ld_metrics, **actor_metrics} + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + # self.ld, + self.actor, + self.ld, + obs, + training=False, + num_samples=num_samples, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} diff --git a/flowrl/agent/online/alac/network.py b/flowrl/agent/online/alac/network.py new file mode 100644 index 0000000..6f9e05b --- /dev/null +++ b/flowrl/agent/online/alac/network.py @@ -0,0 +1,87 @@ +import flax.linen as nn +import jax.numpy as jnp + +from flowrl.functional.activation import mish +from flowrl.module.mlp import MLP +from flowrl.types import * + + +class EnergyNet(nn.Module): + mlp_impl: nn.Module + hidden_dims: Sequence[int] + output_dim: int = 1 + activation: Callable = nn.relu + layer_norm: bool = False + dropout: Optional[float] = None + cond_embedding: Optional[nn.Module] = None + time_embedding: Optional[nn.Module] = None + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + t: Optional[jnp.ndarray] = None, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + if condition is not None: + if self.cond_embedding is not None: + condition = self.cond_embedding()(condition, training=training) + else: + condition = condition + x = jnp.concatenate([x, condition], axis=-1) + if self.time_embedding is not None: + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([x, t_ff], axis=-1) + x = self.mlp_impl( + hidden_dims=self.hidden_dims, + output_dim=self.output_dim, + activation=self.activation, + layer_norm=self.layer_norm, + dropout=self.dropout, + )(x, training) + return x + + +class EnsembleEnergyNet(nn.Module): + mlp_impl: nn.Module + hidden_dims: Sequence[int] + output_dim: int = 1 + activation: Callable = nn.relu + layer_norm: bool = False + dropout: Optional[float] = None + cond_embedding: Optional[nn.Module] = None + time_embedding: Optional[nn.Module] = None + ensemble_size: int = 2 + + @nn.compact + def __call__( + self, + x: 
jnp.ndarray, + t: Optional[jnp.ndarray] = None, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + vmap_energy_net = nn.vmap( + EnergyNet, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + x = vmap_energy_net( + mlp_impl=self.mlp_impl, + hidden_dims=self.hidden_dims, + output_dim=self.output_dim, + activation=self.activation, + layer_norm=self.layer_norm, + dropout=self.dropout, + cond_embedding=self.cond_embedding, + time_embedding=self.time_embedding, + )(x, t, condition, training) + return x diff --git a/flowrl/agent/online/qsm.py b/flowrl/agent/online/qsm.py index 6f20fa5..9e013da 100644 --- a/flowrl/agent/online/qsm.py +++ b/flowrl/agent/online/qsm.py @@ -55,7 +55,6 @@ def jit_update_qsm_critic( solver: str, ema: float, ) -> Tuple[PRNGKey, Model, Model, Metric]: - # update critic rng, next_xT_rng = jax.random.split(rng) next_xT = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) rng, next_action, _ = actor.sample(rng, next_xT, batch.next_obs, training=False, solver=solver) @@ -75,6 +74,7 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar "loss/critic_loss": critic_loss, "misc/q_mean": q.mean(), "misc/reward": batch.reward.mean(), + "misc/next_action_l1": jnp.abs(next_action).mean(), } new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) @@ -96,7 +96,7 @@ def jit_update_qsm_actor( q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) q_grad = q_grad_fn(at, batch.obs) - eps_estimation = - alpha2 * q_grad / temp + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: eps_pred = actor.apply( @@ -171,20 +171,6 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: QSMConfig, seed: int): t_schedule_n=1.0, optimizer=optax.adam(learning_rate=actor_lr), ) - # CHECK: is this really necessary, since we are not using the target actor for policy evaluation? 
- self.actor_target = ContinuousDDPM.create( - network=backbone_def, - rng=actor_rng, - inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), - x_dim=self.act_dim, - steps=cfg.diffusion.steps, - noise_schedule="cosine", - noise_schedule_params={}, - clip_sampler=cfg.diffusion.clip_sampler, - x_min=cfg.diffusion.x_min, - x_max=cfg.diffusion.x_max, - t_schedule_n=1.0, - ) # define the critic critic_def = EnsembleCritic( @@ -251,4 +237,6 @@ def sample_actions( num_samples=num_samples, solver=self.cfg.diffusion.solver, ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) return action, {} diff --git a/flowrl/config/online/mujoco/__init__.py b/flowrl/config/online/mujoco/__init__.py index 6928331..c4d559f 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -1,5 +1,6 @@ from hydra.core.config_store import ConfigStore +from .algo.alac import ALACConfig from .algo.base import BaseAlgoConfig from .algo.ctrl_td3 import CtrlTD3Config from .algo.dpmd import DPMDConfig @@ -27,7 +28,8 @@ "dpmd": DPMDConfig, "ctrl": CtrlTD3Config, "qsm": QSMConfig, - "idem": IDEMConfig + "idem": IDEMConfig, + "alac": ALACConfig, } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/alac.py b/flowrl/config/online/mujoco/algo/alac.py new file mode 100644 index 0000000..b07cc24 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/alac.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class ALACLangevinDynamicsConfig: + resnet: bool + activation: str + ensemble_size: int + time_dim: int + hidden_dims: List[int] + cond_hidden_dims: List[int] + steps: int + step_size: float + noise_scale: float + noise_schedule: str + clip_sampler: bool + x_min: float + x_max: float + epsilon: float + lr: float + clip_grad_norm: float | None + +@dataclass +class ALACConfig(BaseAlgoConfig): + name: str + discount: float + ema: float + num_samples: int + ld: ALACLangevinDynamicsConfig diff --git a/flowrl/flow/langevin_dynamics.py b/flowrl/flow/langevin_dynamics.py new file mode 100644 index 0000000..81fe14b --- /dev/null +++ b/flowrl/flow/langevin_dynamics.py @@ -0,0 +1,293 @@ +from functools import partial +from typing import Callable, Optional, Sequence, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax +from flax.struct import PyTreeNode, dataclass, field +from flax.training.train_state import TrainState + +from flowrl.flow.continuous_ddpm import cosine_noise_schedule, linear_noise_schedule +from flowrl.module.model import Model +from flowrl.types import * + +# ======= Langevin Dynamics Sampling ======= + +@dataclass +class LangevinDynamics(Model): + state: TrainState + dropout_rng: PRNGKey = field(pytree_node=True) + x_dim: int = field(pytree_node=False, default=None) + grad_prediction: bool = field(pytree_node=False, default=True) + steps: int = field(pytree_node=False, default=None) + step_size: float = field(pytree_node=False, default=None) + noise_scale: float = field(pytree_node=False, default=None) + clip_sampler: bool = field(pytree_node=False, default=None) + x_min: float = field(pytree_node=False, default=None) + x_max: float = field(pytree_node=False, default=None) + + @classmethod + def create( + cls, + network: nn.Module, + rng: PRNGKey, + inputs: Sequence[jnp.ndarray], + x_dim: int, + grad_prediction: bool = True, + steps: int = 100, + step_size: float = 0.01, + 
noise_scale: float = 1.0, + clip_sampler: bool = False, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + optimizer: Optional[optax.GradientTransformation] = None, + clip_grad_norm: float = None + ) -> 'LangevinDynamics': + ret = super().create(network, rng, inputs, optimizer, clip_grad_norm) + + return ret.replace( + x_dim=x_dim, + grad_prediction=grad_prediction, + steps=steps, + step_size=step_size, + noise_scale=noise_scale, + clip_sampler=clip_sampler, + x_min=x_min, + x_max=x_max, + ) + + @partial(jax.jit, static_argnames=("training")) + def compute_grad( + self, + x: jnp.ndarray, + i: int, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + params: Optional[Param] = None, + dropout_rng: Optional[PRNGKey] = None + ) -> jnp.ndarray: + original_shape = x.shape[:-1] + t = i * jnp.ones((*x.shape[:-1], 1), dtype=jnp.int32) + + x = x.reshape(-1, x.shape[-1]) + t = t.reshape(-1, 1) + condition = condition.reshape(-1, condition.shape[-1]) + if self.grad_prediction: + if training: + grad = self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ) + else: + grad = self(x, t, condition=condition, training=training) + energy = jnp.zeros_like((*x.shape[:-1], 1), dtype=jnp.float32) + else: + if training: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ).mean())) + else: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self(x, t, condition=condition, training=training).mean())) + energy, grad = energy_and_grad_fn(x, t, condition) + return grad.reshape(*original_shape, self.x_dim), energy.reshape(*original_shape, 1) + + @partial(jax.jit, static_argnames=("training", "steps","step_size","noise_scale")) + def sample( + self, + rng: PRNGKey, + x_init: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + steps: Optional[int] = None, + step_size: Optional[float] = None, + noise_scale: Optional[float] = None, + params: Optional[Param] = None, + ) -> Tuple[PRNGKey, jnp.ndarray, Optional[jnp.ndarray]]: + steps = steps or self.steps + step_size = step_size or self.step_size + noise_scale = noise_scale or self.noise_scale + + def fn(input_tuple, i): + rng_, xt = input_tuple + rng_, noise_rng, dropout_rng_ = jax.random.split(rng_, 3) + + grad, energy = self.compute_grad(xt, i, condition=condition, training=training, params=params, dropout_rng=dropout_rng_) + + xt_1 = xt + step_size * grad + if self.clip_sampler: + xt_1 = jnp.clip(xt_1, self.x_min, self.x_max) + noise = jax.random.normal(noise_rng, xt_1.shape, dtype=jnp.float32) + xt_1 += (i>1) * jnp.sqrt(2 * step_size * noise_scale) * noise + + return (rng_, xt_1), (xt, grad, energy) + + output, history = jax.lax.scan(fn, (rng, x_init), jnp.arange(steps, 0, -1), unroll=True) + rng, action = output + return rng, action, history + + +@dataclass +class AnnealedLangevinDynamics(LangevinDynamics): + state: TrainState + dropout_rng: PRNGKey = field(pytree_node=True) + x_dim: int = field(pytree_node=False, default=None) + grad_prediction: bool = field(pytree_node=False, default=True) + steps: int = field(pytree_node=False, default=None) + step_size: float = field(pytree_node=False, default=None) + noise_scale: float = field(pytree_node=False, default=None) + clip_sampler: bool = field(pytree_node=False, default=None) + x_min: float = field(pytree_node=False, default=None) + x_max: 
float = field(pytree_node=False, default=None) + t_schedule_n: float = field(pytree_node=False, default=None) + t_diffusion: Tuple[float, float] = field(pytree_node=False, default=None) + noise_schedule_func: Callable = field(pytree_node=False, default=None) + + @classmethod + def create( + cls, + network: nn.Module, + rng: PRNGKey, + inputs: Sequence[jnp.ndarray], + x_dim: int, + grad_prediction: bool, + steps: int, + step_size: float, + noise_scale: float, + noise_schedule: str, + noise_schedule_params: Optional[Dict]=None, + clip_sampler: bool = False, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + t_schedule_n: float=1.0, + epsilon: float=0.001, + optimizer: Optional[optax.GradientTransformation]=None, + clip_grad_norm: float=None + ) -> 'AnnealedLangevinDynamics': + ret = super().create( + network, + rng, + inputs, + x_dim, + grad_prediction, + steps, + step_size, + noise_scale, + clip_sampler, + x_min, + x_max, + optimizer, + clip_grad_norm, + ) + + if noise_schedule_params is None: + noise_schedule_params = {} + if noise_schedule == "cosine": + t_diffusion = [epsilon, 0.9946] + else: + t_diffusion = [epsilon, 1.0] + if noise_schedule == "linear": + noise_schedule_func = partial(linear_noise_schedule, **noise_schedule_params) + elif noise_schedule == "cosine": + noise_schedule_func = partial(cosine_noise_schedule, **noise_schedule_params) + elif noise_schedule == "none": + noise_schedule_func = lambda t: (jnp.ones_like(t), jnp.zeros_like(t)) + else: + raise NotImplementedError(f"Unsupported noise schedule: {noise_schedule}") + + return ret.replace( + t_schedule_n=t_schedule_n, + t_diffusion=t_diffusion, + noise_schedule_func=noise_schedule_func, + ) + + @partial(jax.jit, static_argnames=("training")) + def compute_grad( + self, + x: jnp.ndarray, + t: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + params: Optional[Param] = None, + dropout_rng: Optional[PRNGKey] = None + ) -> jnp.ndarray: + original_shape = x.shape[:-1] + t = t * jnp.ones((*x.shape[:-1], 1), dtype=jnp.int32) + + x = x.reshape(-1, x.shape[-1]) + t = t.reshape(-1, 1) + condition = condition.reshape(-1, condition.shape[-1]) + if self.grad_prediction: + if training: + grad = self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ) + else: + grad = self(x, t, condition=condition, training=training) + energy = jnp.zeros((*x.shape[:-1], 1), dtype=jnp.float32) + else: + if training: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ).mean())) + else: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self(x, t, condition=condition, training=training).mean())) + energy, grad = energy_and_grad_fn(x, t, condition) + # alpha, sigma = self.noise_schedule_func(t) + # grad = alpha * grad - sigma * x + return grad.reshape(*original_shape, self.x_dim), energy.reshape(*original_shape, 1) + + def add_noise(self, rng: PRNGKey, x: jnp.ndarray) -> Tuple[PRNGKey, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + rng, t_rng, noise_rng = jax.random.split(rng, 3) + t = jax.random.uniform(t_rng, (*x.shape[:-1], 1), dtype=jnp.float32, minval=self.t_diffusion[0], maxval=self.t_diffusion[1]) + alpha, sigma = self.noise_schedule_func(t) + eps = jax.random.normal(noise_rng, x.shape, dtype=jnp.float32) + xt = alpha * x + sigma * eps + return rng, xt, t, eps + + @partial(jax.jit, 
static_argnames=("training", "steps","step_size","noise_scale")) + def sample( + self, + rng: PRNGKey, + x_init: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + steps: Optional[int] = None, + step_size: Optional[float] = None, + noise_scale: Optional[float] = None, + params: Optional[Param] = None, + ) -> Tuple[PRNGKey, jnp.ndarray, Optional[jnp.ndarray]]: + steps = steps or self.steps + # step_size = step_size or self.step_size + # noise_scale = noise_scale or self.noise_scale + t_schedule_n = 1.0 + from flowrl.flow.continuous_ddpm import quad_t_schedule + ts = quad_t_schedule(steps, n=t_schedule_n, tmin=self.t_diffusion[0], tmax=self.t_diffusion[1]) + alpha_hats = self.noise_schedule_func(ts)[0] ** 2 + alphas = alpha_hats[1:] / alpha_hats[:-1] + alphas = jnp.concat([jnp.ones((1, )), alphas], axis=0) + betas = 1 - alphas + alpha1, alpha2 = self.noise_schedule_func(ts) + + t_proto = jnp.ones((*x_init.shape[:-1], 1), dtype=jnp.int32) + + def fn(input_tuple, i): + rng_, xt = input_tuple + rng_, dropout_rng_, key_ = jax.random.split(rng_, 3) + input_t = t_proto * ts[i] + + q_grad, energy = self.compute_grad(xt, ts[i], condition=condition, training=training, params=params, dropout_rng=dropout_rng_) + eps_theta = q_grad + + x0_hat = (xt - jnp.sqrt(1 - alpha_hats[i]) * eps_theta) / jnp.sqrt(alpha_hats[i]) + x0_hat = jnp.clip(x0_hat, self.x_min, self.x_max) if self.clip_sampler else x0_hat + + mean_coef1 = jnp.sqrt(alpha_hats[i-1]) * betas[i] / (1 - alpha_hats[i]) + mean_coef2 = jnp.sqrt(alphas[i]) * (1 - alpha_hats[i-1]) / (1 - alpha_hats[i]) + xt_1 = mean_coef1 * x0_hat + mean_coef2 * xt + xt_1 += (i>1) * jnp.sqrt(betas[i]) * jax.random.normal(key_, xt_1.shape) + + return (rng_, xt_1), (xt, eps_theta, energy) + + output, history = jax.lax.scan(fn, (rng, x_init), jnp.arange(steps, 0, -1), unroll=True) + rng, action = output + return rng, action, history diff --git a/scripts/mujoco/alac.sh b/scripts/mujoco/alac.sh new file mode 100644 index 0000000..2232a3a --- /dev/null +++ b/scripts/mujoco/alac.sh @@ -0,0 +1,60 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "Ant-v5" + "HalfCheetah-v5" + # "Hopper-v5" + # "HumanoidStandup-v5" + "Humanoid-v5" + # "InvertedDoublePendulum-v5" + # "InvertedPendulum-v5" + # "Pusher-v5" + # "Reacher-v5" + # "Swimmer-v5" + "Walker2d-v5" +) + +SHARED_ARGS=( + "algo=alac" + # "algo.ld.step_size=0.1" + # "algo.ld.noise_scale=0.01" + # "algo.ld.steps=50" + # "log.tag=noise_none-stepsize0.1-noise0.01-steps50-no_last_noise" + "algo.ld.activation=relu" + "algo.ld.steps=20" + "algo.ld.noise_schedule=cosine" + "log.tag=use_ld_but_actually_diffusion-decay_q-temp0.5" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_mujoco_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/mujoco/qsm.sh b/scripts/mujoco/qsm.sh new file mode 100644 index 0000000..04023c8 --- /dev/null +++ b/scripts/mujoco/qsm.sh @@ -0,0 +1,53 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "Ant-v5" + "HalfCheetah-v5" + "Hopper-v5" + "HumanoidStandup-v5" + "Humanoid-v5" + "InvertedDoublePendulum-v5" + "InvertedPendulum-v5" + "Pusher-v5" + "Reacher-v5" + "Swimmer-v5" + "Walker2d-v5" +) + +SHARED_ARGS=( + "algo=qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_mujoco_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi From 330da0f17826a6187d0a0fa61c9989d54012150e Mon Sep 17 00:00:00 2001 From: typoverflow Date: Thu, 30 Oct 2025 22:34:58 -0400 Subject: [PATCH 5/9] update --- examples/online/config/dmc/algo/ctrl_qsm.yaml | 49 +++ examples/online/config/dmc/algo/qsm.yaml | 3 +- examples/online/config/mujoco/algo/qsm.yaml | 1 + examples/online/main_dmc_offpolicy.py | 1 + flowrl/agent/online/__init__.py | 3 +- flowrl/agent/online/ctrl/__init__.py | 7 + flowrl/agent/online/ctrl/ctrl_qsm.py | 287 ++++++++++++++++++ .../online/ctrl/{ctrl.py => ctrl_td3.py} | 2 +- flowrl/agent/online/qsm.py | 8 +- flowrl/config/online/mujoco/__init__.py | 7 +- .../online/mujoco/algo/ctrl/__init__.py | 7 + .../online/mujoco/algo/ctrl/ctrl_qsm.py | 40 +++ .../online/mujoco/algo/{ => ctrl}/ctrl_td3.py | 2 +- flowrl/config/online/mujoco/algo/qsm.py | 1 + 14 files changed, 409 insertions(+), 9 deletions(-) create mode 100644 examples/online/config/dmc/algo/ctrl_qsm.yaml create mode 100644 flowrl/agent/online/ctrl/__init__.py create mode 100644 flowrl/agent/online/ctrl/ctrl_qsm.py rename flowrl/agent/online/ctrl/{ctrl.py => ctrl_td3.py} (99%) create mode 100644 flowrl/config/online/mujoco/algo/ctrl/__init__.py create mode 100644 flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py rename flowrl/config/online/mujoco/algo/{ => ctrl}/ctrl_td3.py (96%) diff --git a/examples/online/config/dmc/algo/ctrl_qsm.yaml b/examples/online/config/dmc/algo/ctrl_qsm.yaml new file mode 100644 index 0000000..af85061 --- /dev/null +++ b/examples/online/config/dmc/algo/ctrl_qsm.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +algo: + name: ctrl_qsm + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + # critic_hidden_dims: [512, 512, 512] # not used + critic_activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + critic_lr: 0.0003 + clip_grad_norm: null + + # below are params specific to ctrl_td3 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + phi_hidden_dims: [512, 512] + 
mu_hidden_dims: [512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ctrl_coef: 1.0 + reward_coef: 1.0 + back_critic_grad: false + critic_coef: 1.0 + + num_noises: 25 + linear: false + ranking: true + + num_samples: 10 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/qsm.yaml b/examples/online/config/dmc/algo/qsm.yaml index 1a3fdd1..12b9624 100644 --- a/examples/online/config/dmc/algo/qsm.yaml +++ b/examples/online/config/dmc/algo/qsm.yaml @@ -3,11 +3,12 @@ algo: name: qsm critic_hidden_dims: [512, 512, 512] + critic_activation: elu critic_lr: 0.0003 discount: 0.99 num_samples: 10 ema: 0.005 - temp: 0.2 + temp: 0.1 diffusion: time_dim: 64 mlp_hidden_dims: [512, 512, 512] diff --git a/examples/online/config/mujoco/algo/qsm.yaml b/examples/online/config/mujoco/algo/qsm.yaml index a6d14e3..c646df7 100644 --- a/examples/online/config/mujoco/algo/qsm.yaml +++ b/examples/online/config/mujoco/algo/qsm.yaml @@ -3,6 +3,7 @@ algo: name: qsm critic_hidden_dims: [512, 512] + critic_activation: relu critic_lr: 0.0003 discount: 0.99 num_samples: 10 diff --git a/examples/online/main_dmc_offpolicy.py b/examples/online/main_dmc_offpolicy.py index df1d54b..8f51bb9 100644 --- a/examples/online/main_dmc_offpolicy.py +++ b/examples/online/main_dmc_offpolicy.py @@ -28,6 +28,7 @@ "dpmd": DPMDAgent, "ctrl_td3": CtrlTD3Agent, "qsm": QSMAgent, + "ctrl_qsm": CtrlQSMAgent, } class OffPolicyTrainer(): diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index 915b658..db396ef 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -1,6 +1,6 @@ from ..base import BaseAgent -from .ctrl.ctrl import CtrlTD3Agent from .alac.alac import ALACAgent +from .ctrl import * from .dpmd import DPMDAgent from .idem import IDEMAgent from .ppo import PPOAgent @@ -23,4 +23,5 @@ "IDEMAgent", "ALACAgent", "CtrlTD3Agent", + "CtrlQSMAgent", ] diff --git a/flowrl/agent/online/ctrl/__init__.py b/flowrl/agent/online/ctrl/__init__.py new file mode 100644 index 0000000..602a563 --- /dev/null +++ b/flowrl/agent/online/ctrl/__init__.py @@ -0,0 +1,7 @@ +from .ctrl_qsm import CtrlQSMAgent +from .ctrl_td3 import CtrlTD3Agent + +__all__ = [ + "CtrlTD3Agent", + "CtrlQSMAgent", +] diff --git a/flowrl/agent/online/ctrl/ctrl_qsm.py b/flowrl/agent/online/ctrl/ctrl_qsm.py new file mode 100644 index 0000000..2bf388b --- /dev/null +++ b/flowrl/agent/online/ctrl/ctrl_qsm.py @@ -0,0 +1,287 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce +from flowrl.agent.online.qsm import QSMAgent +from flowrl.config.online.mujoco.algo.ctrl.ctrl_qsm import CtrlQSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM +from flowrl.functional.ema import ema_update +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert 
len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + feature = nce_target(obs_repeat, actions, method="forward_phi") + qs = critic(feature) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + nce_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample( + rng, + next_xT, + batch.next_obs, + training=False, + solver=solver, + ) + next_feature = nce_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + feature = nce_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, metrics + +@partial(jax.jit, static_argnames=("temp")) +def update_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, Model, Metric]: + + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray: + feature = nce_target(obs, action, method="forward_phi") + q = critic_target(feature) + return q.min(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs) + q_grad = alpha1 * q_grad - alpha2 * at + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class CtrlQSMAgent(QSMAgent): + """ + CTRL with Q Score Matching (QSM) agent. 
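+
+    The actor update follows QSM, but the Q-value whose action-gradient supplies the
+    score target is taken through the learned representation: the noised action a_t is
+    encoded with the factorized NCE feature phi(s, a_t) and scored by the RFF ensemble
+    critic (min over the ensemble). The predicted noise is regressed toward the negative
+    of that gradient, scaled by the noise level at t and 1/temp and normalized by its
+    mean absolute value. The NCE encoder, the critic, and the diffusion actor are
+    updated jointly, with EMA targets kept for both the critic (ema) and the NCE
+    (feature_ema).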
+ """ + + name = "CtrlQSMAgent" + model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlQSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ctrl_coef = cfg.ctrl_coef + self.critic_coef = cfg.critic_coef + + self.linear = cfg.linear + self.ranking = cfg.ranking + self.feature_dim = cfg.feature_dim + self.num_noises = cfg.num_noises + self.reward_coef = cfg.reward_coef + self.rff_dim = cfg.rff_dim + self.actor_update_freq = cfg.actor_update_freq + self.target_update_freq = cfg.target_update_freq + + + # sanity checks for the hyper-parameters + assert not self.linear, "linear mode is not supported yet" + + # networks + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + nce_def = FactorizedNCE( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + self.ranking, + ) + self.nce = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.nce_target = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.nce_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.nce_target, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.nce_target, + self.critic, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + 
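+    # Minimal usage sketch; `buffer`, `total_steps` and `batch_size` are placeholders,
+    # and the actual training loop is the generic off-policy trainer in
+    # examples/online/main_dmc_offpolicy.py, where this class is registered as "ctrl_qsm":
+    #
+    #     agent = CtrlQSMAgent(obs_dim, act_dim, cfg.algo, seed=cfg.seed)
+    #     for step in range(total_steps):
+    #         batch = buffer.sample(batch_size)
+    #         metrics = agent.train_step(batch, step)
+    #     action, _ = agent.sample_actions(obs, deterministic=True)
+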
def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/ctrl/ctrl.py b/flowrl/agent/online/ctrl/ctrl_td3.py similarity index 99% rename from flowrl/agent/online/ctrl/ctrl.py rename to flowrl/agent/online/ctrl/ctrl_td3.py index 186c4d9..f379b46 100644 --- a/flowrl/agent/online/ctrl/ctrl.py +++ b/flowrl/agent/online/ctrl/ctrl_td3.py @@ -7,7 +7,7 @@ from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce from flowrl.agent.online.td3 import TD3Agent -from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config +from flowrl.config.online.mujoco.algo.ctrl.ctrl_td3 import CtrlTD3Config from flowrl.functional.ema import ema_update from flowrl.module.actor import SquashedDeterministicActor from flowrl.module.mlp import MLP diff --git a/flowrl/agent/online/qsm.py b/flowrl/agent/online/qsm.py index 9e013da..0a4e0f6 100644 --- a/flowrl/agent/online/qsm.py +++ b/flowrl/agent/online/qsm.py @@ -111,6 +111,7 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra return loss, { "loss/actor_loss": loss, "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) @@ -156,7 +157,6 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: QSMConfig, seed: int): else: actor_lr = cfg.diffusion.lr - self.actor = ContinuousDDPM.create( network=backbone_def, rng=actor_rng, @@ -173,9 +173,13 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: QSMConfig, seed: int): ) # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] critic_def = EnsembleCritic( hidden_dims=cfg.critic_hidden_dims, - activation=jax.nn.relu, + activation=critic_activation, layer_norm=False, dropout=None, ensemble_size=2, diff --git a/flowrl/config/online/mujoco/__init__.py b/flowrl/config/online/mujoco/__init__.py index c4d559f..801cfc1 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -2,7 +2,7 @@ from .algo.alac import ALACConfig from .algo.base import BaseAlgoConfig -from .algo.ctrl_td3 import CtrlTD3Config +from .algo.ctrl import * from .algo.dpmd import DPMDConfig from .algo.idem import IDEMConfig from .algo.qsm import QSMConfig @@ -26,10 +26,11 @@ "td3": TD3Config, "td7": TD7Config, "dpmd": DPMDConfig, - "ctrl": CtrlTD3Config, "qsm": QSMConfig, - "idem": IDEMConfig, "alac": ALACConfig, + "idem": IDEMConfig, + "ctrl_td3": CtrlTD3Config, + "ctrl_qsm": CtrlQSMConfig, } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/ctrl/__init__.py b/flowrl/config/online/mujoco/algo/ctrl/__init__.py new file mode 100644 index 0000000..7ed1456 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrl/__init__.py @@ -0,0 +1,7 @@ +from .ctrl_qsm import CtrlQSMConfig +from .ctrl_td3 import CtrlTD3Config + +__all__ = [ + "CtrlTD3Config", + "CtrlQSMConfig", +] diff --git a/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py b/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py new file mode 100644 index 0000000..318e72b --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py @@ -0,0 +1,40 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig +from ..qsm import QSMDiffusionConfig + + +@dataclass +class 
CtrlQSMConfig(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + # critic_hidden_dims: List[int] + critic_activation: str # not used + critic_ensemble_size: int + layer_norm: bool + critic_lr: float + clip_grad_norm: float | None + + feature_dim: int + feature_lr: float + feature_ema: float + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ctrl_coef: float + reward_coef: float + back_critic_grad: bool + critic_coef: float + + num_noises: int + linear: bool + ranking: bool + + num_samples: int + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/ctrl_td3.py b/flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py similarity index 96% rename from flowrl/config/online/mujoco/algo/ctrl_td3.py rename to flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py index 667496a..374d820 100644 --- a/flowrl/config/online/mujoco/algo/ctrl_td3.py +++ b/flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import List -from .base import BaseAlgoConfig +from ..base import BaseAlgoConfig @dataclass diff --git a/flowrl/config/online/mujoco/algo/qsm.py b/flowrl/config/online/mujoco/algo/qsm.py index 158b821..8c02f0f 100644 --- a/flowrl/config/online/mujoco/algo/qsm.py +++ b/flowrl/config/online/mujoco/algo/qsm.py @@ -23,6 +23,7 @@ class QSMDiffusionConfig: class QSMConfig(BaseAlgoConfig): name: str critic_hidden_dims: List[int] + critic_activation: str critic_lr: float discount: float num_samples: int From df0525095542bbe46c27817daf439370a3444430 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Tue, 4 Nov 2025 16:03:31 -0500 Subject: [PATCH 6/9] stash ctrl_qsm --- flowrl/agent/online/ctrl/ctrl_qsm.py | 2 +- .../online/mujoco/algo/ctrl/ctrl_qsm.py | 1 + scripts/dmc/ctrl_qsm.sh | 68 +++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 scripts/dmc/ctrl_qsm.sh diff --git a/flowrl/agent/online/ctrl/ctrl_qsm.py b/flowrl/agent/online/ctrl/ctrl_qsm.py index 2bf388b..2960ae1 100644 --- a/flowrl/agent/online/ctrl/ctrl_qsm.py +++ b/flowrl/agent/online/ctrl/ctrl_qsm.py @@ -107,7 +107,6 @@ def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray: return q.min(axis=0).mean() q_grad_fn = jax.vmap(jax.grad(get_q_value)) q_grad = q_grad_fn(at, batch.obs) - q_grad = alpha1 * q_grad - alpha2 * at eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: @@ -123,6 +122,7 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra return loss, { "loss/actor_loss": loss, "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) diff --git a/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py b/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py index 318e72b..402aa92 100644 --- a/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py +++ b/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py @@ -37,4 +37,5 @@ class CtrlQSMConfig(BaseAlgoConfig): ranking: bool num_samples: int + temp: float diffusion: QSMDiffusionConfig diff --git a/scripts/dmc/ctrl_qsm.sh b/scripts/dmc/ctrl_qsm.sh new file mode 100644 index 0000000..34729d6 --- /dev/null +++ b/scripts/dmc/ctrl_qsm.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use 
+GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=ctrl_qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi From f148d68e9f067b8de0646aa50c001384c5a5d0f7 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Wed, 5 Nov 2025 15:27:04 -0500 Subject: [PATCH 7/9] update: add aca --- examples/online/config/dmc/algo/aca.yaml | 41 +++ examples/online/main_dmc_offpolicy.py | 1 + flowrl/agent/online/__init__.py | 2 + flowrl/agent/online/unirep/__init__.py | 5 + flowrl/agent/online/unirep/aca.py | 338 ++++++++++++++++++ flowrl/agent/online/unirep/network.py | 208 +++++++++++ flowrl/config/online/mujoco/__init__.py | 2 + .../online/mujoco/algo/unirep/__init__.py | 5 + .../config/online/mujoco/algo/unirep/aca.py | 49 +++ scripts/dmc/aca.sh | 68 ++++ 10 files changed, 719 insertions(+) create mode 100644 examples/online/config/dmc/algo/aca.yaml create mode 100644 flowrl/agent/online/unirep/__init__.py create mode 100644 flowrl/agent/online/unirep/aca.py create mode 100644 flowrl/agent/online/unirep/network.py create mode 100644 flowrl/config/online/mujoco/algo/unirep/__init__.py create mode 100644 flowrl/config/online/mujoco/algo/unirep/aca.py create mode 100644 scripts/dmc/aca.sh diff --git a/examples/online/config/dmc/algo/aca.yaml b/examples/online/config/dmc/algo/aca.yaml new file mode 100644 index 0000000..d40491a --- /dev/null +++ b/examples/online/config/dmc/algo/aca.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +algo: + name: aca + target_update_freq: 1 + feature_dim: 512 + rff_dim: 1024 + critic_hidden_dims: [512, 512] + reward_hidden_dims: [512, 512] + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + ctrl_coef: 1.0 + reward_coef: 1.0 + critic_coef: 1.0 + critic_activation: elu # not used + back_critic_grad: false + feature_lr: 0.0001 + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + feature_ema: 0.005 + clip_grad_norm: null + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + num_noises: 25 + linear: false + ranking: true + +norm_obs: true diff --git a/examples/online/main_dmc_offpolicy.py 
b/examples/online/main_dmc_offpolicy.py index 8f51bb9..5d8c0fb 100644 --- a/examples/online/main_dmc_offpolicy.py +++ b/examples/online/main_dmc_offpolicy.py @@ -29,6 +29,7 @@ "ctrl_td3": CtrlTD3Agent, "qsm": QSMAgent, "ctrl_qsm": CtrlQSMAgent, + "aca": ACAAgent, } class OffPolicyTrainer(): diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index db396ef..4de6c7a 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -9,6 +9,7 @@ from .sdac import SDACAgent from .td3 import TD3Agent from .td7.td7 import TD7Agent +from .unirep import * __all__ = [ "BaseAgent", @@ -24,4 +25,5 @@ "ALACAgent", "CtrlTD3Agent", "CtrlQSMAgent", + "ACAAgent", ] diff --git a/flowrl/agent/online/unirep/__init__.py b/flowrl/agent/online/unirep/__init__.py new file mode 100644 index 0000000..e789c67 --- /dev/null +++ b/flowrl/agent/online/unirep/__init__.py @@ -0,0 +1,5 @@ +from .aca import ACAAgent + +__all__ = [ + "ACAAgent", +] diff --git a/flowrl/agent/online/unirep/aca.py b/flowrl/agent/online/unirep/aca.py new file mode 100644 index 0000000..e5451ee --- /dev/null +++ b/flowrl/agent/online/unirep/aca.py @@ -0,0 +1,338 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.unirep.network import FactorizedNCE, update_factorized_nce +from flowrl.config.online.mujoco.algo.unirep.aca import ACAConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + nce_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + t0 = jnp.ones((obs_repeat.shape[0], num_samples, 1)) + f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + qs = critic(f0) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver", "critic_coef")) +def jit_update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + nce_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float, +) -> Tuple[PRNGKey, Model, Metric]: + # q0 target + t0 = jnp.ones((batch.obs.shape[0], 1)) + rng, next_aT_rng = jax.random.split(rng) + next_aT = jax.random.normal(next_aT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_a0, _ = actor.sample(rng, next_aT, batch.next_obs, training=False, solver=solver) + next_f0 = nce_target(batch.next_obs, next_a0, t0, method="forward_phi") + q0_target = 
critic_target(next_f0) + q0_target = batch.reward + discount * (1 - batch.terminal) * q0_target.min(axis=0) + + # qt target + a0 = batch.action + f0 = nce_target(batch.obs, a0, t0, method="forward_phi") + qt_target = critic_target(f0) + + # features + rng, at, t, eps = actor.add_noise(rng, a0) + ft = nce_target(batch.obs, at, t, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q0_pred = critic.apply( + {"params": critic_params}, + f0, + training=True, + rngs={"dropout": dropout_rng}, + ) + qt_pred = critic.apply( + {"params": critic_params}, + ft, + training=True, + rngs={"dropout": dropout_rng}, + ) + critic_loss = ( + ((q0_pred - q0_target[jnp.newaxis, :])**2).mean() + + ((qt_pred - qt_target[jnp.newaxis, :])**2).mean() + ) + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q0_mean": q0_pred.mean(), + "misc/qt_mean": qt_pred.mean(), + "misc/reward": batch.reward.mean(), + "misc/next_action_l1": jnp.abs(next_a0).mean(), + } + + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, critic_metrics + +@partial(jax.jit, static_argnames=("temp",)) +def jit_update_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha, sigma = actor.noise_schedule_func(t) + + def get_q_value(at: jnp.ndarray, obs: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray: + ft = nce_target(obs, at, t, method="forward_phi") + q = critic_target(ft) + return q.mean(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs, t) + eps_estimation = - sigma * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class ACAAgent(BaseAgent): + """ + ACA (Actor-Critic with Actor) agent. 
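+
+    The agent learns a factorized NCE feature phi(s, a_t, t) over state, (possibly
+    noised) action and diffusion time, with a reward head on top. A single ensemble
+    critic is trained on these features twice per step: at the clean action against a
+    TD target, and at a noised action against the target critic's clean-action values,
+    so it provides a Q estimate at every noise level. The diffusion actor is trained by
+    score matching, regressing its noise prediction toward the action-gradient of that
+    noise-conditioned Q, scaled by -sigma_t / temp and normalized by the mean absolute
+    gradient. Targets for the critic and the NCE are maintained by EMA.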
+ """ + name = "ACAAgent" + model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.feature_dim = cfg.feature_dim + self.ranking = cfg.ranking + self.linear = cfg.linear + self.reward_coef = cfg.reward_coef + self.critic_coef = cfg.critic_coef + + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + + # define the nce + nce_def = FactorizedNCE( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + self.ranking, + ) + self.nce = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.nce_target = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + # define the actor + time_embedding = partial(LearnableFourierEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + + if cfg.diffusion.lr_decay_steps is not None: + actor_lr = optax.linear_schedule( + init_value=cfg.diffusion.lr, + end_value=cfg.diffusion.end_lr, + transition_steps=cfg.diffusion.lr_decay_steps, + transition_begin=cfg.diffusion.lr_decay_begin, + ) + else: + actor_lr = cfg.diffusion.lr + + self.actor = ContinuousDDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + t_schedule_n=1.0, + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] + critic_def = EnsembleCritic( + hidden_dims=cfg.critic_hidden_dims, + activation=critic_activation, + layer_norm=True, + dropout=None, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim))), + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) + self.rng, self.critic, critic_metrics = jit_update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.nce_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + 
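+        # jit_update_critic fits the critic at the clean action against the TD target and
+        # at a noised action against the target critic's clean-action values, so the same
+        # feature-space critic can score actions at any diffusion step.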
metrics.update(critic_metrics) + self.rng, self.actor, actor_metrics = jit_update_actor( + self.rng, + self.actor, + self.nce_target, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.cfg.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + self.nce_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/network.py b/flowrl/agent/online/unirep/network.py new file mode 100644 index 0000000..a4bc468 --- /dev/null +++ b/flowrl/agent/online/unirep/network.py @@ -0,0 +1,208 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax + +from flowrl.flow.continuous_ddpm import cosine_noise_schedule +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import l2_normalize, mish +from flowrl.module.critic import Critic +from flowrl.module.mlp import ResidualMLP +from flowrl.module.model import Model +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class FactorizedNCE(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int = 0 + num_noises: int = 0 + ranking: bool = False + + def setup(self): + self.mlp_t = nn.Sequential( + [LearnableFourierEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] + ) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + # self.reward = RffReward( + # self.feature_dim, + # self.reward_hidden_dims, + # rff_dim=self.rff_dim, + # ) + self.reward = Critic( + hidden_dims=self.reward_hidden_dims, + activation=nn.elu, + layer_norm=True, + dropout=None, + ) + if self.num_noises > 0: + self.use_noise_perturbation = True + self.noise_schedule_fn = cosine_noise_schedule + else: + self.use_noise_perturbation = False + self.N = max(self.num_noises, 1) + if not self.ranking: + self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) + else: + self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) + + def forward_phi(self, s, at, t): + x = jnp.concat([s, at], axis=-1) + if t is not None: + t_ff = self.mlp_t(t) + x = jnp.concat([x, t_ff], axis=-1) + x = self.mlp_phi(x) + x = l2_normalize(x, group_size=None) + return x + + def forward_mu(self, sp): + 
sp = self.mlp_mu(sp) + return sp + + def forward_reward(self, x: jnp.ndarray): # for z_phi + return self.reward(x) + + def forward_logits( + self, + rng: PRNGKey, + s: jnp.ndarray, + a: jnp.ndarray, + sp: jnp.ndarray, + z_mu: jnp.ndarray | None=None + ): + B, D = sp.shape + rng, t_rng, eps_rng = jax.random.split(rng, 3) + if z_mu is None: + z_mu = self.forward_mu(sp) + if self.use_noise_perturbation: + s = jnp.broadcast_to(s, (self.N, B, s.shape[-1])) + a0 = jnp.broadcast_to(a, (self.N, B, a.shape[-1])) + t = jax.random.uniform(t_rng, (self.N,), dtype=jnp.float32) # check removing min val and max val is valid + t = jnp.repeat(t, B).reshape(self.N, B, 1) + eps = jax.random.normal(eps_rng, a0.shape) + alpha, sigma = self.noise_schedule_fn(t) + at = alpha * a0 + sigma * eps + else: + s = jnp.expand_dims(s, 0) + at = jnp.expand_dims(a, 0) + t = None + z_phi = self.forward_phi(s, at, t) + z_mu = jnp.broadcast_to(z_mu, (self.N, B, self.feature_dim)) + logits = jax.lax.batch_matmul(z_phi, jnp.swapaxes(z_mu, -1, -2)) + logits = logits / jnp.exp(self.normalizer[:, None, None]) + rewards = self.forward_reward(z_phi) + return logits, rewards + + def forward_normalizer(self): + return self.normalizer + + def __call__( + self, + rng: PRNGKey, + s, + a, + sp, + ): + logits, rewards = self.forward_logits(rng, s, a, sp) + _ = self.forward_normalizer() + + return logits, rewards + + +@partial(jax.jit, static_argnames=("ranking", "reward_coef")) +def update_factorized_nce( + rng: PRNGKey, + nce: Model, + batch: Batch, + ranking: bool, + reward_coef: float, +) -> Tuple[PRNGKey, Model, Metric]: + B = batch.obs.shape[0] + rng, logits_rng = jax.random.split(rng) + if ranking: + labels = jnp.arange(B) + else: + labels = jnp.eye(B) + + def loss_fn(nce_params: Param, dropout_rng: PRNGKey): + z_mu = nce.apply({"params": nce_params}, batch.next_obs, method="forward_mu") + logits, rewards = nce.apply( + {"params": nce_params}, + logits_rng, + batch.obs, + batch.action, + batch.next_obs, + z_mu, + method="forward_logits", + ) + + if ranking: + model_loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.broadcast_to(labels, (logits.shape[0], B)) + ).mean(axis=-1) + else: + normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + eff_logits = logits + normalizer[:, None, None] - jnp.log(B) + model_loss = optax.sigmoid_binary_cross_entropy(eff_logits, labels).mean([-2, -1]) + normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + rewards_target = jnp.broadcast_to(batch.reward, rewards.shape) + reward_loss = jnp.mean((rewards - rewards_target) ** 2) + + nce_loss = model_loss.mean() + reward_coef * reward_loss + 0.000 * (logits**2).mean() + + pos_logits = logits[ + jnp.arange(logits.shape[0])[..., jnp.newaxis], + jnp.arange(logits.shape[1]), + jnp.arange(logits.shape[2])[jnp.newaxis, ...].repeat(logits.shape[0], axis=0) + ] + pos_logits_per_noise = pos_logits.mean(axis=-1) + neg_logits = (logits.sum(axis=-1) - pos_logits) / (logits.shape[-1] - 1) + neg_logits_per_noise = neg_logits.mean(axis=-1) + metrics = { + "loss/nce_loss": nce_loss, + "loss/model_loss": model_loss.mean(), + "loss/reward_loss": reward_loss, + "misc/obs_mean": batch.obs.mean(), + "misc/obs_std": batch.obs.std(axis=0).mean(), + } + checkpoints = list(range(0, logits.shape[0], logits.shape[0]//5)) + [logits.shape[0]-1] + metrics.update({ + f"misc/positive_logits_{i}": pos_logits_per_noise[i].mean() for i in checkpoints + }) + metrics.update({ + f"misc/negative_logits_{i}": 
neg_logits_per_noise[i].mean() for i in checkpoints + }) + metrics.update({ + f"misc/logits_gap_{i}": (pos_logits_per_noise[i] - neg_logits_per_noise[i]).mean() for i in checkpoints + }) + metrics.update({ + f"misc/normalizer_{i}": jnp.exp(normalizer[i]) for i in checkpoints + }) + return nce_loss, metrics + + new_nce, metrics = nce.apply_gradient(loss_fn) + return rng, new_nce, metrics diff --git a/flowrl/config/online/mujoco/__init__.py b/flowrl/config/online/mujoco/__init__.py index 801cfc1..8a72576 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -10,6 +10,7 @@ from .algo.sdac import SDACConfig from .algo.td3 import TD3Config from .algo.td7 import TD7Config +from .algo.unirep import * from .config import Config, LogConfig _DEF_SUFFIX = "_cfg_def" @@ -31,6 +32,7 @@ "idem": IDEMConfig, "ctrl_td3": CtrlTD3Config, "ctrl_qsm": CtrlQSMConfig, + "aca": ACAConfig, } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/unirep/__init__.py b/flowrl/config/online/mujoco/algo/unirep/__init__.py new file mode 100644 index 0000000..dafaab0 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/unirep/__init__.py @@ -0,0 +1,5 @@ +from .aca import ACAConfig + +__all__ = [ + "ACAConfig", +] diff --git a/flowrl/config/online/mujoco/algo/unirep/aca.py b/flowrl/config/online/mujoco/algo/unirep/aca.py new file mode 100644 index 0000000..3d68d4a --- /dev/null +++ b/flowrl/config/online/mujoco/algo/unirep/aca.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig + + +@dataclass +class ACADiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class ACAConfig(BaseAlgoConfig): + name: str + target_update_freq: int + feature_dim: int + rff_dim: int + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + ctrl_coef: float + reward_coef: float + critic_coef: float + critic_activation: str + back_critic_grad: bool + feature_lr: float + critic_lr: float + discount: float + num_samples: int + ema: float + feature_ema: float + clip_grad_norm: float | None + temp: float + diffusion: ACADiffusionConfig + + num_noises: int + linear: bool + ranking: bool diff --git a/scripts/dmc/aca.sh b/scripts/dmc/aca.sh new file mode 100644 index 0000000..c4ac7c0 --- /dev/null +++ b/scripts/dmc/aca.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=aca" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env 
$seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi From 40ec550fb29b19504ba701c504b36f828b93226c Mon Sep 17 00:00:00 2001 From: typoverflow Date: Thu, 13 Nov 2025 20:04:21 -0500 Subject: [PATCH 8/9] separate q network works --- examples/online/config/dmc/algo/aca.yaml | 2 +- examples/toy2d/config/algo/aca.yaml | 40 +++ examples/toy2d/config/algo/qsm.yaml | 23 ++ examples/toy2d/config/algo/sdac.yaml | 23 ++ examples/toy2d/main_toy2d.py | 22 +- examples/toy2d/utils.py | 48 ++- flowrl/agent/online/unirep/aca.py | 419 +++++++++++++++++++---- flowrl/agent/online/unirep/network.py | 131 ++++++- flowrl/dataset/toy2d.py | 28 +- scripts/dmc/aca.sh | 56 +-- 10 files changed, 684 insertions(+), 108 deletions(-) create mode 100644 examples/toy2d/config/algo/aca.yaml create mode 100644 examples/toy2d/config/algo/qsm.yaml create mode 100644 examples/toy2d/config/algo/sdac.yaml diff --git a/examples/online/config/dmc/algo/aca.yaml b/examples/online/config/dmc/algo/aca.yaml index d40491a..bc5d564 100644 --- a/examples/online/config/dmc/algo/aca.yaml +++ b/examples/online/config/dmc/algo/aca.yaml @@ -38,4 +38,4 @@ algo: linear: false ranking: true -norm_obs: true +# norm_obs: true diff --git a/examples/toy2d/config/algo/aca.yaml b/examples/toy2d/config/algo/aca.yaml new file mode 100644 index 0000000..4484f3c --- /dev/null +++ b/examples/toy2d/config/algo/aca.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +algo: + name: aca + target_update_freq: 1 + feature_dim: 512 + rff_dim: 1024 + critic_hidden_dims: [512, 512] + reward_hidden_dims: [512, 512] + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + ctrl_coef: 1.0 + reward_coef: 1.0 + critic_coef: 1.0 + critic_activation: elu # not used + back_critic_grad: false + feature_lr: 0.0001 + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + feature_ema: 0.005 + clip_grad_norm: null + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -5.0 + x_max: 5.0 + # solver: ddpm + solver: ddpm + num_noises: 25 + linear: false + ranking: true diff --git a/examples/toy2d/config/algo/qsm.yaml b/examples/toy2d/config/algo/qsm.yaml new file mode 100644 index 0000000..816a84b --- /dev/null +++ b/examples/toy2d/config/algo/qsm.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512, 512] + critic_activation: elu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -5.0 + x_max: 5.0 + solver: ddpm diff --git a/examples/toy2d/config/algo/sdac.yaml b/examples/toy2d/config/algo/sdac.yaml new file mode 100644 index 0000000..3505de9 --- /dev/null +++ b/examples/toy2d/config/algo/sdac.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: sdac + critic_hidden_dims: [256, 256] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + 
ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [256, 256] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: false + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/toy2d/main_toy2d.py b/examples/toy2d/main_toy2d.py index f2f3881..3863d8c 100644 --- a/examples/toy2d/main_toy2d.py +++ b/examples/toy2d/main_toy2d.py @@ -3,10 +3,10 @@ import hydra import omegaconf -import wandb from omegaconf import OmegaConf from tqdm import trange +import wandb from examples.toy2d.utils import compute_metrics, plot_data, plot_energy, plot_sample from flowrl.agent.offline import * from flowrl.agent.online import * @@ -19,6 +19,9 @@ SUPPORTED_AGENTS: Dict[str, Type[BaseAgent]] = { "bdpo": BDPOAgent, "dac": DACAgent, + "qsm": QSMAgent, + "sdac": SDACAgent, + "aca": ACAAgent, } class Trainer(): @@ -30,14 +33,15 @@ def __init__(self, cfg: Config): log_dir="/".join([cfg.log.dir, cfg.algo.name, cfg.log.tag, cfg.task]), name="seed"+str(cfg.seed), logger_config={ - "TensorboardLogger": {"activate": True}, - "WandbLogger": { - "activate": True, - "config": OmegaConf.to_container(cfg), - "settings": wandb.Settings(_disable_stats=True), - "project": cfg.log.project, - "entity": cfg.log.entity - } if ("project" in cfg.log and "entity" in cfg.log) else {"activate": False}, + "CsvLogger": {"activate": True}, + # "TensorboardLogger": {"activate": True}, + # "WandbLogger": { + # "activate": True, + # "config": OmegaConf.to_container(cfg), + # "settings": wandb.Settings(_disable_stats=True), + # "project": cfg.log.project, + # "entity": cfg.log.entity + # } if ("project" in cfg.log and "entity" in cfg.log) else {"activate": False}, } ) self.ckpt_save_dir = os.path.join(self.logger.log_dir, "ckpt") diff --git a/examples/toy2d/utils.py b/examples/toy2d/utils.py index 6b38344..8445f0c 100644 --- a/examples/toy2d/utils.py +++ b/examples/toy2d/utils.py @@ -9,6 +9,7 @@ from flowrl.agent.base import BaseAgent from flowrl.agent.offline import * +from flowrl.agent.online import * from flowrl.dataset.toy2d import Toy2dDataset, inf_train_gen SAMPLE_GRAPH_SIZE = 2000 @@ -108,6 +109,9 @@ def plot_energy(out_dir, task: str, agent: BaseAgent): vmin = e.min() vmax = e.max() + def default_plot(): + pass + def bdpo_plot(): tt = [0, 1, 3, 5, 10, 20, 30, 40, 50] plt.figure(figsize=(30, 3.0)) @@ -168,12 +172,54 @@ def dac_plot(): plt.close() tqdm.write(f"Saved value plot to {saveto}") + def aca_plot(): + tt = [0, 1, 3, 5, 10, 20] + plt.figure(figsize=(20, 3.0)) + axes = [] + for i, t in enumerate(tt): + plt.subplot(1, len(tt), i+1) + if t == 0: + model = agent.critic_target + c = model(zero, id_matrix).mean(axis=0).reshape(90, 90) + else: + model = agent.value_target + t_input = np.ones((90*90, 1)) * t + c = model(zero, id_matrix, t_input).mean(axis=0).reshape(90, 90) + plt.gca().set_aspect("equal", adjustable="box") + plt.xlim(0, 89) + plt.ylim(0, 89) + if i == 0: + mappable = plt.imshow( + c, origin="lower", vmin=vmin, vmax=vmax, + cmap="viridis", rasterized=True + ) + plt.yticks(ticks=[5, 25, 45, 65, 85], labels=[-4, -2, 0, 2, 4]) + else: + plt.imshow( + c, origin="lower", vmin=vmin, vmax=vmax, + cmap="viridis", rasterized=True + ) + plt.yticks(ticks=[5, 25, 45, 65, 85], labels=[None, None, None, None, None]) + + axes.append(plt.gca()) + plt.xticks(ticks=[5, 25, 45, 65, 85], labels=[-4, -2, 0, 2, 4]) + plt.title(f't={t}') + plt.tight_layout() + cbar = plt.gcf().colorbar(mappable, ax=axes, fraction=0.1, pad=0.02, aspect=12) + 
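+        # every panel is drawn with the same vmin/vmax, so the single shared colorbar
+        # makes the value heatmaps at different diffusion steps t directly comparable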
plt.gcf().axes[-1].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f')) + saveto = os.path.join(out_dir, "qt_space.png") + plt.savefig(saveto, dpi=300) + plt.close() + tqdm.write(f"Saved value plot to {saveto}") + if isinstance(agent, BDPOAgent): bdpo_plot() elif isinstance(agent, DACAgent): dac_plot() + elif isinstance(agent, ACAAgent): + aca_plot() else: - raise NotImplementedError(f"Plotting for {type(agent)} is not implemented") + default_plot() diff --git a/flowrl/agent/online/unirep/aca.py b/flowrl/agent/online/unirep/aca.py index e5451ee..e23bdb6 100644 --- a/flowrl/agent/online/unirep/aca.py +++ b/flowrl/agent/online/unirep/aca.py @@ -9,19 +9,20 @@ from flowrl.agent.online.unirep.network import FactorizedNCE, update_factorized_nce from flowrl.config.online.mujoco.algo.unirep.aca import ACAConfig from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone -from flowrl.functional.activation import mish +from flowrl.flow.ddpm import DDPM, DDPMBackbone +from flowrl.functional.activation import atanh, l2_normalize, mish, tanh from flowrl.functional.ema import ema_update -from flowrl.module.critic import EnsembleCritic +from flowrl.module.critic import EnsembleCritic, EnsembleCriticT from flowrl.module.mlp import MLP from flowrl.module.model import Model -from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.module.time_embedding import LearnableFourierEmbedding, PositionalEmbedding from flowrl.types import Batch, Metric, Param, PRNGKey @partial(jax.jit, static_argnames=("training", "num_samples", "solver")) def jit_sample_actions( rng: PRNGKey, - actor: ContinuousDDPM, + actor: Model, critic: Model, nce_target: Model, obs: jnp.ndarray, @@ -40,20 +41,38 @@ def jit_sample_actions( if num_samples == 1: actions = actions[:, 0] else: - t0 = jnp.ones((obs_repeat.shape[0], num_samples, 1)) - f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") - qs = critic(f0) + # t0 = jnp.zeros((obs_repeat.shape[0], num_samples, 1)) + # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + qs = critic(obs_repeat, actions) qs = qs.min(axis=0).reshape(B, num_samples) best_idx = qs.argmax(axis=-1) actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] return rng, actions + +@partial(jax.jit, static_argnames=("deterministic", "exploration_noise")) +def jit_td3_sample_action( + rng: PRNGKey, + actor: Model, + obs: jnp.ndarray, + deterministic: bool, + exploration_noise: float, +) -> jnp.ndarray: + action = actor(obs, training=False) + if not deterministic: + action = action + exploration_noise * jax.random.normal(rng, action.shape) + action = jnp.clip(action, -1.0, 1.0) + return action + @partial(jax.jit, static_argnames=("discount", "solver", "critic_coef")) def jit_update_critic( rng: PRNGKey, critic: Model, critic_target: Model, - actor: ContinuousDDPM, + value: Model, + value_target: Model, + actor: Model, + backup: Model, nce_target: Model, batch: Batch, discount: float, @@ -61,71 +80,216 @@ def jit_update_critic( critic_coef: float, ) -> Tuple[PRNGKey, Model, Metric]: # q0 target + B = batch.obs.shape[0] t0 = jnp.ones((batch.obs.shape[0], 1)) rng, next_aT_rng = jax.random.split(rng) next_aT = jax.random.normal(next_aT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) - rng, next_a0, _ = actor.sample(rng, next_aT, batch.next_obs, training=False, solver=solver) - next_f0 = nce_target(batch.next_obs, next_a0, t0, method="forward_phi") - q0_target = critic_target(next_f0) + # rng, next_a0, _ = actor.sample(rng, 
next_aT, batch.next_obs, training=False, solver=solver) + next_a0 = backup(batch.next_obs) + # next_f0 = nce_target(batch.next_obs, next_a0, t0, method="forward_phi") + # q0_target = critic_target(next_f0) + q0_target = critic_target(batch.next_obs, next_a0) q0_target = batch.reward + discount * (1 - batch.terminal) * q0_target.min(axis=0) # qt target a0 = batch.action - f0 = nce_target(batch.obs, a0, t0, method="forward_phi") - qt_target = critic_target(f0) + # f0 = nce_target(batch.obs, a0, t0, method="forward_phi") + qt_target = critic_target(batch.obs, a0).mean(axis=0) + # qt_target = critic(batch.obs, a0).mean(axis=0) # features - rng, at, t, eps = actor.add_noise(rng, a0) - ft = nce_target(batch.obs, at, t, method="forward_phi") + # rng, at, t, eps = actor.add_noise(rng, a0) + # weight_t = actor.alpha_hats[t] / (1-actor.alpha_hats[t]) + weight_t = 1.0 + # ft = nce_target(batch.obs, at, t, method="forward_phi") + # rng, t_rng, noise_rng = jax.random.split(rng, 3) + # t = jax.random.randint(t_rng, (*a0.shape[:-1], 1), 0, actor.steps+1) + # t = jnp.ones((*a0.shape[:-1], 1), dtype=jnp.int32) + # eps = jax.random.normal(noise_rng, a0.shape) + # at = jnp.sqrt(actor.alpha_hats[t]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t]) * eps def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: q0_pred = critic.apply( {"params": critic_params}, - f0, - training=True, - rngs={"dropout": dropout_rng}, - ) - qt_pred = critic.apply( - {"params": critic_params}, - ft, - training=True, + batch.obs, + a0, + # t0, + # training=True, rngs={"dropout": dropout_rng}, ) + # qt_pred = value.apply( + # {"params": critic_params}, + # batch.obs, + # at, + # t, + # # training=True, + # rngs={"dropout": dropout_rng}, + # ) critic_loss = ( - ((q0_pred - q0_target[jnp.newaxis, :])**2).mean() + - ((qt_pred - qt_target[jnp.newaxis, :])**2).mean() + ((q0_pred - q0_target[jnp.newaxis, :])**2).mean() + # + ((qt_pred - qt_target[jnp.newaxis, :])**2).mean() ) return critic_loss, { "loss/critic_loss": critic_loss, "misc/q0_mean": q0_pred.mean(), - "misc/qt_mean": qt_pred.mean(), + # "misc/qt_mean": qt_pred.mean(), "misc/reward": batch.reward.mean(), "misc/next_action_l1": jnp.abs(next_a0).mean(), } new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) - return rng, new_critic, critic_metrics + + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + value_loss = 0 + for t in range(0, actor.steps+1): + t_input = jnp.ones((B, 1)) * t + noise_rng, dropout_rng = jax.random.split(dropout_rng) + eps = jax.random.normal(noise_rng, a0.shape) + at = jnp.sqrt(actor.alpha_hats[t]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t]) * eps + qt_pred = value.apply( + {"params": value_params}, + batch.obs, + at, + t_input, + training=True, + rngs={"dropout": dropout_rng}, + ) + value_loss += ((qt_pred - q0_target[:])**2).mean() + return value_loss, { + "loss/value_loss": value_loss, + "misc/qt_mean": qt_pred.mean(), + } + new_value, value_metrics = value.apply_gradient(value_loss_fn) + + # t_zero = jnp.zeros((B, 1)) + # v0_target = value_target(batch.next_obs, next_a0, t_zero) + # v0_target = v0_target.min(axis=0) + # v0_target = batch.reward + discount * (1-batch.terminal) * v0_target + + # def td_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + # v0_pred = value.apply( + # {"params": value_params}, + # batch.obs, + # batch.action, + # t_zero, + # training=True, + # rngs={"dropout": dropout_rng}, + # ) + # td_loss = ((v0_pred - 
v0_target[jnp.newaxis, :])**2).mean() + # return td_loss, { + # "loss/td_loss": td_loss, + # } + # # value, td_metrics = value.apply_gradient(td_loss_fn) + # td_metrics = {} + + # rng, rng1, rng2, rng3 = jax.random.split(rng, 4) + # t = jax.random.randint(rng1, (B, 1), 0, actor.steps) + # eps = jax.random.normal(rng2, batch.action.shape) + # at = jnp.sqrt(actor.alpha_hats[t]) * batch.action + jnp.sqrt(1 - actor.alpha_hats[t]) * eps + # vt_target = value(batch.obs, batch.action, t_zero).mean(axis=0) + # def nontd_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + # vt_pred = value.apply( + # {"params": value_params}, + # batch.obs, + # at, + # t, + # training=True, + # rngs={"dropout": dropout_rng}, + # ) + # nontd_loss = ((vt_pred - q0_target[jnp.newaxis, :])**2).mean() + # return nontd_loss, { + # "loss/nontd_loss": nontd_loss, + # } + # new_value, nontd_metrics = value.apply_gradient(nontd_loss_fn) + + # return rng, new_critic, new_value, { + # **critic_metrics, + # # **value_metrics, + # **td_metrics, + # **nontd_metrics, + # } + + return rng, new_critic, new_value, { + **critic_metrics, + **value_metrics, + } + +def jit_compute_metrics( + rng: PRNGKey, + actor: Model, + critic: Model, + value: Model, + batch: Batch, +) -> Tuple[PRNGKey, Metric]: + B, S = batch.obs.shape + A = batch.action.shape[-1] + num_actions = 50 + metrics = {} + rng, action_rng = jax.random.split(rng) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = batch.action[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + + def get_critic(at, obs): + q = critic(obs, at) + return q.mean() + all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + ) + all_critic = all_critic.reshape(B, num_actions, 1) + all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_mean/critic": all_critic.mean(), + f"q_std/critic": all_critic.std(axis=1).mean(), + f"q_grad/critic": jnp.abs(all_critic_grad).mean(), + }) + + def get_value(at, obs, t): + q = value(obs, at, t) + return q.mean() + + for t in [0] + list(range(1, actor.steps+1, actor.steps//5)): + t_input = jnp.ones((B, num_actions, 1)) * t + all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + t_input.reshape(-1, 1), + ) + all_value = all_value.reshape(B, num_actions, 1) + all_value_grad = all_value_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_mean/value_{t}": all_value.mean(), + f"q_std/value_{t}": all_value.std(axis=1).mean(), + f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), + }) + return rng, metrics + @partial(jax.jit, static_argnames=("temp",)) def jit_update_actor( rng: PRNGKey, - actor: ContinuousDDPM, + actor: Model, + backup: Model, nce_target: Model, critic_target: Model, + value_target: Model, batch: Batch, temp: float, -) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: +) -> Tuple[PRNGKey, Model, Metric]: a0 = batch.action rng, at, t, eps = actor.add_noise(rng, a0) - alpha, sigma = actor.noise_schedule_func(t) - + # alpha, sigma = actor.noise_schedule_func(t) + sigma = jnp.sqrt(1 - actor.alpha_hats[t]) def get_q_value(at: jnp.ndarray, obs: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray: - ft = nce_target(obs, at, t, method="forward_phi") - q = critic_target(ft) + # ft = nce_target(obs, at, t, 
method="forward_phi") + q = value_target(obs, at, t) + # q = critic_target(obs, at) return q.mean(axis=0).mean() q_grad_fn = jax.vmap(jax.grad(get_q_value)) q_grad = q_grad_fn(at, batch.obs, t) eps_estimation = - sigma * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + # eps_estimation = -sigma * l2_normalize(q_grad) / temp def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: eps_pred = actor.apply( @@ -142,7 +306,19 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) - return rng, new_actor, actor_metrics + + def backup_loss_fn(backup_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + new_action = backup.apply( + {"params": backup_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + q = critic_target(batch.obs, new_action) + backup_loss = - q.mean() + return backup_loss, {} + new_backup, backup_metrics = backup.apply_gradient(backup_loss_fn) + return rng, new_actor, new_backup, actor_metrics class ACAAgent(BaseAgent): @@ -150,7 +326,7 @@ class ACAAgent(BaseAgent): ACA (Actor-Critic with Actor) agent. """ name = "ACAAgent" - model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] + model_names = ["nce", "nce_target", "actor", "critic", "critic_target", "value", "value_target"] def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): super().__init__(obs_dim, act_dim, cfg, seed) @@ -162,7 +338,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): self.reward_coef = cfg.reward_coef self.critic_coef = cfg.critic_coef - self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng, value_rng = jax.random.split(self.rng, 6) # define the nce nce_def = FactorizedNCE( @@ -200,7 +376,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): ) # define the actor - time_embedding = partial(LearnableFourierEmbedding, output_dim=cfg.diffusion.time_dim) + time_embedding = partial(PositionalEmbedding, output_dim=cfg.diffusion.time_dim) cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) noise_predictor = partial( MLP, @@ -210,7 +386,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): layer_norm=False, dropout=None, ) - backbone_def = ContinuousDDPMBackbone( + backbone_def = DDPMBackbone( noise_predictor=noise_predictor, time_embedding=time_embedding, cond_embedding=cond_embedding, @@ -226,7 +402,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): else: actor_lr = cfg.diffusion.lr - self.actor = ContinuousDDPM.create( + self.actor = DDPM.create( network=backbone_def, rng=actor_rng, inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), @@ -234,10 +410,28 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): steps=cfg.diffusion.steps, noise_schedule="cosine", noise_schedule_params={}, + approx_postvar=False, clip_sampler=cfg.diffusion.clip_sampler, x_min=cfg.diffusion.x_min, x_max=cfg.diffusion.x_max, - t_schedule_n=1.0, + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the backup actor + from flowrl.module.actor import SquashedDeterministicActor + backup_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=[512,512,512], + layer_norm=True, 
+ dropout=None, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + self.backup = Model.create( + backup_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), optimizer=optax.adam(learning_rate=actor_lr), ) @@ -246,44 +440,116 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): "relu": jax.nn.relu, "elu": jax.nn.elu, }[cfg.critic_activation] + # critic_def = nn.Sequential([ + # nn.LayerNorm(), + # EnsembleCritic( + # hidden_dims=cfg.critic_hidden_dims, + # activation=critic_activation, + # layer_norm=True, + # dropout=None, + # ensemble_size=2, + # ) + # ]) + # self.critic = Model.create( + # critic_def, + # critic_rng, + # inputs=(jnp.ones((1, self.feature_dim))), + # optimizer=optax.adam(learning_rate=cfg.critic_lr), + # ) + # self.critic_target = Model.create( + # critic_def, + # critic_rng, + # inputs=(jnp.ones((1, self.feature_dim))), + # ) critic_def = EnsembleCritic( - hidden_dims=cfg.critic_hidden_dims, + # time_embedding=time_embedding, + hidden_dims=[512, 512, 512], activation=critic_activation, - layer_norm=True, - dropout=None, ensemble_size=2, + layer_norm=False, ) self.critic = Model.create( critic_def, critic_rng, - inputs=(jnp.ones((1, self.feature_dim))), + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), optimizer=optax.adam(learning_rate=cfg.critic_lr), ) self.critic_target = Model.create( critic_def, critic_rng, - inputs=(jnp.ones((1, self.feature_dim))), + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + ) + # value_def = EnsembleCriticT( + # time_embedding=time_embedding, + # hidden_dims=[512, 512, 512], + # activation=critic_activation, + # ensemble_size=2, + # layer_norm=True, + # ) + from flowrl.agent.online.unirep.network import ( + EnsembleACACritic, + EnsembleResidualCritic, + ResidualCritic, + SeparateCritic, ) + # value_def = EnsembleACACritic( + # time_dim=16, + # hidden_dims=[256,256,256], + # activation=jax.nn.mish, + # ensemble_size=2, + # ) + # value_def = EnsembleResidualCritic( + # time_embedding=time_embedding, + # hidden_dims=[512, 512, 512], + # activation=jax.nn.mish, + # ) + value_def = SeparateCritic( + hidden_dims=[512, 512, 512], + activation=jax.nn.mish, + ensemble_size=cfg.diffusion.steps+1, + ) + self.value = Model.create( + value_def, + value_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.value_target = Model.create( + value_def, + value_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), + ) # define tracking variables self._n_training_steps = 0 def train_step(self, batch: Batch, step: int) -> Metric: metrics = {} + # batch = Batch( + # obs=batch.obs, + # action=atanh(batch.action, scale=5.0), + # next_obs=batch.next_obs, + # reward=batch.reward, + # terminal=batch.terminal, + # next_action=atanh(batch.next_action, scale=5.0), + # ) - self.rng, self.nce, nce_metrics = update_factorized_nce( - self.rng, - self.nce, - batch, - self.ranking, - self.reward_coef, - ) - metrics.update(nce_metrics) - self.rng, self.critic, critic_metrics = jit_update_critic( + # self.rng, self.nce, nce_metrics = update_factorized_nce( + # self.rng, + # self.nce, + # batch, + # self.ranking, + # self.reward_coef, + # ) + # metrics.update(nce_metrics) + self.rng, self.critic, self.value, critic_metrics = jit_update_critic( self.rng, self.critic, self.critic_target, + self.value, + self.value_target, self.actor, + self.backup, self.nce_target, 
batch, discount=self.cfg.discount, @@ -291,11 +557,13 @@ def train_step(self, batch: Batch, step: int) -> Metric: critic_coef=self.critic_coef, ) metrics.update(critic_metrics) - self.rng, self.actor, actor_metrics = jit_update_actor( + self.rng, self.actor, self.backup, actor_metrics = jit_update_actor( self.rng, self.actor, + self.backup, self.nce_target, self.critic_target, + self.value_target, batch, temp=self.cfg.temp, ) @@ -304,6 +572,15 @@ def train_step(self, batch: Batch, step: int) -> Metric: if self._n_training_steps % self.cfg.target_update_freq == 0: self.sync_target() + if self._n_training_steps % 2000 == 0: + self.rng, metrics = jit_compute_metrics( + self.rng, + self.actor, + self.critic, + self.value, + batch, + ) + self._n_training_steps += 1 return metrics @@ -317,22 +594,28 @@ def sample_actions( # if not, sample 1 action if deterministic: num_samples = self.cfg.num_samples + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + self.nce_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) else: - num_samples = 1 - self.rng, action = jit_sample_actions( - self.rng, - self.actor, - self.critic, - self.nce_target, - obs, - training=False, - num_samples=num_samples, - solver=self.cfg.diffusion.solver, - ) - if not deterministic: - action = action + 0.1 * jax.random.normal(self.rng, action.shape) + self.rng, action_rng = jax.random.split(self.rng) + action = jit_td3_sample_action( + self.rng, + self.backup, + obs, + deterministic, + exploration_noise=0.2, + ) return action, {} def sync_target(self): self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.value_target = ema_update(self.value, self.value_target, self.cfg.ema) self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/network.py b/flowrl/agent/online/unirep/network.py index a4bc468..e1dc431 100644 --- a/flowrl/agent/online/unirep/network.py +++ b/flowrl/agent/online/unirep/network.py @@ -9,7 +9,7 @@ from flowrl.flow.ddpm import get_noise_schedule from flowrl.functional.activation import l2_normalize, mish from flowrl.module.critic import Critic -from flowrl.module.mlp import ResidualMLP +from flowrl.module.mlp import MLP, ResidualMLP from flowrl.module.model import Model from flowrl.module.rff import RffReward from flowrl.module.time_embedding import LearnableFourierEmbedding @@ -17,6 +17,135 @@ from flowrl.types import Sequence +class ResidualCritic(nn.Module): + time_embedding: nn.Module + hidden_dims: Sequence[int] + activation: Callable = nn.relu + + @nn.compact + def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, t: jnp.ndarray, training: bool=False): + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([item for item in [obs, action, t_ff] if item is not None], axis=-1) + x = ResidualMLP( + hidden_dims=self.hidden_dims, + output_dim=1, + multiplier=1, + activation=self.activation, + layer_norm=True, + )(x, training) + return x + +class EnsembleResidualCritic(nn.Module): + time_embedding: nn.Module + hidden_dims: Sequence[int] + activation: Callable = nn.relu + ensemble_size: int = 2 + + @nn.compact + def __call__( + self, + obs: Optional[jnp.ndarray] = None, + action: Optional[jnp.ndarray] = None, + t: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + vmap_critic = nn.vmap( + ResidualCritic, + 
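# --- Reference sketch: the nn.vmap ensembling pattern shared by
# EnsembleResidualCritic, EnsembleACACritic and SeparateCritic in this file.
# Each member gets independent parameters (variable_axes={"params": 0}) while
# all members see the same inputs (in_axes=None) and outputs are stacked on a
# leading ensemble axis. `member_cls` is a placeholder for any Flax critic
# module; this is an illustrative sketch, not code from the patch.
import flax.linen as nn

def make_ensemble(member_cls, ensemble_size: int, **member_kwargs):
    vmapped = nn.vmap(
        member_cls,
        variable_axes={"params": 0},                   # independent params per member
        split_rngs={"params": True, "dropout": True},  # distinct init / dropout rngs
        in_axes=None,                                  # broadcast identical inputs
        out_axes=0,                                    # stack member outputs
        axis_size=ensemble_size,
    )
    return vmapped(**member_kwargs)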
variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + x = vmap_critic( + time_embedding=self.time_embedding, + hidden_dims=self.hidden_dims, + activation=self.activation, + )(obs, action, t, training) + return x + +from flowrl.module.time_embedding import PositionalEmbedding + + +class ACACritic(nn.Module): + time_dim: int + hidden_dims: Sequence[int] + activation: Callable + + @nn.compact + def __call__(self, obs, action, t, training=False): + t_ff = PositionalEmbedding(self.time_dim)(t) + t_ff = nn.Dense(2*self.time_dim)(t_ff) + t_ff = self.activation(t_ff) + t_ff = nn.Dense(self.time_dim)(t_ff) + x = jnp.concatenate([obs, action, t_ff], axis=-1) + return MLP( + hidden_dims=self.hidden_dims, + output_dim=1, + activation=self.activation, + layer_norm=False + )(x, training) + +class EnsembleACACritic(nn.Module): + time_dim: int + hidden_dims: Sequence[int] + activation: Callable + ensemble_size: int + + @nn.compact + def __call__(self, obs, action, t, training=False): + vmap_critic = nn.vmap( + ACACritic, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + return vmap_critic( + time_dim=self.time_dim, + hidden_dims=self.hidden_dims, + activation=self.activation, + )(obs, action, t, training) + + +class SeparateCritic(nn.Module): + hidden_dims: Sequence[int] + activation: Callable + ensemble_size: int + + @nn.compact + def __call__(self, obs, action, t, training=False): + vmap_critic = nn.vmap( + MLP, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=-1, + axis_size=self.ensemble_size + ) + x = jnp.concatenate([obs, action], axis=-1) + out = vmap_critic( + hidden_dims=self.hidden_dims, + output_dim=1, + activation=self.activation, + layer_norm=False + )(x, training) + out = out.reshape(*out.shape[:-2], -1) + # Using jnp.take_along_axis for batched index selection (broadcasting as needed) + # out: (E, B, T, 1), need to select on axis=2 using t_indices[b] for each batch + # We assume batch dim is axis=1, time axis=2 + out = jnp.take_along_axis( + out, + t.astype(jnp.int32), + axis=-1 + ) + return out + class FactorizedNCE(nn.Module): obs_dim: int action_dim: int diff --git a/flowrl/dataset/toy2d.py b/flowrl/dataset/toy2d.py index 649ccd4..ba93b76 100644 --- a/flowrl/dataset/toy2d.py +++ b/flowrl/dataset/toy2d.py @@ -4,6 +4,7 @@ import numpy as np import sklearn import sklearn.datasets +from scipy.stats import multivariate_normal from sklearn.utils import shuffle as util_shuffle from flowrl.types import Batch @@ -142,13 +143,38 @@ def inf_train_gen(data, batch_size=200): x = np.random.rand(batch_size) * 5 - 2.5 y = np.sin(x) * 2.5 return np.stack((x, y), 1) + + elif data == "8gaussiansmix": + scale = 3.5 + centers = [ + (0, 1), + (-1. / np.sqrt(2), 1. / np.sqrt(2)), + (-1, 0), + (-1. / np.sqrt(2), -1. / np.sqrt(2)), + (0, -1), + (1. / np.sqrt(2), -1. / np.sqrt(2)), + (1, 0), + (1. / np.sqrt(2), 1. 
/ np.sqrt(2)), + ] + weights = [8, 7, 6, 5, 4, 3, 2, 1] + + + centers = [(scale * x, scale * y) for x, y in centers] + cov = 1.5**2 + x = np.random.rand(batch_size, 2) * 10 - 5 + energy = [ + multivariate_normal.pdf(x, center, cov) + for center in centers + ] + energy = sum([weights[i] * energy[i] for i in range(8)]) + return x, energy[:, None] else: assert False class Toy2dDataset(object): def __init__(self, task: str, data_size: int=1000000, scan: bool=True): - assert task in ["swissroll", "8gaussians", "moons", "rings", "checkerboard", "2spirals"] + assert task in ["swissroll", "8gaussians", "moons", "rings", "checkerboard", "2spirals", "8gaussiansmix"] self.task = task self.data_size = data_size self.scan = scan diff --git a/scripts/dmc/aca.sh b/scripts/dmc/aca.sh index c4ac7c0..39f1737 100644 --- a/scripts/dmc/aca.sh +++ b/scripts/dmc/aca.sh @@ -1,44 +1,46 @@ # Specify which GPUs to use GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use -SEEDS=(0 1 2 3) -NUM_EACH_GPU=3 +SEEDS=(0 1) +NUM_EACH_GPU=2 PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) TASKS=( - "acrobot-swingup" - "ball_in_cup-catch" - "cartpole-balance" - "cartpole-balance_sparse" - "cartpole-swingup" - "cartpole-swingup_sparse" + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" "cheetah-run" - "dog-run" + # "dog-run" "dog-stand" - "dog-trot" - "dog-walk" - "finger-spin" - "finger-turn_easy" - "finger-turn_hard" - "fish-swim" - "hopper-hop" - "hopper-stand" - "humanoid-run" - "humanoid-stand" - "humanoid-walk" - "pendulum-swingup" + # "dog-trot" + # "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" "quadruped-run" - "quadruped-walk" - "reacher-easy" - "reacher-hard" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" "walker-run" - "walker-stand" - "walker-walk" + # "walker-stand" + # "walker-walk" ) SHARED_ARGS=( "algo=aca" - "log.tag=default" + "algo.temp=0.05" + # "algo.critic_activation=relu" + "log.tag=backup-temp0.05-sepcritic" "log.project=flow-rl" "log.entity=lambda-rl" ) From 348c8d1e5f8d985797abd80487fb3d22d9ac3b10 Mon Sep 17 00:00:00 2001 From: typoverflow Date: Tue, 25 Nov 2025 17:41:56 -0500 Subject: [PATCH 9/9] update --- examples/online/config/dmc/algo/aca.yaml | 4 +- .../algo/{ctrl_td3.yaml => ctrlsr_td3.yaml} | 2 +- .../online/config/dmc/algo/diffsr_aca.yaml | 51 ++ .../config/dmc/algo/diffsr_aca_sep.yaml | 51 ++ .../online/config/dmc/algo/diffsr_qsm.yaml | 47 ++ .../online/config/dmc/algo/diffsr_td3.yaml | 37 ++ examples/online/config/dmc/algo/sdac.yaml | 24 + examples/online/config/mujoco/algo/sdac.yaml | 3 +- examples/online/main_dmc_offpolicy.py | 6 +- flowrl/agent/online/__init__.py | 10 +- flowrl/agent/online/ctrl/__init__.py | 7 - flowrl/agent/online/ctrlsr/__init__.py | 6 + .../agent/online/{ctrl => ctrlsr}/ctrl_qsm.py | 0 .../ctrl_td3.py => ctrlsr/ctrlsr_td3.py} | 17 +- .../agent/online/{ctrl => ctrlsr}/network.py | 106 ++-- flowrl/agent/online/diffsr/__init__.py | 7 + flowrl/agent/online/diffsr/diffsr_qsm.py | 342 ++++++++++++ flowrl/agent/online/diffsr/diffsr_td3.py | 237 +++++++++ flowrl/agent/online/diffsr/network.py | 139 +++++ flowrl/agent/online/sdac.py | 21 +- flowrl/agent/online/unirep/__init__.py | 2 + flowrl/agent/online/unirep/aca.py | 299 +++++------ .../online/unirep/diffsr/diffsr_aca_sep.py | 495 
++++++++++++++++++ .../agent/online/unirep/diffsr/diffsr_td3.py | 476 +++++++++++++++++ flowrl/agent/online/unirep/diffsr/network.py | 161 ++++++ flowrl/agent/online/unirep/network.py | 94 ++-- flowrl/config/online/mujoco/__init__.py | 8 +- .../online/mujoco/algo/ctrl/__init__.py | 7 - .../online/mujoco/algo/ctrlsr/__init__.py | 5 + .../mujoco/algo/{ctrl => ctrlsr}/ctrl_qsm.py | 0 .../ctrl_td3.py => ctrlsr/ctrlsr_td3.py} | 2 +- .../online/mujoco/algo/diffsr/__init__.py | 7 + .../online/mujoco/algo/diffsr/diffsr_qsm.py | 39 ++ .../online/mujoco/algo/diffsr/diffsr_td3.py | 38 ++ flowrl/config/online/mujoco/algo/sdac.py | 1 + flowrl/module/actor.py | 25 +- flowrl/module/critic.py | 21 +- flowrl/module/initialization.py | 15 +- flowrl/module/mlp.py | 33 +- flowrl/module/rff.py | 21 +- flowrl/types.py | 3 +- scripts/dmc/aca.sh | 8 +- scripts/dmc/{ctrl_td3.sh => ctrlsr_td3.sh} | 48 +- scripts/dmc/diffsr_aca.sh | 68 +++ scripts/dmc/diffsr_aca_sep.sh | 68 +++ scripts/dmc/diffsr_td3.sh | 70 +++ scripts/dmc/sdac.sh | 70 +++ 47 files changed, 2803 insertions(+), 398 deletions(-) rename examples/online/config/dmc/algo/{ctrl_td3.yaml => ctrlsr_td3.yaml} (97%) create mode 100644 examples/online/config/dmc/algo/diffsr_aca.yaml create mode 100644 examples/online/config/dmc/algo/diffsr_aca_sep.yaml create mode 100644 examples/online/config/dmc/algo/diffsr_qsm.yaml create mode 100644 examples/online/config/dmc/algo/diffsr_td3.yaml create mode 100644 examples/online/config/dmc/algo/sdac.yaml delete mode 100644 flowrl/agent/online/ctrl/__init__.py create mode 100644 flowrl/agent/online/ctrlsr/__init__.py rename flowrl/agent/online/{ctrl => ctrlsr}/ctrl_qsm.py (100%) rename flowrl/agent/online/{ctrl/ctrl_td3.py => ctrlsr/ctrlsr_td3.py} (92%) rename flowrl/agent/online/{ctrl => ctrlsr}/network.py (64%) create mode 100644 flowrl/agent/online/diffsr/__init__.py create mode 100644 flowrl/agent/online/diffsr/diffsr_qsm.py create mode 100644 flowrl/agent/online/diffsr/diffsr_td3.py create mode 100644 flowrl/agent/online/diffsr/network.py create mode 100644 flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py create mode 100644 flowrl/agent/online/unirep/diffsr/diffsr_td3.py create mode 100644 flowrl/agent/online/unirep/diffsr/network.py delete mode 100644 flowrl/config/online/mujoco/algo/ctrl/__init__.py create mode 100644 flowrl/config/online/mujoco/algo/ctrlsr/__init__.py rename flowrl/config/online/mujoco/algo/{ctrl => ctrlsr}/ctrl_qsm.py (100%) rename flowrl/config/online/mujoco/algo/{ctrl/ctrl_td3.py => ctrlsr/ctrlsr_td3.py} (95%) create mode 100644 flowrl/config/online/mujoco/algo/diffsr/__init__.py create mode 100644 flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py create mode 100644 flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py rename scripts/dmc/{ctrl_td3.sh => ctrlsr_td3.sh} (65%) create mode 100644 scripts/dmc/diffsr_aca.sh create mode 100644 scripts/dmc/diffsr_aca_sep.sh create mode 100644 scripts/dmc/diffsr_td3.sh create mode 100644 scripts/dmc/sdac.sh diff --git a/examples/online/config/dmc/algo/aca.yaml b/examples/online/config/dmc/algo/aca.yaml index bc5d564..b9e6818 100644 --- a/examples/online/config/dmc/algo/aca.yaml +++ b/examples/online/config/dmc/algo/aca.yaml @@ -34,8 +34,8 @@ algo: x_min: -1.0 x_max: 1.0 solver: ddpm - num_noises: 25 + num_noises: 20 linear: false ranking: true -# norm_obs: true +norm_obs: true diff --git a/examples/online/config/dmc/algo/ctrl_td3.yaml b/examples/online/config/dmc/algo/ctrlsr_td3.yaml similarity index 97% rename from 
examples/online/config/dmc/algo/ctrl_td3.yaml rename to examples/online/config/dmc/algo/ctrlsr_td3.yaml index 10a754d..33e840e 100644 --- a/examples/online/config/dmc/algo/ctrl_td3.yaml +++ b/examples/online/config/dmc/algo/ctrlsr_td3.yaml @@ -1,7 +1,7 @@ # @package _global_ algo: - name: ctrl_td3 + name: ctrlsr_td3 actor_update_freq: 1 target_update_freq: 1 discount: 0.99 diff --git a/examples/online/config/dmc/algo/diffsr_aca.yaml b/examples/online/config/dmc/algo/diffsr_aca.yaml new file mode 100644 index 0000000..9952f73 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_aca.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +algo: + name: diffsr_aca + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + actor_hidden_dims: [512, 512, 512] + # critic_hidden_dims: [512, 512, 512] # not used + activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + actor_lr: 0.0003 + critic_lr: 0.0003 + clip_grad_norm: null + target_policy_noise: 0.2 + noise_clip: 0.3 + exploration_noise: 0.2 + + # below are params specific to diffsr_td3 + num_noises: 20 + num_samples: 10 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 128 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/diffsr_aca_sep.yaml b/examples/online/config/dmc/algo/diffsr_aca_sep.yaml new file mode 100644 index 0000000..1318f25 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_aca_sep.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +algo: + name: diffsr_aca_sep + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + actor_hidden_dims: [512, 512, 512] + # critic_hidden_dims: [512, 512, 512] # not used + activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + actor_lr: 0.0003 + critic_lr: 0.0003 + clip_grad_norm: null + target_policy_noise: 0.2 + noise_clip: 0.3 + exploration_noise: 0.2 + + # below are params specific to diffsr_td3 + num_noises: 20 + num_samples: 10 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 128 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/diffsr_qsm.yaml b/examples/online/config/dmc/algo/diffsr_qsm.yaml new file mode 100644 index 0000000..e926893 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_qsm.yaml @@ -0,0 +1,47 @@ +# @package _global_ + +algo: + name: diffsr_qsm + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + # critic_hidden_dims: [512, 512, 512] # not used + critic_activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + critic_lr: 0.0003 + clip_grad_norm: null + + # below are params specific to ctrl_td3 + num_noises: 50 + 
feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 256 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + + num_samples: 10 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/diffsr_td3.yaml b/examples/online/config/dmc/algo/diffsr_td3.yaml new file mode 100644 index 0000000..eb7e4a7 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_td3.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +algo: + name: diffsr_td3 + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + actor_hidden_dims: [512, 512, 512] + # critic_hidden_dims: [512, 512, 512] # not used + activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + actor_lr: 0.0003 + critic_lr: 0.0003 + clip_grad_norm: null + target_policy_noise: 0.2 + noise_clip: 0.3 + exploration_noise: 0.2 + + # below are params specific to diffsr_td3 + num_noises: 50 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 128 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + +norm_obs: true diff --git a/examples/online/config/dmc/algo/sdac.yaml b/examples/online/config/dmc/algo/sdac.yaml new file mode 100644 index 0000000..0b68526 --- /dev/null +++ b/examples/online/config/dmc/algo/sdac.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +algo: + name: sdac + critic_hidden_dims: [512, 512, 512] + critic_activation: elu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + ema: 0.005 + temp: 0.05 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/sdac.yaml b/examples/online/config/mujoco/algo/sdac.yaml index a4e8ff3..79833ae 100644 --- a/examples/online/config/mujoco/algo/sdac.yaml +++ b/examples/online/config/mujoco/algo/sdac.yaml @@ -3,12 +3,13 @@ algo: name: sdac critic_hidden_dims: [256, 256] + critic_activation: relu critic_lr: 0.0003 discount: 0.99 num_samples: 10 num_reverse_samples: 500 ema: 0.005 - temp: 0.2 + temp: 0.05 diffusion: time_dim: 64 mlp_hidden_dims: [256, 256] diff --git a/examples/online/main_dmc_offpolicy.py b/examples/online/main_dmc_offpolicy.py index 5d8c0fb..7e651cc 100644 --- a/examples/online/main_dmc_offpolicy.py +++ b/examples/online/main_dmc_offpolicy.py @@ -26,10 +26,10 @@ "td7": TD7Agent, "sdac": SDACAgent, "dpmd": DPMDAgent, - "ctrl_td3": CtrlTD3Agent, "qsm": QSMAgent, - "ctrl_qsm": CtrlQSMAgent, - "aca": ACAAgent, + "ctrlsr_td3": CtrlSRTD3Agent, + "diffsr_td3": DiffSRTD3Agent, + "diffsr_qsm": DiffSRQSMAgent, } class OffPolicyTrainer(): diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index 4de6c7a..6e0aa19 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -1,6 +1,7 @@ from ..base import BaseAgent from .alac.alac import ALACAgent -from .ctrl import * +from .ctrlsr import * 
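# --- Reference sketch: how the agents re-exported here are resolved by algo
# name in the example entry points. The mapping mirrors the one edited in
# examples/online/main_dmc_offpolicy.py; `AGENT_REGISTRY` and `build_agent` are
# hypothetical names used only for illustration, while the constructor
# signature (obs_dim, act_dim, cfg, seed) follows the agents added in this patch.
from flowrl.agent.online import CtrlSRTD3Agent, DiffSRQSMAgent, DiffSRTD3Agent

AGENT_REGISTRY = {
    "ctrlsr_td3": CtrlSRTD3Agent,
    "diffsr_td3": DiffSRTD3Agent,
    "diffsr_qsm": DiffSRQSMAgent,
}

def build_agent(name: str, obs_dim: int, act_dim: int, cfg, seed: int = 0):
    # Resolve the agent class by its algo name and construct it.
    return AGENT_REGISTRY[name](obs_dim, act_dim, cfg, seed)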
+from .diffsr import * from .dpmd import DPMDAgent from .idem import IDEMAgent from .ppo import PPOAgent @@ -19,11 +20,12 @@ "SDACAgent", "DPMDAgent", "PPOAgent", - "CtrlTD3Agent", + "CtrlSRTD3Agent", + "DiffSRTD3Agent", + "DiffSRQSMAgent", "QSMAgent", "IDEMAgent", "ALACAgent", - "CtrlTD3Agent", - "CtrlQSMAgent", "ACAAgent", + "DiffSRACAAgent", ] diff --git a/flowrl/agent/online/ctrl/__init__.py b/flowrl/agent/online/ctrl/__init__.py deleted file mode 100644 index 602a563..0000000 --- a/flowrl/agent/online/ctrl/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .ctrl_qsm import CtrlQSMAgent -from .ctrl_td3 import CtrlTD3Agent - -__all__ = [ - "CtrlTD3Agent", - "CtrlQSMAgent", -] diff --git a/flowrl/agent/online/ctrlsr/__init__.py b/flowrl/agent/online/ctrlsr/__init__.py new file mode 100644 index 0000000..81927fb --- /dev/null +++ b/flowrl/agent/online/ctrlsr/__init__.py @@ -0,0 +1,6 @@ +# from .ctrl_qsm import CtrlQSMAgent +from .ctrlsr_td3 import CtrlSRTD3Agent + +__all__ = [ + "CtrlSRTD3Agent", +] diff --git a/flowrl/agent/online/ctrl/ctrl_qsm.py b/flowrl/agent/online/ctrlsr/ctrl_qsm.py similarity index 100% rename from flowrl/agent/online/ctrl/ctrl_qsm.py rename to flowrl/agent/online/ctrlsr/ctrl_qsm.py diff --git a/flowrl/agent/online/ctrl/ctrl_td3.py b/flowrl/agent/online/ctrlsr/ctrlsr_td3.py similarity index 92% rename from flowrl/agent/online/ctrl/ctrl_td3.py rename to flowrl/agent/online/ctrlsr/ctrlsr_td3.py index f379b46..3a32143 100644 --- a/flowrl/agent/online/ctrl/ctrl_td3.py +++ b/flowrl/agent/online/ctrlsr/ctrlsr_td3.py @@ -5,9 +5,10 @@ import jax.numpy as jnp import optax -from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce +import flowrl.module.initialization as init +from flowrl.agent.online.ctrlsr.network import FactorizedNCE, update_factorized_nce from flowrl.agent.online.td3 import TD3Agent -from flowrl.config.online.mujoco.algo.ctrl.ctrl_td3 import CtrlTD3Config +from flowrl.config.online.mujoco.algo.ctrlsr.ctrlsr_td3 import CtrlSRTD3Config from flowrl.functional.ema import ema_update from flowrl.module.actor import SquashedDeterministicActor from flowrl.module.mlp import MLP @@ -90,15 +91,15 @@ def actor_loss_fn( return rng, new_actor, metrics -class CtrlTD3Agent(TD3Agent): +class CtrlSRTD3Agent(TD3Agent): """ - CTRL with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent. + CTRL-SR with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent. 
""" - name = "CtrlTD3Agent" + name = "CtrlSRTD3Agent" model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] - def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int): + def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlSRTD3Config, seed: int): super().__init__(obs_dim, act_dim, cfg, seed) self.cfg = cfg @@ -156,6 +157,8 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int): hidden_dims=cfg.actor_hidden_dims, layer_norm=cfg.layer_norm, dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ), obs_dim=self.obs_dim, action_dim=self.act_dim, @@ -165,6 +168,8 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int): hidden_dims=cfg.critic_hidden_dims, rff_dim=cfg.rff_dim, ensemble_size=2, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) self.actor = Model.create( actor_def, diff --git a/flowrl/agent/online/ctrl/network.py b/flowrl/agent/online/ctrlsr/network.py similarity index 64% rename from flowrl/agent/online/ctrl/network.py rename to flowrl/agent/online/ctrlsr/network.py index fad30e9..c023a15 100644 --- a/flowrl/agent/online/ctrl/network.py +++ b/flowrl/agent/online/ctrlsr/network.py @@ -5,9 +5,10 @@ import jax.numpy as jnp import optax +import flowrl.module.initialization as init from flowrl.flow.ddpm import get_noise_schedule from flowrl.functional.activation import l2_normalize, mish -from flowrl.module.mlp import ResidualMLP +from flowrl.module.mlp import MLP, ResidualMLP from flowrl.module.model import Model from flowrl.module.rff import RffReward from flowrl.module.time_embedding import PositionalEmbedding @@ -27,9 +28,13 @@ class FactorizedNCE(nn.Module): ranking: bool = False def setup(self): - self.mlp_t = nn.Sequential( - [PositionalEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] - ) + MLP_torch_init = partial(MLP, kernel_init=init.pytorch_kernel_init, bias_init=init.pytorch_bias_init) + self.mlp_t = nn.Sequential([ + PositionalEmbedding(128), + MLP_torch_init(output_dim=256), + mish, + MLP_torch_init(output_dim=128) + ]) self.mlp_phi = ResidualMLP( self.phi_hidden_dims, self.feature_dim, @@ -37,6 +42,8 @@ def setup(self): activation=mish, layer_norm=True, dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) self.mlp_mu = ResidualMLP( self.mu_hidden_dims, @@ -45,32 +52,27 @@ def setup(self): activation=mish, layer_norm=True, dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) self.reward = RffReward( self.feature_dim, self.reward_hidden_dims, rff_dim=self.rff_dim, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) if self.num_noises > 0: self.use_noise_perturbation = True betas, alphas, alphabars = get_noise_schedule("vp", self.num_noises) - alphabars_prev = jnp.pad(alphabars[:-1], (1, 0), constant_values=1.0) - self.betas = betas[..., jnp.newaxis] - self.alphas = alphas[..., jnp.newaxis] - self.alphabars = alphabars[..., jnp.newaxis] - self.alphabars_prev = alphabars_prev[..., jnp.newaxis] + self.alphabars = alphabars else: self.use_noise_perturbation = False self.N = max(self.num_noises, 1) - if not self.ranking: - self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) - else: - self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) def forward_phi(self, s, a): x = jnp.concat([s, a], axis=-1) x = self.mlp_phi(x) - x = 
l2_normalize(x, group_size=None) return x def forward_mu(self, sp, t=None): @@ -78,57 +80,33 @@ def forward_mu(self, sp, t=None): t_ff = self.mlp_t(t) sp = jnp.concat([sp, t_ff], axis=-1) sp = self.mlp_mu(sp) + sp = jnp.tanh(sp) return sp def forward_reward(self, x: jnp.ndarray): # for z_phi return self.reward(x) - def forward_logits( - self, - rng: PRNGKey, - s: jnp.ndarray, - a: jnp.ndarray, - sp: jnp.ndarray, - z_phi: jnp.ndarray | None=None - ): + def __call__(self, rng, s, a, sp, training: bool=False): B, D = sp.shape rng, eps_rng = jax.random.split(rng, 2) - if z_phi is None: - z_phi = self.forward_phi(s, a) + z_phi = self.forward_phi(s, a) if self.use_noise_perturbation: sp = jnp.broadcast_to(sp, (self.N, B, D)) t = jnp.arange(self.num_noises) - t = jnp.repeat(t, B).reshape(self.N, B) + t = jnp.repeat(t, B).reshape(self.N, B, 1) alphabars = self.alphabars[t] eps = jax.random.normal(eps_rng, sp.shape) xt = jnp.sqrt(alphabars) * sp + jnp.sqrt(1-alphabars) * eps - t = jnp.expand_dims(t, -1) else: xt = jnp.expand_dims(sp, 0) t = None z_mu = self.forward_mu(xt, t) - z_phi = jnp.broadcast_to(z_phi, (self.N, B, self.feature_dim)) - logits = jax.lax.batch_matmul(z_phi, jnp.swapaxes(z_mu, -1, -2)) - logits = logits / jnp.exp(self.normalizer[:, None, None]) - return logits - - def forward_normalizer(self): - return self.normalizer - - def __call__( - self, - rng: PRNGKey, - s, - a, - sp, - ): - z_phi = self.forward_phi(s, a) - _ = self.forward_reward(z_phi) - _ = self.forward_logits(rng, s, a, sp, z_phi=z_phi) - - _ = self.forward_normalizer() - - return z_phi + logits = jax.lax.batch_matmul( + jnp.broadcast_to(z_phi, (self.N, B, self.feature_dim)), + jnp.swapaxes(z_mu, -1, -2) + ) + r_pred = self.forward_reward(z_phi) + return logits, r_pred, z_phi @partial(jax.jit, static_argnames=("ranking", "reward_coef")) @@ -140,46 +118,31 @@ def update_factorized_nce( reward_coef: float, ) -> Tuple[PRNGKey, Model, Metric]: B = batch.obs.shape[0] - rng, logits_rng = jax.random.split(rng) + rng, update_rng = jax.random.split(rng) if ranking: labels = jnp.arange(B) else: labels = jnp.eye(B) def loss_fn(nce_params: Param, dropout_rng: PRNGKey): - z_phi = nce.apply( + logits, r_pred, z_phi = nce.apply( {"params": nce_params}, - batch.obs, - batch.action, - method="forward_phi", - ) - logits = nce.apply( - {"params": nce_params}, - logits_rng, + update_rng, batch.obs, batch.action, batch.next_obs, - z_phi, - method="forward_logits", + training=True, + rngs={"dropout": dropout_rng}, ) - if ranking: model_loss = optax.softmax_cross_entropy_with_integer_labels( logits, jnp.broadcast_to(labels, (logits.shape[0], B)) ).mean(axis=-1) else: - normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") - eff_logits = logits + normalizer[:, None, None] - jnp.log(B) - model_loss = optax.sigmoid_binary_cross_entropy(eff_logits, labels).mean([-2, -1]) - r_pred = nce.apply( - {"params": nce_params}, - z_phi, - method="forward_reward", - ) - normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + raise NotImplementedError("non-ranking mode is not supported") reward_loss = jnp.mean((r_pred - batch.reward) ** 2) - nce_loss = model_loss.mean() + reward_coef * reward_loss + 0.000 * (logits**2).mean() + nce_loss = model_loss.mean() + reward_coef * reward_loss pos_logits = logits[ jnp.arange(logits.shape[0])[..., jnp.newaxis], @@ -207,9 +170,6 @@ def loss_fn(nce_params: Param, dropout_rng: PRNGKey): metrics.update({ f"misc/logits_gap_{i}": (pos_logits_per_noise[i] - 
neg_logits_per_noise[i]).mean() for i in checkpoints }) - metrics.update({ - f"misc/normalizer_{i}": jnp.exp(normalizer[i]) for i in checkpoints - }) return nce_loss, metrics new_nce, metrics = nce.apply_gradient(loss_fn) diff --git a/flowrl/agent/online/diffsr/__init__.py b/flowrl/agent/online/diffsr/__init__.py new file mode 100644 index 0000000..c20bd2c --- /dev/null +++ b/flowrl/agent/online/diffsr/__init__.py @@ -0,0 +1,7 @@ +from .diffsr_qsm import DiffSRQSMAgent +from .diffsr_td3 import DiffSRTD3Agent + +__all__ = [ + "DiffSRTD3Agent", + "DiffSRQSMAgent", +] diff --git a/flowrl/agent/online/diffsr/diffsr_qsm.py b/flowrl/agent/online/diffsr/diffsr_qsm.py new file mode 100644 index 0000000..830e947 --- /dev/null +++ b/flowrl/agent/online/diffsr/diffsr_qsm.py @@ -0,0 +1,342 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.diffsr.network import FactorizedDDPM, update_factorized_ddpm +from flowrl.agent.online.qsm import QSMAgent +from flowrl.config.online.mujoco.algo.diffsr import DiffSRQSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + ddpm_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + feature = ddpm_target(obs_repeat, actions, method="forward_phi") + qs = critic(feature) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + ddpm_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample( + rng, + next_xT, + batch.next_obs, + training=False, + solver=solver, + ) + next_feature = ddpm_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + feature = ddpm_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + 
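# --- Reference sketch: the TD target computed in update_critic above. Features
# come from the EMA copy of the factorized model (forward_phi) and the target
# takes the pessimistic minimum over the critic ensemble:
#     y = r + discount * (1 - terminal) * min_i Q_i(phi(s', a')).
# `phi_fn` and `q_target_fn` are placeholders for ddpm_target and critic_target;
# this is an illustrative sketch, not code from the patch.
def td_target(phi_fn, q_target_fn, reward, terminal, next_obs, next_action, discount=0.99):
    next_feat = phi_fn(next_obs, next_action)       # phi(s', a') from the EMA encoder
    next_q = q_target_fn(next_feat).min(axis=0)     # min over the ensemble dimension
    return reward + discount * (1.0 - terminal) * next_q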
"misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, metrics + + +@partial(jax.jit, static_argnames=("temp")) +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: ContinuousDDPM, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, Model, Metric]: + + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray: + feature = ddpm_target(obs, action, method="forward_phi") + q = critic_target(feature) + return q.min(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def loss_fn(diffusion_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": diffusion_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(loss_fn) + return rng, new_actor, actor_metrics + +# @jax.jit +# def jit_compute_metrics( +# rng: PRNGKey, +# critic: Model, +# ddpm_target: Model, +# diffusion_value: Model, +# diffusion_actor: Model, +# batch: Batch, +# ) -> Tuple[PRNGKey, Metric]: +# B, S = batch.obs.shape +# A = batch.action.shape[-1] +# num_actions = 50 +# metrics = {} +# rng, action_rng = jax.random.split(rng) +# obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) +# action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + +# def get_critic(at, obs): +# t1 = jnp.ones((1, ), dtype=jnp.int32) +# ft = ddpm_target(obs, at, t1, method="forward_phi") +# q = critic(ft) +# return q.mean() +# all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( +# action_repeat.reshape(-1, A), +# obs_repeat.reshape(-1, S), +# ) +# all_critic = all_critic.reshape(B, num_actions, 1) +# all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) +# metrics.update({ +# f"q_std/critic": all_critic.std(axis=1).mean(), +# f"q_grad/critic": jnp.abs(all_critic_grad).mean(), +# }) + +# def get_value(at, obs, t): +# ft = ddpm_target(obs, at, t, method="forward_phi") +# q = diffusion_value(ft) +# return q.mean() +# for t in [0] + list(range(1, diffusion_actor.steps+1, diffusion_actor.steps//5)): +# t_input = jnp.ones((B, num_actions, 1)) * t +# all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( +# action_repeat.reshape(-1, A), +# obs_repeat.reshape(-1, S), +# t_input.reshape(-1, 1), +# ) +# all_value = all_value.reshape(B, num_actions, 1) +# all_value_grad = all_value_grad.reshape(B, num_actions, -1) +# metrics.update({ +# f"q_std/value_{t}": all_value.std(axis=1).mean(), +# f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), +# }) +# return rng, metrics + +class DiffSRQSMAgent(QSMAgent): + """ + Diff-SR with QSM agent. 
+ """ + + name = "DiffSRQSMAgent" + model_names = ["ddpm", "ddpm_target", "actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRQSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + self.rff_dim = cfg.rff_dim + self.actor_update_freq = cfg.actor_update_freq + self.target_update_freq = cfg.target_update_freq + self.temp = cfg.temp + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.ddpm_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic_target, + batch, + temp=self.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + # if self._n_training_steps % 2000 == 0: + # self.rng, metrics = jit_compute_metrics( + # self.rng, + # self.critic, + # self.ddpm_target, + # self.diffusion_value, + # self.diffusion_actor, + # batch, + # ) + # metrics.update(metrics) + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + self.ddpm_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + def sync_target(self): + 
self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/diffsr/diffsr_td3.py b/flowrl/agent/online/diffsr/diffsr_td3.py new file mode 100644 index 0000000..335eaa3 --- /dev/null +++ b/flowrl/agent/online/diffsr/diffsr_td3.py @@ -0,0 +1,237 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +import flowrl.module.initialization as init +from flowrl.agent.online.diffsr.network import FactorizedDDPM, update_factorized_ddpm +from flowrl.agent.online.td3 import TD3Agent +from flowrl.config.online.mujoco.algo.diffsr import DiffSRTD3Config +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor_target: Model, + ddpm_target: Model, + batch: Batch, + discount: float, + target_policy_noise: float, + noise_clip: float, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise + noise = jnp.clip(noise, -noise_clip, noise_clip) + next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0) + + next_feature = ddpm_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + back_critic_grad = False + if back_critic_grad: + raise NotImplementedError("no back critic grad exists") + + feature = ddpm_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, metrics + + +@jax.jit +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: Model, + critic: Model, + batch: Batch, +) -> Tuple[PRNGKey, Model, Metric]: + def actor_loss_fn( + actor_params: Param, dropout_rng: PRNGKey + ) -> Tuple[jnp.ndarray, Metric]: + new_action = actor.apply( + {"params": actor_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + new_feature = ddpm_target(batch.obs, new_action, method="forward_phi") + q = critic(new_feature) + actor_loss = - q.mean() + + return actor_loss, { + "loss/actor_loss": actor_loss, + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, metrics + + +class DiffSRTD3Agent(TD3Agent): + """ + Diff-SR with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent. 
+ """ + + name = "DiffSRTD3Agent" + model_names = ["ddpm", "ddpm_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRTD3Config, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + actor_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=cfg.actor_hidden_dims, + layer_norm=cfg.layer_norm, + dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + self.actor = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=cfg.actor_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.actor_target = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor_target, + self.ddpm_target, + batch, + discount=self.cfg.discount, + target_policy_noise=self.target_policy_noise, + noise_clip=self.noise_clip, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic, + batch, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sync_target(self): + super().sync_target() + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/diffsr/network.py b/flowrl/agent/online/diffsr/network.py new file mode 100644 index 
0000000..c49c859 --- /dev/null +++ b/flowrl/agent/online/diffsr/network.py @@ -0,0 +1,139 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp + +import flowrl.module.initialization as init +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import mish +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import PositionalEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class FactorizedDDPM(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + embed_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int + num_noises: int + + def setup(self): + MLP_torch_init = partial(MLP, kernel_init=init.pytorch_kernel_init, bias_init=init.pytorch_bias_init) + self.mlp_t = nn.Sequential([ + PositionalEmbedding(self.embed_dim), + MLP_torch_init(output_dim=2*self.embed_dim), + mish, + MLP_torch_init(output_dim=self.embed_dim) + ]) + # self.mlp_s = nn.Sequential([ + # MLP_torch_init(output_dim=self.embed_dim*2), + # mish, + # MLP_torch_init(output_dim=self.embed_dim) + # ]) + # self.mlp_a = nn.Sequential([ + # MLP_torch_init(output_dim=self.embed_dim*2), + # mish, + # MLP_torch_init(output_dim=self.embed_dim) + # ]) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim*self.obs_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + self.reward = RffReward( + self.feature_dim, + self.reward_hidden_dims, + rff_dim=self.rff_dim, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + betas, alphas, alphabars = get_noise_schedule("vp", self.num_noises) + self.alphabars = alphabars + + def forward_phi(self, s, a): + # s = self.mlp_s(s) + # a = self.mlp_a(a) + x = jnp.concat([s, a], axis=-1) + x = self.mlp_phi(x) + return x + + def forward_mu(self, sp, t): + t = self.mlp_t(t) + x = jnp.concat([sp, t], axis=-1) + x = self.mlp_mu(x) + return x.reshape(-1, self.feature_dim, self.obs_dim) + + def forward_reward(self, x: jnp.ndarray): + return self.reward(x) + + def __call__(self, rng, s, a, sp, training: bool=False): + rng, t_rng, eps_rng = jax.random.split(rng, 3) + t = jax.random.randint(t_rng, (s.shape[0], 1), 0, self.num_noises+1) + eps = jax.random.normal(eps_rng, sp.shape) + spt = jnp.sqrt(self.alphabars[t]) * sp + jnp.sqrt(1-self.alphabars[t]) * eps + z_phi = self.forward_phi(s, a) + z_mu = self.forward_mu(spt, t) + eps_pred = jax.lax.batch_matmul(z_phi[..., jnp.newaxis, :], z_mu)[..., 0, :] + r_pred = self.forward_reward(z_phi) + return eps, eps_pred, r_pred, z_phi + + +@partial(jax.jit, static_argnames=("reward_coef")) +def update_factorized_ddpm( + rng: PRNGKey, + ddpm: FactorizedDDPM, + batch: Batch, + reward_coef: float, +) -> Tuple[PRNGKey, FactorizedDDPM, Metric]: + B = batch.obs.shape[0] + rng, update_rng = jax.random.split(rng) + def loss_fn(ddpm_params: Param, dropout_rng: PRNGKey): + eps, eps_pred, r_pred, z_phi = ddpm.apply( + {"params": ddpm_params}, + update_rng, + batch.obs, + batch.action, + batch.next_obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + ddpm_loss = 
((eps_pred - eps) ** 2).sum(axis=-1).mean() + reward_loss = ((r_pred - batch.reward) ** 2).mean() + loss = ddpm_loss + reward_coef * reward_loss + return loss, { + "loss/ddpm_loss": ddpm_loss, + "loss/reward_loss": reward_loss, + "misc/sp0_mean": batch.next_obs.mean(), + "misc/sp0_std": batch.next_obs.std(axis=0).mean(), + "misc/sp0_l1": jnp.abs(batch.next_obs).mean(), + "misc/eps_mean": eps_pred.mean(), + "misc/eps_l1": jnp.abs(eps_pred).mean(), + "misc/reward_mean": r_pred.mean(), + "misc/z_phi_l1": jnp.abs(z_phi).mean(), + } + + new_ddpm, metrics = ddpm.apply_gradient(loss_fn) + return rng, new_ddpm, metrics diff --git a/flowrl/agent/online/sdac.py b/flowrl/agent/online/sdac.py index 4e89413..2fea9b3 100644 --- a/flowrl/agent/online/sdac.py +++ b/flowrl/agent/online/sdac.py @@ -116,6 +116,7 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra "misc/weight_std": weights.std(0).mean(), "misc/weights_max": weights.max(0).mean(), "misc/weights_min": weights.min(0).mean(), + "misc/eps_estimation_l1": jnp.abs((weights * eps_reverse).sum(axis=0)).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) @@ -181,25 +182,15 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: SDACConfig, seed: int): t_schedule_n=1.0, optimizer=optax.adam(learning_rate=actor_lr), ) - # CHECK: is this really necessary, since we are not using the target actor for policy evaluation? - self.actor_target = ContinuousDDPM.create( - network=backbone_def, - rng=actor_rng, - inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), - x_dim=self.act_dim, - steps=cfg.diffusion.steps, - noise_schedule="cosine", - noise_schedule_params={}, - clip_sampler=cfg.diffusion.clip_sampler, - x_min=cfg.diffusion.x_min, - x_max=cfg.diffusion.x_max, - t_schedule_n=1.0, - ) # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] critic_def = EnsembleCritic( hidden_dims=cfg.critic_hidden_dims, - activation=jax.nn.relu, + activation=critic_activation, layer_norm=False, dropout=None, ensemble_size=2, diff --git a/flowrl/agent/online/unirep/__init__.py b/flowrl/agent/online/unirep/__init__.py index e789c67..fa740b6 100644 --- a/flowrl/agent/online/unirep/__init__.py +++ b/flowrl/agent/online/unirep/__init__.py @@ -1,5 +1,7 @@ from .aca import ACAAgent +from .diffsr.diffsr_td3 import DiffSRACAAgent __all__ = [ "ACAAgent", + "DiffSRACAAgent", ] diff --git a/flowrl/agent/online/unirep/aca.py b/flowrl/agent/online/unirep/aca.py index e23bdb6..5ec437d 100644 --- a/flowrl/agent/online/unirep/aca.py +++ b/flowrl/agent/online/unirep/aca.py @@ -10,11 +10,12 @@ from flowrl.config.online.mujoco.algo.unirep.aca import ACAConfig from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone from flowrl.flow.ddpm import DDPM, DDPMBackbone -from flowrl.functional.activation import atanh, l2_normalize, mish, tanh +from flowrl.functional.activation import l2_normalize, mish from flowrl.functional.ema import ema_update -from flowrl.module.critic import EnsembleCritic, EnsembleCriticT +from flowrl.module.critic import Critic, EnsembleCritic, EnsembleCriticT from flowrl.module.mlp import MLP from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic from flowrl.module.time_embedding import LearnableFourierEmbedding, PositionalEmbedding from flowrl.types import Batch, Metric, Param, PRNGKey @@ -43,13 +44,15 @@ def jit_sample_actions( else: # t0 = jnp.zeros((obs_repeat.shape[0], 
num_samples, 1)) # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") - qs = critic(obs_repeat, actions) - qs = qs.min(axis=0).reshape(B, num_samples) + # qs = critic(obs_repeat, actions) + t1 = jnp.ones((obs_repeat.shape[0], num_samples, 1), dtype=jnp.int32) + f1 = nce_target(obs_repeat, actions, t1, method="forward_phi") + qs = critic(f1).min(axis=0).reshape(B, num_samples) + # qs = qs.min(axis=0).reshape(B, num_samples) best_idx = qs.argmax(axis=-1) actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] return rng, actions - @partial(jax.jit, static_argnames=("deterministic", "exploration_noise")) def jit_td3_sample_action( rng: PRNGKey, @@ -81,26 +84,28 @@ def jit_update_critic( ) -> Tuple[PRNGKey, Model, Metric]: # q0 target B = batch.obs.shape[0] - t0 = jnp.ones((batch.obs.shape[0], 1)) + A = batch.action.shape[-1] + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) + a0 = batch.action + + rng, next_aT_rng = jax.random.split(rng) - next_aT = jax.random.normal(next_aT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + next_a0 = backup(batch.next_obs, training=False) + # next_aT = jax.random.normal(next_aT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) # rng, next_a0, _ = actor.sample(rng, next_aT, batch.next_obs, training=False, solver=solver) - next_a0 = backup(batch.next_obs) # next_f0 = nce_target(batch.next_obs, next_a0, t0, method="forward_phi") # q0_target = critic_target(next_f0) - q0_target = critic_target(batch.next_obs, next_a0) + next_f1 = nce_target(batch.next_obs, next_a0, t1, method="forward_phi") + f1 = nce_target(batch.obs, a0, t1, method="forward_phi") + # q0_target = critic_target(batch.next_obs, next_a0) + q0_target = critic_target(next_f1) q0_target = batch.reward + discount * (1 - batch.terminal) * q0_target.min(axis=0) - # qt target - a0 = batch.action - # f0 = nce_target(batch.obs, a0, t0, method="forward_phi") - qt_target = critic_target(batch.obs, a0).mean(axis=0) - # qt_target = critic(batch.obs, a0).mean(axis=0) # features # rng, at, t, eps = actor.add_noise(rng, a0) # weight_t = actor.alpha_hats[t] / (1-actor.alpha_hats[t]) - weight_t = 1.0 + # weight_t = 1.0 # ft = nce_target(batch.obs, at, t, method="forward_phi") # rng, t_rng, noise_rng = jax.random.split(rng, 3) # t = jax.random.randint(t_rng, (*a0.shape[:-1], 1), 0, actor.steps+1) @@ -109,106 +114,45 @@ def jit_update_critic( # at = jnp.sqrt(actor.alpha_hats[t]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t]) * eps def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + # f0 = nce_target(batch.obs, a0, t0, method="forward_phi") q0_pred = critic.apply( {"params": critic_params}, - batch.obs, - a0, - # t0, - # training=True, - rngs={"dropout": dropout_rng}, + f1, + # batch.obs, + # a0, ) - # qt_pred = value.apply( - # {"params": critic_params}, - # batch.obs, - # at, - # t, - # # training=True, - # rngs={"dropout": dropout_rng}, - # ) critic_loss = ( ((q0_pred - q0_target[jnp.newaxis, :])**2).mean() - # + ((qt_pred - qt_target[jnp.newaxis, :])**2).mean() ) return critic_loss, { "loss/critic_loss": critic_loss, "misc/q0_mean": q0_pred.mean(), - # "misc/qt_mean": qt_pred.mean(), "misc/reward": batch.reward.mean(), "misc/next_action_l1": jnp.abs(next_a0).mean(), + "misc/q0_target": q0_target.mean(), } new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + # rng, at, t, eps = actor.add_noise(rng, a0) + rng, t_rng, noise_rng = jax.random.split(rng, 3) + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) + eps = 
jax.random.normal(noise_rng, a0.shape) + at = jnp.sqrt(actor.alpha_hats[t1]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t1]) * eps + ft = nce_target(batch.obs, at, t1, method="forward_phi") + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: - value_loss = 0 - for t in range(0, actor.steps+1): - t_input = jnp.ones((B, 1)) * t - noise_rng, dropout_rng = jax.random.split(dropout_rng) - eps = jax.random.normal(noise_rng, a0.shape) - at = jnp.sqrt(actor.alpha_hats[t]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t]) * eps - qt_pred = value.apply( - {"params": value_params}, - batch.obs, - at, - t_input, - training=True, - rngs={"dropout": dropout_rng}, - ) - value_loss += ((qt_pred - q0_target[:])**2).mean() + qt_pred = value.apply( + {"params": value_params}, + ft, + ) + value_loss = ((qt_pred - q0_target)**2).mean() return value_loss, { "loss/value_loss": value_loss, "misc/qt_mean": qt_pred.mean(), } new_value, value_metrics = value.apply_gradient(value_loss_fn) - # t_zero = jnp.zeros((B, 1)) - # v0_target = value_target(batch.next_obs, next_a0, t_zero) - # v0_target = v0_target.min(axis=0) - # v0_target = batch.reward + discount * (1-batch.terminal) * v0_target - - # def td_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: - # v0_pred = value.apply( - # {"params": value_params}, - # batch.obs, - # batch.action, - # t_zero, - # training=True, - # rngs={"dropout": dropout_rng}, - # ) - # td_loss = ((v0_pred - v0_target[jnp.newaxis, :])**2).mean() - # return td_loss, { - # "loss/td_loss": td_loss, - # } - # # value, td_metrics = value.apply_gradient(td_loss_fn) - # td_metrics = {} - - # rng, rng1, rng2, rng3 = jax.random.split(rng, 4) - # t = jax.random.randint(rng1, (B, 1), 0, actor.steps) - # eps = jax.random.normal(rng2, batch.action.shape) - # at = jnp.sqrt(actor.alpha_hats[t]) * batch.action + jnp.sqrt(1 - actor.alpha_hats[t]) * eps - # vt_target = value(batch.obs, batch.action, t_zero).mean(axis=0) - # def nontd_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: - # vt_pred = value.apply( - # {"params": value_params}, - # batch.obs, - # at, - # t, - # training=True, - # rngs={"dropout": dropout_rng}, - # ) - # nontd_loss = ((vt_pred - q0_target[jnp.newaxis, :])**2).mean() - # return nontd_loss, { - # "loss/nontd_loss": nontd_loss, - # } - # new_value, nontd_metrics = value.apply_gradient(nontd_loss_fn) - - # return rng, new_critic, new_value, { - # **critic_metrics, - # # **value_metrics, - # **td_metrics, - # **nontd_metrics, - # } - return rng, new_critic, new_value, { **critic_metrics, **value_metrics, @@ -219,6 +163,7 @@ def jit_compute_metrics( actor: Model, critic: Model, value: Model, + nce_target: Model, batch: Batch, ) -> Tuple[PRNGKey, Metric]: B, S = batch.obs.shape @@ -231,7 +176,9 @@ def jit_compute_metrics( action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) def get_critic(at, obs): - q = critic(obs, at) + t1 = jnp.ones((1, ), dtype=jnp.int32) + f1 = nce_target(obs, at, t1, method="forward_phi") + q = critic(f1) return q.mean() all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( action_repeat.reshape(-1, A), @@ -240,13 +187,13 @@ def get_critic(at, obs): all_critic = all_critic.reshape(B, num_actions, 1) all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) metrics.update({ - f"q_mean/critic": all_critic.mean(), f"q_std/critic": all_critic.std(axis=1).mean(), f"q_grad/critic": jnp.abs(all_critic_grad).mean(), }) def 
get_value(at, obs, t): - q = value(obs, at, t) + ft = nce_target(obs, at, t, method="forward_phi") + q = value(ft) return q.mean() for t in [0] + list(range(1, actor.steps+1, actor.steps//5)): @@ -259,7 +206,6 @@ def get_value(at, obs, t): all_value = all_value.reshape(B, num_actions, 1) all_value_grad = all_value_grad.reshape(B, num_actions, -1) metrics.update({ - f"q_mean/value_{t}": all_value.mean(), f"q_std/value_{t}": all_value.std(axis=1).mean(), f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), }) @@ -278,16 +224,17 @@ def jit_update_actor( temp: float, ) -> Tuple[PRNGKey, Model, Metric]: a0 = batch.action + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) rng, at, t, eps = actor.add_noise(rng, a0) - # alpha, sigma = actor.noise_schedule_func(t) sigma = jnp.sqrt(1 - actor.alpha_hats[t]) def get_q_value(at: jnp.ndarray, obs: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray: - # ft = nce_target(obs, at, t, method="forward_phi") - q = value_target(obs, at, t) - # q = critic_target(obs, at) - return q.mean(axis=0).mean() + t1 = jnp.ones(t.shape, dtype=jnp.int32) + ft = nce_target(obs, at, t1, method="forward_phi") + q = value_target(ft) + return q.mean() q_grad_fn = jax.vmap(jax.grad(get_q_value)) q_grad = q_grad_fn(at, batch.obs, t) + q_grad_l1 = jnp.abs(q_grad).mean() eps_estimation = - sigma * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) # eps_estimation = -sigma * l2_normalize(q_grad) / temp @@ -304,6 +251,7 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra return loss, { "loss/actor_loss": loss, "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/q_grad_l1": q_grad_l1, } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) @@ -314,11 +262,18 @@ def backup_loss_fn(backup_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar training=True, rngs={"dropout": dropout_rng}, ) - q = critic_target(batch.obs, new_action) - backup_loss = - q.mean() - return backup_loss, {} + f1 = nce_target(batch.obs, new_action, t1, method="forward_phi") + q = critic_target(f1) + loss = - q.mean() + return loss, { + "loss/backup_loss": loss, + } new_backup, backup_metrics = backup.apply_gradient(backup_loss_fn) - return rng, new_actor, new_backup, actor_metrics + + return rng, new_actor, new_backup, { + **actor_metrics, + **backup_metrics, + } class ACAAgent(BaseAgent): @@ -440,59 +395,54 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): "relu": jax.nn.relu, "elu": jax.nn.elu, }[cfg.critic_activation] - # critic_def = nn.Sequential([ - # nn.LayerNorm(), - # EnsembleCritic( - # hidden_dims=cfg.critic_hidden_dims, - # activation=critic_activation, - # layer_norm=True, - # dropout=None, - # ensemble_size=2, - # ) - # ]) - # self.critic = Model.create( - # critic_def, - # critic_rng, - # inputs=(jnp.ones((1, self.feature_dim))), - # optimizer=optax.adam(learning_rate=cfg.critic_lr), - # ) - # self.critic_target = Model.create( - # critic_def, - # critic_rng, - # inputs=(jnp.ones((1, self.feature_dim))), + # critic_def = EnsembleCritic( + # hidden_dims=[512, 512], + # activation=critic_activation, + # ensemble_size=2, + # layer_norm=True, # ) - critic_def = EnsembleCritic( - # time_embedding=time_embedding, - hidden_dims=[512, 512, 512], - activation=critic_activation, + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=[512,], + rff_dim=cfg.rff_dim, ensemble_size=2, - layer_norm=False, ) self.critic = Model.create( critic_def, critic_rng, - inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, 
self.act_dim))), + inputs=(jnp.ones((1, self.feature_dim))), optimizer=optax.adam(learning_rate=cfg.critic_lr), ) self.critic_target = Model.create( critic_def, critic_rng, - inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + inputs=(jnp.ones((1, self.feature_dim))), ) - # value_def = EnsembleCriticT( - # time_embedding=time_embedding, - # hidden_dims=[512, 512, 512], - # activation=critic_activation, - # ensemble_size=2, - # layer_norm=True, - # ) - from flowrl.agent.online.unirep.network import ( - EnsembleACACritic, - EnsembleResidualCritic, - ResidualCritic, - SeparateCritic, + + value_def = Critic( + hidden_dims=[512, 512], + activation=critic_activation, + layer_norm=True, + ) + self.value = Model.create( + value_def, + value_rng, + inputs=(jnp.ones((1, self.feature_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.value_target = Model.create( + value_def, + value_rng, + inputs=(jnp.ones((1, self.feature_dim))), ) + # from flowrl.agent.online.unirep.network import ( + # EnsembleACACritic, + # EnsembleResidualCritic, + # ResidualCritic, + # SeparateCritic, + # ) + # value_def = EnsembleACACritic( # time_dim=16, # hidden_dims=[256,256,256], @@ -504,44 +454,36 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): # hidden_dims=[512, 512, 512], # activation=jax.nn.mish, # ) - value_def = SeparateCritic( - hidden_dims=[512, 512, 512], - activation=jax.nn.mish, - ensemble_size=cfg.diffusion.steps+1, - ) - self.value = Model.create( - value_def, - value_rng, - inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), - optimizer=optax.adam(learning_rate=cfg.critic_lr), - ) - self.value_target = Model.create( - value_def, - value_rng, - inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), - ) + # value_def = SeparateCritic( + # hidden_dims=[512, 512, 512], + # activation=jax.nn.mish, + # ensemble_size=cfg.diffusion.steps+1, + # ) + # self.value = Model.create( + # value_def, + # value_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), + # optimizer=optax.adam(learning_rate=cfg.critic_lr), + # ) + # self.value_target = Model.create( + # value_def, + # value_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), + # ) # define tracking variables self._n_training_steps = 0 def train_step(self, batch: Batch, step: int) -> Metric: metrics = {} - # batch = Batch( - # obs=batch.obs, - # action=atanh(batch.action, scale=5.0), - # next_obs=batch.next_obs, - # reward=batch.reward, - # terminal=batch.terminal, - # next_action=atanh(batch.next_action, scale=5.0), - # ) - # self.rng, self.nce, nce_metrics = update_factorized_nce( - # self.rng, - # self.nce, - # batch, - # self.ranking, - # self.reward_coef, - # ) - # metrics.update(nce_metrics) + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) self.rng, self.critic, self.value, critic_metrics = jit_update_critic( self.rng, self.critic, @@ -578,6 +520,7 @@ def train_step(self, batch: Batch, step: int) -> Metric: self.actor, self.critic, self.value, + self.nce_target, batch, ) @@ -607,7 +550,7 @@ def sample_actions( else: self.rng, action_rng = jax.random.split(self.rng) action = jit_td3_sample_action( - self.rng, + action_rng, self.backup, obs, deterministic, diff --git a/flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py 
b/flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py new file mode 100644 index 0000000..7bf711e --- /dev/null +++ b/flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py @@ -0,0 +1,495 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.td3 import TD3Agent +from flowrl.agent.online.unirep.diffsr.network import ( + FactorizedDDPM, + update_factorized_ddpm, +) +from flowrl.config.online.mujoco.algo.diffsr import DiffSRTD3Config +from flowrl.functional.activation import l2_normalize +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + ddpm_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + # t0 = jnp.zeros((obs_repeat.shape[0], num_samples, 1)) + # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + # qs = critic(obs_repeat, actions) + t1 = jnp.ones((obs_repeat.shape[0], num_samples, 1), dtype=jnp.int32) + f1 = ddpm_target(obs_repeat, actions, t1, method="forward_phi") + qs = critic(f1).min(axis=0).reshape(B, num_samples) + # qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + + +@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor_target: Model, + ddpm_target: Model, + diffusion_actor: Model, + diffusion_value: Model, + batch: Batch, + discount: float, + target_policy_noise: float, + noise_clip: float, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + rng, sample_rng = jax.random.split(rng) + noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise + noise = jnp.clip(noise, -noise_clip, noise_clip) + next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0) + + next_feature = ddpm_target(batch.next_obs, next_action, t1, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + back_critic_grad = False + if back_critic_grad: + raise NotImplementedError("no back critic grad exists") + + feature = ddpm_target(batch.obs, batch.action, t1, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, 
+ "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + + new_value = [] + value_metrics = {} + for i in range(len(diffusion_value)): + rng, eps_rng = jax.random.split(rng) + eps = jax.random.normal(eps_rng, batch.action.shape) + t = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(i) + at = jnp.sqrt(diffusion_actor.alpha_hats[t]) * batch.action + jnp.sqrt(1-diffusion_actor.alpha_hats[t]) * eps + ft = ddpm_target(batch.obs, at, t, method="forward_phi") + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + qt_pred = diffusion_value[i].apply( + {"params": value_params}, + ft, + ) + value_loss = ((qt_pred - q_target) ** 2).mean() + return value_loss, { + f"loss/value_{i}_loss": value_loss, + } + this_value, this_metrics = diffusion_value[i].apply_gradient(value_loss_fn) + new_value.append(this_value) + value_metrics.update(this_metrics) + + return rng, new_critic, new_value, { + **metrics, + **value_metrics, + } + + +@jax.jit +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: Model, + critic: Model, + diffusion_actor, + diffusion_value, + batch: Batch, +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + def actor_loss_fn( + actor_params: Param, dropout_rng: PRNGKey + ) -> Tuple[jnp.ndarray, Metric]: + new_action = actor.apply( + {"params": actor_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + new_feature = ddpm_target(batch.obs, new_action, t1, method="forward_phi") + q = critic(new_feature) + actor_loss = - q.mean() + + return actor_loss, { + "loss/actor_loss": actor_loss, + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + + # rng, at, t, eps = diffusion_actor.add_noise(rng, batch.action) + # sigma = jnp.sqrt(1 - diffusion_actor.alpha_hats[t]) + rng, rng2 = jax.random.split(rng) + B, S = batch.obs.shape + A = batch.action.shape[-1] + t_repeat = jnp.arange(diffusion_actor.steps+1)[..., jnp.newaxis, jnp.newaxis].repeat(B, axis=1) + obs_repeat = batch.obs[jnp.newaxis, ...].repeat(diffusion_actor.steps+1, axis=0) + action_repeat = batch.action[jnp.newaxis, ...].repeat(diffusion_actor.steps+1, axis=0) + at_repeat = jnp.sqrt(diffusion_actor.alpha_hats[t_repeat]) * action_repeat + jnp.sqrt(1 - diffusion_actor.alpha_hats[t_repeat]) * jax.random.normal(rng2, action_repeat.shape) + q_grad = [] + for i in range(diffusion_actor.steps+1): + def get_q_value(at, obs, t): + ft = ddpm_target(obs, at, t, method="forward_phi") + q = diffusion_value[i](ft) + return q.mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad.append(q_grad_fn(at_repeat[i], obs_repeat[i], t_repeat[i])) + q_grad = jnp.stack(q_grad, axis=0) + # eps_estimation = -jnp.sqrt(1 - diffusion_actor.alpha_hats[t_repeat]) * q_grad + # eps_estimation = l2_normalize(eps_estimation) * (eps_estimation.shape[-1] ** 0.5) + # eps_estimation = eps_estimation / 0.2 + eps_estimation = - jnp.sqrt(1 - diffusion_actor.alpha_hats[t_repeat]) * q_grad / 0.1 / (jnp.abs(q_grad).mean() + 1e-6) + + def diffusion_loss_fn(diffusion_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = diffusion_actor.apply( + {"params": diffusion_params}, + at_repeat, + t_repeat, + condition=obs_repeat, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/diffusion_loss": loss, + "misc/eps_estimation_l1": 
jnp.abs(eps_estimation).mean(), + } + new_diffusion, diffusion_metrics = diffusion_actor.apply_gradient(diffusion_loss_fn) + return rng, new_actor, new_diffusion, { + **metrics, + **diffusion_metrics, + } + +@jax.jit +def jit_compute_metrics( + rng: PRNGKey, + critic: Model, + ddpm_target: Model, + diffusion_value: Model, + diffusion_actor: Model, + batch: Batch, +) -> Tuple[PRNGKey, Metric]: + B, S = batch.obs.shape + A = batch.action.shape[-1] + num_actions = 50 + metrics = {} + rng, action_rng = jax.random.split(rng) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + + def get_critic(at, obs): + t1 = jnp.ones((1, ), dtype=jnp.int32) + ft = ddpm_target(obs, at, t1, method="forward_phi") + q = critic(ft) + return q.mean() + all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + ) + all_critic = all_critic.reshape(B, num_actions, 1) + all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/critic": all_critic.std(axis=1).mean(), + f"q_grad/critic": jnp.abs(all_critic_grad).mean(), + }) + + for i in [0] + list(range(1, diffusion_actor.steps+1, diffusion_actor.steps//5)): + def get_value(at, obs, t): + ft = ddpm_target(obs, at, t, method="forward_phi") + q = diffusion_value[i](ft) + return q.mean() + t_repeat = jnp.ones((B, num_actions, 1)) * jnp.int32(i) + all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + t_repeat.reshape(-1, 1), + ) + all_value = all_value.reshape(B, num_actions, 1) + all_value_grad = all_value_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/value_{i}": all_value.std(axis=1).mean(), + f"q_grad/value_{i}": jnp.abs(all_value_grad).mean(), + }) + return rng, metrics + +class DiffSRACASepAgent(TD3Agent): + """ + Diff-SR with ACA agent. 
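+
+    Learns a factorized DDPM over (s, a) -> s' as a shared state-action representation,
+    fits one value head per diffusion step on the ddpm_target (EMA) features, and
+    distills the value gradients into a conditional diffusion policy via QSM-style
+    epsilon targets, alongside a TD3-style deterministic actor/critic on the same features.
+
+    Illustrative usage (a sketch; assumes `cfg` is a populated DiffSRTD3Config and
+    `batch` / `obs` come from the replay buffer):
+
+        agent = DiffSRACASepAgent(obs_dim, act_dim, cfg, seed=0)
+        metrics = agent.train_step(batch, step=0)
+        action, _ = agent.sample_actions(obs, deterministic=True)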
+ """ + + name = "DiffSRACAAgent" + model_names = ["ddpm", "ddpm_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRTD3Config, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + actor_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=cfg.actor_hidden_dims, + layer_norm=cfg.layer_norm, + dropout=None, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + # critic_def = RffEnsembleCritic( + # feature_dim=self.feature_dim, + # hidden_dims=cfg.critic_hidden_dims, + # rff_dim=cfg.rff_dim, + # ensemble_size=2, + # ) + critic_def = EnsembleCritic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ensemble_size=2, + ) + self.actor = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=cfg.actor_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.actor_target = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + # define the ACA value and actor + from flowrl.agent.online.unirep.diffsr.network import ResidualCritic + from flowrl.flow.ddpm import DDPM, DDPMBackbone + from flowrl.functional.activation import mish + from flowrl.module.critic import Critic + from flowrl.module.time_embedding import PositionalEmbedding + self.rng, diffusion_value_rng, diffusion_actor_rng = jax.random.split(self.rng, 3) + time_embedding = partial(PositionalEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = DDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + self.diffusion_actor = DDPM.create( + network=backbone_def, + rng=diffusion_actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="vp", + noise_schedule_params={}, + approx_postvar=False, + clip_sampler=cfg.diffusion.clip_sampler, + 
x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + value_def = Critic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ) + # value_def = ResidualCritic( + # time_embedding=time_embedding, + # hidden_dims=[512, 512], + # activation=jax.nn.elu, + # ) + self.diffusion_value = [] + for t in range(cfg.diffusion.steps+1): + this_rng, diffusion_value_rng = jax.random.split(diffusion_value_rng) + self.diffusion_value.append( + Model.create( + value_def, + this_rng, + inputs=(jnp.ones((1, self.feature_dim)), ), + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, self.diffusion_value, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor_target, + self.ddpm_target, + self.diffusion_actor, + self.diffusion_value, + batch, + discount=self.cfg.discount, + target_policy_noise=self.target_policy_noise, + noise_clip=self.noise_clip, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, self.diffusion_actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic, + self.diffusion_actor, + self.diffusion_value, + batch, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + if self._n_training_steps % 2000 == 0: + self.rng, metrics = jit_compute_metrics( + self.rng, + self.critic, + self.ddpm_target, + self.diffusion_value, + self.diffusion_actor, + batch, + ) + metrics.update(metrics) + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if deterministic: + num_samples = self.cfg.num_samples + self.rng, action = jit_sample_actions( + self.rng, + self.diffusion_actor, + self.critic, + self.ddpm_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + else: + action = super().sample_actions(obs, deterministic, num_samples) + return action, {} + + def sync_target(self): + super().sync_target() + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/diffsr/diffsr_td3.py b/flowrl/agent/online/unirep/diffsr/diffsr_td3.py new file mode 100644 index 0000000..2e2cfb1 --- /dev/null +++ b/flowrl/agent/online/unirep/diffsr/diffsr_td3.py @@ -0,0 +1,476 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.td3 import TD3Agent +from flowrl.agent.online.unirep.diffsr.network import ( + FactorizedDDPM, + update_factorized_ddpm, +) +from flowrl.config.online.mujoco.algo.diffsr import DiffSRTD3Config +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, 
static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + ddpm_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + # t0 = jnp.zeros((obs_repeat.shape[0], num_samples, 1)) + # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + # qs = critic(obs_repeat, actions) + t1 = jnp.ones((obs_repeat.shape[0], num_samples, 1), dtype=jnp.int32) + f1 = ddpm_target(obs_repeat, actions, t1, method="forward_phi") + qs = critic(f1).min(axis=0).reshape(B, num_samples) + # qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + + +@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor_target: Model, + ddpm_target: Model, + diffusion_actor: Model, + diffusion_value: Model, + batch: Batch, + discount: float, + target_policy_noise: float, + noise_clip: float, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + rng, sample_rng = jax.random.split(rng) + noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise + noise = jnp.clip(noise, -noise_clip, noise_clip) + next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0) + + next_feature = ddpm_target(batch.next_obs, next_action, t1, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + back_critic_grad = False + if back_critic_grad: + raise NotImplementedError("no back critic grad exists") + + feature = ddpm_target(batch.obs, batch.action, t1, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + + rng, t_rng, eps_rng = jax.random.split(rng, 3) + # rng, at, t, eps = diffusion_actor.add_noise(rng, batch.action) + # ft = ddpm_target(batch.obs, at, t, method="forward_phi") + t = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + at = jnp.sqrt(diffusion_actor.alpha_hats[t]) * batch.action + jnp.sqrt(1 - diffusion_actor.alpha_hats[t]) * jax.random.normal(eps_rng, batch.action.shape) + ft = ddpm_target(batch.obs, at, t, method="forward_phi") + # snr = jnp.sqrt(diffusion_actor.alpha_hats[t] / (1 - diffusion_actor.alpha_hats[t])) + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + qt_pred = diffusion_value.apply( + {"params": value_params}, + ft, + ) + value_loss = (((qt_pred - q_target) ** 2)).mean() + return 
value_loss, { + "loss/value_loss": value_loss, + "misc/qt_mean": qt_pred.mean(), + } + new_value, value_metrics = diffusion_value.apply_gradient(value_loss_fn) + + return rng, new_critic, new_value, { + **metrics, + **value_metrics, + } + + +@jax.jit +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: Model, + critic: Model, + diffusion_actor, + diffusion_value, + batch: Batch, +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + def actor_loss_fn( + actor_params: Param, dropout_rng: PRNGKey + ) -> Tuple[jnp.ndarray, Metric]: + new_action = actor.apply( + {"params": actor_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + new_feature = ddpm_target(batch.obs, new_action, t1, method="forward_phi") + q = critic(new_feature) + actor_loss = - q.mean() + + return actor_loss, { + "loss/actor_loss": actor_loss, + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + + rng, at, t, eps = diffusion_actor.add_noise(rng, batch.action) + sigma = jnp.sqrt(1 - diffusion_actor.alpha_hats[t]) + def get_q_value(at, obs, t): + t1 = jnp.ones(t.shape, dtype=jnp.int32) + ft = ddpm_target(obs, at, t1, method="forward_phi") + q = diffusion_value(ft) + return q.mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs, t) + q_grad_l1 = jnp.abs(q_grad).mean() + eps_estimation = - sigma * q_grad / 0.1 / (jnp.abs(q_grad).mean() + 1e-6) + def diffusion_loss_fn(diffusion_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = diffusion_actor.apply( + {"params": diffusion_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/diffusion_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/q_grad_l1": q_grad_l1, + } + new_diffusion, diffusion_metrics = diffusion_actor.apply_gradient(diffusion_loss_fn) + return rng, new_actor, new_diffusion, { + **metrics, + **diffusion_metrics, + } + +@jax.jit +def jit_compute_metrics( + rng: PRNGKey, + critic: Model, + ddpm_target: Model, + diffusion_value: Model, + diffusion_actor: Model, + batch: Batch, +) -> Tuple[PRNGKey, Metric]: + B, S = batch.obs.shape + A = batch.action.shape[-1] + num_actions = 50 + metrics = {} + rng, action_rng = jax.random.split(rng) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + + def get_critic(at, obs): + t1 = jnp.ones((1, ), dtype=jnp.int32) + ft = ddpm_target(obs, at, t1, method="forward_phi") + q = critic(ft) + return q.mean() + all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + ) + all_critic = all_critic.reshape(B, num_actions, 1) + all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/critic": all_critic.std(axis=1).mean(), + f"q_grad/critic": jnp.abs(all_critic_grad).mean(), + }) + + def get_value(at, obs, t): + ft = ddpm_target(obs, at, t, method="forward_phi") + q = diffusion_value(ft) + return q.mean() + for t in [0] + list(range(1, diffusion_actor.steps+1, diffusion_actor.steps//5)): + t_input = jnp.ones((B, num_actions, 1)) * t + all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + t_input.reshape(-1, 1), + ) + all_value 
= all_value.reshape(B, num_actions, 1) + all_value_grad = all_value_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/value_{t}": all_value.std(axis=1).mean(), + f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), + }) + return rng, metrics + +class DiffSRACAAgent(TD3Agent): + """ + Diff-SR with ACA agent. + """ + + name = "DiffSRACAAgent" + model_names = ["ddpm", "ddpm_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRTD3Config, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + actor_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=cfg.actor_hidden_dims, + layer_norm=cfg.layer_norm, + dropout=None, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + # critic_def = RffEnsembleCritic( + # feature_dim=self.feature_dim, + # hidden_dims=cfg.critic_hidden_dims, + # rff_dim=cfg.rff_dim, + # ensemble_size=2, + # ) + critic_def = EnsembleCritic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ensemble_size=2, + ) + self.actor = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=cfg.actor_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.actor_target = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + # define the ACA value and actor + from flowrl.agent.online.unirep.diffsr.network import ResidualCritic + from flowrl.flow.ddpm import DDPM, DDPMBackbone + from flowrl.functional.activation import mish + from flowrl.module.critic import Critic + from flowrl.module.time_embedding import PositionalEmbedding + self.rng, diffusion_value_rng, diffusion_actor_rng = jax.random.split(self.rng, 3) + time_embedding = partial(PositionalEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = DDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + self.diffusion_actor = DDPM.create( + 
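+            # diffusion policy head: an epsilon-predictor conditioned on obs; its regression
+            # targets are built from the value-gradient estimates in update_actor above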
network=backbone_def, + rng=diffusion_actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="vp", + noise_schedule_params={}, + approx_postvar=False, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + value_def = Critic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ) + # value_def = ResidualCritic( + # time_embedding=time_embedding, + # hidden_dims=[512, 512], + # activation=jax.nn.elu, + # ) + self.diffusion_value = Model.create( + value_def, + diffusion_value_rng, + inputs=(jnp.ones((1, self.feature_dim)), ), + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, self.diffusion_value, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor_target, + self.ddpm_target, + self.diffusion_actor, + self.diffusion_value, + batch, + discount=self.cfg.discount, + target_policy_noise=self.target_policy_noise, + noise_clip=self.noise_clip, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, self.diffusion_actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic, + self.diffusion_actor, + self.diffusion_value, + batch, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + if self._n_training_steps % 2000 == 0: + self.rng, metrics = jit_compute_metrics( + self.rng, + self.critic, + self.ddpm_target, + self.diffusion_value, + self.diffusion_actor, + batch, + ) + metrics.update(metrics) + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if deterministic: + num_samples = self.cfg.num_samples + self.rng, action = jit_sample_actions( + self.rng, + self.diffusion_actor, + self.critic, + self.ddpm_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + else: + action = super().sample_actions(obs, deterministic, num_samples) + return action, {} + + def sync_target(self): + super().sync_target() + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/diffsr/network.py b/flowrl/agent/online/unirep/diffsr/network.py new file mode 100644 index 0000000..d4139dc --- /dev/null +++ b/flowrl/agent/online/unirep/diffsr/network.py @@ -0,0 +1,161 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp + +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import l2_normalize, mish +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import PositionalEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class ResidualCritic(nn.Module): + time_embedding: nn.Module + hidden_dims: Sequence[int] + activation: Callable = nn.relu 
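+    # Time-conditioned critic head over DDPM features: a time embedding t_ff is built from t,
+    # but as written only ft is concatenated into the residual MLP input below.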
+ + @nn.compact + def __call__(self, ft: jnp.ndarray, t: jnp.ndarray, training: bool=False): + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([item for item in [ft, ] if item is not None], axis=-1) + x = ResidualMLP( + hidden_dims=self.hidden_dims, + output_dim=1, + multiplier=1, + activation=self.activation, + layer_norm=True, + )(x, training) + return x + +class FactorizedDDPM(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + embed_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int + num_noises: int + + def setup(self): + self.mlp_t1 = nn.Sequential([ + PositionalEmbedding(self.embed_dim), + nn.Dense(2*self.embed_dim), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_t2 = nn.Sequential([ + PositionalEmbedding(self.embed_dim), + nn.Dense(2*self.embed_dim), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_s = nn.Sequential([ + nn.Dense(self.embed_dim*2), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_a = nn.Sequential([ + nn.Dense(self.embed_dim*2), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim*self.obs_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.reward = RffReward( + self.feature_dim, + self.reward_hidden_dims, + rff_dim=self.rff_dim, + ) + betas, alphas, alphabars = get_noise_schedule("vp", self.num_noises) + alphabars_prev = jnp.pad(alphabars[:-1], (1, 0), constant_values=1.0) + self.alphabars = alphabars + + def forward_phi(self, s, a, t): + s = self.mlp_s(s) + a = self.mlp_a(a) + t_ff = self.mlp_t1(t) + x = jnp.concat([s, a, t_ff], axis=-1) + x = self.mlp_phi(x) + return x + + def forward_mu(self, sp, t): + t = self.mlp_t2(t) + x = jnp.concat([sp, t], axis=-1) + x = self.mlp_mu(x) + return x.reshape(-1, self.feature_dim, self.obs_dim) + + def forward_reward(self, x: jnp.ndarray): + return self.reward(x) + + def __call__(self, rng, s, a, sp, training: bool=False): + rng, t1_rng, eps1_rng, t2_rng, eps2_rng = jax.random.split(rng, 5) + t1 = jax.random.randint(t1_rng, (s.shape[0], 1), 0, self.num_noises+1) + t2 = jax.random.randint(t2_rng, (s.shape[0], 1), 0, self.num_noises+1) + eps1 = jax.random.normal(eps1_rng, a.shape) + eps2 = jax.random.normal(eps2_rng, sp.shape) + at = jnp.sqrt(self.alphabars[t1]) * a + jnp.sqrt(1-self.alphabars[t1]) * eps1 + spt = jnp.sqrt(self.alphabars[t2]) * sp + jnp.sqrt(1-self.alphabars[t2]) * eps2 + z_phi = self.forward_phi(s, at, t1) + z_mu = self.forward_mu(spt, t2) + eps_pred = jax.lax.batch_matmul(z_phi[..., jnp.newaxis, :], z_mu)[..., 0, :] + r_pred = self.forward_reward(z_phi) + return eps2, eps_pred, r_pred + + +@partial(jax.jit, static_argnames=("reward_coef")) +def update_factorized_ddpm( + rng: PRNGKey, + ddpm: FactorizedDDPM, + batch: Batch, + reward_coef: float, +) -> Tuple[PRNGKey, FactorizedDDPM, Metric]: + B = batch.obs.shape[0] + rng, update_rng = jax.random.split(rng) + def loss_fn(ddpm_params: Param, dropout_rng: PRNGKey): + eps, eps_pred, r_pred = ddpm.apply( + {"params": ddpm_params}, + update_rng, + batch.obs, + batch.action, + batch.next_obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + ddpm_loss = ((eps_pred - eps) ** 2).mean() + reward_loss = ((r_pred - batch.reward) ** 2).mean() + 
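+        # total objective: denoising loss on the next-state noise plus reward regression,
+        # with the reward term weighted by reward_coef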
loss = ddpm_loss + reward_coef * reward_loss + return loss, { + "loss/ddpm_loss": ddpm_loss, + "loss/reward_loss": reward_loss, + "misc/sp0_mean": batch.next_obs.mean(), + "misc/sp0_std": batch.next_obs.std(axis=0).mean(), + "misc/eps_mean": eps_pred.mean(), + "misc/reward_mean": r_pred.mean(), + } + + new_ddpm, metrics = ddpm.apply_gradient(loss_fn) + return rng, new_ddpm, metrics diff --git a/flowrl/agent/online/unirep/network.py b/flowrl/agent/online/unirep/network.py index e1dc431..2cb806b 100644 --- a/flowrl/agent/online/unirep/network.py +++ b/flowrl/agent/online/unirep/network.py @@ -12,7 +12,7 @@ from flowrl.module.mlp import MLP, ResidualMLP from flowrl.module.model import Model from flowrl.module.rff import RffReward -from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.module.time_embedding import LearnableFourierEmbedding, PositionalEmbedding from flowrl.types import * from flowrl.types import Sequence @@ -23,13 +23,13 @@ class ResidualCritic(nn.Module): activation: Callable = nn.relu @nn.compact - def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, t: jnp.ndarray, training: bool=False): + def __call__(self, ft: jnp.ndarray, t: jnp.ndarray, training: bool=False): t_ff = self.time_embedding()(t) t_ff = MLP( hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], activation=mish, )(t_ff) - x = jnp.concatenate([item for item in [obs, action, t_ff] if item is not None], axis=-1) + x = jnp.concatenate([item for item in [ft, t_ff] if item is not None], axis=-1) x = ResidualMLP( hidden_dims=self.hidden_dims, output_dim=1, @@ -68,9 +68,6 @@ def __call__( )(obs, action, t, training) return x -from flowrl.module.time_embedding import PositionalEmbedding - - class ACACritic(nn.Module): time_dim: int hidden_dims: Sequence[int] @@ -158,8 +155,11 @@ class FactorizedNCE(nn.Module): ranking: bool = False def setup(self): - self.mlp_t = nn.Sequential( - [LearnableFourierEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] + self.mlp_t1 = nn.Sequential( + [PositionalEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] + ) + self.mlp_t2 = nn.Sequential( + [PositionalEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] ) self.mlp_phi = ResidualMLP( self.phi_hidden_dims, @@ -177,20 +177,24 @@ def setup(self): layer_norm=True, dropout=None, ) - # self.reward = RffReward( - # self.feature_dim, - # self.reward_hidden_dims, - # rff_dim=self.rff_dim, - # ) - self.reward = Critic( - hidden_dims=self.reward_hidden_dims, - activation=nn.elu, - layer_norm=True, - dropout=None, + self.reward = RffReward( + self.feature_dim, + [512,], + rff_dim=self.rff_dim, ) + # self.reward = Critic( + # hidden_dims=self.reward_hidden_dims, + # activation=nn.elu, + # layer_norm=True, + # dropout=None, + # ) if self.num_noises > 0: self.use_noise_perturbation = True - self.noise_schedule_fn = cosine_noise_schedule + from flowrl.flow.ddpm import cosine_beta_schedule + betas = cosine_beta_schedule(T=self.num_noises) + betas = jnp.concatenate([jnp.zeros((1,)), betas]) + alphas = 1 - betas + self.alpha_hats = jnp.cumprod(alphas) else: self.use_noise_perturbation = False self.N = max(self.num_noises, 1) @@ -201,15 +205,18 @@ def setup(self): def forward_phi(self, s, at, t): x = jnp.concat([s, at], axis=-1) - if t is not None: - t_ff = self.mlp_t(t) - x = jnp.concat([x, t_ff], axis=-1) + t_ff = self.mlp_t1(t) + x = jnp.concat([x, t_ff], axis=-1) x = self.mlp_phi(x) x = l2_normalize(x, group_size=None) + # x = jnp.tanh(x) return x - def forward_mu(self, sp): + def forward_mu(self, sp, t): + t_ff = self.mlp_t2(t) 
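+        # condition mu on the noise level of the perturbed next state before projecting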
+ sp = jnp.concat([sp, t_ff], axis=-1) sp = self.mlp_mu(sp) + sp = jnp.tanh(sp) return sp def forward_reward(self, x: jnp.ndarray): # for z_phi @@ -221,27 +228,38 @@ def forward_logits( s: jnp.ndarray, a: jnp.ndarray, sp: jnp.ndarray, - z_mu: jnp.ndarray | None=None ): B, D = sp.shape - rng, t_rng, eps_rng = jax.random.split(rng, 3) - if z_mu is None: - z_mu = self.forward_mu(sp) + rng, t_rng, eps1_rng, eps2_rng = jax.random.split(rng, 4) if self.use_noise_perturbation: - s = jnp.broadcast_to(s, (self.N, B, s.shape[-1])) - a0 = jnp.broadcast_to(a, (self.N, B, a.shape[-1])) - t = jax.random.uniform(t_rng, (self.N,), dtype=jnp.float32) # check removing min val and max val is valid - t = jnp.repeat(t, B).reshape(self.N, B, 1) - eps = jax.random.normal(eps_rng, a0.shape) - alpha, sigma = self.noise_schedule_fn(t) - at = alpha * a0 + sigma * eps + + # perturb sp, (N, B, D) with the noise level shared in each N + sp0 = jnp.broadcast_to(sp, (self.N, B, sp.shape[-1])) + t_sp = jnp.arange(1, self.num_noises+1) + t_sp = jnp.repeat(t_sp, B).reshape(self.N, B, 1) + eps_sp = jax.random.normal(eps1_rng, sp0.shape) + alpha, sigma = jnp.sqrt(self.alpha_hats[t_sp]), jnp.sqrt(1 - self.alpha_hats[t_sp]) + spt = alpha * sp0 + sigma * eps_sp + + # perturb a, (N, B, D) with noise level independent across N + a0 = a + t_a = jax.random.randint(t_rng, (B, 1), 1, self.num_noises+1) + # t_a = jnp.ones((B, 1), dtype=jnp.int32) + eps_a = jax.random.normal(eps2_rng, a0.shape) + alpha, sigma = jnp.sqrt(self.alpha_hats[t_a]), jnp.sqrt(1 - self.alpha_hats[t_a]) + at = alpha * a0 + sigma * eps_a + # at = jnp.broadcast_to(at, (self.N, B, at.shape[-1])) + # t_a = jnp.broadcast_to(t_a, (self.N, B, 1)) else: s = jnp.expand_dims(s, 0) at = jnp.expand_dims(a, 0) t = None - z_phi = self.forward_phi(s, at, t) - z_mu = jnp.broadcast_to(z_mu, (self.N, B, self.feature_dim)) - logits = jax.lax.batch_matmul(z_phi, jnp.swapaxes(z_mu, -1, -2)) + z_phi = self.forward_phi(s, at, t_a) + z_mu = self.forward_mu(spt, t_sp) + logits = jax.lax.batch_matmul( + jnp.broadcast_to(z_phi, (self.N, B, z_phi.shape[-1])), + jnp.swapaxes(z_mu, -1, -2) + ) logits = logits / jnp.exp(self.normalizer[:, None, None]) rewards = self.forward_reward(z_phi) return logits, rewards @@ -278,14 +296,12 @@ def update_factorized_nce( labels = jnp.eye(B) def loss_fn(nce_params: Param, dropout_rng: PRNGKey): - z_mu = nce.apply({"params": nce_params}, batch.next_obs, method="forward_mu") logits, rewards = nce.apply( {"params": nce_params}, logits_rng, batch.obs, batch.action, batch.next_obs, - z_mu, method="forward_logits", ) diff --git a/flowrl/config/online/mujoco/__init__.py b/flowrl/config/online/mujoco/__init__.py index 8a72576..89a4940 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -2,7 +2,8 @@ from .algo.alac import ALACConfig from .algo.base import BaseAlgoConfig -from .algo.ctrl import * +from .algo.ctrlsr import * +from .algo.diffsr import * from .algo.dpmd import DPMDConfig from .algo.idem import IDEMConfig from .algo.qsm import QSMConfig @@ -30,9 +31,10 @@ "qsm": QSMConfig, "alac": ALACConfig, "idem": IDEMConfig, - "ctrl_td3": CtrlTD3Config, - "ctrl_qsm": CtrlQSMConfig, + "ctrlsr_td3": CtrlSRTD3Config, "aca": ACAConfig, + "diffsr_td3": DiffSRTD3Config, + "diffsr_qsm": DiffSRQSMConfig, } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/ctrl/__init__.py b/flowrl/config/online/mujoco/algo/ctrl/__init__.py deleted file mode 100644 index 7ed1456..0000000 --- 
a/flowrl/config/online/mujoco/algo/ctrl/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .ctrl_qsm import CtrlQSMConfig -from .ctrl_td3 import CtrlTD3Config - -__all__ = [ - "CtrlTD3Config", - "CtrlQSMConfig", -] diff --git a/flowrl/config/online/mujoco/algo/ctrlsr/__init__.py b/flowrl/config/online/mujoco/algo/ctrlsr/__init__.py new file mode 100644 index 0000000..0dad455 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrlsr/__init__.py @@ -0,0 +1,5 @@ +from .ctrlsr_td3 import CtrlSRTD3Config + +__all__ = [ + "CtrlSRTD3Config", +] diff --git a/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py b/flowrl/config/online/mujoco/algo/ctrlsr/ctrl_qsm.py similarity index 100% rename from flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py rename to flowrl/config/online/mujoco/algo/ctrlsr/ctrl_qsm.py diff --git a/flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py b/flowrl/config/online/mujoco/algo/ctrlsr/ctrlsr_td3.py similarity index 95% rename from flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py rename to flowrl/config/online/mujoco/algo/ctrlsr/ctrlsr_td3.py index 374d820..5c51a2a 100644 --- a/flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py +++ b/flowrl/config/online/mujoco/algo/ctrlsr/ctrlsr_td3.py @@ -5,7 +5,7 @@ @dataclass -class CtrlTD3Config(BaseAlgoConfig): +class CtrlSRTD3Config(BaseAlgoConfig): name: str actor_update_freq: int target_update_freq: int diff --git a/flowrl/config/online/mujoco/algo/diffsr/__init__.py b/flowrl/config/online/mujoco/algo/diffsr/__init__.py new file mode 100644 index 0000000..177cbe6 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/diffsr/__init__.py @@ -0,0 +1,7 @@ +from .diffsr_qsm import DiffSRQSMConfig +from .diffsr_td3 import DiffSRTD3Config + +__all__ = [ + "DiffSRTD3Config", + "DiffSRQSMConfig", +] diff --git a/flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py b/flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py new file mode 100644 index 0000000..db067f7 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig +from ..qsm import QSMDiffusionConfig + + +@dataclass +class DiffSRQSMConfig(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + # critic_hidden_dims: List[int] + critic_activation: str # not used + critic_ensemble_size: int + layer_norm: bool + critic_lr: float + clip_grad_norm: float | None + + num_noises: int + feature_dim: int + feature_lr: float + feature_ema: float + embed_dim: int + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ddpm_coef: float + reward_coef: float + back_critic_grad: bool + critic_coef: float + + num_samples: int + temp: float + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py b/flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py new file mode 100644 index 0000000..6f68a2c --- /dev/null +++ b/flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig + + +@dataclass +class DiffSRTD3Config(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + actor_hidden_dims: List[int] + # critic_hidden_dims: List[int] + critic_ensemble_size: int + layer_norm: bool + actor_lr: float + critic_lr: float + clip_grad_norm: float | None 
+ target_policy_noise: float + noise_clip: float + exploration_noise: float + + num_noises: int + feature_dim: int + feature_lr: float + feature_ema: float + embed_dim: int + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ddpm_coef: float + reward_coef: float + back_critic_grad: bool + critic_coef: float diff --git a/flowrl/config/online/mujoco/algo/sdac.py b/flowrl/config/online/mujoco/algo/sdac.py index bf1f78e..c26e080 100644 --- a/flowrl/config/online/mujoco/algo/sdac.py +++ b/flowrl/config/online/mujoco/algo/sdac.py @@ -23,6 +23,7 @@ class SDACDiffusionConfig: class SDACConfig(BaseAlgoConfig): name: str critic_hidden_dims: List[int] + critic_activation: str critic_lr: float discount: float num_samples: int diff --git a/flowrl/module/actor.py b/flowrl/module/actor.py index 5add3dd..882143b 100644 --- a/flowrl/module/actor.py +++ b/flowrl/module/actor.py @@ -2,6 +2,7 @@ import flax.linen as nn import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.types import * from flowrl.utils.distribution import TanhMultivariateNormalDiag @@ -12,6 +13,8 @@ class DeterministicActor(nn.Module): backbone: nn.Module obs_dim: int action_dim: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -20,7 +23,7 @@ def __call__( training: bool = False, ) -> jnp.ndarray: x = self.backbone(obs, training) - x = MLP(output_dim=self.action_dim)(x) + x = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) return x @@ -28,6 +31,8 @@ class SquashedDeterministicActor(DeterministicActor): backbone: nn.Module obs_dim: int action_dim: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -45,6 +50,8 @@ class GaussianActor(nn.Module): conditional_logstd: bool = False logstd_min: float = -20.0 logstd_max: float = 2.0 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -54,10 +61,10 @@ def __call__( ) -> jnp.ndarray: x = self.backbone(obs, training) if self.conditional_logstd: - mean_logstd = MLP(output_dim=2*self.action_dim)(x) + mean_logstd = MLP(output_dim=2*self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) mean, logstd = jnp.split(mean_logstd, 2, axis=-1) else: - mean = MLP(output_dim=self.action_dim)(x) + mean = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) logstd = self.param("logstd", nn.initializers.zeros, (self.action_dim,)) logstd = jnp.clip(logstd, self.logstd_min, self.logstd_max) distribution = distrax.MultivariateNormalDiag(mean, jnp.exp(logstd)) @@ -71,6 +78,8 @@ class SquashedGaussianActor(GaussianActor): conditional_logstd: bool = False logstd_min: float = -20.0 logstd_max: float = 2.0 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -80,10 +89,10 @@ def __call__( ) -> distrax.Distribution: x = self.backbone(obs, training) if self.conditional_logstd: - mean_logstd = MLP(output_dim=2*self.action_dim)(x) + mean_logstd = MLP(output_dim=2*self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) mean, logstd = jnp.split(mean_logstd, 2, axis=-1) else: - mean = MLP(output_dim=self.action_dim)(x) + mean = MLP(output_dim=self.action_dim, 
kernel_init=self.kernel_init, bias_init=self.bias_init)(x) logstd = self.param("logstd", nn.initializers.zeros, (self.action_dim,)) logstd = jnp.clip(logstd, self.logstd_min, self.logstd_max) distribution = TanhMultivariateNormalDiag(mean, jnp.exp(logstd)) @@ -97,6 +106,8 @@ class TanhMeanGaussianActor(GaussianActor): conditional_logstd: bool = False logstd_min: float = -20.0 logstd_max: float = 2.0 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -106,10 +117,10 @@ def __call__( ) -> jnp.ndarray: x = self.backbone(obs, training) if self.conditional_logstd: - mean_logstd = MLP(output_dim=2*self.action_dim)(x) + mean_logstd = MLP(output_dim=2*self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) mean, logstd = jnp.split(mean_logstd, 2, axis=-1) else: - mean = MLP(output_dim=self.action_dim)(x) + mean = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) logstd = self.param("logstd", nn.initializers.zeros, (self.action_dim,)) # broadcast logstd to the shape of mean logstd = jnp.broadcast_to(logstd, mean.shape) diff --git a/flowrl/module/critic.py b/flowrl/module/critic.py index e99e2ca..f0f90d4 100644 --- a/flowrl/module/critic.py +++ b/flowrl/module/critic.py @@ -1,6 +1,7 @@ import flax.linen as nn import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.functional.activation import mish from flowrl.types import * @@ -12,6 +13,8 @@ class Critic(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -30,6 +33,8 @@ def __call__( activation=self.activation, layer_norm=self.layer_norm, dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(x, training) return x @@ -40,6 +45,8 @@ class EnsembleCritic(nn.Module): layer_norm: bool = False dropout: Optional[float] = None ensemble_size: int = 2 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -61,6 +68,8 @@ def __call__( activation=self.activation, layer_norm=self.layer_norm, dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(obs, action, training) return x @@ -70,6 +79,8 @@ class CriticT(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -83,6 +94,8 @@ def __call__( t_ff = MLP( hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], activation=mish, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(t_ff) x = jnp.concatenate([item for item in [obs, action, t_ff] if item is not None], axis=-1) x = MLP( @@ -91,6 +104,8 @@ def __call__( activation=self.activation, layer_norm=self.layer_norm, dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(x, training) return x @@ -102,6 +117,8 @@ class EnsembleCriticT(nn.Module): layer_norm: bool = False dropout: Optional[float] = None ensemble_size: int = 2 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -124,6 +141,8 @@ def __call__( hidden_dims=self.hidden_dims, activation=self.activation, layer_norm=self.layer_norm, - 
dropout=self.dropout + dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(obs, action, t, training) return x diff --git a/flowrl/module/initialization.py b/flowrl/module/initialization.py index 926e557..0b51721 100644 --- a/flowrl/module/initialization.py +++ b/flowrl/module/initialization.py @@ -1,5 +1,6 @@ import flax.linen as nn -from flax.linen.initializers import lecun_normal +import jax +from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init from jax import numpy as jnp from flowrl.types import * @@ -10,11 +11,11 @@ def orthogonal_init(scale: Optional[float] = None): scale = jnp.sqrt(2) return nn.initializers.orthogonal(scale) +def pytorch_kernel_init(): + return variance_scaling(scale=1/3, mode="fan_in", distribution="uniform") -def uniform_init(scale_final=None): - if scale_final is not None: - return nn.initializers.xavier_uniform(scale_final) - return nn.initializers.xavier_uniform() +def pytorch_bias_init(): + return lambda key, shape, dtype=jnp.float32: (jax.random.uniform(key, shape, dtype)*2-1) / jnp.sqrt(shape[0]) - -default_init = orthogonal_init +default_kernel_init = orthogonal_init +default_bias_init = zeros_init diff --git a/flowrl/module/mlp.py b/flowrl/module/mlp.py index 22c3454..f447c35 100644 --- a/flowrl/module/mlp.py +++ b/flowrl/module/mlp.py @@ -3,10 +3,9 @@ import flax.linen as nn import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.types import * -from .initialization import default_init - class MLP(nn.Module): hidden_dims: Sequence[int] = field(default_factory=lambda: []) @@ -14,6 +13,8 @@ class MLP(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -22,14 +23,14 @@ def __call__( training: bool = False, ) -> jnp.ndarray: for i, size in enumerate(self.hidden_dims): - x = nn.Dense(size, kernel_init=default_init())(x) + x = nn.Dense(size, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) if self.layer_norm: x = nn.LayerNorm()(x) x = self.activation(x) if self.dropout and self.dropout > 0: x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) if self.output_dim > 0: - x = nn.Dense(self.output_dim, kernel_init=default_init())(x) + x = nn.Dense(self.output_dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) return x @@ -39,6 +40,8 @@ class ResidualLinear(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -51,12 +54,12 @@ def __call__( x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) if self.layer_norm: x = nn.LayerNorm()(x) - x = nn.Dense(self.dim * self.multiplier, kernel_init=default_init())(x) + x = nn.Dense(self.dim * self.multiplier, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) x = self.activation(x) - x = nn.Dense(self.dim, kernel_init=default_init())(x) + x = nn.Dense(self.dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) if residual.shape != x.shape: - residual = nn.Dense(self.dim, kernel_init=default_init())(residual) + residual = nn.Dense(self.dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(residual) return residual + x @@ -68,6 +71,8 @@ class ResidualMLP(nn.Module): activation: Callable = nn.relu 
layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -76,10 +81,18 @@ def __call__( training: bool = False, ) -> jnp.ndarray: if len(self.hidden_dims) > 0: - x = nn.Dense(self.hidden_dims[0], kernel_init=default_init())(x) + x = nn.Dense(self.hidden_dims[0], kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) for i, size in enumerate(self.hidden_dims): - x = ResidualLinear(size, self.multiplier, self.activation, self.layer_norm, self.dropout)(x, training) + x = ResidualLinear( + size, + self.multiplier, + self.activation, + self.layer_norm, + self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )(x, training) if self.output_dim > 0: x = self.activation(x) - x = nn.Dense(self.output_dim, kernel_init=default_init())(x) + x = nn.Dense(self.output_dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) return x diff --git a/flowrl/module/rff.py b/flowrl/module/rff.py index 373bf3b..53f3139 100644 --- a/flowrl/module/rff.py +++ b/flowrl/module/rff.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.module.mlp import MLP from flowrl.types import * @@ -10,6 +11,8 @@ class RffLayer(nn.Module): feature_dim: int rff_dim: int learnable: bool = True + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: @@ -17,7 +20,12 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: half = self.rff_dim // 2 if self.learnable: - x = MLP(hidden_dims=[], output_dim=half)(x) + x = MLP( + hidden_dims=[], + output_dim=half, + kernel_init=self.kernel_init, + bias_init=self.bias_init + )(x) else: noise = self.variable( "noise", @@ -32,17 +40,20 @@ class RffReward(nn.Module): feature_dim: int hidden_dims: list[int] rff_dim: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: x = nn.LayerNorm()(x) - x = RffLayer(self.feature_dim, self.rff_dim, learnable=True)(x) - + x = RffLayer(self.feature_dim, self.rff_dim, learnable=True, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) x = MLP( hidden_dims=self.hidden_dims, output_dim=1, layer_norm=True, activation=nn.elu, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(x, training=training) return x @@ -56,6 +67,8 @@ class RffEnsembleCritic(nn.Module): hidden_dims: Sequence[int] rff_dim: int ensemble_size: int = 2 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__(self, x) -> jnp.ndarray: @@ -67,5 +80,5 @@ def __call__(self, x) -> jnp.ndarray: out_axes=0, axis_size=self.ensemble_size, ) - x = vmap_rff(self.feature_dim, self.hidden_dims, self.rff_dim)(x) + x = vmap_rff(self.feature_dim, self.hidden_dims, self.rff_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) return x diff --git a/flowrl/types.py b/flowrl/types.py index cdf9989..67f42b8 100644 --- a/flowrl/types.py +++ b/flowrl/types.py @@ -4,6 +4,7 @@ import flax import jax import jax.numpy as jnp +from flax.linen.initializers import Initializer from flax.training.train_state import TrainState PRNGKey = NewType("PRNGKey", jax.Array) @@ -15,4 +16,4 @@ ['obs', 'action', 'reward', 'terminal', 'next_obs', 'next_action'], ) -__all__ 
= ["Batch", "PRNGKey", "Param", "Shape", "Metric", "Optional", "Sequence", "Any", "Dict", "Callable", "Union", "Tuple"] +__all__ = ["Batch", "PRNGKey", "Param", "Shape", "Metric", "Optional", "Sequence", "Any", "Dict", "Callable", "Union", "Tuple", "Initializer"] diff --git a/scripts/dmc/aca.sh b/scripts/dmc/aca.sh index 39f1737..efb0576 100644 --- a/scripts/dmc/aca.sh +++ b/scripts/dmc/aca.sh @@ -1,6 +1,6 @@ # Specify which GPUs to use -GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use -SEEDS=(0 1) +GPUS=(0 1 2 3) # Modify this array to specify which GPUs to use +SEEDS=(0) NUM_EACH_GPU=2 PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) @@ -38,9 +38,9 @@ TASKS=( SHARED_ARGS=( "algo=aca" - "algo.temp=0.05" + "algo.temp=0.1" # "algo.critic_activation=relu" - "log.tag=backup-temp0.05-sepcritic" + "log.tag=backup_repr_rff-tanh-nonorm-rep_multi-rl_onlyuse0-512x4" "log.project=flow-rl" "log.entity=lambda-rl" ) diff --git a/scripts/dmc/ctrl_td3.sh b/scripts/dmc/ctrlsr_td3.sh similarity index 65% rename from scripts/dmc/ctrl_td3.sh rename to scripts/dmc/ctrlsr_td3.sh index fd9e038..522335a 100644 --- a/scripts/dmc/ctrl_td3.sh +++ b/scripts/dmc/ctrlsr_td3.sh @@ -1,43 +1,43 @@ # Specify which GPUs to use GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use SEEDS=(0 1 2 3 4) -NUM_EACH_GPU=3 +NUM_EACH_GPU=2 PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) TASKS=( "acrobot-swingup" - "ball_in_cup-catch" - "cartpole-balance" - "cartpole-balance_sparse" - "cartpole-swingup" - "cartpole-swingup_sparse" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" "cheetah-run" "dog-run" "dog-stand" "dog-trot" "dog-walk" - "finger-spin" - "finger-turn_easy" - "finger-turn_hard" - "fish-swim" - "hopper-hop" - "hopper-stand" - "humanoid-run" - "humanoid-stand" - "humanoid-walk" - "pendulum-swingup" - "quadruped-run" - "quadruped-walk" - "reacher-easy" - "reacher-hard" - "walker-run" - "walker-stand" - "walker-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + # "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + # "walker-run" + # "walker-stand" + # "walker-walk" ) SHARED_ARGS=( - "algo=ctrl_td3" + "algo=ctrlsr_td3" "log.tag=default" "log.project=flow-rl" "log.entity=lambda-rl" diff --git a/scripts/dmc/diffsr_aca.sh b/scripts/dmc/diffsr_aca.sh new file mode 100644 index 0000000..a7c35cb --- /dev/null +++ b/scripts/dmc/diffsr_aca.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3) # Modify this array to specify which GPUs to use +SEEDS=(0 1) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + "cheetah-run" + # "dog-run" + "dog-stand" + # "dog-trot" + # "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=diffsr_aca" + "log.tag=new-critic_only_1-actor_only_1-temp0.1-normal_critic" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + 
    task=$1
+    seed=$2
+    slot=$3
+    num_gpus=${#GPUS[@]}
+    device_idx=$((slot % num_gpus))
+    device=${GPUS[$device_idx]}
+    echo "Running $task $seed on GPU $device"
+    command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}"
+    if [ -n "$DRY_RUN" ]; then
+        echo $command
+    else
+        echo $command
+        $command
+    fi
+}
+
+. env_parallel.bash
+if [ -n "$DRY_RUN" ]; then
+    env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+else
+    env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+fi
diff --git a/scripts/dmc/diffsr_aca_sep.sh b/scripts/dmc/diffsr_aca_sep.sh
new file mode 100644
index 0000000..4a69522
--- /dev/null
+++ b/scripts/dmc/diffsr_aca_sep.sh
@@ -0,0 +1,68 @@
+# Specify which GPUs to use
+GPUS=(0 1 2 3) # Modify this array to specify which GPUs to use
+SEEDS=(0 1)
+NUM_EACH_GPU=3
+
+PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
+
+TASKS=(
+    # "acrobot-swingup"
+    # "ball_in_cup-catch"
+    # "cartpole-balance"
+    # "cartpole-balance_sparse"
+    # "cartpole-swingup"
+    # "cartpole-swingup_sparse"
+    "cheetah-run"
+    # "dog-run"
+    "dog-stand"
+    # "dog-trot"
+    # "dog-walk"
+    # "finger-spin"
+    # "finger-turn_easy"
+    # "finger-turn_hard"
+    # "fish-swim"
+    # "hopper-hop"
+    # "hopper-stand"
+    # "humanoid-run"
+    # "humanoid-stand"
+    # "humanoid-walk"
+    # "pendulum-swingup"
+    "quadruped-run"
+    # "quadruped-walk"
+    # "reacher-easy"
+    # "reacher-hard"
+    "walker-run"
+    # "walker-stand"
+    # "walker-walk"
+)
+
+SHARED_ARGS=(
+    "algo=diffsr_aca_sep"
+    "log.tag=sep-normeps-temp0.2"
+    "log.project=flow-rl"
+    "log.entity=lambda-rl"
+)
+
+run_task() {
+    task=$1
+    seed=$2
+    slot=$3
+    num_gpus=${#GPUS[@]}
+    device_idx=$((slot % num_gpus))
+    device=${GPUS[$device_idx]}
+    echo "Running $task $seed on GPU $device"
+    command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}"
+    if [ -n "$DRY_RUN" ]; then
+        echo $command
+    else
+        echo $command
+        $command
+    fi
+}
+
+. env_parallel.bash
+if [ -n "$DRY_RUN" ]; then
+    env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+else
+    env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+fi
diff --git a/scripts/dmc/diffsr_td3.sh b/scripts/dmc/diffsr_td3.sh
new file mode 100644
index 0000000..9557580
--- /dev/null
+++ b/scripts/dmc/diffsr_td3.sh
@@ -0,0 +1,70 @@
+# Specify which GPUs to use
+GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use
+SEEDS=(0 1 2 3 4)
+NUM_EACH_GPU=3
+
+PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
+
+TASKS=(
+    # "acrobot-swingup"
+    # "ball_in_cup-catch"
+    # "cartpole-balance"
+    # "cartpole-balance_sparse"
+    # "cartpole-swingup"
+    # "cartpole-swingup_sparse"
+    "cheetah-run"
+    "dog-run"
+    "dog-stand"
+    "dog-trot"
+    "dog-walk"
+    # "finger-spin"
+    # "finger-turn_easy"
+    # "finger-turn_hard"
+    # "fish-swim"
+    # "hopper-hop"
+    # "hopper-stand"
+    # "humanoid-run"
+    # "humanoid-stand"
+    # "humanoid-walk"
+    # # "pendulum-swingup"
+    # "quadruped-run"
+    # "quadruped-walk"
+    # "reacher-easy"
+    # "reacher-hard"
+    # "walker-run"
+    # "walker-stand"
+    # "walker-walk"
+)
+
+SHARED_ARGS=(
+    "algo=diffsr_td3"
+    # "algo.embed_dim=256"
+    # "algo.reward_coef=1.0"
+    "log.tag=default"
+    "log.project=flow-rl"
+    "log.entity=lambda-rl"
+)
+
+run_task() {
+    task=$1
+    seed=$2
+    slot=$3
+    num_gpus=${#GPUS[@]}
+    device_idx=$((slot % num_gpus))
+    device=${GPUS[$device_idx]}
+    echo "Running $task $seed on GPU $device"
+    command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}"
+    if [ -n "$DRY_RUN" ]; then
+        echo $command
+    else
+        echo $command
+        $command
+    fi
+}
+
+. env_parallel.bash
+if [ -n "$DRY_RUN" ]; then
+    env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+else
+    env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+fi
diff --git a/scripts/dmc/sdac.sh b/scripts/dmc/sdac.sh
new file mode 100644
index 0000000..93d5dd3
--- /dev/null
+++ b/scripts/dmc/sdac.sh
@@ -0,0 +1,70 @@
+# Specify which GPUs to use
+GPUS=(0 1 2 3 4 5) # Modify this array to specify which GPUs to use
+SEEDS=(0 1 2 3)
+NUM_EACH_GPU=2
+
+PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
+
+TASKS=(
+    # "acrobot-swingup"
+    # "ball_in_cup-catch"
+    # "cartpole-balance"
+    # "cartpole-balance_sparse"
+    # "cartpole-swingup"
+    # "cartpole-swingup_sparse"
+    # "cheetah-run"
+    # "dog-run"
+    # "dog-stand"
+    # "dog-trot"
+    # "dog-walk"
+    # "finger-spin"
+    # "finger-turn_easy"
+    # "finger-turn_hard"
+    # "fish-swim"
+    # "hopper-hop"
+    # "hopper-stand"
+    # "humanoid-run"
+    # "humanoid-stand"
+    # "humanoid-walk"
+    # "pendulum-swingup"
+    "quadruped-run"
+    "quadruped-walk"
+    "reacher-easy"
+    # "reacher-hard"
+    # "walker-run"
+    # "walker-stand"
+    # "walker-walk"
+)
+
+SHARED_ARGS=(
+    "algo=sdac"
+    "algo.temp=0.01"
+    "log.tag=fix_temp-0.01"
+    "log.project=flow-rl"
+    "log.entity=lamda-rl"
+)
+
+
+run_task() {
+    task=$1
+    seed=$2
+    slot=$3
+    num_gpus=${#GPUS[@]}
+    device_idx=$((slot % num_gpus))
+    device=${GPUS[$device_idx]}
+    echo "Running $task $seed on GPU $device"
+    command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}"
+    if [ -n "$DRY_RUN" ]; then
+        echo $command
+    else
+        echo $command
+        $command
+    fi
+}
+
+. env_parallel.bash
+if [ -n "$DRY_RUN" ]; then
+    env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+else
+    env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
+fi