
Commit 65e9265

Merge branch 'master' into qsm_idem
2 parents 668b38f + 0af1dba

15 files changed: +749 −12 lines
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# @package _global_

algo:
  name: ctrl_td3
  cls: ctrl_td3
  discount: 0.99

  tau: 0.005 # hard update
  actor_update_freq: 1
  target_update_freq: 1

  ema: 0.005

  # critic_hidden_dims: [512, 512, 512] # this is actually not used
  actor_hidden_dims: [512, 512, 512]

  critic_lr: 0.0003
  actor_lr: 0.0003

  target_policy_noise: 0.2
  noise_clip: 0.3
  exploration_noise: 0.2

  # below are params specific to ctrl_td3
  feature_dim: 512
  feature_lr: 0.0001
  feature_tau: 0.005
  phi_hidden_dims: [512, 512]
  mu_hidden_dims: [512, 512]
  ctrl_coef: 1.0
  reward_coef: 1.0
  feature_update_ratio: 1
  critic_hidden_dims: [512, ]
  reward_hidden_dims: [512, ]
  rff_dim: 1024
  back_critic_grad: false
  critic_coef: 1.0
  aug_batch_size: 512

  num_noises: 25
  linear: false
  beta: 1.0 # not used
  ranking: true
  activation: elu
  layer_norm: true
  critic_ensemble_size: 2
  clip_grad_norm: null

  norm_obs: true
  batch_size: 1024
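
For orientation, `aug_batch_size` is consumed by `Ctrl_TD3_Agent.train_step` further down in this commit: the incoming batch is split at `batch.obs.shape[0] - aug_batch_size`, the leading slice drives the TD3 update and the trailing `aug_batch_size` transitions feed the contrastive feature update (how many transitions the trainer actually samples per step is outside the hunks shown here). A minimal sketch, assuming only OmegaConf (already a dependency of the training scripts), of reading a hand-copied subset of these fields:

from omegaconf import OmegaConf

# Hand-copied subset of the algo config above; no repository file path is assumed.
cfg = OmegaConf.create({
    "algo": {
        "name": "ctrl_td3",
        "feature_lr": 1e-4,
        "aug_batch_size": 512,
        "batch_size": 1024,
    }
})

# Mirror of the split used in Ctrl_TD3_Agent.train_step: the trailing
# aug_batch_size transitions of a batch go to the NCE/feature update.
example_batch_len = cfg.algo.batch_size
split_index = example_batch_len - cfg.algo.aug_batch_size
print(cfg.algo.name, split_index)  # ctrl_td3 512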

examples/online/main_dmc_offpolicy.py

Lines changed: 2 additions & 1 deletion
@@ -5,10 +5,10 @@
 import jax.numpy as jnp
 import numpy as np
 import omegaconf
+import wandb
 from omegaconf import OmegaConf
 from tqdm import tqdm, trange

-import wandb
 from flowrl.agent.online import *
 from flowrl.config.online.mujoco import Config
 from flowrl.dataset.buffer.state import ReplayBuffer
@@ -24,6 +24,7 @@
     "sdac": SDACAgent,
     "dpmd": DPMDAgent,
     "qsm": QSMAgent,
+    "ctrl_td3": Ctrl_TD3_Agent,
 }

 class OffPolicyTrainer():
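
The second hunk registers the new agent under the key "ctrl_td3", matching `algo.name` / `algo.cls` in the config file added above. How the trainer resolves that key into a class is outside the shown hunks; the sketch below is a hypothetical, self-contained version of such a name-to-class registry (the stub classes only stand in for the real flowrl agents):

from typing import Dict, Type

# Stub classes standing in for the real agent classes imported by the script.
class BaseAgent: ...
class QSMAgent(BaseAgent): ...
class Ctrl_TD3_Agent(BaseAgent): ...

AGENT_REGISTRY: Dict[str, Type[BaseAgent]] = {
    "qsm": QSMAgent,
    "ctrl_td3": Ctrl_TD3_Agent,
}

def resolve_agent(name: str) -> Type[BaseAgent]:
    # Fail with a readable message on an unregistered algo name.
    if name not in AGENT_REGISTRY:
        raise ValueError(f"unknown algo '{name}', expected one of {sorted(AGENT_REGISTRY)}")
    return AGENT_REGISTRY[name]

print(resolve_agent("ctrl_td3").__name__)  # Ctrl_TD3_Agent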

examples/online/main_mujoco_offpolicy.py

Lines changed: 1 addition & 1 deletion
@@ -5,10 +5,10 @@
 import hydra
 import numpy as np
 import omegaconf
+import wandb
 from omegaconf import OmegaConf
 from tqdm import tqdm

-import wandb
 from flowrl.agent.online import *
 from flowrl.agent.online.idem import IDEMAgent
 from flowrl.config.online.mujoco import Config

examples/online/main_mujoco_onpolicy.py

Lines changed: 1 addition & 1 deletion
@@ -6,10 +6,10 @@
 import hydra
 import jax.numpy as jnp
 import numpy as np
+import wandb
 from omegaconf import OmegaConf
 from tqdm import tqdm

-import wandb
 from flowrl.agent.online import *
 from flowrl.config.online.mujoco import Config
 from flowrl.types import *

flowrl/agent/online/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 from ..base import BaseAgent
+from .ctrl.ctrl import Ctrl_TD3_Agent
 from .dpmd import DPMDAgent
 from .idem import IDEMAgent
 from .ppo import PPOAgent
@@ -18,4 +19,5 @@
     "PPOAgent",
     "QSMAgent",
     "IDEMAgent",
+    "Ctrl_TD3_Agent",
 ]

flowrl/agent/online/ctrl/ctrl.py

Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
from functools import partial
from operator import attrgetter
from typing import Tuple

import jax
import jax.numpy as jnp
import optax

from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
from flowrl.agent.online.td3 import TD3Agent
from flowrl.config.online.mujoco.algo.ctrl_td3 import CTRL_TD3_Config
from flowrl.functional.ema import ema_update
from flowrl.module.actor import SquashedDeterministicActor
from flowrl.module.mlp import MLP
from flowrl.module.model import Model
from flowrl.module.rff import RffEnsembleCritic
from flowrl.types import Batch, Metric, Param, PRNGKey


@partial(jax.jit, static_argnames=("discount", "target_policy_noise", "noise_clip"))
def update_critic(
    rng: PRNGKey,
    critic: Model,
    critic_target: Model,
    actor_target: Model,
    nce: Model,
    nce_target: Model,
    batch: Batch,
    discount: float,
    target_policy_noise: float,
    noise_clip: float,
    critic_coef: float,
) -> Tuple[PRNGKey, Model, Metric]:
    rng, sample_rng = jax.random.split(rng)
    noise = jax.random.normal(sample_rng, batch.action.shape) * target_policy_noise
    noise = jnp.clip(noise, -noise_clip, noise_clip)
    next_action = jnp.clip(actor_target(batch.next_obs) + noise, -1.0, 1.0)

    next_feature = nce_target(batch.next_obs, next_action, method="forward_phi")
    q_target = critic_target(next_feature).min(0)
    q_target = batch.reward + discount * (1 - batch.terminal) * q_target

    back_critic_grad = False
    if back_critic_grad:
        # this part will use feature
        raise NotImplementedError("no back critic grad exists")

    feature = nce_target(batch.obs, batch.action, method="forward_phi")

    def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]:
        q_pred = critic.apply(
            {"params": critic_params},
            feature,
            rngs={"dropout": dropout_rng},
        )
        # q_pred (2, 512, 1), q_target (512, 1)
        critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean()
        return critic_loss, {
            "loss/critic_loss": critic_loss,
            "misc/q_mean": q_pred.mean(),
            "misc/reward": batch.reward.mean(),
        }

    new_critic, metrics = critic.apply_gradient(critic_loss_fn)
    return rng, new_critic, metrics


@jax.jit
def update_actor(
    rng: PRNGKey,
    actor: Model,
    nce: Model,
    nce_target: Model,
    critic: Model,
    batch: Batch,
) -> Tuple[PRNGKey, Model, Metric]:
    def actor_loss_fn(
        actor_params: Param, dropout_rng: PRNGKey
    ) -> Tuple[jnp.ndarray, Metric]:
        new_action = actor.apply(
            {"params": actor_params},
            batch.obs,
            training=True,
            rngs={"dropout": dropout_rng},
        )
        new_feature = nce_target(batch.obs, new_action, method="forward_phi")
        q = critic(new_feature)
        actor_loss = -q.mean()

        return actor_loss, {
            "loss/actor_loss": actor_loss,
        }

    new_actor, metrics = actor.apply_gradient(actor_loss_fn)
    return rng, new_actor, metrics


class Ctrl_TD3_Agent(TD3Agent):
    """
    CTRL Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
    """

    name = "CTRLTD3Agent"
    model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"]

    def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
        super().__init__(obs_dim, act_dim, cfg, seed)
        self.cfg = cfg

        self.ctrl_coef = cfg.ctrl_coef
        self.critic_coef = cfg.critic_coef

        self.aug_batch_size = cfg.aug_batch_size
        self.feature_tau = cfg.feature_tau
        self.linear = cfg.linear
        self.ranking = cfg.ranking
        self.feature_dim = cfg.feature_dim
        self.num_noises = cfg.num_noises
        self.reward_coef = cfg.reward_coef
        self.rff_dim = cfg.rff_dim

        # sanity checks for the hyper-parameters
        assert not self.linear, "Removing linear version for now"

        # networks
        self.rng, nce_rng, actor_rng, critic_rng = jax.random.split(self.rng, 4)
        nce_def = FactorizedNCE(
            self.obs_dim,
            self.act_dim,
            self.feature_dim,
            cfg.phi_hidden_dims,
            cfg.mu_hidden_dims,
            cfg.reward_hidden_dims,
            cfg.rff_dim,
            cfg.num_noises,
            self.ranking,
        )
        self.nce = Model.create(
            nce_def,
            nce_rng,
            inputs=(
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
            ),
            optimizer=optax.adam(learning_rate=cfg.feature_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.nce_target = Model.create(
            nce_def,
            nce_rng,
            inputs=(
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
            ),
        )

        actor_def = SquashedDeterministicActor(
            backbone=MLP(
                hidden_dims=cfg.actor_hidden_dims,
                layer_norm=cfg.layer_norm,
                dropout=None,
            ),
            obs_dim=self.obs_dim,
            action_dim=self.act_dim,
        )
        critic_def = RffEnsembleCritic(
            feature_dim=self.feature_dim,
            hidden_dims=cfg.critic_hidden_dims,
            rff_dim=cfg.rff_dim,
            ensemble_size=2,
        )
        self.actor = Model.create(
            actor_def,
            actor_rng,
            inputs=(jnp.ones((1, self.obs_dim)),),
            optimizer=optax.adam(learning_rate=cfg.actor_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.critic = Model.create(
            critic_def,
            critic_rng,
            inputs=(jnp.ones((1, self.feature_dim)),),
            optimizer=optax.adam(learning_rate=cfg.critic_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.actor_target = Model.create(
            actor_def,
            actor_rng,
            inputs=(jnp.ones((1, self.obs_dim)),),
        )
        self.critic_target = Model.create(
            critic_def,
            critic_rng,
            inputs=(jnp.ones((1, self.feature_dim)),),
        )

        self._n_training_steps = 0

    def train_step(self, batch: Batch, step: int) -> Metric:
        metrics = {}

        split_index = batch.obs.shape[0] - self.aug_batch_size
        obs, action, next_obs, reward, terminal = [
            b[:split_index]
            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
                batch
            )
        ]
        fobs, faction, fnext_obs, freward, fterminal = [
            b[split_index:]
            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
                batch
            )
        ]
        rl_batch = Batch(obs, action, reward, terminal, next_obs, None)

        self.rng, self.nce, nce_metrics = update_factorized_nce(
            self.rng,
            self.nce,
            fobs,
            faction,
            fnext_obs,
            freward,
            self.ranking,
            self.reward_coef,
        )
        metrics.update(nce_metrics)

        self.rng, self.critic, critic_metrics = update_critic(
            self.rng,
            self.critic,
            self.critic_target,
            self.actor_target,
            self.nce,
            self.nce_target,
            rl_batch,
            discount=self.cfg.discount,
            target_policy_noise=self.target_policy_noise,
            noise_clip=self.noise_clip,
            critic_coef=self.critic_coef,
        )
        metrics.update(critic_metrics)

        if self._n_training_steps % self.actor_update_freq == 0:
            self.rng, self.actor, actor_metrics = update_actor(
                self.rng,
                self.actor,
                self.nce,
                self.nce_target,
                self.critic,
                rl_batch,
            )
            metrics.update(actor_metrics)

        if self._n_training_steps % self.target_update_freq == 0:
            self.sync_target()

        self._n_training_steps += 1
        return metrics

    def sync_target(self):
        super().sync_target()
        self.nce_target = ema_update(self.nce, self.nce_target, self.feature_tau)
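
The contrastive objective itself lives in `flowrl/agent/online/ctrl/network.py` (`FactorizedNCE`, `update_factorized_nce`), which is not part of the hunks shown on this page. For readers unfamiliar with the idea, the sketch below is a generic InfoNCE loss over factorized features phi(s, a)·mu(s') with in-batch negatives; it illustrates the technique only and is not the repository's implementation, which additionally handles `num_noises`, `rff_dim`, the `ranking` variant, and a reward head:

import jax
import jax.numpy as jnp

def infonce_loss(phi_sa: jnp.ndarray, mu_next: jnp.ndarray) -> jnp.ndarray:
    # logits[i, j] = phi(s_i, a_i) . mu(s'_j); the matched transition sits on
    # the diagonal and is treated as the positive pair.
    logits = phi_sa @ mu_next.T
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.diag(log_probs).mean()

# Toy usage with random features of the configured feature_dim (512).
key_phi, key_mu = jax.random.split(jax.random.PRNGKey(0))
phi = jax.random.normal(key_phi, (512, 512))   # phi(s, a) for an aug batch
mu = jax.random.normal(key_mu, (512, 512))     # mu(s') for the same transitions
print(infonce_loss(phi, mu))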
