
Commit f444f94

update
1 parent 7844900 commit f444f94

14 files changed: +408 additions, -8 deletions
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
# @package _global_

algo:
  name: ctrl_qsm
  actor_update_freq: 1
  target_update_freq: 1
  discount: 0.99
  ema: 0.005
  # critic_hidden_dims: [512, 512, 512] # not used
  critic_activation: elu # not used
  critic_ensemble_size: 2
  layer_norm: true
  critic_lr: 0.0003
  clip_grad_norm: null

  # below are params specific to ctrl_td3
  feature_dim: 512
  feature_lr: 0.0001
  feature_ema: 0.005
  phi_hidden_dims: [512, 512]
  mu_hidden_dims: [512, 512]
  critic_hidden_dims: [512, ]
  reward_hidden_dims: [512, ]
  rff_dim: 1024
  ctrl_coef: 1.0
  reward_coef: 1.0
  back_critic_grad: false
  critic_coef: 1.0

  num_noises: 25
  linear: false
  ranking: true

  num_samples: 10
  temp: 0.1
  diffusion:
    time_dim: 64
    mlp_hidden_dims: [512, 512, 512]
    lr: 0.0003
    end_lr: null
    lr_decay_steps: null
    lr_decay_begin: null
    steps: 20
    clip_sampler: true
    x_min: -1.0
    x_max: 1.0
    solver: ddpm

norm_obs: true
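
This new config registers the ctrl_qsm algorithm alongside the existing Hydra config groups. A minimal inspection sketch, assuming the file sits next to the dmc qsm.yaml shown below (the exact path is not visible in this diff, so treat it as hypothetical):

from omegaconf import OmegaConf

# Hypothetical path, inferred from the sibling examples/online/config/dmc/algo/qsm.yaml below.
cfg = OmegaConf.load("examples/online/config/dmc/algo/ctrl_qsm.yaml")
print(cfg.algo.name)              # ctrl_qsm
print(cfg.algo.feature_dim)       # 512
print(cfg.algo.diffusion.solver)  # ddpm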

examples/online/config/dmc/algo/qsm.yaml

Lines changed: 2 additions & 1 deletion
@@ -3,11 +3,12 @@
 algo:
   name: qsm
   critic_hidden_dims: [512, 512, 512]
+  critic_activation: elu
   critic_lr: 0.0003
   discount: 0.99
   num_samples: 10
   ema: 0.005
-  temp: 0.2
+  temp: 0.1
   diffusion:
     time_dim: 64
     mlp_hidden_dims: [512, 512, 512]

examples/online/config/mujoco/algo/qsm.yaml

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@
 algo:
   name: qsm
   critic_hidden_dims: [512, 512]
+  critic_activation: relu
   critic_lr: 0.0003
   discount: 0.99
   num_samples: 10

examples/online/main_dmc_offpolicy.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
     "dpmd": DPMDAgent,
     "qsm": QSMAgent,
     "ctrl_td3": CtrlTD3Agent,
+    "ctrl_qsm": CtrlQSMAgent,
 }

 class OffPolicyTrainer():

flowrl/agent/online/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,6 @@
 from ..base import BaseAgent
 from .alac.alac import ALACAgent
-from .ctrl.ctrl import CtrlTD3Agent
+from .ctrl import *
 from .dpmd import DPMDAgent
 from .idem import IDEMAgent
 from .ppo import PPOAgent
@@ -22,4 +22,5 @@
     "IDEMAgent",
     "ALACAgent",
     "CtrlTD3Agent",
+    "CtrlQSMAgent",
 ]
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
from .ctrl_qsm import CtrlQSMAgent
from .ctrl_td3 import CtrlTD3Agent

__all__ = [
    "CtrlTD3Agent",
    "CtrlQSMAgent",
]
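
With the package re-exports above, the new agent class becomes importable from the online agent namespace. A quick sanity-check sketch, assuming flowrl from this commit is importable in the current environment:

# Both classes resolve through `from .ctrl import *` in flowrl/agent/online/__init__.py.
from flowrl.agent.online import CtrlQSMAgent, CtrlTD3Agent

print(CtrlQSMAgent.__name__, CtrlTD3Agent.__name__)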
Lines changed: 287 additions & 0 deletions
@@ -0,0 +1,287 @@
from functools import partial
from typing import Tuple

import jax
import jax.numpy as jnp
import optax

from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
from flowrl.agent.online.qsm import QSMAgent
from flowrl.config.online.mujoco.algo.ctrl.ctrl_qsm import CtrlQSMConfig
from flowrl.flow.continuous_ddpm import ContinuousDDPM
from flowrl.functional.ema import ema_update
from flowrl.module.model import Model
from flowrl.module.rff import RffEnsembleCritic
from flowrl.types import Batch, Metric, Param, PRNGKey


@partial(jax.jit, static_argnames=("training", "num_samples", "solver"))
def jit_sample_actions(
    rng: PRNGKey,
    actor: ContinuousDDPM,
    nce_target: Model,
    critic: Model,
    obs: jnp.ndarray,
    training: bool,
    num_samples: int,
    solver: str,
) -> Tuple[PRNGKey, jnp.ndarray]:
    assert len(obs.shape) == 2
    B = obs.shape[0]
    rng, xT_rng = jax.random.split(rng)

    # sample
    obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2)
    xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim))
    rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver)
    if num_samples == 1:
        actions = actions[:, 0]
    else:
        feature = nce_target(obs_repeat, actions, method="forward_phi")
        qs = critic(feature)
        qs = qs.min(axis=0).reshape(B, num_samples)
        best_idx = qs.argmax(axis=-1)
        actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx]
    return rng, actions

@partial(jax.jit, static_argnames=("discount", "solver"))
def update_critic(
    rng: PRNGKey,
    critic: Model,
    critic_target: Model,
    actor: ContinuousDDPM,
    nce_target: Model,
    batch: Batch,
    discount: float,
    solver: str,
    critic_coef: float
) -> Tuple[PRNGKey, Model, Metric]:
    rng, sample_rng = jax.random.split(rng)
    next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim))
    rng, next_action, _ = actor.sample(
        rng,
        next_xT,
        batch.next_obs,
        training=False,
        solver=solver,
    )
    next_feature = nce_target(batch.next_obs, next_action, method="forward_phi")
    q_target = critic_target(next_feature).min(0)
    q_target = batch.reward + discount * (1 - batch.terminal) * q_target

    feature = nce_target(batch.obs, batch.action, method="forward_phi")

    def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]:
        q_pred = critic.apply(
            {"params": critic_params},
            feature,
            rngs={"dropout": dropout_rng},
        )
        critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean()
        return critic_loss, {
            "loss/critic_loss": critic_loss,
            "misc/q_mean": q_pred.mean(),
            "misc/reward": batch.reward.mean(),
        }

    new_critic, metrics = critic.apply_gradient(critic_loss_fn)
    return rng, new_critic, metrics

@partial(jax.jit, static_argnames=("temp"))
def update_actor(
    rng: PRNGKey,
    actor: ContinuousDDPM,
    nce_target: Model,
    critic_target: Model,
    batch: Batch,
    temp: float,
) -> Tuple[PRNGKey, Model, Metric]:

    a0 = batch.action
    rng, at, t, eps = actor.add_noise(rng, a0)
    alpha1, alpha2 = actor.noise_schedule_func(t)

    def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray:
        feature = nce_target(obs, action, method="forward_phi")
        q = critic_target(feature)
        return q.min(axis=0).mean()
    q_grad_fn = jax.vmap(jax.grad(get_q_value))
    q_grad = q_grad_fn(at, batch.obs)
    q_grad = alpha1 * q_grad - alpha2 * at
    eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6)

    def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]:
        eps_pred = actor.apply(
            {"params": actor_params},
            at,
            t,
            condition=batch.obs,
            training=True,
            rngs={"dropout": dropout_rng},
        )
        loss = ((eps_pred - eps_estimation) ** 2).mean()
        return loss, {
            "loss/actor_loss": loss,
            "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(),
        }

    new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn)
    return rng, new_actor, actor_metrics


class CtrlQSMAgent(QSMAgent):
    """
    CTRL with Q Score Matching (QSM) agent.
    """

    name = "CtrlQSMAgent"
    model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"]

    def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlQSMConfig, seed: int):
        super().__init__(obs_dim, act_dim, cfg, seed)
        self.cfg = cfg

        self.ctrl_coef = cfg.ctrl_coef
        self.critic_coef = cfg.critic_coef

        self.linear = cfg.linear
        self.ranking = cfg.ranking
        self.feature_dim = cfg.feature_dim
        self.num_noises = cfg.num_noises
        self.reward_coef = cfg.reward_coef
        self.rff_dim = cfg.rff_dim
        self.actor_update_freq = cfg.actor_update_freq
        self.target_update_freq = cfg.target_update_freq


        # sanity checks for the hyper-parameters
        assert not self.linear, "linear mode is not supported yet"

        # networks
        self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5)
        nce_def = FactorizedNCE(
            self.obs_dim,
            self.act_dim,
            self.feature_dim,
            cfg.phi_hidden_dims,
            cfg.mu_hidden_dims,
            cfg.reward_hidden_dims,
            cfg.rff_dim,
            cfg.num_noises,
            self.ranking,
        )
        self.nce = Model.create(
            nce_def,
            nce_rng,
            inputs=(
                nce_init_rng,
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
            ),
            optimizer=optax.adam(learning_rate=cfg.feature_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.nce_target = Model.create(
            nce_def,
            nce_rng,
            inputs=(
                nce_init_rng,
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
            ),
        )

        critic_def = RffEnsembleCritic(
            feature_dim=self.feature_dim,
            hidden_dims=cfg.critic_hidden_dims,
            rff_dim=cfg.rff_dim,
            ensemble_size=2,
        )
        self.critic = Model.create(
            critic_def,
            critic_rng,
            inputs=(jnp.ones((1, self.feature_dim)),),
            optimizer=optax.adam(learning_rate=cfg.critic_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )
        self.critic_target = Model.create(
            critic_def,
            critic_rng,
            inputs=(jnp.ones((1, self.feature_dim)),),
        )

        self._n_training_steps = 0

    def train_step(self, batch: Batch, step: int) -> Metric:
        metrics = {}

        self.rng, self.nce, nce_metrics = update_factorized_nce(
            self.rng,
            self.nce,
            batch,
            self.ranking,
            self.reward_coef,
        )
        metrics.update(nce_metrics)

        self.rng, self.critic, critic_metrics = update_critic(
            self.rng,
            self.critic,
            self.critic_target,
            self.actor,
            self.nce_target,
            batch,
            discount=self.cfg.discount,
            solver=self.cfg.diffusion.solver,
            critic_coef=self.critic_coef,
        )
        metrics.update(critic_metrics)

        if self._n_training_steps % self.actor_update_freq == 0:
            self.rng, self.actor, actor_metrics = update_actor(
                self.rng,
                self.actor,
                self.nce_target,
                self.critic_target,
                batch,
                temp=self.cfg.temp,
            )
            metrics.update(actor_metrics)

        if self._n_training_steps % self.target_update_freq == 0:
            self.sync_target()

        self._n_training_steps += 1
        return metrics

    def sample_actions(
        self,
        obs: jnp.ndarray,
        deterministic: bool = True,
        num_samples: int = 1,
    ) -> Tuple[jnp.ndarray, Metric]:
        # if deterministic is true, sample cfg.num_samples actions and select the best one
        # if not, sample 1 action
        if deterministic:
            num_samples = self.cfg.num_samples
        else:
            num_samples = 1
        self.rng, action = jit_sample_actions(
            self.rng,
            self.actor,
            self.nce_target,
            self.critic,
            obs,
            training=False,
            num_samples=num_samples,
            solver=self.cfg.diffusion.solver,
        )
        if not deterministic:
            action = action + 0.1 * jax.random.normal(self.rng, action.shape)
        return action, {}

    def sync_target(self):
        self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema)
        self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema)
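
A minimal driver sketch for the new agent's public surface. Assumptions (not part of this commit): `agent` is an already constructed CtrlQSMAgent, `buffer.sample(batch_size)` returns a flowrl `Batch` with the obs/action/reward/terminal/next_obs fields consumed above, and `eval_obs` is a `(B, obs_dim)` observation array.

# Hypothetical training/evaluation driver; `agent`, `buffer`, `batch_size`, and
# `eval_obs` are assumed to exist and are not defined in this commit.
for step in range(100_000):
    batch = buffer.sample(batch_size)        # Batch(obs, action, reward, terminal, next_obs, ...)
    metrics = agent.train_step(batch, step)  # NCE + critic update each step; actor/target updates periodically

# deterministic=True samples cfg.num_samples candidate actions per state and keeps the
# one with the highest min-ensemble Q computed on NCE features (see jit_sample_actions).
eval_actions, _ = agent.sample_actions(eval_obs, deterministic=True)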
Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 
 from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
 from flowrl.agent.online.td3 import TD3Agent
-from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config
+from flowrl.config.online.mujoco.algo.ctrl.ctrl_td3 import CtrlTD3Config
 from flowrl.functional.ema import ema_update
 from flowrl.module.actor import SquashedDeterministicActor
 from flowrl.module.mlp import MLP
