diff --git a/examples/online/config/dmc/algo/aca.yaml b/examples/online/config/dmc/algo/aca.yaml new file mode 100644 index 0000000..b9e6818 --- /dev/null +++ b/examples/online/config/dmc/algo/aca.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +algo: + name: aca + target_update_freq: 1 + feature_dim: 512 + rff_dim: 1024 + critic_hidden_dims: [512, 512] + reward_hidden_dims: [512, 512] + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + ctrl_coef: 1.0 + reward_coef: 1.0 + critic_coef: 1.0 + critic_activation: elu # not used + back_critic_grad: false + feature_lr: 0.0001 + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + feature_ema: 0.005 + clip_grad_norm: null + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + num_noises: 20 + linear: false + ranking: true + +norm_obs: true diff --git a/examples/online/config/dmc/algo/ctrl_qsm.yaml b/examples/online/config/dmc/algo/ctrl_qsm.yaml new file mode 100644 index 0000000..af85061 --- /dev/null +++ b/examples/online/config/dmc/algo/ctrl_qsm.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +algo: + name: ctrl_qsm + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + # critic_hidden_dims: [512, 512, 512] # not used + critic_activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + critic_lr: 0.0003 + clip_grad_norm: null + + # below are params specific to ctrl_td3 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ctrl_coef: 1.0 + reward_coef: 1.0 + back_critic_grad: false + critic_coef: 1.0 + + num_noises: 25 + linear: false + ranking: true + + num_samples: 10 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/ctrl_td3.yaml b/examples/online/config/dmc/algo/ctrlsr_td3.yaml similarity index 97% rename from examples/online/config/dmc/algo/ctrl_td3.yaml rename to examples/online/config/dmc/algo/ctrlsr_td3.yaml index 10a754d..33e840e 100644 --- a/examples/online/config/dmc/algo/ctrl_td3.yaml +++ b/examples/online/config/dmc/algo/ctrlsr_td3.yaml @@ -1,7 +1,7 @@ # @package _global_ algo: - name: ctrl_td3 + name: ctrlsr_td3 actor_update_freq: 1 target_update_freq: 1 discount: 0.99 diff --git a/examples/online/config/dmc/algo/diffsr_aca.yaml b/examples/online/config/dmc/algo/diffsr_aca.yaml new file mode 100644 index 0000000..9952f73 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_aca.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +algo: + name: diffsr_aca + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + actor_hidden_dims: [512, 512, 512] + # critic_hidden_dims: [512, 512, 512] # not used + activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + actor_lr: 0.0003 + critic_lr: 0.0003 + clip_grad_norm: null + target_policy_noise: 0.2 + noise_clip: 0.3 + exploration_noise: 0.2 + + # below are params specific to diffsr_td3 + num_noises: 20 + num_samples: 10 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 128 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 
512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/diffsr_aca_sep.yaml b/examples/online/config/dmc/algo/diffsr_aca_sep.yaml new file mode 100644 index 0000000..1318f25 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_aca_sep.yaml @@ -0,0 +1,51 @@ +# @package _global_ + +algo: + name: diffsr_aca_sep + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + actor_hidden_dims: [512, 512, 512] + # critic_hidden_dims: [512, 512, 512] # not used + activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + actor_lr: 0.0003 + critic_lr: 0.0003 + clip_grad_norm: null + target_policy_noise: 0.2 + noise_clip: 0.3 + exploration_noise: 0.2 + + # below are params specific to diffsr_td3 + num_noises: 20 + num_samples: 10 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 128 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/diffsr_qsm.yaml b/examples/online/config/dmc/algo/diffsr_qsm.yaml new file mode 100644 index 0000000..e926893 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_qsm.yaml @@ -0,0 +1,47 @@ +# @package _global_ + +algo: + name: diffsr_qsm + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + # critic_hidden_dims: [512, 512, 512] # not used + critic_activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + critic_lr: 0.0003 + clip_grad_norm: null + + # below are params specific to ctrl_td3 + num_noises: 50 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 256 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + + num_samples: 10 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/diffsr_td3.yaml b/examples/online/config/dmc/algo/diffsr_td3.yaml new file mode 100644 index 0000000..eb7e4a7 --- /dev/null +++ b/examples/online/config/dmc/algo/diffsr_td3.yaml @@ -0,0 +1,37 @@ +# @package _global_ + +algo: + name: diffsr_td3 + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + actor_hidden_dims: [512, 512, 512] + # critic_hidden_dims: [512, 512, 512] # not used + activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + actor_lr: 0.0003 + critic_lr: 0.0003 + clip_grad_norm: null + target_policy_noise: 0.2 + noise_clip: 0.3 + exploration_noise: 0.2 + + # below are params specific to 
diffsr_td3 + num_noises: 50 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + embed_dim: 128 + phi_hidden_dims: [512, 512, 512] + mu_hidden_dims: [512, 512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ddpm_coef: 1.0 + reward_coef: 0.1 + back_critic_grad: false + critic_coef: 1.0 + +norm_obs: true diff --git a/examples/online/config/dmc/algo/qsm.yaml b/examples/online/config/dmc/algo/qsm.yaml new file mode 100644 index 0000000..12b9624 --- /dev/null +++ b/examples/online/config/dmc/algo/qsm.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512, 512] + critic_activation: elu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/dmc/algo/sdac.yaml b/examples/online/config/dmc/algo/sdac.yaml new file mode 100644 index 0000000..0b68526 --- /dev/null +++ b/examples/online/config/dmc/algo/sdac.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +algo: + name: sdac + critic_hidden_dims: [512, 512, 512] + critic_activation: elu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + ema: 0.005 + temp: 0.05 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/alac.yaml b/examples/online/config/mujoco/algo/alac.yaml new file mode 100644 index 0000000..94fe9e3 --- /dev/null +++ b/examples/online/config/mujoco/algo/alac.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +algo: + name: alac + discount: 0.99 + num_samples: 10 + ema: 0.005 + ld: + resnet: false + activation: relu + ensemble_size: 2 + time_dim: 64 + hidden_dims: [512, 512] + cond_hidden_dims: [128, 128] + steps: 20 + step_size: 0.05 + noise_scale: 1.0 + noise_schedule: "none" + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + epsilon: 0.001 + lr: 0.0003 + clip_grad_norm: null diff --git a/examples/online/config/mujoco/algo/idem.yaml b/examples/online/config/mujoco/algo/idem.yaml new file mode 100644 index 0000000..f4a9b36 --- /dev/null +++ b/examples/online/config/mujoco/algo/idem.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: idem + critic_hidden_dims: [256, 256] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [256, 256] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/qsm.yaml b/examples/online/config/mujoco/algo/qsm.yaml new file mode 100644 index 0000000..c646df7 --- /dev/null +++ b/examples/online/config/mujoco/algo/qsm.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512] + critic_activation: relu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/sdac.yaml 
b/examples/online/config/mujoco/algo/sdac.yaml index a4e8ff3..79833ae 100644 --- a/examples/online/config/mujoco/algo/sdac.yaml +++ b/examples/online/config/mujoco/algo/sdac.yaml @@ -3,12 +3,13 @@ algo: name: sdac critic_hidden_dims: [256, 256] + critic_activation: relu critic_lr: 0.0003 discount: 0.99 num_samples: 10 num_reverse_samples: 500 ema: 0.005 - temp: 0.2 + temp: 0.05 diffusion: time_dim: 64 mlp_hidden_dims: [256, 256] diff --git a/examples/online/config/mujoco/config.yaml b/examples/online/config/mujoco/config.yaml index f662225..7658820 100644 --- a/examples/online/config/mujoco/config.yaml +++ b/examples/online/config/mujoco/config.yaml @@ -27,7 +27,6 @@ random_frames: 5_000 eval_frames: 10_000 log_frames: 1_000 lap_reset_frames: 250 -eval_episodes: 10 log: dir: logs tag: debug diff --git a/examples/online/main_dmc_offpolicy.py b/examples/online/main_dmc_offpolicy.py index 0f8e59f..7e651cc 100644 --- a/examples/online/main_dmc_offpolicy.py +++ b/examples/online/main_dmc_offpolicy.py @@ -26,7 +26,10 @@ "td7": TD7Agent, "sdac": SDACAgent, "dpmd": DPMDAgent, - "ctrl_td3": CtrlTD3Agent, + "qsm": QSMAgent, + "ctrlsr_td3": CtrlSRTD3Agent, + "diffsr_td3": DiffSRTD3Agent, + "diffsr_qsm": DiffSRQSMAgent, } class OffPolicyTrainer(): diff --git a/examples/online/main_mujoco_offpolicy.py b/examples/online/main_mujoco_offpolicy.py index f408493..4432994 100644 --- a/examples/online/main_mujoco_offpolicy.py +++ b/examples/online/main_mujoco_offpolicy.py @@ -6,10 +6,10 @@ import jax import numpy as np import omegaconf -import wandb from omegaconf import OmegaConf from tqdm import tqdm +import wandb from flowrl.agent.online import * from flowrl.config.online.mujoco import Config from flowrl.dataset.buffer.state import ReplayBuffer @@ -25,6 +25,9 @@ "td7": TD7Agent, "sdac": SDACAgent, "dpmd": DPMDAgent, + "qsm": QSMAgent, + "idem": IDEMAgent, + "alac": ALACAgent, } class OffPolicyTrainer(): diff --git a/examples/toy2d/config/algo/aca.yaml b/examples/toy2d/config/algo/aca.yaml new file mode 100644 index 0000000..4484f3c --- /dev/null +++ b/examples/toy2d/config/algo/aca.yaml @@ -0,0 +1,40 @@ +# @package _global_ + +algo: + name: aca + target_update_freq: 1 + feature_dim: 512 + rff_dim: 1024 + critic_hidden_dims: [512, 512] + reward_hidden_dims: [512, 512] + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + ctrl_coef: 1.0 + reward_coef: 1.0 + critic_coef: 1.0 + critic_activation: elu # not used + back_critic_grad: false + feature_lr: 0.0001 + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + feature_ema: 0.005 + clip_grad_norm: null + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -5.0 + x_max: 5.0 + # solver: ddpm + solver: ddpm + num_noises: 25 + linear: false + ranking: true diff --git a/examples/toy2d/config/algo/qsm.yaml b/examples/toy2d/config/algo/qsm.yaml new file mode 100644 index 0000000..816a84b --- /dev/null +++ b/examples/toy2d/config/algo/qsm.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512, 512] + critic_activation: elu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -5.0 + x_max: 5.0 + solver: ddpm diff --git a/examples/toy2d/config/algo/sdac.yaml 
b/examples/toy2d/config/algo/sdac.yaml new file mode 100644 index 0000000..3505de9 --- /dev/null +++ b/examples/toy2d/config/algo/sdac.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: sdac + critic_hidden_dims: [256, 256] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [256, 256] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: false + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/toy2d/main_toy2d.py b/examples/toy2d/main_toy2d.py index f2f3881..3863d8c 100644 --- a/examples/toy2d/main_toy2d.py +++ b/examples/toy2d/main_toy2d.py @@ -3,10 +3,10 @@ import hydra import omegaconf -import wandb from omegaconf import OmegaConf from tqdm import trange +import wandb from examples.toy2d.utils import compute_metrics, plot_data, plot_energy, plot_sample from flowrl.agent.offline import * from flowrl.agent.online import * @@ -19,6 +19,9 @@ SUPPORTED_AGENTS: Dict[str, Type[BaseAgent]] = { "bdpo": BDPOAgent, "dac": DACAgent, + "qsm": QSMAgent, + "sdac": SDACAgent, + "aca": ACAAgent, } class Trainer(): @@ -30,14 +33,15 @@ def __init__(self, cfg: Config): log_dir="/".join([cfg.log.dir, cfg.algo.name, cfg.log.tag, cfg.task]), name="seed"+str(cfg.seed), logger_config={ - "TensorboardLogger": {"activate": True}, - "WandbLogger": { - "activate": True, - "config": OmegaConf.to_container(cfg), - "settings": wandb.Settings(_disable_stats=True), - "project": cfg.log.project, - "entity": cfg.log.entity - } if ("project" in cfg.log and "entity" in cfg.log) else {"activate": False}, + "CsvLogger": {"activate": True}, + # "TensorboardLogger": {"activate": True}, + # "WandbLogger": { + # "activate": True, + # "config": OmegaConf.to_container(cfg), + # "settings": wandb.Settings(_disable_stats=True), + # "project": cfg.log.project, + # "entity": cfg.log.entity + # } if ("project" in cfg.log and "entity" in cfg.log) else {"activate": False}, } ) self.ckpt_save_dir = os.path.join(self.logger.log_dir, "ckpt") diff --git a/examples/toy2d/utils.py b/examples/toy2d/utils.py index 6b38344..8445f0c 100644 --- a/examples/toy2d/utils.py +++ b/examples/toy2d/utils.py @@ -9,6 +9,7 @@ from flowrl.agent.base import BaseAgent from flowrl.agent.offline import * +from flowrl.agent.online import * from flowrl.dataset.toy2d import Toy2dDataset, inf_train_gen SAMPLE_GRAPH_SIZE = 2000 @@ -108,6 +109,9 @@ def plot_energy(out_dir, task: str, agent: BaseAgent): vmin = e.min() vmax = e.max() + def default_plot(): + pass + def bdpo_plot(): tt = [0, 1, 3, 5, 10, 20, 30, 40, 50] plt.figure(figsize=(30, 3.0)) @@ -168,12 +172,54 @@ def dac_plot(): plt.close() tqdm.write(f"Saved value plot to {saveto}") + def aca_plot(): + tt = [0, 1, 3, 5, 10, 20] + plt.figure(figsize=(20, 3.0)) + axes = [] + for i, t in enumerate(tt): + plt.subplot(1, len(tt), i+1) + if t == 0: + model = agent.critic_target + c = model(zero, id_matrix).mean(axis=0).reshape(90, 90) + else: + model = agent.value_target + t_input = np.ones((90*90, 1)) * t + c = model(zero, id_matrix, t_input).mean(axis=0).reshape(90, 90) + plt.gca().set_aspect("equal", adjustable="box") + plt.xlim(0, 89) + plt.ylim(0, 89) + if i == 0: + mappable = plt.imshow( + c, origin="lower", vmin=vmin, vmax=vmax, + cmap="viridis", rasterized=True + ) + plt.yticks(ticks=[5, 25, 45, 65, 85], labels=[-4, -2, 0, 2, 4]) + else: + plt.imshow( + c, origin="lower", vmin=vmin, vmax=vmax, + cmap="viridis", rasterized=True + ) + 
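# hide y tick labels on all panels except the first +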
plt.yticks(ticks=[5, 25, 45, 65, 85], labels=[None, None, None, None, None]) + + axes.append(plt.gca()) + plt.xticks(ticks=[5, 25, 45, 65, 85], labels=[-4, -2, 0, 2, 4]) + plt.title(f't={t}') + plt.tight_layout() + cbar = plt.gcf().colorbar(mappable, ax=axes, fraction=0.1, pad=0.02, aspect=12) + plt.gcf().axes[-1].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f')) + saveto = os.path.join(out_dir, "qt_space.png") + plt.savefig(saveto, dpi=300) + plt.close() + tqdm.write(f"Saved value plot to {saveto}") + if isinstance(agent, BDPOAgent): bdpo_plot() elif isinstance(agent, DACAgent): dac_plot() + elif isinstance(agent, ACAAgent): + aca_plot() else: - raise NotImplementedError(f"Plotting for {type(agent)} is not implemented") + default_plot() diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index 9041320..6e0aa19 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -1,11 +1,16 @@ from ..base import BaseAgent -from .ctrl.ctrl import CtrlTD3Agent +from .alac.alac import ALACAgent +from .ctrlsr import * +from .diffsr import * from .dpmd import DPMDAgent +from .idem import IDEMAgent from .ppo import PPOAgent +from .qsm import QSMAgent from .sac import SACAgent from .sdac import SDACAgent from .td3 import TD3Agent from .td7.td7 import TD7Agent +from .unirep import * __all__ = [ "BaseAgent", @@ -15,5 +20,12 @@ "SDACAgent", "DPMDAgent", "PPOAgent", - "CtrlTD3Agent", + "CtrlSRTD3Agent", + "DiffSRTD3Agent", + "DiffSRQSMAgent", + "QSMAgent", + "IDEMAgent", + "ALACAgent", + "ACAAgent", + "DiffSRACAAgent", ] diff --git a/flowrl/agent/online/alac/alac.py b/flowrl/agent/online/alac/alac.py new file mode 100644 index 0000000..45bbe3c --- /dev/null +++ b/flowrl/agent/online/alac/alac.py @@ -0,0 +1,339 @@ +from functools import partial +from typing import Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.alac.network import EnsembleEnergyNet +from flowrl.config.online.mujoco.algo.alac import ALACConfig +from flowrl.flow.langevin_dynamics import AnnealedLangevinDynamics +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples")) +def jit_sample_actions( + rng: PRNGKey, + actor: AnnealedLangevinDynamics, + ld: AnnealedLangevinDynamics, + obs, + training: bool, + num_samples: int +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, x_init_rng = jax.random.split(rng) + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + x_init = jax.random.normal(x_init_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, x_init, obs_repeat, training=training) + if num_samples == 1: + actions = actions[:, 0] + else: + qs = ld(actions, t=jnp.zeros((B, num_samples, 1), dtype=jnp.float32), condition=obs_repeat) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "ema")) +def jit_update_ld( + rng: PRNGKey, + ld: AnnealedLangevinDynamics, + ld_target: AnnealedLangevinDynamics, + actor: AnnealedLangevinDynamics, + batch: Batch, + 
discount: float, + ema: float, +) -> Tuple[PRNGKey, AnnealedLangevinDynamics, AnnealedLangevinDynamics, Metric]: + B, A = batch.action.shape[0], batch.action.shape[1] + feed_t = jnp.zeros((B, 1), dtype=jnp.float32) + + rng, next_xT_rng = jax.random.split(rng) + # next_action_init = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], ld.x_dim)) + next_action_init = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + # rng, next_action, history = ld_target.sample( + # rng, + # next_action_init, + # batch.next_obs, + # training=False, + # ) + rng, next_action, history = actor.sample( + rng, + next_action_init, + batch.next_obs, + training=False, + ) + q_target = ld_target(next_action, feed_t, batch.next_obs, training=False) + # q_target = ld_target(batch.next_obs, next_action, training=False) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target.min(axis=0) + + def ld_loss_fn(params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = ld.apply( + {"params": params}, + batch.action, + t=feed_t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + # q_pred = ld.apply( + # {"params": params}, + # batch.obs, + # batch.action, + # training=True, + # rngs={"dropout": dropout_rng}, + # ) + ld_loss = ((q_pred - q_target[jnp.newaxis, :])**2).mean() + return ld_loss, { + "loss/ld_loss": ld_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + # "misc/q_grad_l1": jnp.abs(history[1]).mean(), + } + + new_ld, ld_metrics = ld.apply_gradient(ld_loss_fn) + new_ld_target = ema_update(new_ld, ld_target, ema) + + # record energy + # num_checkpoints = 5 + # stepsize_checkpoint = ld.steps // num_checkpoints + # energy_history = history[2][jnp.arange(0, ld.steps, stepsize_checkpoint)] + # energy_history = energy_history.mean(axis=[-2, -1]) + # ld_metrics.update({ + # f"info/energy_step{i}": energy for i, energy in enumerate(energy_history) + # }) + + return rng, new_ld, new_ld_target, ld_metrics + + +@partial(jax.jit, static_argnames=()) +def jit_update_actor( + rng: PRNGKey, + actor: AnnealedLangevinDynamics, + critic_target: AnnealedLangevinDynamics, + batch: Batch, +) -> Tuple[PRNGKey, AnnealedLangevinDynamics, Metric]: + x0 = batch.action + rng, xt, t, eps = actor.add_noise(rng, x0) + # rng, t_rng, noise_rng = jax.random.split(rng, 3) + # t = jax.random.uniform(t_rng, (*x0.shape[:-1], 1), dtype=jnp.float32, minval=actor.t_diffusion[0], maxval=actor.t_diffusion[1]) + # eps = jax.random.normal(noise_rng, x0.shape, dtype=jnp.float32) + alpha, sigma = actor.noise_schedule_func(t) + xt = alpha * x0 + sigma * eps + + q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(a, None, condition=s).min(axis=0).mean())) + # q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) + q_grad = q_grad_fn(xt, batch.obs) + q_grad = alpha * q_grad - sigma * xt + eps_estimation = sigma * q_grad / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + xt, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class ALACAgent(BaseAgent): + """ + Annealed Langevin Dynamics 
Actor-Critic (ALAC) agent. + """ + name = "ALACAgent" + model_names = ["actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: ALACConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, ld_rng = jax.random.split(self.rng, 2) + + # define the critic + # from flowrl.module.critic import EnsembleCritic + # from flowrl.module.model import Model + # critic_def = EnsembleCritic( + # hidden_dims=cfg.ld.hidden_dims, + # activation=jax.nn.relu, + # layer_norm=False, + # dropout=None, + # ensemble_size=2, + # ) + # self.ld = Model.create( + # critic_def, + # ld_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + # optimizer=optax.adam(learning_rate=cfg.ld.lr), + # ) + # self.ld_target = Model.create( + # critic_def, + # ld_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + # ) + + mlp_impl = ResidualMLP if cfg.ld.resnet else MLP + activation = {"mish": mish, "relu": jax.nn.relu}[cfg.ld.activation] + energy_def = EnsembleEnergyNet( + mlp_impl=mlp_impl, + hidden_dims=cfg.ld.hidden_dims, + output_dim=1, + activation=activation, + layer_norm=False, + dropout=None, + ensemble_size=cfg.ld.ensemble_size, + # time_embedding=partial(LearnableFourierEmbedding, output_dim=cfg.ld.time_dim), + time_embedding=None, + # cond_embedding=partial(MLP, hidden_dims=cfg.ld.cond_hidden_dims, activation=activation), + cond_embedding=None, + ) + self.ld = AnnealedLangevinDynamics.create( + network=energy_def, + rng=ld_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=False, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule=cfg.ld.noise_schedule, + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + optimizer=optax.adam(learning_rate=cfg.ld.lr), + clip_grad_norm=cfg.ld.clip_grad_norm, + ) + self.ld_target = AnnealedLangevinDynamics.create( + network=energy_def, + rng=ld_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=False, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule=cfg.ld.noise_schedule, + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + ) + + # DEBUG define the actor + from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone + self.rng, actor_rng = jax.random.split(self.rng, 2) + time_embedding = partial[LearnableFourierEmbedding](LearnableFourierEmbedding, output_dim=cfg.ld.time_dim) + cond_embedding = partial(MLP, hidden_dims=[128, 128], activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.ld.hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + # self.actor = ContinuousDDPM.create( + # network=backbone_def, + # rng=actor_rng, + # inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + # x_dim=self.act_dim, + # steps=cfg.ld.steps, + # noise_schedule="cosine", + # noise_schedule_params={}, + # clip_sampler=cfg.ld.clip_sampler, + # x_min=cfg.ld.x_min, + # 
x_max=cfg.ld.x_max, + # t_schedule_n=1.0, + # optimizer=optax.adam(learning_rate=cfg.ld.lr), + # ) + self.actor = AnnealedLangevinDynamics.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=True, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + optimizer=optax.adam(learning_rate=cfg.ld.lr), + clip_grad_norm=cfg.ld.clip_grad_norm, + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.ld, self.ld_target, ld_metrics = jit_update_ld( + self.rng, + self.ld, + self.ld_target, + self.actor, + batch, + self.cfg.discount, + self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_actor( + self.rng, + self.actor, + self.ld_target, + batch, + ) + + self._n_training_steps += 1 + return {**ld_metrics, **actor_metrics} + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + # self.ld, + self.actor, + self.ld, + obs, + training=False, + num_samples=num_samples, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} diff --git a/flowrl/agent/online/alac/network.py b/flowrl/agent/online/alac/network.py new file mode 100644 index 0000000..6f9e05b --- /dev/null +++ b/flowrl/agent/online/alac/network.py @@ -0,0 +1,87 @@ +import flax.linen as nn +import jax.numpy as jnp + +from flowrl.functional.activation import mish +from flowrl.module.mlp import MLP +from flowrl.types import * + + +class EnergyNet(nn.Module): + mlp_impl: nn.Module + hidden_dims: Sequence[int] + output_dim: int = 1 + activation: Callable = nn.relu + layer_norm: bool = False + dropout: Optional[float] = None + cond_embedding: Optional[nn.Module] = None + time_embedding: Optional[nn.Module] = None + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + t: Optional[jnp.ndarray] = None, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + if condition is not None: + if self.cond_embedding is not None: + condition = self.cond_embedding()(condition, training=training) + else: + condition = condition + x = jnp.concatenate([x, condition], axis=-1) + if self.time_embedding is not None: + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([x, t_ff], axis=-1) + x = self.mlp_impl( + hidden_dims=self.hidden_dims, + output_dim=self.output_dim, + activation=self.activation, + layer_norm=self.layer_norm, + dropout=self.dropout, + )(x, training) + return x + + +class EnsembleEnergyNet(nn.Module): + mlp_impl: nn.Module + hidden_dims: Sequence[int] + output_dim: int = 1 + activation: Callable = nn.relu + layer_norm: bool = False + dropout: Optional[float] = None + cond_embedding: Optional[nn.Module] = None + time_embedding: Optional[nn.Module] = None + ensemble_size: int = 2 + + @nn.compact + def 
__call__( + self, + x: jnp.ndarray, + t: Optional[jnp.ndarray] = None, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + vmap_energy_net = nn.vmap( + EnergyNet, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + x = vmap_energy_net( + mlp_impl=self.mlp_impl, + hidden_dims=self.hidden_dims, + output_dim=self.output_dim, + activation=self.activation, + layer_norm=self.layer_norm, + dropout=self.dropout, + cond_embedding=self.cond_embedding, + time_embedding=self.time_embedding, + )(x, t, condition, training) + return x diff --git a/flowrl/agent/online/ctrlsr/__init__.py b/flowrl/agent/online/ctrlsr/__init__.py new file mode 100644 index 0000000..81927fb --- /dev/null +++ b/flowrl/agent/online/ctrlsr/__init__.py @@ -0,0 +1,6 @@ +# from .ctrl_qsm import CtrlQSMAgent +from .ctrlsr_td3 import CtrlSRTD3Agent + +__all__ = [ + "CtrlSRTD3Agent", +] diff --git a/flowrl/agent/online/ctrlsr/ctrl_qsm.py b/flowrl/agent/online/ctrlsr/ctrl_qsm.py new file mode 100644 index 0000000..2960ae1 --- /dev/null +++ b/flowrl/agent/online/ctrlsr/ctrl_qsm.py @@ -0,0 +1,287 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce +from flowrl.agent.online.qsm import QSMAgent +from flowrl.config.online.mujoco.algo.ctrl.ctrl_qsm import CtrlQSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM +from flowrl.functional.ema import ema_update +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + feature = nce_target(obs_repeat, actions, method="forward_phi") + qs = critic(feature) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + nce_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample( + rng, + next_xT, + batch.next_obs, + training=False, + solver=solver, + ) + next_feature = nce_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + feature = nce_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> 
Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, metrics + +@partial(jax.jit, static_argnames=("temp")) +def update_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, Model, Metric]: + + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray: + feature = nce_target(obs, action, method="forward_phi") + q = critic_target(feature) + return q.min(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class CtrlQSMAgent(QSMAgent): + """ + CTRL with Q Score Matching (QSM) agent. 
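+ Learns factorized contrastive (NCE) features of (s, a); an RFF ensemble critic head on these features supplies Q-values, and the DDPM actor is trained by Q score matching against the critic's action gradient.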
+ """ + + name = "CtrlQSMAgent" + model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlQSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ctrl_coef = cfg.ctrl_coef + self.critic_coef = cfg.critic_coef + + self.linear = cfg.linear + self.ranking = cfg.ranking + self.feature_dim = cfg.feature_dim + self.num_noises = cfg.num_noises + self.reward_coef = cfg.reward_coef + self.rff_dim = cfg.rff_dim + self.actor_update_freq = cfg.actor_update_freq + self.target_update_freq = cfg.target_update_freq + + + # sanity checks for the hyper-parameters + assert not self.linear, "linear mode is not supported yet" + + # networks + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + nce_def = FactorizedNCE( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + self.ranking, + ) + self.nce = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.nce_target = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.nce_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.nce_target, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.nce_target, + self.critic, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + 
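# Polyak-average the online critic and NCE feature networks into their target copies. +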
def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/ctrl/ctrl.py b/flowrl/agent/online/ctrlsr/ctrlsr_td3.py similarity index 92% rename from flowrl/agent/online/ctrl/ctrl.py rename to flowrl/agent/online/ctrlsr/ctrlsr_td3.py index 186c4d9..3a32143 100644 --- a/flowrl/agent/online/ctrl/ctrl.py +++ b/flowrl/agent/online/ctrlsr/ctrlsr_td3.py @@ -5,9 +5,10 @@ import jax.numpy as jnp import optax -from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce +import flowrl.module.initialization as init +from flowrl.agent.online.ctrlsr.network import FactorizedNCE, update_factorized_nce from flowrl.agent.online.td3 import TD3Agent -from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config +from flowrl.config.online.mujoco.algo.ctrlsr.ctrlsr_td3 import CtrlSRTD3Config from flowrl.functional.ema import ema_update from flowrl.module.actor import SquashedDeterministicActor from flowrl.module.mlp import MLP @@ -90,15 +91,15 @@ def actor_loss_fn( return rng, new_actor, metrics -class CtrlTD3Agent(TD3Agent): +class CtrlSRTD3Agent(TD3Agent): """ - CTRL with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent. + CTRL-SR with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent. """ - name = "CtrlTD3Agent" + name = "CtrlSRTD3Agent" model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] - def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int): + def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlSRTD3Config, seed: int): super().__init__(obs_dim, act_dim, cfg, seed) self.cfg = cfg @@ -156,6 +157,8 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int): hidden_dims=cfg.actor_hidden_dims, layer_norm=cfg.layer_norm, dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ), obs_dim=self.obs_dim, action_dim=self.act_dim, @@ -165,6 +168,8 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int): hidden_dims=cfg.critic_hidden_dims, rff_dim=cfg.rff_dim, ensemble_size=2, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) self.actor = Model.create( actor_def, diff --git a/flowrl/agent/online/ctrl/network.py b/flowrl/agent/online/ctrlsr/network.py similarity index 64% rename from flowrl/agent/online/ctrl/network.py rename to flowrl/agent/online/ctrlsr/network.py index fad30e9..c023a15 100644 --- a/flowrl/agent/online/ctrl/network.py +++ b/flowrl/agent/online/ctrlsr/network.py @@ -5,9 +5,10 @@ import jax.numpy as jnp import optax +import flowrl.module.initialization as init from flowrl.flow.ddpm import get_noise_schedule from flowrl.functional.activation import l2_normalize, mish -from flowrl.module.mlp import ResidualMLP +from flowrl.module.mlp import MLP, ResidualMLP from flowrl.module.model import Model from flowrl.module.rff import RffReward from flowrl.module.time_embedding import PositionalEmbedding @@ -27,9 +28,13 @@ class FactorizedNCE(nn.Module): ranking: bool = False def setup(self): - self.mlp_t = nn.Sequential( - [PositionalEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] - ) + MLP_torch_init = partial(MLP, kernel_init=init.pytorch_kernel_init, bias_init=init.pytorch_bias_init) + self.mlp_t = nn.Sequential([ + PositionalEmbedding(128), + MLP_torch_init(output_dim=256), + mish, + MLP_torch_init(output_dim=128) + ]) self.mlp_phi = 
ResidualMLP( self.phi_hidden_dims, self.feature_dim, @@ -37,6 +42,8 @@ def setup(self): activation=mish, layer_norm=True, dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) self.mlp_mu = ResidualMLP( self.mu_hidden_dims, @@ -45,32 +52,27 @@ def setup(self): activation=mish, layer_norm=True, dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) self.reward = RffReward( self.feature_dim, self.reward_hidden_dims, rff_dim=self.rff_dim, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, ) if self.num_noises > 0: self.use_noise_perturbation = True betas, alphas, alphabars = get_noise_schedule("vp", self.num_noises) - alphabars_prev = jnp.pad(alphabars[:-1], (1, 0), constant_values=1.0) - self.betas = betas[..., jnp.newaxis] - self.alphas = alphas[..., jnp.newaxis] - self.alphabars = alphabars[..., jnp.newaxis] - self.alphabars_prev = alphabars_prev[..., jnp.newaxis] + self.alphabars = alphabars else: self.use_noise_perturbation = False self.N = max(self.num_noises, 1) - if not self.ranking: - self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) - else: - self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) def forward_phi(self, s, a): x = jnp.concat([s, a], axis=-1) x = self.mlp_phi(x) - x = l2_normalize(x, group_size=None) return x def forward_mu(self, sp, t=None): @@ -78,57 +80,33 @@ def forward_mu(self, sp, t=None): t_ff = self.mlp_t(t) sp = jnp.concat([sp, t_ff], axis=-1) sp = self.mlp_mu(sp) + sp = jnp.tanh(sp) return sp def forward_reward(self, x: jnp.ndarray): # for z_phi return self.reward(x) - def forward_logits( - self, - rng: PRNGKey, - s: jnp.ndarray, - a: jnp.ndarray, - sp: jnp.ndarray, - z_phi: jnp.ndarray | None=None - ): + def __call__(self, rng, s, a, sp, training: bool=False): B, D = sp.shape rng, eps_rng = jax.random.split(rng, 2) - if z_phi is None: - z_phi = self.forward_phi(s, a) + z_phi = self.forward_phi(s, a) if self.use_noise_perturbation: sp = jnp.broadcast_to(sp, (self.N, B, D)) t = jnp.arange(self.num_noises) - t = jnp.repeat(t, B).reshape(self.N, B) + t = jnp.repeat(t, B).reshape(self.N, B, 1) alphabars = self.alphabars[t] eps = jax.random.normal(eps_rng, sp.shape) xt = jnp.sqrt(alphabars) * sp + jnp.sqrt(1-alphabars) * eps - t = jnp.expand_dims(t, -1) else: xt = jnp.expand_dims(sp, 0) t = None z_mu = self.forward_mu(xt, t) - z_phi = jnp.broadcast_to(z_phi, (self.N, B, self.feature_dim)) - logits = jax.lax.batch_matmul(z_phi, jnp.swapaxes(z_mu, -1, -2)) - logits = logits / jnp.exp(self.normalizer[:, None, None]) - return logits - - def forward_normalizer(self): - return self.normalizer - - def __call__( - self, - rng: PRNGKey, - s, - a, - sp, - ): - z_phi = self.forward_phi(s, a) - _ = self.forward_reward(z_phi) - _ = self.forward_logits(rng, s, a, sp, z_phi=z_phi) - - _ = self.forward_normalizer() - - return z_phi + logits = jax.lax.batch_matmul( + jnp.broadcast_to(z_phi, (self.N, B, self.feature_dim)), + jnp.swapaxes(z_mu, -1, -2) + ) + r_pred = self.forward_reward(z_phi) + return logits, r_pred, z_phi @partial(jax.jit, static_argnames=("ranking", "reward_coef")) @@ -140,46 +118,31 @@ def update_factorized_nce( reward_coef: float, ) -> Tuple[PRNGKey, Model, Metric]: B = batch.obs.shape[0] - rng, logits_rng = jax.random.split(rng) + rng, update_rng = jax.random.split(rng) if ranking: labels = jnp.arange(B) else: labels = jnp.eye(B) def loss_fn(nce_params: Param, dropout_rng: PRNGKey): - 
z_phi = nce.apply( + logits, r_pred, z_phi = nce.apply( {"params": nce_params}, - batch.obs, - batch.action, - method="forward_phi", - ) - logits = nce.apply( - {"params": nce_params}, - logits_rng, + update_rng, batch.obs, batch.action, batch.next_obs, - z_phi, - method="forward_logits", + training=True, + rngs={"dropout": dropout_rng}, ) - if ranking: model_loss = optax.softmax_cross_entropy_with_integer_labels( logits, jnp.broadcast_to(labels, (logits.shape[0], B)) ).mean(axis=-1) else: - normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") - eff_logits = logits + normalizer[:, None, None] - jnp.log(B) - model_loss = optax.sigmoid_binary_cross_entropy(eff_logits, labels).mean([-2, -1]) - r_pred = nce.apply( - {"params": nce_params}, - z_phi, - method="forward_reward", - ) - normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + raise NotImplementedError("non-ranking mode is not supported") reward_loss = jnp.mean((r_pred - batch.reward) ** 2) - nce_loss = model_loss.mean() + reward_coef * reward_loss + 0.000 * (logits**2).mean() + nce_loss = model_loss.mean() + reward_coef * reward_loss pos_logits = logits[ jnp.arange(logits.shape[0])[..., jnp.newaxis], @@ -207,9 +170,6 @@ def loss_fn(nce_params: Param, dropout_rng: PRNGKey): metrics.update({ f"misc/logits_gap_{i}": (pos_logits_per_noise[i] - neg_logits_per_noise[i]).mean() for i in checkpoints }) - metrics.update({ - f"misc/normalizer_{i}": jnp.exp(normalizer[i]) for i in checkpoints - }) return nce_loss, metrics new_nce, metrics = nce.apply_gradient(loss_fn) diff --git a/flowrl/agent/online/diffsr/__init__.py b/flowrl/agent/online/diffsr/__init__.py new file mode 100644 index 0000000..c20bd2c --- /dev/null +++ b/flowrl/agent/online/diffsr/__init__.py @@ -0,0 +1,7 @@ +from .diffsr_qsm import DiffSRQSMAgent +from .diffsr_td3 import DiffSRTD3Agent + +__all__ = [ + "DiffSRTD3Agent", + "DiffSRQSMAgent", +] diff --git a/flowrl/agent/online/diffsr/diffsr_qsm.py b/flowrl/agent/online/diffsr/diffsr_qsm.py new file mode 100644 index 0000000..830e947 --- /dev/null +++ b/flowrl/agent/online/diffsr/diffsr_qsm.py @@ -0,0 +1,342 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.diffsr.network import FactorizedDDPM, update_factorized_ddpm +from flowrl.agent.online.qsm import QSMAgent +from flowrl.config.online.mujoco.algo.diffsr import DiffSRQSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + ddpm_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + feature = ddpm_target(obs_repeat, actions, method="forward_phi") + qs = critic(feature) + qs 
= qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + ddpm_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample( + rng, + next_xT, + batch.next_obs, + training=False, + solver=solver, + ) + next_feature = ddpm_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + feature = ddpm_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, metrics + + +@partial(jax.jit, static_argnames=("temp")) +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: ContinuousDDPM, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, Model, Metric]: + + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray: + feature = ddpm_target(obs, action, method="forward_phi") + q = critic_target(feature) + return q.min(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def loss_fn(diffusion_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": diffusion_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(loss_fn) + return rng, new_actor, actor_metrics + +# @jax.jit +# def jit_compute_metrics( +# rng: PRNGKey, +# critic: Model, +# ddpm_target: Model, +# diffusion_value: Model, +# diffusion_actor: Model, +# batch: Batch, +# ) -> Tuple[PRNGKey, Metric]: +# B, S = batch.obs.shape +# A = batch.action.shape[-1] +# num_actions = 50 +# metrics = {} +# rng, action_rng = jax.random.split(rng) +# obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) +# action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + +# def get_critic(at, obs): +# t1 = jnp.ones((1, ), dtype=jnp.int32) +# ft = ddpm_target(obs, at, t1, method="forward_phi") +# q = critic(ft) +# return q.mean() +# all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( +# action_repeat.reshape(-1, A), +# obs_repeat.reshape(-1, 
S), +# ) +# all_critic = all_critic.reshape(B, num_actions, 1) +# all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) +# metrics.update({ +# f"q_std/critic": all_critic.std(axis=1).mean(), +# f"q_grad/critic": jnp.abs(all_critic_grad).mean(), +# }) + +# def get_value(at, obs, t): +# ft = ddpm_target(obs, at, t, method="forward_phi") +# q = diffusion_value(ft) +# return q.mean() +# for t in [0] + list(range(1, diffusion_actor.steps+1, diffusion_actor.steps//5)): +# t_input = jnp.ones((B, num_actions, 1)) * t +# all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( +# action_repeat.reshape(-1, A), +# obs_repeat.reshape(-1, S), +# t_input.reshape(-1, 1), +# ) +# all_value = all_value.reshape(B, num_actions, 1) +# all_value_grad = all_value_grad.reshape(B, num_actions, -1) +# metrics.update({ +# f"q_std/value_{t}": all_value.std(axis=1).mean(), +# f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), +# }) +# return rng, metrics + +class DiffSRQSMAgent(QSMAgent): + """ + Diff-SR with QSM agent. + """ + + name = "DiffSRQSMAgent" + model_names = ["ddpm", "ddpm_target", "actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRQSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + self.rff_dim = cfg.rff_dim + self.actor_update_freq = cfg.actor_update_freq + self.target_update_freq = cfg.target_update_freq + self.temp = cfg.temp + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.ddpm_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + 
self.ddpm_target, + self.critic_target, + batch, + temp=self.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + # if self._n_training_steps % 2000 == 0: + # self.rng, metrics = jit_compute_metrics( + # self.rng, + # self.critic, + # self.ddpm_target, + # self.diffusion_value, + # self.diffusion_actor, + # batch, + # ) + # metrics.update(metrics) + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + self.ddpm_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/diffsr/diffsr_td3.py b/flowrl/agent/online/diffsr/diffsr_td3.py new file mode 100644 index 0000000..335eaa3 --- /dev/null +++ b/flowrl/agent/online/diffsr/diffsr_td3.py @@ -0,0 +1,237 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +import flowrl.module.initialization as init +from flowrl.agent.online.diffsr.network import FactorizedDDPM, update_factorized_ddpm +from flowrl.agent.online.td3 import TD3Agent +from flowrl.config.online.mujoco.algo.diffsr import DiffSRTD3Config +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor_target: Model, + ddpm_target: Model, + batch: Batch, + discount: float, + target_policy_noise: float, + noise_clip: float, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise + noise = jnp.clip(noise, -noise_clip, noise_clip) + next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0) + + next_feature = ddpm_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + back_critic_grad = False + if back_critic_grad: + raise NotImplementedError("no back critic grad exists") + + feature = ddpm_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, 
metrics + + +@jax.jit +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: Model, + critic: Model, + batch: Batch, +) -> Tuple[PRNGKey, Model, Metric]: + def actor_loss_fn( + actor_params: Param, dropout_rng: PRNGKey + ) -> Tuple[jnp.ndarray, Metric]: + new_action = actor.apply( + {"params": actor_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + new_feature = ddpm_target(batch.obs, new_action, method="forward_phi") + q = critic(new_feature) + actor_loss = - q.mean() + + return actor_loss, { + "loss/actor_loss": actor_loss, + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, metrics + + +class DiffSRTD3Agent(TD3Agent): + """ + Diff-SR with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent. + """ + + name = "DiffSRTD3Agent" + model_names = ["ddpm", "ddpm_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRTD3Config, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + actor_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=cfg.actor_hidden_dims, + layer_norm=cfg.layer_norm, + dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + self.actor = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=cfg.actor_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.actor_target = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor_target, + self.ddpm_target, + batch, + 
discount=self.cfg.discount, + target_policy_noise=self.target_policy_noise, + noise_clip=self.noise_clip, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic, + batch, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sync_target(self): + super().sync_target() + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/diffsr/network.py b/flowrl/agent/online/diffsr/network.py new file mode 100644 index 0000000..c49c859 --- /dev/null +++ b/flowrl/agent/online/diffsr/network.py @@ -0,0 +1,139 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp + +import flowrl.module.initialization as init +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import mish +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import PositionalEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class FactorizedDDPM(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + embed_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int + num_noises: int + + def setup(self): + MLP_torch_init = partial(MLP, kernel_init=init.pytorch_kernel_init, bias_init=init.pytorch_bias_init) + self.mlp_t = nn.Sequential([ + PositionalEmbedding(self.embed_dim), + MLP_torch_init(output_dim=2*self.embed_dim), + mish, + MLP_torch_init(output_dim=self.embed_dim) + ]) + # self.mlp_s = nn.Sequential([ + # MLP_torch_init(output_dim=self.embed_dim*2), + # mish, + # MLP_torch_init(output_dim=self.embed_dim) + # ]) + # self.mlp_a = nn.Sequential([ + # MLP_torch_init(output_dim=self.embed_dim*2), + # mish, + # MLP_torch_init(output_dim=self.embed_dim) + # ]) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim*self.obs_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + self.reward = RffReward( + self.feature_dim, + self.reward_hidden_dims, + rff_dim=self.rff_dim, + kernel_init=init.pytorch_kernel_init, + bias_init=init.pytorch_bias_init, + ) + betas, alphas, alphabars = get_noise_schedule("vp", self.num_noises) + self.alphabars = alphabars + + def forward_phi(self, s, a): + # s = self.mlp_s(s) + # a = self.mlp_a(a) + x = jnp.concat([s, a], axis=-1) + x = self.mlp_phi(x) + return x + + def forward_mu(self, sp, t): + t = self.mlp_t(t) + x = jnp.concat([sp, t], axis=-1) + x = self.mlp_mu(x) + return x.reshape(-1, self.feature_dim, self.obs_dim) + + def forward_reward(self, x: jnp.ndarray): + return self.reward(x) + + def __call__(self, rng, s, a, sp, training: bool=False): + rng, t_rng, eps_rng = jax.random.split(rng, 3) + t = jax.random.randint(t_rng, (s.shape[0], 1), 0, self.num_noises+1) + eps = jax.random.normal(eps_rng, sp.shape) + spt = jnp.sqrt(self.alphabars[t]) * sp + 
jnp.sqrt(1-self.alphabars[t]) * eps + z_phi = self.forward_phi(s, a) + z_mu = self.forward_mu(spt, t) + eps_pred = jax.lax.batch_matmul(z_phi[..., jnp.newaxis, :], z_mu)[..., 0, :] + r_pred = self.forward_reward(z_phi) + return eps, eps_pred, r_pred, z_phi + + +@partial(jax.jit, static_argnames=("reward_coef")) +def update_factorized_ddpm( + rng: PRNGKey, + ddpm: FactorizedDDPM, + batch: Batch, + reward_coef: float, +) -> Tuple[PRNGKey, FactorizedDDPM, Metric]: + B = batch.obs.shape[0] + rng, update_rng = jax.random.split(rng) + def loss_fn(ddpm_params: Param, dropout_rng: PRNGKey): + eps, eps_pred, r_pred, z_phi = ddpm.apply( + {"params": ddpm_params}, + update_rng, + batch.obs, + batch.action, + batch.next_obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + ddpm_loss = ((eps_pred - eps) ** 2).sum(axis=-1).mean() + reward_loss = ((r_pred - batch.reward) ** 2).mean() + loss = ddpm_loss + reward_coef * reward_loss + return loss, { + "loss/ddpm_loss": ddpm_loss, + "loss/reward_loss": reward_loss, + "misc/sp0_mean": batch.next_obs.mean(), + "misc/sp0_std": batch.next_obs.std(axis=0).mean(), + "misc/sp0_l1": jnp.abs(batch.next_obs).mean(), + "misc/eps_mean": eps_pred.mean(), + "misc/eps_l1": jnp.abs(eps_pred).mean(), + "misc/reward_mean": r_pred.mean(), + "misc/z_phi_l1": jnp.abs(z_phi).mean(), + } + + new_ddpm, metrics = ddpm.apply_gradient(loss_fn) + return rng, new_ddpm, metrics diff --git a/flowrl/agent/online/idem.py b/flowrl/agent/online/idem.py new file mode 100644 index 0000000..4fdfe41 --- /dev/null +++ b/flowrl/agent/online/idem.py @@ -0,0 +1,105 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.qsm import QSMAgent, jit_update_qsm_critic +from flowrl.config.online.mujoco.algo.idem import IDEMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + +jit_update_idem_critic = jit_update_qsm_critic + +@partial(jax.jit, static_argnames=("num_reverse_samples", "temp",)) +def jit_update_idem_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + critic_target: Model, + batch: Batch, + num_reverse_samples: int, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + obs_repeat = batch.obs[jnp.newaxis, ...].repeat(num_reverse_samples, axis=0) + + rng, tnormal_rng, clipped_rng = jax.random.split(rng, 3) + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + lower_bound = - 1.0 / alpha2 * at - alpha1 / alpha2 + upper_bound = - 1.0 / alpha2 * at + alpha1 / alpha2 + tnormal_noise = jax.random.truncated_normal(tnormal_rng, lower_bound, upper_bound, (num_reverse_samples, *at.shape)) + normal_noise = jax.random.normal(clipped_rng, (num_reverse_samples, *at.shape)) + normal_noise_clipped = jnp.clip(normal_noise, lower_bound, upper_bound) + eps_reverse = jnp.where(jnp.isnan(tnormal_noise), normal_noise_clipped, tnormal_noise) + a0_hat = 1 / alpha1 * at + alpha2 / alpha1 * eps_reverse + + q_value_and_grad_fn = jax.vmap( + jax.vmap( + jax.value_and_grad(lambda a, s: critic_target(s, a).min(axis=0).mean()), + ) + ) + q_value, q_grad = 
q_value_and_grad_fn(a0_hat, obs_repeat) + q_grad = q_grad / temp + weight = jax.nn.softmax(q_value / temp, axis=0) + eps_estimation = - (alpha2 / alpha1) * jnp.sum(weight[:, :, jnp.newaxis] * q_grad, axis=0) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/weights": weight.mean(), + "misc/weight_std": weight.std(0).mean(), + "misc/weight_max": weight.max(0).mean(), + "misc/weight_min": weight.min(0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class IDEMAgent(QSMAgent): + """ + Iterative Denoising Energy Matching (iDEM) Agent. + """ + name = "IDEMAgent" + model_names = ["actor", "critic", "critic_target"] + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.critic, self.critic_target, critic_metrics = jit_update_idem_critic( + self.rng, + self.actor, + self.critic, + self.critic_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + ema=self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_idem_actor( + self.rng, + self.actor, + self.critic_target, + batch, + num_reverse_samples=self.cfg.num_reverse_samples, + temp=self.cfg.temp, + ) + self._n_training_steps += 1 + return {**critic_metrics, **actor_metrics} diff --git a/flowrl/agent/online/qsm.py b/flowrl/agent/online/qsm.py new file mode 100644 index 0000000..0a4e0f6 --- /dev/null +++ b/flowrl/agent/online/qsm.py @@ -0,0 +1,246 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.config.online.mujoco.algo.qsm import QSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + qs = critic(obs_repeat, actions) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver", "ema")) +def jit_update_qsm_critic( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + critic_target: Model, + batch: Batch, + discount: float, + solver: str, + ema: float, +) -> Tuple[PRNGKey, Model, 
Model, Metric]: + rng, next_xT_rng = jax.random.split(rng) + next_xT = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample(rng, next_xT, batch.next_obs, training=False, solver=solver) + q_target = critic_target(batch.next_obs, next_action) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target.min(axis=0) + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q = critic.apply( + {"params": critic_params}, + batch.obs, + batch.action, + training=True, + rngs={"dropout": dropout_rng}, + ) + critic_loss = ((q - q_target[jnp.newaxis, :])**2).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q.mean(), + "misc/reward": batch.reward.mean(), + "misc/next_action_l1": jnp.abs(next_action).mean(), + } + + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + + new_critic_target = ema_update(new_critic, critic_target, ema) + return rng, new_critic, new_critic_target, critic_metrics + +@partial(jax.jit, static_argnames=("temp",)) +def jit_update_qsm_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class QSMAgent(BaseAgent): + """ + Q Score Matching (QSM) agent and beyond. 
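+ + The critic is an ensemble Q-network on (s, a) trained with a TD target: next actions are sampled from the diffusion actor, the target is the ensemble minimum, and the target network is updated by EMA (jit_update_qsm_critic). + The actor is a conditional continuous-time DDPM trained by Q score matching (jit_update_qsm_actor): for a noised action a_t, the predicted noise is regressed onto eps_estimation = -alpha2 * grad_a Q(s, a_t) / temp, normalized by the mean absolute gradient. + At evaluation time, cfg.num_samples candidate actions are sampled per state and the one with the highest minimum ensemble Q-value is returned; during exploration a single action is sampled and perturbed with 0.1-scale Gaussian noise.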
+ """ + name = "QSMAgent" + model_names = ["actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: QSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, actor_rng, critic_rng = jax.random.split(self.rng, 3) + + # define the actor + time_embedding = partial(LearnableFourierEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + + if cfg.diffusion.lr_decay_steps is not None: + actor_lr = optax.linear_schedule( + init_value=cfg.diffusion.lr, + end_value=cfg.diffusion.end_lr, + transition_steps=cfg.diffusion.lr_decay_steps, + transition_begin=cfg.diffusion.lr_decay_begin, + ) + else: + actor_lr = cfg.diffusion.lr + + self.actor = ContinuousDDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + t_schedule_n=1.0, + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] + critic_def = EnsembleCritic( + hidden_dims=cfg.critic_hidden_dims, + activation=critic_activation, + layer_norm=False, + dropout=None, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.critic, self.critic_target, critic_metrics = jit_update_qsm_critic( + self.rng, + self.actor, + self.critic, + self.critic_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + ema=self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_qsm_actor( + self.rng, + self.actor, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + self._n_training_steps += 1 + return {**critic_metrics, **actor_metrics} + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} diff --git a/flowrl/agent/online/sdac.py b/flowrl/agent/online/sdac.py index 4e89413..2fea9b3 100644 --- a/flowrl/agent/online/sdac.py +++ b/flowrl/agent/online/sdac.py @@ -116,6 +116,7 @@ def 
actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra "misc/weight_std": weights.std(0).mean(), "misc/weights_max": weights.max(0).mean(), "misc/weights_min": weights.min(0).mean(), + "misc/eps_estimation_l1": jnp.abs((weights * eps_reverse).sum(axis=0)).mean(), } new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) @@ -181,25 +182,15 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: SDACConfig, seed: int): t_schedule_n=1.0, optimizer=optax.adam(learning_rate=actor_lr), ) - # CHECK: is this really necessary, since we are not using the target actor for policy evaluation? - self.actor_target = ContinuousDDPM.create( - network=backbone_def, - rng=actor_rng, - inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), - x_dim=self.act_dim, - steps=cfg.diffusion.steps, - noise_schedule="cosine", - noise_schedule_params={}, - clip_sampler=cfg.diffusion.clip_sampler, - x_min=cfg.diffusion.x_min, - x_max=cfg.diffusion.x_max, - t_schedule_n=1.0, - ) # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] critic_def = EnsembleCritic( hidden_dims=cfg.critic_hidden_dims, - activation=jax.nn.relu, + activation=critic_activation, layer_norm=False, dropout=None, ensemble_size=2, diff --git a/flowrl/agent/online/unirep/__init__.py b/flowrl/agent/online/unirep/__init__.py new file mode 100644 index 0000000..fa740b6 --- /dev/null +++ b/flowrl/agent/online/unirep/__init__.py @@ -0,0 +1,7 @@ +from .aca import ACAAgent +from .diffsr.diffsr_td3 import DiffSRACAAgent + +__all__ = [ + "ACAAgent", + "DiffSRACAAgent", +] diff --git a/flowrl/agent/online/unirep/aca.py b/flowrl/agent/online/unirep/aca.py new file mode 100644 index 0000000..5ec437d --- /dev/null +++ b/flowrl/agent/online/unirep/aca.py @@ -0,0 +1,564 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.unirep.network import FactorizedNCE, update_factorized_nce +from flowrl.config.online.mujoco.algo.unirep.aca import ACAConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.flow.ddpm import DDPM, DDPMBackbone +from flowrl.functional.activation import l2_normalize, mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import Critic, EnsembleCritic, EnsembleCriticT +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.module.time_embedding import LearnableFourierEmbedding, PositionalEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + nce_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + # t0 = jnp.zeros((obs_repeat.shape[0], num_samples, 1)) + # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + # qs = critic(obs_repeat, actions) + t1 = 
jnp.ones((obs_repeat.shape[0], num_samples, 1), dtype=jnp.int32) + f1 = nce_target(obs_repeat, actions, t1, method="forward_phi") + qs = critic(f1).min(axis=0).reshape(B, num_samples) + # qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("deterministic", "exploration_noise")) +def jit_td3_sample_action( + rng: PRNGKey, + actor: Model, + obs: jnp.ndarray, + deterministic: bool, + exploration_noise: float, +) -> jnp.ndarray: + action = actor(obs, training=False) + if not deterministic: + action = action + exploration_noise * jax.random.normal(rng, action.shape) + action = jnp.clip(action, -1.0, 1.0) + return action + +@partial(jax.jit, static_argnames=("discount", "solver", "critic_coef")) +def jit_update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + value: Model, + value_target: Model, + actor: Model, + backup: Model, + nce_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float, +) -> Tuple[PRNGKey, Model, Metric]: + # q0 target + B = batch.obs.shape[0] + A = batch.action.shape[-1] + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) + a0 = batch.action + + + rng, next_aT_rng = jax.random.split(rng) + next_a0 = backup(batch.next_obs, training=False) + # next_aT = jax.random.normal(next_aT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + # rng, next_a0, _ = actor.sample(rng, next_aT, batch.next_obs, training=False, solver=solver) + # next_f0 = nce_target(batch.next_obs, next_a0, t0, method="forward_phi") + # q0_target = critic_target(next_f0) + next_f1 = nce_target(batch.next_obs, next_a0, t1, method="forward_phi") + f1 = nce_target(batch.obs, a0, t1, method="forward_phi") + # q0_target = critic_target(batch.next_obs, next_a0) + q0_target = critic_target(next_f1) + q0_target = batch.reward + discount * (1 - batch.terminal) * q0_target.min(axis=0) + + + # features + # rng, at, t, eps = actor.add_noise(rng, a0) + # weight_t = actor.alpha_hats[t] / (1-actor.alpha_hats[t]) + # weight_t = 1.0 + # ft = nce_target(batch.obs, at, t, method="forward_phi") + # rng, t_rng, noise_rng = jax.random.split(rng, 3) + # t = jax.random.randint(t_rng, (*a0.shape[:-1], 1), 0, actor.steps+1) + # t = jnp.ones((*a0.shape[:-1], 1), dtype=jnp.int32) + # eps = jax.random.normal(noise_rng, a0.shape) + # at = jnp.sqrt(actor.alpha_hats[t]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t]) * eps + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + # f0 = nce_target(batch.obs, a0, t0, method="forward_phi") + q0_pred = critic.apply( + {"params": critic_params}, + f1, + # batch.obs, + # a0, + ) + critic_loss = ( + ((q0_pred - q0_target[jnp.newaxis, :])**2).mean() + ) + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q0_mean": q0_pred.mean(), + "misc/reward": batch.reward.mean(), + "misc/next_action_l1": jnp.abs(next_a0).mean(), + "misc/q0_target": q0_target.mean(), + } + + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + + # rng, at, t, eps = actor.add_noise(rng, a0) + rng, t_rng, noise_rng = jax.random.split(rng, 3) + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) + eps = jax.random.normal(noise_rng, a0.shape) + at = jnp.sqrt(actor.alpha_hats[t1]) * a0 + jnp.sqrt(1 - actor.alpha_hats[t1]) * eps + ft = nce_target(batch.obs, at, t1, method="forward_phi") + + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> 
Tuple[jnp.ndarray, Metric]: + qt_pred = value.apply( + {"params": value_params}, + ft, + ) + value_loss = ((qt_pred - q0_target)**2).mean() + return value_loss, { + "loss/value_loss": value_loss, + "misc/qt_mean": qt_pred.mean(), + } + new_value, value_metrics = value.apply_gradient(value_loss_fn) + + return rng, new_critic, new_value, { + **critic_metrics, + **value_metrics, + } + +def jit_compute_metrics( + rng: PRNGKey, + actor: Model, + critic: Model, + value: Model, + nce_target: Model, + batch: Batch, +) -> Tuple[PRNGKey, Metric]: + B, S = batch.obs.shape + A = batch.action.shape[-1] + num_actions = 50 + metrics = {} + rng, action_rng = jax.random.split(rng) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = batch.action[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + + def get_critic(at, obs): + t1 = jnp.ones((1, ), dtype=jnp.int32) + f1 = nce_target(obs, at, t1, method="forward_phi") + q = critic(f1) + return q.mean() + all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + ) + all_critic = all_critic.reshape(B, num_actions, 1) + all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/critic": all_critic.std(axis=1).mean(), + f"q_grad/critic": jnp.abs(all_critic_grad).mean(), + }) + + def get_value(at, obs, t): + ft = nce_target(obs, at, t, method="forward_phi") + q = value(ft) + return q.mean() + + for t in [0] + list(range(1, actor.steps+1, actor.steps//5)): + t_input = jnp.ones((B, num_actions, 1)) * t + all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + t_input.reshape(-1, 1), + ) + all_value = all_value.reshape(B, num_actions, 1) + all_value_grad = all_value_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/value_{t}": all_value.std(axis=1).mean(), + f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), + }) + return rng, metrics + + +@partial(jax.jit, static_argnames=("temp",)) +def jit_update_actor( + rng: PRNGKey, + actor: Model, + backup: Model, + nce_target: Model, + critic_target: Model, + value_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, Model, Metric]: + a0 = batch.action + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) + rng, at, t, eps = actor.add_noise(rng, a0) + sigma = jnp.sqrt(1 - actor.alpha_hats[t]) + def get_q_value(at: jnp.ndarray, obs: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray: + t1 = jnp.ones(t.shape, dtype=jnp.int32) + ft = nce_target(obs, at, t1, method="forward_phi") + q = value_target(ft) + return q.mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs, t) + q_grad_l1 = jnp.abs(q_grad).mean() + eps_estimation = - sigma * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + # eps_estimation = -sigma * l2_normalize(q_grad) / temp + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/q_grad_l1": q_grad_l1, + } + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + + def backup_loss_fn(backup_params: 
Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + new_action = backup.apply( + {"params": backup_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + f1 = nce_target(batch.obs, new_action, t1, method="forward_phi") + q = critic_target(f1) + loss = - q.mean() + return loss, { + "loss/backup_loss": loss, + } + new_backup, backup_metrics = backup.apply_gradient(backup_loss_fn) + + return rng, new_actor, new_backup, { + **actor_metrics, + **backup_metrics, + } + + +class ACAAgent(BaseAgent): + """ + ACA (Actor-Critic with Actor) agent. + """ + name = "ACAAgent" + model_names = ["nce", "nce_target", "actor", "critic", "critic_target", "value", "value_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.feature_dim = cfg.feature_dim + self.ranking = cfg.ranking + self.linear = cfg.linear + self.reward_coef = cfg.reward_coef + self.critic_coef = cfg.critic_coef + + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng, value_rng = jax.random.split(self.rng, 6) + + # define the nce + nce_def = FactorizedNCE( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + self.ranking, + ) + self.nce = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.nce_target = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + # define the actor + time_embedding = partial(PositionalEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = DDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + + if cfg.diffusion.lr_decay_steps is not None: + actor_lr = optax.linear_schedule( + init_value=cfg.diffusion.lr, + end_value=cfg.diffusion.end_lr, + transition_steps=cfg.diffusion.lr_decay_steps, + transition_begin=cfg.diffusion.lr_decay_begin, + ) + else: + actor_lr = cfg.diffusion.lr + + self.actor = DDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + approx_postvar=False, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the backup actor + from flowrl.module.actor import SquashedDeterministicActor + backup_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=[512,512,512], + layer_norm=True, + dropout=None, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + self.backup = Model.create( + backup_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] + # 
critic_def = EnsembleCritic( + # hidden_dims=[512, 512], + # activation=critic_activation, + # ensemble_size=2, + # layer_norm=True, + # ) + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=[512,], + rff_dim=cfg.rff_dim, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim))), + ) + + value_def = Critic( + hidden_dims=[512, 512], + activation=critic_activation, + layer_norm=True, + ) + self.value = Model.create( + value_def, + value_rng, + inputs=(jnp.ones((1, self.feature_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.value_target = Model.create( + value_def, + value_rng, + inputs=(jnp.ones((1, self.feature_dim))), + ) + + # from flowrl.agent.online.unirep.network import ( + # EnsembleACACritic, + # EnsembleResidualCritic, + # ResidualCritic, + # SeparateCritic, + # ) + + # value_def = EnsembleACACritic( + # time_dim=16, + # hidden_dims=[256,256,256], + # activation=jax.nn.mish, + # ensemble_size=2, + # ) + # value_def = EnsembleResidualCritic( + # time_embedding=time_embedding, + # hidden_dims=[512, 512, 512], + # activation=jax.nn.mish, + # ) + # value_def = SeparateCritic( + # hidden_dims=[512, 512, 512], + # activation=jax.nn.mish, + # ensemble_size=cfg.diffusion.steps+1, + # ) + # self.value = Model.create( + # value_def, + # value_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), + # optimizer=optax.adam(learning_rate=cfg.critic_lr), + # ) + # self.value_target = Model.create( + # value_def, + # value_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim)), jnp.ones((1, 1))), + # ) + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) + self.rng, self.critic, self.value, critic_metrics = jit_update_critic( + self.rng, + self.critic, + self.critic_target, + self.value, + self.value_target, + self.actor, + self.backup, + self.nce_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + self.rng, self.actor, self.backup, actor_metrics = jit_update_actor( + self.rng, + self.actor, + self.backup, + self.nce_target, + self.critic_target, + self.value_target, + batch, + temp=self.cfg.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.cfg.target_update_freq == 0: + self.sync_target() + + if self._n_training_steps % 2000 == 0: + self.rng, metrics = jit_compute_metrics( + self.rng, + self.actor, + self.critic, + self.value, + self.nce_target, + batch, + ) + + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + self.nce_target, + obs, + training=False, + num_samples=num_samples, + 
solver=self.cfg.diffusion.solver, + ) + else: + self.rng, action_rng = jax.random.split(self.rng) + action = jit_td3_sample_action( + action_rng, + self.backup, + obs, + deterministic, + exploration_noise=0.2, + ) + return action, {} + + def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.value_target = ema_update(self.value, self.value_target, self.cfg.ema) + self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py b/flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py new file mode 100644 index 0000000..7bf711e --- /dev/null +++ b/flowrl/agent/online/unirep/diffsr/diffsr_aca_sep.py @@ -0,0 +1,495 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.td3 import TD3Agent +from flowrl.agent.online.unirep.diffsr.network import ( + FactorizedDDPM, + update_factorized_ddpm, +) +from flowrl.config.online.mujoco.algo.diffsr import DiffSRTD3Config +from flowrl.functional.activation import l2_normalize +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + ddpm_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + # t0 = jnp.zeros((obs_repeat.shape[0], num_samples, 1)) + # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + # qs = critic(obs_repeat, actions) + t1 = jnp.ones((obs_repeat.shape[0], num_samples, 1), dtype=jnp.int32) + f1 = ddpm_target(obs_repeat, actions, t1, method="forward_phi") + qs = critic(f1).min(axis=0).reshape(B, num_samples) + # qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + + +@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor_target: Model, + ddpm_target: Model, + diffusion_actor: Model, + diffusion_value: Model, + batch: Batch, + discount: float, + target_policy_noise: float, + noise_clip: float, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + rng, sample_rng = jax.random.split(rng) + noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise + noise = jnp.clip(noise, -noise_clip, noise_clip) + next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0) + + next_feature = ddpm_target(batch.next_obs, next_action, t1, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount 
* (1 - batch.terminal) * q_target + + back_critic_grad = False + if back_critic_grad: + raise NotImplementedError("no back critic grad exists") + + feature = ddpm_target(batch.obs, batch.action, t1, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + + new_value = [] + value_metrics = {} + for i in range(len(diffusion_value)): + rng, eps_rng = jax.random.split(rng) + eps = jax.random.normal(eps_rng, batch.action.shape) + t = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(i) + at = jnp.sqrt(diffusion_actor.alpha_hats[t]) * batch.action + jnp.sqrt(1-diffusion_actor.alpha_hats[t]) * eps + ft = ddpm_target(batch.obs, at, t, method="forward_phi") + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + qt_pred = diffusion_value[i].apply( + {"params": value_params}, + ft, + ) + value_loss = ((qt_pred - q_target) ** 2).mean() + return value_loss, { + f"loss/value_{i}_loss": value_loss, + } + this_value, this_metrics = diffusion_value[i].apply_gradient(value_loss_fn) + new_value.append(this_value) + value_metrics.update(this_metrics) + + return rng, new_critic, new_value, { + **metrics, + **value_metrics, + } + + +@jax.jit +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: Model, + critic: Model, + diffusion_actor, + diffusion_value, + batch: Batch, +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + def actor_loss_fn( + actor_params: Param, dropout_rng: PRNGKey + ) -> Tuple[jnp.ndarray, Metric]: + new_action = actor.apply( + {"params": actor_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + new_feature = ddpm_target(batch.obs, new_action, t1, method="forward_phi") + q = critic(new_feature) + actor_loss = - q.mean() + + return actor_loss, { + "loss/actor_loss": actor_loss, + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + + # rng, at, t, eps = diffusion_actor.add_noise(rng, batch.action) + # sigma = jnp.sqrt(1 - diffusion_actor.alpha_hats[t]) + rng, rng2 = jax.random.split(rng) + B, S = batch.obs.shape + A = batch.action.shape[-1] + t_repeat = jnp.arange(diffusion_actor.steps+1)[..., jnp.newaxis, jnp.newaxis].repeat(B, axis=1) + obs_repeat = batch.obs[jnp.newaxis, ...].repeat(diffusion_actor.steps+1, axis=0) + action_repeat = batch.action[jnp.newaxis, ...].repeat(diffusion_actor.steps+1, axis=0) + at_repeat = jnp.sqrt(diffusion_actor.alpha_hats[t_repeat]) * action_repeat + jnp.sqrt(1 - diffusion_actor.alpha_hats[t_repeat]) * jax.random.normal(rng2, action_repeat.shape) + q_grad = [] + for i in range(diffusion_actor.steps+1): + def get_q_value(at, obs, t): + ft = ddpm_target(obs, at, t, method="forward_phi") + q = diffusion_value[i](ft) + return q.mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad.append(q_grad_fn(at_repeat[i], obs_repeat[i], t_repeat[i])) + q_grad = jnp.stack(q_grad, axis=0) + # eps_estimation = -jnp.sqrt(1 - diffusion_actor.alpha_hats[t_repeat]) * q_grad + # eps_estimation = l2_normalize(eps_estimation) * (eps_estimation.shape[-1] ** 0.5) + # 
eps_estimation = eps_estimation / 0.2 + eps_estimation = - jnp.sqrt(1 - diffusion_actor.alpha_hats[t_repeat]) * q_grad / 0.1 / (jnp.abs(q_grad).mean() + 1e-6) + + def diffusion_loss_fn(diffusion_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = diffusion_actor.apply( + {"params": diffusion_params}, + at_repeat, + t_repeat, + condition=obs_repeat, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/diffusion_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + new_diffusion, diffusion_metrics = diffusion_actor.apply_gradient(diffusion_loss_fn) + return rng, new_actor, new_diffusion, { + **metrics, + **diffusion_metrics, + } + +@jax.jit +def jit_compute_metrics( + rng: PRNGKey, + critic: Model, + ddpm_target: Model, + diffusion_value: Model, + diffusion_actor: Model, + batch: Batch, +) -> Tuple[PRNGKey, Metric]: + B, S = batch.obs.shape + A = batch.action.shape[-1] + num_actions = 50 + metrics = {} + rng, action_rng = jax.random.split(rng) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + + def get_critic(at, obs): + t1 = jnp.ones((1, ), dtype=jnp.int32) + ft = ddpm_target(obs, at, t1, method="forward_phi") + q = critic(ft) + return q.mean() + all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + ) + all_critic = all_critic.reshape(B, num_actions, 1) + all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/critic": all_critic.std(axis=1).mean(), + f"q_grad/critic": jnp.abs(all_critic_grad).mean(), + }) + + for i in [0] + list(range(1, diffusion_actor.steps+1, diffusion_actor.steps//5)): + def get_value(at, obs, t): + ft = ddpm_target(obs, at, t, method="forward_phi") + q = diffusion_value[i](ft) + return q.mean() + t_repeat = jnp.ones((B, num_actions, 1)) * jnp.int32(i) + all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + t_repeat.reshape(-1, 1), + ) + all_value = all_value.reshape(B, num_actions, 1) + all_value_grad = all_value_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/value_{i}": all_value.std(axis=1).mean(), + f"q_grad/value_{i}": jnp.abs(all_value_grad).mean(), + }) + return rng, metrics + +class DiffSRACASepAgent(TD3Agent): + """ + Diff-SR with ACA agent. 
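+ + The agent learns a factorized DDPM transition model (phi over state-action pairs, mu over the noised next state) with a denoising plus reward-prediction loss (update_factorized_ddpm), and runs a TD3-style deterministic actor and feature-space critic on phi evaluated at diffusion time t=1. + Unlike the shared-value variant in diffsr_td3.py, it keeps a separate value head for every diffusion timestep (hence "Sep" in the class name); each head regresses the TD target from features of actions noised to its own timestep, and a DDPM policy is fit by matching its noise prediction to the gradients of these per-timestep value heads. + At evaluation, the diffusion policy samples cfg.num_samples candidate actions per state and the critic selects the best one; otherwise the TD3 actor from the parent class is used.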
+ """ + + name = "DiffSRACAAgent" + model_names = ["ddpm", "ddpm_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRTD3Config, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + actor_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=cfg.actor_hidden_dims, + layer_norm=cfg.layer_norm, + dropout=None, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + # critic_def = RffEnsembleCritic( + # feature_dim=self.feature_dim, + # hidden_dims=cfg.critic_hidden_dims, + # rff_dim=cfg.rff_dim, + # ensemble_size=2, + # ) + critic_def = EnsembleCritic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ensemble_size=2, + ) + self.actor = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=cfg.actor_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.actor_target = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + # define the ACA value and actor + from flowrl.agent.online.unirep.diffsr.network import ResidualCritic + from flowrl.flow.ddpm import DDPM, DDPMBackbone + from flowrl.functional.activation import mish + from flowrl.module.critic import Critic + from flowrl.module.time_embedding import PositionalEmbedding + self.rng, diffusion_value_rng, diffusion_actor_rng = jax.random.split(self.rng, 3) + time_embedding = partial(PositionalEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = DDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + self.diffusion_actor = DDPM.create( + network=backbone_def, + rng=diffusion_actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="vp", + noise_schedule_params={}, + approx_postvar=False, + clip_sampler=cfg.diffusion.clip_sampler, + 
x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + value_def = Critic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ) + # value_def = ResidualCritic( + # time_embedding=time_embedding, + # hidden_dims=[512, 512], + # activation=jax.nn.elu, + # ) + self.diffusion_value = [] + for t in range(cfg.diffusion.steps+1): + this_rng, diffusion_value_rng = jax.random.split(diffusion_value_rng) + self.diffusion_value.append( + Model.create( + value_def, + this_rng, + inputs=(jnp.ones((1, self.feature_dim)), ), + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, self.diffusion_value, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor_target, + self.ddpm_target, + self.diffusion_actor, + self.diffusion_value, + batch, + discount=self.cfg.discount, + target_policy_noise=self.target_policy_noise, + noise_clip=self.noise_clip, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, self.diffusion_actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic, + self.diffusion_actor, + self.diffusion_value, + batch, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + if self._n_training_steps % 2000 == 0: + self.rng, metrics = jit_compute_metrics( + self.rng, + self.critic, + self.ddpm_target, + self.diffusion_value, + self.diffusion_actor, + batch, + ) + metrics.update(metrics) + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if deterministic: + num_samples = self.cfg.num_samples + self.rng, action = jit_sample_actions( + self.rng, + self.diffusion_actor, + self.critic, + self.ddpm_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + else: + action = super().sample_actions(obs, deterministic, num_samples) + return action, {} + + def sync_target(self): + super().sync_target() + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/diffsr/diffsr_td3.py b/flowrl/agent/online/unirep/diffsr/diffsr_td3.py new file mode 100644 index 0000000..2e2cfb1 --- /dev/null +++ b/flowrl/agent/online/unirep/diffsr/diffsr_td3.py @@ -0,0 +1,476 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.td3 import TD3Agent +from flowrl.agent.online.unirep.diffsr.network import ( + FactorizedDDPM, + update_factorized_ddpm, +) +from flowrl.config.online.mujoco.algo.diffsr import DiffSRTD3Config +from flowrl.functional.ema import ema_update +from flowrl.module.actor import SquashedDeterministicActor +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, 
static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: Model, + critic: Model, + ddpm_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + # t0 = jnp.zeros((obs_repeat.shape[0], num_samples, 1)) + # f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + # qs = critic(obs_repeat, actions) + t1 = jnp.ones((obs_repeat.shape[0], num_samples, 1), dtype=jnp.int32) + f1 = ddpm_target(obs_repeat, actions, t1, method="forward_phi") + qs = critic(f1).min(axis=0).reshape(B, num_samples) + # qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + + +@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor_target: Model, + ddpm_target: Model, + diffusion_actor: Model, + diffusion_value: Model, + batch: Batch, + discount: float, + target_policy_noise: float, + noise_clip: float, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + rng, sample_rng = jax.random.split(rng) + noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise + noise = jnp.clip(noise, -noise_clip, noise_clip) + next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0) + + next_feature = ddpm_target(batch.next_obs, next_action, t1, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + back_critic_grad = False + if back_critic_grad: + raise NotImplementedError("no back critic grad exists") + + feature = ddpm_target(batch.obs, batch.action, t1, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + + rng, t_rng, eps_rng = jax.random.split(rng, 3) + # rng, at, t, eps = diffusion_actor.add_noise(rng, batch.action) + # ft = ddpm_target(batch.obs, at, t, method="forward_phi") + t = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + at = jnp.sqrt(diffusion_actor.alpha_hats[t]) * batch.action + jnp.sqrt(1 - diffusion_actor.alpha_hats[t]) * jax.random.normal(eps_rng, batch.action.shape) + ft = ddpm_target(batch.obs, at, t, method="forward_phi") + # snr = jnp.sqrt(diffusion_actor.alpha_hats[t] / (1 - diffusion_actor.alpha_hats[t])) + def value_loss_fn(value_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + qt_pred = diffusion_value.apply( + {"params": value_params}, + ft, + ) + value_loss = (((qt_pred - q_target) ** 2)).mean() + return 
value_loss, { + "loss/value_loss": value_loss, + "misc/qt_mean": qt_pred.mean(), + } + new_value, value_metrics = diffusion_value.apply_gradient(value_loss_fn) + + return rng, new_critic, new_value, { + **metrics, + **value_metrics, + } + + +@jax.jit +def update_actor( + rng: PRNGKey, + actor: Model, + ddpm_target: Model, + critic: Model, + diffusion_actor, + diffusion_value, + batch: Batch, +) -> Tuple[PRNGKey, Model, Metric]: + t1 = jnp.ones((batch.obs.shape[0], 1), dtype=jnp.int32) * jnp.int32(1) + def actor_loss_fn( + actor_params: Param, dropout_rng: PRNGKey + ) -> Tuple[jnp.ndarray, Metric]: + new_action = actor.apply( + {"params": actor_params}, + batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + new_feature = ddpm_target(batch.obs, new_action, t1, method="forward_phi") + q = critic(new_feature) + actor_loss = - q.mean() + + return actor_loss, { + "loss/actor_loss": actor_loss, + } + + new_actor, metrics = actor.apply_gradient(actor_loss_fn) + + rng, at, t, eps = diffusion_actor.add_noise(rng, batch.action) + sigma = jnp.sqrt(1 - diffusion_actor.alpha_hats[t]) + def get_q_value(at, obs, t): + t1 = jnp.ones(t.shape, dtype=jnp.int32) + ft = ddpm_target(obs, at, t1, method="forward_phi") + q = diffusion_value(ft) + return q.mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs, t) + q_grad_l1 = jnp.abs(q_grad).mean() + eps_estimation = - sigma * q_grad / 0.1 / (jnp.abs(q_grad).mean() + 1e-6) + def diffusion_loss_fn(diffusion_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = diffusion_actor.apply( + {"params": diffusion_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/diffusion_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/q_grad_l1": q_grad_l1, + } + new_diffusion, diffusion_metrics = diffusion_actor.apply_gradient(diffusion_loss_fn) + return rng, new_actor, new_diffusion, { + **metrics, + **diffusion_metrics, + } + +@jax.jit +def jit_compute_metrics( + rng: PRNGKey, + critic: Model, + ddpm_target: Model, + diffusion_value: Model, + diffusion_actor: Model, + batch: Batch, +) -> Tuple[PRNGKey, Metric]: + B, S = batch.obs.shape + A = batch.action.shape[-1] + num_actions = 50 + metrics = {} + rng, action_rng = jax.random.split(rng) + obs_repeat = batch.obs[..., jnp.newaxis, :].repeat(num_actions, axis=-2) + action_repeat = jax.random.uniform(action_rng, (B, num_actions, A), minval=-1.0, maxval=1.0) + + def get_critic(at, obs): + t1 = jnp.ones((1, ), dtype=jnp.int32) + ft = ddpm_target(obs, at, t1, method="forward_phi") + q = critic(ft) + return q.mean() + all_critic, all_critic_grad = jax.vmap(jax.value_and_grad(get_critic))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + ) + all_critic = all_critic.reshape(B, num_actions, 1) + all_critic_grad = all_critic_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/critic": all_critic.std(axis=1).mean(), + f"q_grad/critic": jnp.abs(all_critic_grad).mean(), + }) + + def get_value(at, obs, t): + ft = ddpm_target(obs, at, t, method="forward_phi") + q = diffusion_value(ft) + return q.mean() + for t in [0] + list(range(1, diffusion_actor.steps+1, diffusion_actor.steps//5)): + t_input = jnp.ones((B, num_actions, 1)) * t + all_value, all_value_grad = jax.vmap(jax.value_and_grad(get_value))( + action_repeat.reshape(-1, A), + obs_repeat.reshape(-1, S), + t_input.reshape(-1, 1), + ) + all_value 
= all_value.reshape(B, num_actions, 1) + all_value_grad = all_value_grad.reshape(B, num_actions, -1) + metrics.update({ + f"q_std/value_{t}": all_value.std(axis=1).mean(), + f"q_grad/value_{t}": jnp.abs(all_value_grad).mean(), + }) + return rng, metrics + +class DiffSRACAAgent(TD3Agent): + """ + Diff-SR with ACA agent. + """ + + name = "DiffSRACAAgent" + model_names = ["ddpm", "ddpm_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: DiffSRTD3Config, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ddpm_coef = cfg.ddpm_coef + self.critic_coef = cfg.critic_coef + self.reward_coef = cfg.reward_coef + self.num_noises = cfg.num_noises + self.feature_dim = cfg.feature_dim + + # networks + self.rng, ddpm_rng, ddpm_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + ddpm_def = FactorizedDDPM( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.embed_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + ) + self.ddpm = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.ddpm_target = Model.create( + ddpm_def, + ddpm_rng, + inputs=( + ddpm_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + actor_def = SquashedDeterministicActor( + backbone=MLP( + hidden_dims=cfg.actor_hidden_dims, + layer_norm=cfg.layer_norm, + dropout=None, + ), + obs_dim=self.obs_dim, + action_dim=self.act_dim, + ) + # critic_def = RffEnsembleCritic( + # feature_dim=self.feature_dim, + # hidden_dims=cfg.critic_hidden_dims, + # rff_dim=cfg.rff_dim, + # ensemble_size=2, + # ) + critic_def = EnsembleCritic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ensemble_size=2, + ) + self.actor = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + optimizer=optax.adam(learning_rate=cfg.actor_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.actor_target = Model.create( + actor_def, + actor_rng, + inputs=(jnp.ones((1, self.obs_dim)),), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + # define the ACA value and actor + from flowrl.agent.online.unirep.diffsr.network import ResidualCritic + from flowrl.flow.ddpm import DDPM, DDPMBackbone + from flowrl.functional.activation import mish + from flowrl.module.critic import Critic + from flowrl.module.time_embedding import PositionalEmbedding + self.rng, diffusion_value_rng, diffusion_actor_rng = jax.random.split(self.rng, 3) + time_embedding = partial(PositionalEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = DDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + self.diffusion_actor = DDPM.create( + 
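+            # Conditional DDPM over actions (the ACA diffusion policy). As implemented in
+            # update_actor above, it is trained to match normalised gradients of the diffusion
+            # value critic w.r.t. the noisy action (a QSM-style score target) rather than by
+            # behaviour cloning; jit_sample_actions then picks the best of num_samples
+            # candidates by critic value.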
network=backbone_def, + rng=diffusion_actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="vp", + noise_schedule_params={}, + approx_postvar=False, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + value_def = Critic( + hidden_dims=[512, 512, 512], + activation=jax.nn.elu, + layer_norm=True, + ) + # value_def = ResidualCritic( + # time_embedding=time_embedding, + # hidden_dims=[512, 512], + # activation=jax.nn.elu, + # ) + self.diffusion_value = Model.create( + value_def, + diffusion_value_rng, + inputs=(jnp.ones((1, self.feature_dim)), ), + optimizer=optax.adam(learning_rate=cfg.diffusion.lr), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.ddpm, ddpm_metrics = update_factorized_ddpm( + self.rng, + self.ddpm, + batch, + self.reward_coef, + ) + metrics.update(ddpm_metrics) + + self.rng, self.critic, self.diffusion_value, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor_target, + self.ddpm_target, + self.diffusion_actor, + self.diffusion_value, + batch, + discount=self.cfg.discount, + target_policy_noise=self.target_policy_noise, + noise_clip=self.noise_clip, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, self.diffusion_actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.ddpm_target, + self.critic, + self.diffusion_actor, + self.diffusion_value, + batch, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + if self._n_training_steps % 2000 == 0: + self.rng, metrics = jit_compute_metrics( + self.rng, + self.critic, + self.ddpm_target, + self.diffusion_value, + self.diffusion_actor, + batch, + ) + metrics.update(metrics) + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + if deterministic: + num_samples = self.cfg.num_samples + self.rng, action = jit_sample_actions( + self.rng, + self.diffusion_actor, + self.critic, + self.ddpm_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + else: + action = super().sample_actions(obs, deterministic, num_samples) + return action, {} + + def sync_target(self): + super().sync_target() + self.ddpm_target = ema_update(self.ddpm, self.ddpm_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/diffsr/network.py b/flowrl/agent/online/unirep/diffsr/network.py new file mode 100644 index 0000000..d4139dc --- /dev/null +++ b/flowrl/agent/online/unirep/diffsr/network.py @@ -0,0 +1,161 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp + +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import l2_normalize, mish +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import PositionalEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class ResidualCritic(nn.Module): + time_embedding: nn.Module + hidden_dims: Sequence[int] + activation: Callable = nn.relu 
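+    # Scores the latent feature ft produced by FactorizedDDPM.forward_phi. Note that the time
+    # embedding t_ff computed in __call__ below is not concatenated into the critic input
+    # (only ft is), unlike the ResidualCritic in flowrl/agent/online/unirep/network.py, which
+    # also appends t_ff.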
+ + @nn.compact + def __call__(self, ft: jnp.ndarray, t: jnp.ndarray, training: bool=False): + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([item for item in [ft, ] if item is not None], axis=-1) + x = ResidualMLP( + hidden_dims=self.hidden_dims, + output_dim=1, + multiplier=1, + activation=self.activation, + layer_norm=True, + )(x, training) + return x + +class FactorizedDDPM(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + embed_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int + num_noises: int + + def setup(self): + self.mlp_t1 = nn.Sequential([ + PositionalEmbedding(self.embed_dim), + nn.Dense(2*self.embed_dim), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_t2 = nn.Sequential([ + PositionalEmbedding(self.embed_dim), + nn.Dense(2*self.embed_dim), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_s = nn.Sequential([ + nn.Dense(self.embed_dim*2), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_a = nn.Sequential([ + nn.Dense(self.embed_dim*2), + mish, + nn.Dense(self.embed_dim) + ]) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim*self.obs_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.reward = RffReward( + self.feature_dim, + self.reward_hidden_dims, + rff_dim=self.rff_dim, + ) + betas, alphas, alphabars = get_noise_schedule("vp", self.num_noises) + alphabars_prev = jnp.pad(alphabars[:-1], (1, 0), constant_values=1.0) + self.alphabars = alphabars + + def forward_phi(self, s, a, t): + s = self.mlp_s(s) + a = self.mlp_a(a) + t_ff = self.mlp_t1(t) + x = jnp.concat([s, a, t_ff], axis=-1) + x = self.mlp_phi(x) + return x + + def forward_mu(self, sp, t): + t = self.mlp_t2(t) + x = jnp.concat([sp, t], axis=-1) + x = self.mlp_mu(x) + return x.reshape(-1, self.feature_dim, self.obs_dim) + + def forward_reward(self, x: jnp.ndarray): + return self.reward(x) + + def __call__(self, rng, s, a, sp, training: bool=False): + rng, t1_rng, eps1_rng, t2_rng, eps2_rng = jax.random.split(rng, 5) + t1 = jax.random.randint(t1_rng, (s.shape[0], 1), 0, self.num_noises+1) + t2 = jax.random.randint(t2_rng, (s.shape[0], 1), 0, self.num_noises+1) + eps1 = jax.random.normal(eps1_rng, a.shape) + eps2 = jax.random.normal(eps2_rng, sp.shape) + at = jnp.sqrt(self.alphabars[t1]) * a + jnp.sqrt(1-self.alphabars[t1]) * eps1 + spt = jnp.sqrt(self.alphabars[t2]) * sp + jnp.sqrt(1-self.alphabars[t2]) * eps2 + z_phi = self.forward_phi(s, at, t1) + z_mu = self.forward_mu(spt, t2) + eps_pred = jax.lax.batch_matmul(z_phi[..., jnp.newaxis, :], z_mu)[..., 0, :] + r_pred = self.forward_reward(z_phi) + return eps2, eps_pred, r_pred + + +@partial(jax.jit, static_argnames=("reward_coef")) +def update_factorized_ddpm( + rng: PRNGKey, + ddpm: FactorizedDDPM, + batch: Batch, + reward_coef: float, +) -> Tuple[PRNGKey, FactorizedDDPM, Metric]: + B = batch.obs.shape[0] + rng, update_rng = jax.random.split(rng) + def loss_fn(ddpm_params: Param, dropout_rng: PRNGKey): + eps, eps_pred, r_pred = ddpm.apply( + {"params": ddpm_params}, + update_rng, + batch.obs, + batch.action, + batch.next_obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + ddpm_loss = ((eps_pred - eps) ** 2).mean() + reward_loss = ((r_pred - batch.reward) ** 2).mean() + 
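+        # Combined objective: the factorised denoising term trains phi(s, a_t1, t1)^T mu(s'_t2, t2)
+        # to predict the noise eps2 added to the next state (see FactorizedDDPM.__call__), while
+        # the reward head on phi contributes reward_loss weighted by reward_coef.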
loss = ddpm_loss + reward_coef * reward_loss + return loss, { + "loss/ddpm_loss": ddpm_loss, + "loss/reward_loss": reward_loss, + "misc/sp0_mean": batch.next_obs.mean(), + "misc/sp0_std": batch.next_obs.std(axis=0).mean(), + "misc/eps_mean": eps_pred.mean(), + "misc/reward_mean": r_pred.mean(), + } + + new_ddpm, metrics = ddpm.apply_gradient(loss_fn) + return rng, new_ddpm, metrics diff --git a/flowrl/agent/online/unirep/network.py b/flowrl/agent/online/unirep/network.py new file mode 100644 index 0000000..2cb806b --- /dev/null +++ b/flowrl/agent/online/unirep/network.py @@ -0,0 +1,353 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax + +from flowrl.flow.continuous_ddpm import cosine_noise_schedule +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import l2_normalize, mish +from flowrl.module.critic import Critic +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.model import Model +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import LearnableFourierEmbedding, PositionalEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class ResidualCritic(nn.Module): + time_embedding: nn.Module + hidden_dims: Sequence[int] + activation: Callable = nn.relu + + @nn.compact + def __call__(self, ft: jnp.ndarray, t: jnp.ndarray, training: bool=False): + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([item for item in [ft, t_ff] if item is not None], axis=-1) + x = ResidualMLP( + hidden_dims=self.hidden_dims, + output_dim=1, + multiplier=1, + activation=self.activation, + layer_norm=True, + )(x, training) + return x + +class EnsembleResidualCritic(nn.Module): + time_embedding: nn.Module + hidden_dims: Sequence[int] + activation: Callable = nn.relu + ensemble_size: int = 2 + + @nn.compact + def __call__( + self, + obs: Optional[jnp.ndarray] = None, + action: Optional[jnp.ndarray] = None, + t: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + vmap_critic = nn.vmap( + ResidualCritic, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + x = vmap_critic( + time_embedding=self.time_embedding, + hidden_dims=self.hidden_dims, + activation=self.activation, + )(obs, action, t, training) + return x + +class ACACritic(nn.Module): + time_dim: int + hidden_dims: Sequence[int] + activation: Callable + + @nn.compact + def __call__(self, obs, action, t, training=False): + t_ff = PositionalEmbedding(self.time_dim)(t) + t_ff = nn.Dense(2*self.time_dim)(t_ff) + t_ff = self.activation(t_ff) + t_ff = nn.Dense(self.time_dim)(t_ff) + x = jnp.concatenate([obs, action, t_ff], axis=-1) + return MLP( + hidden_dims=self.hidden_dims, + output_dim=1, + activation=self.activation, + layer_norm=False + )(x, training) + +class EnsembleACACritic(nn.Module): + time_dim: int + hidden_dims: Sequence[int] + activation: Callable + ensemble_size: int + + @nn.compact + def __call__(self, obs, action, t, training=False): + vmap_critic = nn.vmap( + ACACritic, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + return vmap_critic( + time_dim=self.time_dim, + hidden_dims=self.hidden_dims, + activation=self.activation, + )(obs, action, t, training) + + +class 
SeparateCritic(nn.Module): + hidden_dims: Sequence[int] + activation: Callable + ensemble_size: int + + @nn.compact + def __call__(self, obs, action, t, training=False): + vmap_critic = nn.vmap( + MLP, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=-1, + axis_size=self.ensemble_size + ) + x = jnp.concatenate([obs, action], axis=-1) + out = vmap_critic( + hidden_dims=self.hidden_dims, + output_dim=1, + activation=self.activation, + layer_norm=False + )(x, training) + out = out.reshape(*out.shape[:-2], -1) + # Using jnp.take_along_axis for batched index selection (broadcasting as needed) + # out: (E, B, T, 1), need to select on axis=2 using t_indices[b] for each batch + # We assume batch dim is axis=1, time axis=2 + out = jnp.take_along_axis( + out, + t.astype(jnp.int32), + axis=-1 + ) + return out + +class FactorizedNCE(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int = 0 + num_noises: int = 0 + ranking: bool = False + + def setup(self): + self.mlp_t1 = nn.Sequential( + [PositionalEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] + ) + self.mlp_t2 = nn.Sequential( + [PositionalEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] + ) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.reward = RffReward( + self.feature_dim, + [512,], + rff_dim=self.rff_dim, + ) + # self.reward = Critic( + # hidden_dims=self.reward_hidden_dims, + # activation=nn.elu, + # layer_norm=True, + # dropout=None, + # ) + if self.num_noises > 0: + self.use_noise_perturbation = True + from flowrl.flow.ddpm import cosine_beta_schedule + betas = cosine_beta_schedule(T=self.num_noises) + betas = jnp.concatenate([jnp.zeros((1,)), betas]) + alphas = 1 - betas + self.alpha_hats = jnp.cumprod(alphas) + else: + self.use_noise_perturbation = False + self.N = max(self.num_noises, 1) + if not self.ranking: + self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) + else: + self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) + + def forward_phi(self, s, at, t): + x = jnp.concat([s, at], axis=-1) + t_ff = self.mlp_t1(t) + x = jnp.concat([x, t_ff], axis=-1) + x = self.mlp_phi(x) + x = l2_normalize(x, group_size=None) + # x = jnp.tanh(x) + return x + + def forward_mu(self, sp, t): + t_ff = self.mlp_t2(t) + sp = jnp.concat([sp, t_ff], axis=-1) + sp = self.mlp_mu(sp) + sp = jnp.tanh(sp) + return sp + + def forward_reward(self, x: jnp.ndarray): # for z_phi + return self.reward(x) + + def forward_logits( + self, + rng: PRNGKey, + s: jnp.ndarray, + a: jnp.ndarray, + sp: jnp.ndarray, + ): + B, D = sp.shape + rng, t_rng, eps1_rng, eps2_rng = jax.random.split(rng, 4) + if self.use_noise_perturbation: + + # perturb sp, (N, B, D) with the noise level shared in each N + sp0 = jnp.broadcast_to(sp, (self.N, B, sp.shape[-1])) + t_sp = jnp.arange(1, self.num_noises+1) + t_sp = jnp.repeat(t_sp, B).reshape(self.N, B, 1) + eps_sp = jax.random.normal(eps1_rng, sp0.shape) + alpha, sigma = jnp.sqrt(self.alpha_hats[t_sp]), jnp.sqrt(1 - self.alpha_hats[t_sp]) + spt = alpha * sp0 + sigma * eps_sp + + # perturb a, (N, B, D) with noise level independent 
across N + a0 = a + t_a = jax.random.randint(t_rng, (B, 1), 1, self.num_noises+1) + # t_a = jnp.ones((B, 1), dtype=jnp.int32) + eps_a = jax.random.normal(eps2_rng, a0.shape) + alpha, sigma = jnp.sqrt(self.alpha_hats[t_a]), jnp.sqrt(1 - self.alpha_hats[t_a]) + at = alpha * a0 + sigma * eps_a + # at = jnp.broadcast_to(at, (self.N, B, at.shape[-1])) + # t_a = jnp.broadcast_to(t_a, (self.N, B, 1)) + else: + s = jnp.expand_dims(s, 0) + at = jnp.expand_dims(a, 0) + t = None + z_phi = self.forward_phi(s, at, t_a) + z_mu = self.forward_mu(spt, t_sp) + logits = jax.lax.batch_matmul( + jnp.broadcast_to(z_phi, (self.N, B, z_phi.shape[-1])), + jnp.swapaxes(z_mu, -1, -2) + ) + logits = logits / jnp.exp(self.normalizer[:, None, None]) + rewards = self.forward_reward(z_phi) + return logits, rewards + + def forward_normalizer(self): + return self.normalizer + + def __call__( + self, + rng: PRNGKey, + s, + a, + sp, + ): + logits, rewards = self.forward_logits(rng, s, a, sp) + _ = self.forward_normalizer() + + return logits, rewards + + +@partial(jax.jit, static_argnames=("ranking", "reward_coef")) +def update_factorized_nce( + rng: PRNGKey, + nce: Model, + batch: Batch, + ranking: bool, + reward_coef: float, +) -> Tuple[PRNGKey, Model, Metric]: + B = batch.obs.shape[0] + rng, logits_rng = jax.random.split(rng) + if ranking: + labels = jnp.arange(B) + else: + labels = jnp.eye(B) + + def loss_fn(nce_params: Param, dropout_rng: PRNGKey): + logits, rewards = nce.apply( + {"params": nce_params}, + logits_rng, + batch.obs, + batch.action, + batch.next_obs, + method="forward_logits", + ) + + if ranking: + model_loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.broadcast_to(labels, (logits.shape[0], B)) + ).mean(axis=-1) + else: + normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + eff_logits = logits + normalizer[:, None, None] - jnp.log(B) + model_loss = optax.sigmoid_binary_cross_entropy(eff_logits, labels).mean([-2, -1]) + normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + rewards_target = jnp.broadcast_to(batch.reward, rewards.shape) + reward_loss = jnp.mean((rewards - rewards_target) ** 2) + + nce_loss = model_loss.mean() + reward_coef * reward_loss + 0.000 * (logits**2).mean() + + pos_logits = logits[ + jnp.arange(logits.shape[0])[..., jnp.newaxis], + jnp.arange(logits.shape[1]), + jnp.arange(logits.shape[2])[jnp.newaxis, ...].repeat(logits.shape[0], axis=0) + ] + pos_logits_per_noise = pos_logits.mean(axis=-1) + neg_logits = (logits.sum(axis=-1) - pos_logits) / (logits.shape[-1] - 1) + neg_logits_per_noise = neg_logits.mean(axis=-1) + metrics = { + "loss/nce_loss": nce_loss, + "loss/model_loss": model_loss.mean(), + "loss/reward_loss": reward_loss, + "misc/obs_mean": batch.obs.mean(), + "misc/obs_std": batch.obs.std(axis=0).mean(), + } + checkpoints = list(range(0, logits.shape[0], logits.shape[0]//5)) + [logits.shape[0]-1] + metrics.update({ + f"misc/positive_logits_{i}": pos_logits_per_noise[i].mean() for i in checkpoints + }) + metrics.update({ + f"misc/negative_logits_{i}": neg_logits_per_noise[i].mean() for i in checkpoints + }) + metrics.update({ + f"misc/logits_gap_{i}": (pos_logits_per_noise[i] - neg_logits_per_noise[i]).mean() for i in checkpoints + }) + metrics.update({ + f"misc/normalizer_{i}": jnp.exp(normalizer[i]) for i in checkpoints + }) + return nce_loss, metrics + + new_nce, metrics = nce.apply_gradient(loss_fn) + return rng, new_nce, metrics diff --git a/flowrl/config/online/mujoco/__init__.py 
b/flowrl/config/online/mujoco/__init__.py index e775d25..89a4940 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -1,12 +1,17 @@ from hydra.core.config_store import ConfigStore +from .algo.alac import ALACConfig from .algo.base import BaseAlgoConfig -from .algo.ctrl_td3 import CtrlTD3Config +from .algo.ctrlsr import * +from .algo.diffsr import * from .algo.dpmd import DPMDConfig +from .algo.idem import IDEMConfig +from .algo.qsm import QSMConfig from .algo.sac import SACConfig from .algo.sdac import SDACConfig from .algo.td3 import TD3Config from .algo.td7 import TD7Config +from .algo.unirep import * from .config import Config, LogConfig _DEF_SUFFIX = "_cfg_def" @@ -23,7 +28,13 @@ "td3": TD3Config, "td7": TD7Config, "dpmd": DPMDConfig, - "ctrl": CtrlTD3Config, + "qsm": QSMConfig, + "alac": ALACConfig, + "idem": IDEMConfig, + "ctrlsr_td3": CtrlSRTD3Config, + "aca": ACAConfig, + "diffsr_td3": DiffSRTD3Config, + "diffsr_qsm": DiffSRQSMConfig, } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/alac.py b/flowrl/config/online/mujoco/algo/alac.py new file mode 100644 index 0000000..b07cc24 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/alac.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class ALACLangevinDynamicsConfig: + resnet: bool + activation: str + ensemble_size: int + time_dim: int + hidden_dims: List[int] + cond_hidden_dims: List[int] + steps: int + step_size: float + noise_scale: float + noise_schedule: str + clip_sampler: bool + x_min: float + x_max: float + epsilon: float + lr: float + clip_grad_norm: float | None + +@dataclass +class ALACConfig(BaseAlgoConfig): + name: str + discount: float + ema: float + num_samples: int + ld: ALACLangevinDynamicsConfig diff --git a/flowrl/config/online/mujoco/algo/ctrlsr/__init__.py b/flowrl/config/online/mujoco/algo/ctrlsr/__init__.py new file mode 100644 index 0000000..0dad455 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrlsr/__init__.py @@ -0,0 +1,5 @@ +from .ctrlsr_td3 import CtrlSRTD3Config + +__all__ = [ + "CtrlSRTD3Config", +] diff --git a/flowrl/config/online/mujoco/algo/ctrlsr/ctrl_qsm.py b/flowrl/config/online/mujoco/algo/ctrlsr/ctrl_qsm.py new file mode 100644 index 0000000..402aa92 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrlsr/ctrl_qsm.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig +from ..qsm import QSMDiffusionConfig + + +@dataclass +class CtrlQSMConfig(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + # critic_hidden_dims: List[int] + critic_activation: str # not used + critic_ensemble_size: int + layer_norm: bool + critic_lr: float + clip_grad_norm: float | None + + feature_dim: int + feature_lr: float + feature_ema: float + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ctrl_coef: float + reward_coef: float + back_critic_grad: bool + critic_coef: float + + num_noises: int + linear: bool + ranking: bool + + num_samples: int + temp: float + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/ctrl_td3.py b/flowrl/config/online/mujoco/algo/ctrlsr/ctrlsr_td3.py similarity index 91% rename from flowrl/config/online/mujoco/algo/ctrl_td3.py rename to 
flowrl/config/online/mujoco/algo/ctrlsr/ctrlsr_td3.py index 667496a..5c51a2a 100644 --- a/flowrl/config/online/mujoco/algo/ctrl_td3.py +++ b/flowrl/config/online/mujoco/algo/ctrlsr/ctrlsr_td3.py @@ -1,11 +1,11 @@ from dataclasses import dataclass from typing import List -from .base import BaseAlgoConfig +from ..base import BaseAlgoConfig @dataclass -class CtrlTD3Config(BaseAlgoConfig): +class CtrlSRTD3Config(BaseAlgoConfig): name: str actor_update_freq: int target_update_freq: int diff --git a/flowrl/config/online/mujoco/algo/diffsr/__init__.py b/flowrl/config/online/mujoco/algo/diffsr/__init__.py new file mode 100644 index 0000000..177cbe6 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/diffsr/__init__.py @@ -0,0 +1,7 @@ +from .diffsr_qsm import DiffSRQSMConfig +from .diffsr_td3 import DiffSRTD3Config + +__all__ = [ + "DiffSRTD3Config", + "DiffSRQSMConfig", +] diff --git a/flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py b/flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py new file mode 100644 index 0000000..db067f7 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/diffsr/diffsr_qsm.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig +from ..qsm import QSMDiffusionConfig + + +@dataclass +class DiffSRQSMConfig(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + # critic_hidden_dims: List[int] + critic_activation: str # not used + critic_ensemble_size: int + layer_norm: bool + critic_lr: float + clip_grad_norm: float | None + + num_noises: int + feature_dim: int + feature_lr: float + feature_ema: float + embed_dim: int + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ddpm_coef: float + reward_coef: float + back_critic_grad: bool + critic_coef: float + + num_samples: int + temp: float + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py b/flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py new file mode 100644 index 0000000..6f68a2c --- /dev/null +++ b/flowrl/config/online/mujoco/algo/diffsr/diffsr_td3.py @@ -0,0 +1,38 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig + + +@dataclass +class DiffSRTD3Config(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + actor_hidden_dims: List[int] + # critic_hidden_dims: List[int] + critic_ensemble_size: int + layer_norm: bool + actor_lr: float + critic_lr: float + clip_grad_norm: float | None + target_policy_noise: float + noise_clip: float + exploration_noise: float + + num_noises: int + feature_dim: int + feature_lr: float + feature_ema: float + embed_dim: int + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ddpm_coef: float + reward_coef: float + back_critic_grad: bool + critic_coef: float diff --git a/flowrl/config/online/mujoco/algo/idem.py b/flowrl/config/online/mujoco/algo/idem.py new file mode 100644 index 0000000..c46e7b6 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/idem.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class IDEMDiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + 
clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class IDEMConfig(BaseAlgoConfig): + name: str + critic_hidden_dims: List[int] + critic_lr: float + discount: float + num_samples: int + num_reverse_samples: int + ema: float + temp: float + diffusion: IDEMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/qsm.py b/flowrl/config/online/mujoco/algo/qsm.py new file mode 100644 index 0000000..8c02f0f --- /dev/null +++ b/flowrl/config/online/mujoco/algo/qsm.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class QSMDiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class QSMConfig(BaseAlgoConfig): + name: str + critic_hidden_dims: List[int] + critic_activation: str + critic_lr: float + discount: float + num_samples: int + ema: float + temp: float + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/sdac.py b/flowrl/config/online/mujoco/algo/sdac.py index bf1f78e..c26e080 100644 --- a/flowrl/config/online/mujoco/algo/sdac.py +++ b/flowrl/config/online/mujoco/algo/sdac.py @@ -23,6 +23,7 @@ class SDACDiffusionConfig: class SDACConfig(BaseAlgoConfig): name: str critic_hidden_dims: List[int] + critic_activation: str critic_lr: float discount: float num_samples: int diff --git a/flowrl/config/online/mujoco/algo/unirep/__init__.py b/flowrl/config/online/mujoco/algo/unirep/__init__.py new file mode 100644 index 0000000..dafaab0 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/unirep/__init__.py @@ -0,0 +1,5 @@ +from .aca import ACAConfig + +__all__ = [ + "ACAConfig", +] diff --git a/flowrl/config/online/mujoco/algo/unirep/aca.py b/flowrl/config/online/mujoco/algo/unirep/aca.py new file mode 100644 index 0000000..3d68d4a --- /dev/null +++ b/flowrl/config/online/mujoco/algo/unirep/aca.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig + + +@dataclass +class ACADiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class ACAConfig(BaseAlgoConfig): + name: str + target_update_freq: int + feature_dim: int + rff_dim: int + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + ctrl_coef: float + reward_coef: float + critic_coef: float + critic_activation: str + back_critic_grad: bool + feature_lr: float + critic_lr: float + discount: float + num_samples: int + ema: float + feature_ema: float + clip_grad_norm: float | None + temp: float + diffusion: ACADiffusionConfig + + num_noises: int + linear: bool + ranking: bool diff --git a/flowrl/dataset/toy2d.py b/flowrl/dataset/toy2d.py index 649ccd4..ba93b76 100644 --- a/flowrl/dataset/toy2d.py +++ b/flowrl/dataset/toy2d.py @@ -4,6 +4,7 @@ import numpy as np import sklearn import sklearn.datasets +from scipy.stats import multivariate_normal from sklearn.utils import shuffle as util_shuffle from flowrl.types import Batch @@ -142,13 +143,38 @@ def inf_train_gen(data, batch_size=200): x = np.random.rand(batch_size) * 5 - 2.5 y = np.sin(x) * 2.5 return np.stack((x, y), 1) + + elif data == "8gaussiansmix": + scale = 3.5 + centers = 
[ + (0, 1), + (-1. / np.sqrt(2), 1. / np.sqrt(2)), + (-1, 0), + (-1. / np.sqrt(2), -1. / np.sqrt(2)), + (0, -1), + (1. / np.sqrt(2), -1. / np.sqrt(2)), + (1, 0), + (1. / np.sqrt(2), 1. / np.sqrt(2)), + ] + weights = [8, 7, 6, 5, 4, 3, 2, 1] + + + centers = [(scale * x, scale * y) for x, y in centers] + cov = 1.5**2 + x = np.random.rand(batch_size, 2) * 10 - 5 + energy = [ + multivariate_normal.pdf(x, center, cov) + for center in centers + ] + energy = sum([weights[i] * energy[i] for i in range(8)]) + return x, energy[:, None] else: assert False class Toy2dDataset(object): def __init__(self, task: str, data_size: int=1000000, scan: bool=True): - assert task in ["swissroll", "8gaussians", "moons", "rings", "checkerboard", "2spirals"] + assert task in ["swissroll", "8gaussians", "moons", "rings", "checkerboard", "2spirals", "8gaussiansmix"] self.task = task self.data_size = data_size self.scan = scan diff --git a/flowrl/flow/langevin_dynamics.py b/flowrl/flow/langevin_dynamics.py new file mode 100644 index 0000000..81fe14b --- /dev/null +++ b/flowrl/flow/langevin_dynamics.py @@ -0,0 +1,293 @@ +from functools import partial +from typing import Callable, Optional, Sequence, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax +from flax.struct import PyTreeNode, dataclass, field +from flax.training.train_state import TrainState + +from flowrl.flow.continuous_ddpm import cosine_noise_schedule, linear_noise_schedule +from flowrl.module.model import Model +from flowrl.types import * + +# ======= Langevin Dynamics Sampling ======= + +@dataclass +class LangevinDynamics(Model): + state: TrainState + dropout_rng: PRNGKey = field(pytree_node=True) + x_dim: int = field(pytree_node=False, default=None) + grad_prediction: bool = field(pytree_node=False, default=True) + steps: int = field(pytree_node=False, default=None) + step_size: float = field(pytree_node=False, default=None) + noise_scale: float = field(pytree_node=False, default=None) + clip_sampler: bool = field(pytree_node=False, default=None) + x_min: float = field(pytree_node=False, default=None) + x_max: float = field(pytree_node=False, default=None) + + @classmethod + def create( + cls, + network: nn.Module, + rng: PRNGKey, + inputs: Sequence[jnp.ndarray], + x_dim: int, + grad_prediction: bool = True, + steps: int = 100, + step_size: float = 0.01, + noise_scale: float = 1.0, + clip_sampler: bool = False, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + optimizer: Optional[optax.GradientTransformation] = None, + clip_grad_norm: float = None + ) -> 'LangevinDynamics': + ret = super().create(network, rng, inputs, optimizer, clip_grad_norm) + + return ret.replace( + x_dim=x_dim, + grad_prediction=grad_prediction, + steps=steps, + step_size=step_size, + noise_scale=noise_scale, + clip_sampler=clip_sampler, + x_min=x_min, + x_max=x_max, + ) + + @partial(jax.jit, static_argnames=("training")) + def compute_grad( + self, + x: jnp.ndarray, + i: int, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + params: Optional[Param] = None, + dropout_rng: Optional[PRNGKey] = None + ) -> jnp.ndarray: + original_shape = x.shape[:-1] + t = i * jnp.ones((*x.shape[:-1], 1), dtype=jnp.int32) + + x = x.reshape(-1, x.shape[-1]) + t = t.reshape(-1, 1) + condition = condition.reshape(-1, condition.shape[-1]) + if self.grad_prediction: + if training: + grad = self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ) + else: + grad = self(x, t, 
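+                # NOTE: the energy placeholder assigned below in this branch uses jnp.zeros_like on a
+                # shape tuple, which produces a length-2 array rather than zeros of that shape;
+                # jnp.zeros((*x.shape[:-1], 1)), as in AnnealedLangevinDynamics.compute_grad, is
+                # presumably what is intended.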
condition=condition, training=training) + energy = jnp.zeros_like((*x.shape[:-1], 1), dtype=jnp.float32) + else: + if training: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ).mean())) + else: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self(x, t, condition=condition, training=training).mean())) + energy, grad = energy_and_grad_fn(x, t, condition) + return grad.reshape(*original_shape, self.x_dim), energy.reshape(*original_shape, 1) + + @partial(jax.jit, static_argnames=("training", "steps","step_size","noise_scale")) + def sample( + self, + rng: PRNGKey, + x_init: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + steps: Optional[int] = None, + step_size: Optional[float] = None, + noise_scale: Optional[float] = None, + params: Optional[Param] = None, + ) -> Tuple[PRNGKey, jnp.ndarray, Optional[jnp.ndarray]]: + steps = steps or self.steps + step_size = step_size or self.step_size + noise_scale = noise_scale or self.noise_scale + + def fn(input_tuple, i): + rng_, xt = input_tuple + rng_, noise_rng, dropout_rng_ = jax.random.split(rng_, 3) + + grad, energy = self.compute_grad(xt, i, condition=condition, training=training, params=params, dropout_rng=dropout_rng_) + + xt_1 = xt + step_size * grad + if self.clip_sampler: + xt_1 = jnp.clip(xt_1, self.x_min, self.x_max) + noise = jax.random.normal(noise_rng, xt_1.shape, dtype=jnp.float32) + xt_1 += (i>1) * jnp.sqrt(2 * step_size * noise_scale) * noise + + return (rng_, xt_1), (xt, grad, energy) + + output, history = jax.lax.scan(fn, (rng, x_init), jnp.arange(steps, 0, -1), unroll=True) + rng, action = output + return rng, action, history + + +@dataclass +class AnnealedLangevinDynamics(LangevinDynamics): + state: TrainState + dropout_rng: PRNGKey = field(pytree_node=True) + x_dim: int = field(pytree_node=False, default=None) + grad_prediction: bool = field(pytree_node=False, default=True) + steps: int = field(pytree_node=False, default=None) + step_size: float = field(pytree_node=False, default=None) + noise_scale: float = field(pytree_node=False, default=None) + clip_sampler: bool = field(pytree_node=False, default=None) + x_min: float = field(pytree_node=False, default=None) + x_max: float = field(pytree_node=False, default=None) + t_schedule_n: float = field(pytree_node=False, default=None) + t_diffusion: Tuple[float, float] = field(pytree_node=False, default=None) + noise_schedule_func: Callable = field(pytree_node=False, default=None) + + @classmethod + def create( + cls, + network: nn.Module, + rng: PRNGKey, + inputs: Sequence[jnp.ndarray], + x_dim: int, + grad_prediction: bool, + steps: int, + step_size: float, + noise_scale: float, + noise_schedule: str, + noise_schedule_params: Optional[Dict]=None, + clip_sampler: bool = False, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + t_schedule_n: float=1.0, + epsilon: float=0.001, + optimizer: Optional[optax.GradientTransformation]=None, + clip_grad_norm: float=None + ) -> 'AnnealedLangevinDynamics': + ret = super().create( + network, + rng, + inputs, + x_dim, + grad_prediction, + steps, + step_size, + noise_scale, + clip_sampler, + x_min, + x_max, + optimizer, + clip_grad_norm, + ) + + if noise_schedule_params is None: + noise_schedule_params = {} + if noise_schedule == "cosine": + t_diffusion = [epsilon, 0.9946] + else: + t_diffusion = [epsilon, 1.0] + if noise_schedule == 
"linear": + noise_schedule_func = partial(linear_noise_schedule, **noise_schedule_params) + elif noise_schedule == "cosine": + noise_schedule_func = partial(cosine_noise_schedule, **noise_schedule_params) + elif noise_schedule == "none": + noise_schedule_func = lambda t: (jnp.ones_like(t), jnp.zeros_like(t)) + else: + raise NotImplementedError(f"Unsupported noise schedule: {noise_schedule}") + + return ret.replace( + t_schedule_n=t_schedule_n, + t_diffusion=t_diffusion, + noise_schedule_func=noise_schedule_func, + ) + + @partial(jax.jit, static_argnames=("training")) + def compute_grad( + self, + x: jnp.ndarray, + t: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + params: Optional[Param] = None, + dropout_rng: Optional[PRNGKey] = None + ) -> jnp.ndarray: + original_shape = x.shape[:-1] + t = t * jnp.ones((*x.shape[:-1], 1), dtype=jnp.int32) + + x = x.reshape(-1, x.shape[-1]) + t = t.reshape(-1, 1) + condition = condition.reshape(-1, condition.shape[-1]) + if self.grad_prediction: + if training: + grad = self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ) + else: + grad = self(x, t, condition=condition, training=training) + energy = jnp.zeros((*x.shape[:-1], 1), dtype=jnp.float32) + else: + if training: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ).mean())) + else: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self(x, t, condition=condition, training=training).mean())) + energy, grad = energy_and_grad_fn(x, t, condition) + # alpha, sigma = self.noise_schedule_func(t) + # grad = alpha * grad - sigma * x + return grad.reshape(*original_shape, self.x_dim), energy.reshape(*original_shape, 1) + + def add_noise(self, rng: PRNGKey, x: jnp.ndarray) -> Tuple[PRNGKey, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + rng, t_rng, noise_rng = jax.random.split(rng, 3) + t = jax.random.uniform(t_rng, (*x.shape[:-1], 1), dtype=jnp.float32, minval=self.t_diffusion[0], maxval=self.t_diffusion[1]) + alpha, sigma = self.noise_schedule_func(t) + eps = jax.random.normal(noise_rng, x.shape, dtype=jnp.float32) + xt = alpha * x + sigma * eps + return rng, xt, t, eps + + @partial(jax.jit, static_argnames=("training", "steps","step_size","noise_scale")) + def sample( + self, + rng: PRNGKey, + x_init: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + steps: Optional[int] = None, + step_size: Optional[float] = None, + noise_scale: Optional[float] = None, + params: Optional[Param] = None, + ) -> Tuple[PRNGKey, jnp.ndarray, Optional[jnp.ndarray]]: + steps = steps or self.steps + # step_size = step_size or self.step_size + # noise_scale = noise_scale or self.noise_scale + t_schedule_n = 1.0 + from flowrl.flow.continuous_ddpm import quad_t_schedule + ts = quad_t_schedule(steps, n=t_schedule_n, tmin=self.t_diffusion[0], tmax=self.t_diffusion[1]) + alpha_hats = self.noise_schedule_func(ts)[0] ** 2 + alphas = alpha_hats[1:] / alpha_hats[:-1] + alphas = jnp.concat([jnp.ones((1, )), alphas], axis=0) + betas = 1 - alphas + alpha1, alpha2 = self.noise_schedule_func(ts) + + t_proto = jnp.ones((*x_init.shape[:-1], 1), dtype=jnp.int32) + + def fn(input_tuple, i): + rng_, xt = input_tuple + rng_, dropout_rng_, key_ = jax.random.split(rng_, 3) + input_t = t_proto * ts[i] + + q_grad, energy = self.compute_grad(xt, ts[i], condition=condition, 
training=training, params=params, dropout_rng=dropout_rng_) + eps_theta = q_grad + + x0_hat = (xt - jnp.sqrt(1 - alpha_hats[i]) * eps_theta) / jnp.sqrt(alpha_hats[i]) + x0_hat = jnp.clip(x0_hat, self.x_min, self.x_max) if self.clip_sampler else x0_hat + + mean_coef1 = jnp.sqrt(alpha_hats[i-1]) * betas[i] / (1 - alpha_hats[i]) + mean_coef2 = jnp.sqrt(alphas[i]) * (1 - alpha_hats[i-1]) / (1 - alpha_hats[i]) + xt_1 = mean_coef1 * x0_hat + mean_coef2 * xt + xt_1 += (i>1) * jnp.sqrt(betas[i]) * jax.random.normal(key_, xt_1.shape) + + return (rng_, xt_1), (xt, eps_theta, energy) + + output, history = jax.lax.scan(fn, (rng, x_init), jnp.arange(steps, 0, -1), unroll=True) + rng, action = output + return rng, action, history diff --git a/flowrl/module/actor.py b/flowrl/module/actor.py index 5add3dd..882143b 100644 --- a/flowrl/module/actor.py +++ b/flowrl/module/actor.py @@ -2,6 +2,7 @@ import flax.linen as nn import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.types import * from flowrl.utils.distribution import TanhMultivariateNormalDiag @@ -12,6 +13,8 @@ class DeterministicActor(nn.Module): backbone: nn.Module obs_dim: int action_dim: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -20,7 +23,7 @@ def __call__( training: bool = False, ) -> jnp.ndarray: x = self.backbone(obs, training) - x = MLP(output_dim=self.action_dim)(x) + x = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) return x @@ -28,6 +31,8 @@ class SquashedDeterministicActor(DeterministicActor): backbone: nn.Module obs_dim: int action_dim: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -45,6 +50,8 @@ class GaussianActor(nn.Module): conditional_logstd: bool = False logstd_min: float = -20.0 logstd_max: float = 2.0 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -54,10 +61,10 @@ def __call__( ) -> jnp.ndarray: x = self.backbone(obs, training) if self.conditional_logstd: - mean_logstd = MLP(output_dim=2*self.action_dim)(x) + mean_logstd = MLP(output_dim=2*self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) mean, logstd = jnp.split(mean_logstd, 2, axis=-1) else: - mean = MLP(output_dim=self.action_dim)(x) + mean = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) logstd = self.param("logstd", nn.initializers.zeros, (self.action_dim,)) logstd = jnp.clip(logstd, self.logstd_min, self.logstd_max) distribution = distrax.MultivariateNormalDiag(mean, jnp.exp(logstd)) @@ -71,6 +78,8 @@ class SquashedGaussianActor(GaussianActor): conditional_logstd: bool = False logstd_min: float = -20.0 logstd_max: float = 2.0 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -80,10 +89,10 @@ def __call__( ) -> distrax.Distribution: x = self.backbone(obs, training) if self.conditional_logstd: - mean_logstd = MLP(output_dim=2*self.action_dim)(x) + mean_logstd = MLP(output_dim=2*self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) mean, logstd = jnp.split(mean_logstd, 2, axis=-1) else: - mean = MLP(output_dim=self.action_dim)(x) + mean = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) logstd = self.param("logstd", 
nn.initializers.zeros, (self.action_dim,)) logstd = jnp.clip(logstd, self.logstd_min, self.logstd_max) distribution = TanhMultivariateNormalDiag(mean, jnp.exp(logstd)) @@ -97,6 +106,8 @@ class TanhMeanGaussianActor(GaussianActor): conditional_logstd: bool = False logstd_min: float = -20.0 logstd_max: float = 2.0 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -106,10 +117,10 @@ def __call__( ) -> jnp.ndarray: x = self.backbone(obs, training) if self.conditional_logstd: - mean_logstd = MLP(output_dim=2*self.action_dim)(x) + mean_logstd = MLP(output_dim=2*self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) mean, logstd = jnp.split(mean_logstd, 2, axis=-1) else: - mean = MLP(output_dim=self.action_dim)(x) + mean = MLP(output_dim=self.action_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) logstd = self.param("logstd", nn.initializers.zeros, (self.action_dim,)) # broadcast logstd to the shape of mean logstd = jnp.broadcast_to(logstd, mean.shape) diff --git a/flowrl/module/critic.py b/flowrl/module/critic.py index e99e2ca..f0f90d4 100644 --- a/flowrl/module/critic.py +++ b/flowrl/module/critic.py @@ -1,6 +1,7 @@ import flax.linen as nn import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.functional.activation import mish from flowrl.types import * @@ -12,6 +13,8 @@ class Critic(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -30,6 +33,8 @@ def __call__( activation=self.activation, layer_norm=self.layer_norm, dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(x, training) return x @@ -40,6 +45,8 @@ class EnsembleCritic(nn.Module): layer_norm: bool = False dropout: Optional[float] = None ensemble_size: int = 2 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -61,6 +68,8 @@ def __call__( activation=self.activation, layer_norm=self.layer_norm, dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(obs, action, training) return x @@ -70,6 +79,8 @@ class CriticT(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -83,6 +94,8 @@ def __call__( t_ff = MLP( hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], activation=mish, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(t_ff) x = jnp.concatenate([item for item in [obs, action, t_ff] if item is not None], axis=-1) x = MLP( @@ -91,6 +104,8 @@ def __call__( activation=self.activation, layer_norm=self.layer_norm, dropout=self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(x, training) return x @@ -102,6 +117,8 @@ class EnsembleCriticT(nn.Module): layer_norm: bool = False dropout: Optional[float] = None ensemble_size: int = 2 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -124,6 +141,8 @@ def __call__( hidden_dims=self.hidden_dims, activation=self.activation, layer_norm=self.layer_norm, - dropout=self.dropout + dropout=self.dropout, + kernel_init=self.kernel_init, + 
bias_init=self.bias_init, )(obs, action, t, training) return x diff --git a/flowrl/module/initialization.py b/flowrl/module/initialization.py index 926e557..0b51721 100644 --- a/flowrl/module/initialization.py +++ b/flowrl/module/initialization.py @@ -1,5 +1,6 @@ import flax.linen as nn -from flax.linen.initializers import lecun_normal +import jax +from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init from jax import numpy as jnp from flowrl.types import * @@ -10,11 +11,11 @@ def orthogonal_init(scale: Optional[float] = None): scale = jnp.sqrt(2) return nn.initializers.orthogonal(scale) +def pytorch_kernel_init(): + return variance_scaling(scale=1/3, mode="fan_in", distribution="uniform") -def uniform_init(scale_final=None): - if scale_final is not None: - return nn.initializers.xavier_uniform(scale_final) - return nn.initializers.xavier_uniform() +def pytorch_bias_init(): + return lambda key, shape, dtype=jnp.float32: (jax.random.uniform(key, shape, dtype)*2-1) / jnp.sqrt(shape[0]) - -default_init = orthogonal_init +default_kernel_init = orthogonal_init +default_bias_init = zeros_init diff --git a/flowrl/module/mlp.py b/flowrl/module/mlp.py index 22c3454..f447c35 100644 --- a/flowrl/module/mlp.py +++ b/flowrl/module/mlp.py @@ -3,10 +3,9 @@ import flax.linen as nn import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.types import * -from .initialization import default_init - class MLP(nn.Module): hidden_dims: Sequence[int] = field(default_factory=lambda: []) @@ -14,6 +13,8 @@ class MLP(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -22,14 +23,14 @@ def __call__( training: bool = False, ) -> jnp.ndarray: for i, size in enumerate(self.hidden_dims): - x = nn.Dense(size, kernel_init=default_init())(x) + x = nn.Dense(size, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) if self.layer_norm: x = nn.LayerNorm()(x) x = self.activation(x) if self.dropout and self.dropout > 0: x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) if self.output_dim > 0: - x = nn.Dense(self.output_dim, kernel_init=default_init())(x) + x = nn.Dense(self.output_dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) return x @@ -39,6 +40,8 @@ class ResidualLinear(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -51,12 +54,12 @@ def __call__( x = nn.Dropout(rate=self.dropout)(x, deterministic=not training) if self.layer_norm: x = nn.LayerNorm()(x) - x = nn.Dense(self.dim * self.multiplier, kernel_init=default_init())(x) + x = nn.Dense(self.dim * self.multiplier, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) x = self.activation(x) - x = nn.Dense(self.dim, kernel_init=default_init())(x) + x = nn.Dense(self.dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) if residual.shape != x.shape: - residual = nn.Dense(self.dim, kernel_init=default_init())(residual) + residual = nn.Dense(self.dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(residual) return residual + x @@ -68,6 +71,8 @@ class ResidualMLP(nn.Module): activation: Callable = nn.relu layer_norm: bool = False dropout: Optional[float] = None + kernel_init: Initializer = 
init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__( @@ -76,10 +81,18 @@ def __call__( training: bool = False, ) -> jnp.ndarray: if len(self.hidden_dims) > 0: - x = nn.Dense(self.hidden_dims[0], kernel_init=default_init())(x) + x = nn.Dense(self.hidden_dims[0], kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) for i, size in enumerate(self.hidden_dims): - x = ResidualLinear(size, self.multiplier, self.activation, self.layer_norm, self.dropout)(x, training) + x = ResidualLinear( + size, + self.multiplier, + self.activation, + self.layer_norm, + self.dropout, + kernel_init=self.kernel_init, + bias_init=self.bias_init, + )(x, training) if self.output_dim > 0: x = self.activation(x) - x = nn.Dense(self.output_dim, kernel_init=default_init())(x) + x = nn.Dense(self.output_dim, kernel_init=self.kernel_init(), bias_init=self.bias_init())(x) return x diff --git a/flowrl/module/rff.py b/flowrl/module/rff.py index 373bf3b..53f3139 100644 --- a/flowrl/module/rff.py +++ b/flowrl/module/rff.py @@ -2,6 +2,7 @@ import jax import jax.numpy as jnp +import flowrl.module.initialization as init from flowrl.module.mlp import MLP from flowrl.types import * @@ -10,6 +11,8 @@ class RffLayer(nn.Module): feature_dim: int rff_dim: int learnable: bool = True + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__(self, x: jnp.ndarray) -> jnp.ndarray: @@ -17,7 +20,12 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: half = self.rff_dim // 2 if self.learnable: - x = MLP(hidden_dims=[], output_dim=half)(x) + x = MLP( + hidden_dims=[], + output_dim=half, + kernel_init=self.kernel_init, + bias_init=self.bias_init + )(x) else: noise = self.variable( "noise", @@ -32,17 +40,20 @@ class RffReward(nn.Module): feature_dim: int hidden_dims: list[int] rff_dim: int + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__(self, x: jnp.ndarray, training: bool = False) -> jnp.ndarray: x = nn.LayerNorm()(x) - x = RffLayer(self.feature_dim, self.rff_dim, learnable=True)(x) - + x = RffLayer(self.feature_dim, self.rff_dim, learnable=True, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) x = MLP( hidden_dims=self.hidden_dims, output_dim=1, layer_norm=True, activation=nn.elu, + kernel_init=self.kernel_init, + bias_init=self.bias_init, )(x, training=training) return x @@ -56,6 +67,8 @@ class RffEnsembleCritic(nn.Module): hidden_dims: Sequence[int] rff_dim: int ensemble_size: int = 2 + kernel_init: Initializer = init.default_kernel_init + bias_init: Initializer = init.default_bias_init @nn.compact def __call__(self, x) -> jnp.ndarray: @@ -67,5 +80,5 @@ def __call__(self, x) -> jnp.ndarray: out_axes=0, axis_size=self.ensemble_size, ) - x = vmap_rff(self.feature_dim, self.hidden_dims, self.rff_dim)(x) + x = vmap_rff(self.feature_dim, self.hidden_dims, self.rff_dim, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) return x diff --git a/flowrl/types.py b/flowrl/types.py index cdf9989..67f42b8 100644 --- a/flowrl/types.py +++ b/flowrl/types.py @@ -4,6 +4,7 @@ import flax import jax import jax.numpy as jnp +from flax.linen.initializers import Initializer from flax.training.train_state import TrainState PRNGKey = NewType("PRNGKey", jax.Array) @@ -15,4 +16,4 @@ ['obs', 'action', 'reward', 'terminal', 'next_obs', 'next_action'], ) -__all__ = ["Batch", "PRNGKey", "Param", "Shape", "Metric", "Optional", "Sequence", "Any", 
"Dict", "Callable", "Union", "Tuple"] +__all__ = ["Batch", "PRNGKey", "Param", "Shape", "Metric", "Optional", "Sequence", "Any", "Dict", "Callable", "Union", "Tuple", "Initializer"] diff --git a/scripts/dmc/aca.sh b/scripts/dmc/aca.sh new file mode 100644 index 0000000..efb0576 --- /dev/null +++ b/scripts/dmc/aca.sh @@ -0,0 +1,70 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3) # Modify this array to specify which GPUs to use +SEEDS=(0) +NUM_EACH_GPU=2 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + "cheetah-run" + # "dog-run" + "dog-stand" + # "dog-trot" + # "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=aca" + "algo.temp=0.1" + # "algo.critic_activation=relu" + "log.tag=backup_repr_rff-tanh-nonorm-rep_multi-rl_onlyuse0-512x4" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/ctrl_td3.sh b/scripts/dmc/ctrl_qsm.sh similarity index 97% rename from scripts/dmc/ctrl_td3.sh rename to scripts/dmc/ctrl_qsm.sh index fd9e038..34729d6 100644 --- a/scripts/dmc/ctrl_td3.sh +++ b/scripts/dmc/ctrl_qsm.sh @@ -1,6 +1,6 @@ # Specify which GPUs to use GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use -SEEDS=(0 1 2 3 4) +SEEDS=(0 1 2 3) NUM_EACH_GPU=3 PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) @@ -37,7 +37,7 @@ TASKS=( ) SHARED_ARGS=( - "algo=ctrl_td3" + "algo=ctrl_qsm" "log.tag=default" "log.project=flow-rl" "log.entity=lambda-rl" diff --git a/scripts/dmc/ctrlsr_td3.sh b/scripts/dmc/ctrlsr_td3.sh new file mode 100644 index 0000000..522335a --- /dev/null +++ b/scripts/dmc/ctrlsr_td3.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=2 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + # "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + # "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=ctrlsr_td3" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + 
seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/diffsr_aca.sh b/scripts/dmc/diffsr_aca.sh new file mode 100644 index 0000000..a7c35cb --- /dev/null +++ b/scripts/dmc/diffsr_aca.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3) # Modify this array to specify which GPUs to use +SEEDS=(0 1) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + "cheetah-run" + # "dog-run" + "dog-stand" + # "dog-trot" + # "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=diffsr_aca" + "log.tag=new-critic_only_1-actor_only_1-temp0.1-normal_critic" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/diffsr_aca_sep.sh b/scripts/dmc/diffsr_aca_sep.sh new file mode 100644 index 0000000..4a69522 --- /dev/null +++ b/scripts/dmc/diffsr_aca_sep.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3) # Modify this array to specify which GPUs to use +SEEDS=(0 1) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + "cheetah-run" + # "dog-run" + "dog-stand" + # "dog-trot" + # "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=diffsr_aca_sep" + "log.tag=sep-normeps-temp0.2" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/diffsr_td3.sh b/scripts/dmc/diffsr_td3.sh new file mode 100644 index 0000000..9557580 --- /dev/null +++ b/scripts/dmc/diffsr_td3.sh @@ -0,0 +1,70 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # # "pendulum-swingup" + # "quadruped-run" + # "quadruped-walk" + # "reacher-easy" + # "reacher-hard" + # "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=diffsr_td3" + # "algo.embed_dim=256" + # "algo.reward_coef=1.0" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/qsm.sh b/scripts/dmc/qsm.sh new file mode 100644 index 0000000..fea0999 --- /dev/null +++ b/scripts/dmc/qsm.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/sdac.sh b/scripts/dmc/sdac.sh new file mode 100644 index 0000000..93d5dd3 --- /dev/null +++ b/scripts/dmc/sdac.sh @@ -0,0 +1,70 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=2 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + # "acrobot-swingup" + # "ball_in_cup-catch" + # "cartpole-balance" + # "cartpole-balance_sparse" + # "cartpole-swingup" + # "cartpole-swingup_sparse" + # "cheetah-run" + # "dog-run" + # "dog-stand" + # "dog-trot" + # "dog-walk" + # "finger-spin" + # "finger-turn_easy" + # "finger-turn_hard" + # "fish-swim" + # "hopper-hop" + # "hopper-stand" + # "humanoid-run" + # "humanoid-stand" + # "humanoid-walk" + # "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + # "reacher-hard" + # "walker-run" + # "walker-stand" + # "walker-walk" +) + +SHARED_ARGS=( + "algo=sdac" + "algo.temp=0.01" + "log.tag=fix_temp-0.01" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/mujoco/alac.sh b/scripts/mujoco/alac.sh new file mode 100644 index 0000000..2232a3a --- /dev/null +++ b/scripts/mujoco/alac.sh @@ -0,0 +1,60 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "Ant-v5" + "HalfCheetah-v5" + # "Hopper-v5" + # "HumanoidStandup-v5" + "Humanoid-v5" + # "InvertedDoublePendulum-v5" + # "InvertedPendulum-v5" + # "Pusher-v5" + # "Reacher-v5" + # "Swimmer-v5" + "Walker2d-v5" +) + +SHARED_ARGS=( + "algo=alac" + # "algo.ld.step_size=0.1" + # "algo.ld.noise_scale=0.01" + # "algo.ld.steps=50" + # "log.tag=noise_none-stepsize0.1-noise0.01-steps50-no_last_noise" + "algo.ld.activation=relu" + "algo.ld.steps=20" + "algo.ld.noise_schedule=cosine" + "log.tag=use_ld_but_actually_diffusion-decay_q-temp0.5" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_mujoco_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/mujoco/qsm.sh b/scripts/mujoco/qsm.sh new file mode 100644 index 0000000..04023c8 --- /dev/null +++ b/scripts/mujoco/qsm.sh @@ -0,0 +1,53 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "Ant-v5" + "HalfCheetah-v5" + "Hopper-v5" + "HumanoidStandup-v5" + "Humanoid-v5" + "InvertedDoublePendulum-v5" + "InvertedPendulum-v5" + "Pusher-v5" + "Reacher-v5" + "Swimmer-v5" + "Walker2d-v5" +) + +SHARED_ARGS=( + "algo=qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_mujoco_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi
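A minimal usage sketch (not part of the patch; it assumes only the flowrl modules touched above, and the toy input size and variable names are illustrative): wiring the new kernel_init/bias_init fields of MLP to the PyTorch-style initializers introduced in flowrl/module/initialization.py.

import jax
import jax.numpy as jnp

from flowrl.module.mlp import MLP
from flowrl.module.initialization import pytorch_kernel_init, pytorch_bias_init

# MLP stores initializer *factories*: each nn.Dense inside calls
# self.kernel_init() / self.bias_init(), exactly as it does with the
# defaults (orthogonal_init / zeros_init).
mlp = MLP(
    hidden_dims=[512, 512],
    output_dim=1,
    kernel_init=pytorch_kernel_init,
    bias_init=pytorch_bias_init,
)
params = mlp.init(jax.random.PRNGKey(0), jnp.zeros((1, 64)))  # hypothetical 64-dim input
out = mlp.apply(params, jnp.zeros((1, 64)))                   # shape (1, 1)

With scale=1/3, mode="fan_in", distribution="uniform", variance_scaling draws weights from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), which matches the default weight initialization of torch.nn.Linear. The bias factory approximates the corresponding PyTorch bias range using the bias vector's own length, since a Flax bias initializer only sees the bias shape, not the layer's fan-in.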