from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
import optax

from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
from flowrl.agent.online.qsm import QSMAgent
from flowrl.config.online.mujoco.algo.ctrl.ctrl_qsm import CtrlQSMConfig
from flowrl.flow.continuous_ddpm import ContinuousDDPM
from flowrl.functional.ema import ema_update
from flowrl.module.model import Model
from flowrl.module.rff import RffEnsembleCritic
from flowrl.types import Batch, Metric, Param, PRNGKey


@partial(jax.jit, static_argnames=("training", "num_samples", "solver"))
def jit_sample_actions(
    rng: PRNGKey,
    actor: ContinuousDDPM,
    nce_target: Model,
    critic: Model,
    obs: jnp.ndarray,
    training: bool,
    num_samples: int,
    solver: str,
) -> Tuple[PRNGKey, jnp.ndarray]:
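    """Sample actions from the diffusion actor.

    Draws ``num_samples`` candidate actions per observation; when more than one
    candidate is requested, the candidate with the highest Q-value is kept,
    where Q is the ensemble minimum of the critic evaluated on the target NCE
    features.
    """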
    assert len(obs.shape) == 2
    B = obs.shape[0]
    rng, xT_rng = jax.random.split(rng)

    # sample num_samples candidate actions per observation
    obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2)
    xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim))
    rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver)
    if num_samples == 1:
        actions = actions[:, 0]
    else:
        # best-of-N: keep the candidate with the highest (ensemble-min) Q-value
        feature = nce_target(obs_repeat, actions, method="forward_phi")
        qs = critic(feature)
        qs = qs.min(axis=0).reshape(B, num_samples)
        best_idx = qs.argmax(axis=-1)
        actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx]
    return rng, actions


@partial(jax.jit, static_argnames=("discount", "solver"))
def update_critic(
    rng: PRNGKey,
    critic: Model,
    critic_target: Model,
    actor: ContinuousDDPM,
    nce_target: Model,
    batch: Batch,
    discount: float,
    solver: str,
    critic_coef: float,
) -> Tuple[PRNGKey, Model, Metric]:
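    """Update the critic with a TD(0) target.

    The target is computed by sampling a next action from the diffusion actor,
    embedding (next_obs, next_action) with the target NCE feature network, and
    taking the ensemble minimum of the target critic on that feature.
    """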
    rng, sample_rng = jax.random.split(rng)
    next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim))
    rng, next_action, _ = actor.sample(
        rng,
        next_xT,
        batch.next_obs,
        training=False,
        solver=solver,
    )
    next_feature = nce_target(batch.next_obs, next_action, method="forward_phi")
    q_target = critic_target(next_feature).min(0)
    q_target = batch.reward + discount * (1 - batch.terminal) * q_target

    feature = nce_target(batch.obs, batch.action, method="forward_phi")

    def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]:
        q_pred = critic.apply(
            {"params": critic_params},
            feature,
            rngs={"dropout": dropout_rng},
        )
        critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean()
        return critic_loss, {
            "loss/critic_loss": critic_loss,
            "misc/q_mean": q_pred.mean(),
            "misc/reward": batch.reward.mean(),
        }

    new_critic, metrics = critic.apply_gradient(critic_loss_fn)
    return rng, new_critic, metrics


@partial(jax.jit, static_argnames=("temp",))
def update_actor(
    rng: PRNGKey,
    actor: ContinuousDDPM,
    nce_target: Model,
    critic_target: Model,
    batch: Batch,
    temp: float,
) -> Tuple[PRNGKey, Model, Metric]:
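    """Update the diffusion actor via Q score matching.

    Noise is added to the batch actions, the gradient of the (feature-based,
    ensemble-min) target Q-value with respect to the noisy action is converted
    into a target noise prediction, and the actor's predicted noise is
    regressed onto it.
    """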
    a0 = batch.action
    rng, at, t, eps = actor.add_noise(rng, a0)
    alpha1, alpha2 = actor.noise_schedule_func(t)

    def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray:
        feature = nce_target(obs, action, method="forward_phi")
        q = critic_target(feature)
        return q.min(axis=0).mean()

    q_grad_fn = jax.vmap(jax.grad(get_q_value))
    q_grad = q_grad_fn(at, batch.obs)
    q_grad = alpha1 * q_grad - alpha2 * at
    eps_estimation = -alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6)

    def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]:
        eps_pred = actor.apply(
            {"params": actor_params},
            at,
            t,
            condition=batch.obs,
            training=True,
            rngs={"dropout": dropout_rng},
        )
        loss = ((eps_pred - eps_estimation) ** 2).mean()
        return loss, {
            "loss/actor_loss": loss,
            "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(),
        }

    new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn)
    return rng, new_actor, actor_metrics


class CtrlQSMAgent(QSMAgent):
    """
    CTRL with Q Score Matching (QSM) agent.

    Learns a factorized NCE state-action representation and trains an RFF
    ensemble critic and a continuous DDPM actor on top of the learned features,
    maintaining target copies of the NCE, actor, and critic networks.
    """

    name = "CtrlQSMAgent"
    model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"]

    def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlQSMConfig, seed: int):
        super().__init__(obs_dim, act_dim, cfg, seed)
        self.cfg = cfg

        self.ctrl_coef = cfg.ctrl_coef
        self.critic_coef = cfg.critic_coef

        self.linear = cfg.linear
        self.ranking = cfg.ranking
        self.feature_dim = cfg.feature_dim
        self.num_noises = cfg.num_noises
        self.reward_coef = cfg.reward_coef
        self.rff_dim = cfg.rff_dim
        self.actor_update_freq = cfg.actor_update_freq
        self.target_update_freq = cfg.target_update_freq

        # sanity checks for the hyper-parameters
        assert not self.linear, "linear mode is not supported yet"

        # networks
        self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5)
        nce_def = FactorizedNCE(
            self.obs_dim,
            self.act_dim,
            self.feature_dim,
            cfg.phi_hidden_dims,
            cfg.mu_hidden_dims,
            cfg.reward_hidden_dims,
            cfg.rff_dim,
            cfg.num_noises,
            self.ranking,
        )
        self.nce = Model.create(
            nce_def,
            nce_rng,
            inputs=(
                nce_init_rng,
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
            ),
            optimizer=optax.adam(learning_rate=cfg.feature_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.nce_target = Model.create(
            nce_def,
            nce_rng,
            inputs=(
                nce_init_rng,
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
            ),
        )

        critic_def = RffEnsembleCritic(
            feature_dim=self.feature_dim,
            hidden_dims=cfg.critic_hidden_dims,
            rff_dim=cfg.rff_dim,
            ensemble_size=2,
        )
        self.critic = Model.create(
            critic_def,
            critic_rng,
            inputs=(jnp.ones((1, self.feature_dim)),),
            optimizer=optax.adam(learning_rate=cfg.critic_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.critic_target = Model.create(
            critic_def,
            critic_rng,
            inputs=(jnp.ones((1, self.feature_dim)),),
        )

        self._n_training_steps = 0

    def train_step(self, batch: Batch, step: int) -> Metric:
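        """Run one training step: update the NCE representation and the critic
        every step, and the actor and target networks at their configured
        frequencies."""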
        metrics = {}

        self.rng, self.nce, nce_metrics = update_factorized_nce(
            self.rng,
            self.nce,
            batch,
            self.ranking,
            self.reward_coef,
        )
        metrics.update(nce_metrics)

        self.rng, self.critic, critic_metrics = update_critic(
            self.rng,
            self.critic,
            self.critic_target,
            self.actor,
            self.nce_target,
            batch,
            discount=self.cfg.discount,
            solver=self.cfg.diffusion.solver,
            critic_coef=self.critic_coef,
        )
        metrics.update(critic_metrics)

        if self._n_training_steps % self.actor_update_freq == 0:
            self.rng, self.actor, actor_metrics = update_actor(
                self.rng,
                self.actor,
                self.nce_target,
                self.critic_target,
                batch,
                temp=self.cfg.temp,
            )
            metrics.update(actor_metrics)

        if self._n_training_steps % self.target_update_freq == 0:
            self.sync_target()

        self._n_training_steps += 1
        return metrics

    def sample_actions(
        self,
        obs: jnp.ndarray,
        deterministic: bool = True,
        num_samples: int = 1,
    ) -> Tuple[jnp.ndarray, Metric]:
        # if deterministic, sample cfg.num_samples actions and pick the best one;
        # otherwise sample a single action and add exploration noise
        if deterministic:
            num_samples = self.cfg.num_samples
        else:
            num_samples = 1
        self.rng, action = jit_sample_actions(
            self.rng,
            self.actor,
            self.nce_target,
            self.critic,
            obs,
            training=False,
            num_samples=num_samples,
            solver=self.cfg.diffusion.solver,
        )
        if not deterministic:
            # split the key so the exploration noise does not reuse self.rng
            self.rng, noise_rng = jax.random.split(self.rng)
            action = action + 0.1 * jax.random.normal(noise_rng, action.shape)
        return action, {}

    def sync_target(self):
        self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema)
        self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema)