
Commit f11755b

[update]: ctrl new architecture and clean the code (#14)
* add layer norm
* update: add metrics logging and regularizer
* fix minor bugs
* fix typo
* before using categorical Q
* update: clean some code and finalize ctrl implementation

Co-authored-by: Edward Chen <echen333us@gmail.com>
1 parent 9dc014b commit f11755b

12 files changed: +154 −105 lines


examples/online/config/dmc/algo/ctrl_td3.yaml

Lines changed: 10 additions & 21 deletions

@@ -2,49 +2,38 @@
 
 algo:
   name: ctrl_td3
-  cls: ctrl_td3
-  discount: 0.99
-
-  tau: 0.005 # hard update
   actor_update_freq: 1
   target_update_freq: 1
-
+  discount: 0.99
   ema: 0.005
-
-  # critic_hidden_dims: [512, 512, 512] # this is actually not used
   actor_hidden_dims: [512, 512, 512]
-
-  critic_lr: 0.0003
+  # critic_hidden_dims: [512, 512, 512] # not used
+  activation: elu # not used
+  critic_ensemble_size: 2
+  layer_norm: true
   actor_lr: 0.0003
-
+  critic_lr: 0.0003
+  clip_grad_norm: null
   target_policy_noise: 0.2
   noise_clip: 0.3
   exploration_noise: 0.2
 
   # below are params specific to ctrl_td3
   feature_dim: 512
   feature_lr: 0.0001
-  feature_tau: 0.005
+  feature_ema: 0.005
   phi_hidden_dims: [512, 512]
   mu_hidden_dims: [512, 512]
-  ctrl_coef: 1.0
-  reward_coef: 1.0
-  feature_update_ratio: 1
   critic_hidden_dims: [512, ]
   reward_hidden_dims: [512, ]
   rff_dim: 1024
+  ctrl_coef: 1.0
+  reward_coef: 1.0
   back_critic_grad: false
   critic_coef: 1.0
-  aug_batch_size: 512
 
   num_noises: 25
   linear: false
-  beta: 1.0 # not used
   ranking: true
-  activation: elu
-  layer_norm: true
-  critic_ensemble_size: 2
-  clip_grad_norm: null
 
   norm_obs: true
-  batch_size: 1024
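For reference, the renamed keys above can be read back with OmegaConf, which main_dmc_offpolicy.py already imports. A minimal sketch, assuming the YAML is loaded directly from the path shown; the loading path and printed values are illustrative, not the repository's actual config pipeline.

# Minimal sketch of reading the renamed keys with OmegaConf (illustrative only;
# the repository's real loader may compose configs differently).
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/online/config/dmc/algo/ctrl_td3.yaml")

# After this commit, the critic target rate lives under `ema` and the NCE
# feature target rate under `feature_ema` (previously `feature_tau`).
print(cfg.algo.name)         # "ctrl_td3"
print(cfg.algo.ema)          # 0.005
print(cfg.algo.feature_ema)  # 0.005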

examples/online/config/dmc/algo/td3.yaml

Lines changed: 2 additions & 2 deletions

@@ -2,6 +2,8 @@
 
 algo:
   name: td3
+  actor_update_freq: 1
+  target_update_freq: 1
   discount: 0.99
   ema: 0.005
   actor_hidden_dims: [512, 512, 512]
@@ -12,8 +14,6 @@ algo:
   actor_lr: 0.0003
   critic_lr: 0.0003
   clip_grad_norm: null
-  actor_update_freq: 1
-  target_update_freq: 1
   target_policy_noise: 0.2
   noise_clip: 0.3
   exploration_noise: 0.2

examples/online/main_dmc_offpolicy.py

Lines changed: 2 additions & 2 deletions

@@ -6,10 +6,10 @@
 import jax.numpy as jnp
 import numpy as np
 import omegaconf
-import wandb
 from omegaconf import OmegaConf
 from tqdm import tqdm, trange
 
+import wandb
 from flowrl.agent.online import *
 from flowrl.config.online.mujoco import Config
 from flowrl.dataset.buffer.state import ReplayBuffer
@@ -26,7 +26,7 @@
     "td7": TD7Agent,
     "sdac": SDACAgent,
     "dpmd": DPMDAgent,
-    "ctrl_td3": Ctrl_TD3_Agent,
+    "ctrl_td3": CtrlTD3Agent,
 }
 
 class OffPolicyTrainer():
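The mapping above acts as the trainer's agent registry, so the rename only changes which class the "ctrl_td3" key resolves to. A minimal sketch of that lookup, assuming the constructor signature shown in flowrl/agent/online/ctrl/ctrl.py (obs_dim, act_dim, cfg, seed); make_agent and the example dimensions are hypothetical, not part of the repository.

# Hypothetical lookup against the agent registry; the helper name and the
# example arguments are placeholders, not values from the repository.
from flowrl.agent.online import CtrlTD3Agent

AGENT_REGISTRY = {"ctrl_td3": CtrlTD3Agent}

def make_agent(name, obs_dim, act_dim, cfg, seed):
    # After the rename, "ctrl_td3" resolves to CtrlTD3Agent (was Ctrl_TD3_Agent).
    agent_cls = AGENT_REGISTRY[name]
    return agent_cls(obs_dim, act_dim, cfg, seed)

# e.g. agent = make_agent("ctrl_td3", obs_dim, act_dim, cfg.algo, seed=0)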

flowrl/agent/online/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 from ..base import BaseAgent
-from .ctrl.ctrl import Ctrl_TD3_Agent
+from .ctrl.ctrl import CtrlTD3Agent
 from .dpmd import DPMDAgent
 from .ppo import PPOAgent
 from .sac import SACAgent
@@ -15,5 +15,5 @@
     "SDACAgent",
     "DPMDAgent",
     "PPOAgent",
-    "Ctrl_TD3_Agent",
+    "CtrlTD3Agent",
 ]

flowrl/agent/online/ctrl/ctrl.py

Lines changed: 13 additions & 38 deletions

@@ -1,5 +1,4 @@
 from functools import partial
-from operator import attrgetter
 from typing import Tuple
 
 import jax
@@ -8,7 +7,7 @@
 
 from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
 from flowrl.agent.online.td3 import TD3Agent
-from flowrl.config.online.mujoco.algo.ctrl_td3 import CTRL_TD3_Config
+from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config
 from flowrl.functional.ema import ema_update
 from flowrl.module.actor import SquashedDeterministicActor
 from flowrl.module.mlp import MLP
@@ -23,7 +22,6 @@ def update_critic(
     critic: Model,
     critic_target: Model,
     actor_target: Model,
-    nce: Model,
     nce_target: Model,
     batch: Batch,
     discount: float,
@@ -42,7 +40,6 @@ def update_critic(
 
     back_critic_grad = False
     if back_critic_grad:
-        # this part will use feature
         raise NotImplementedError("no back critic grad exists")
 
     feature = nce_target(batch.obs, batch.action, method="forward_phi")
@@ -53,7 +50,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
             feature,
             rngs={"dropout": dropout_rng},
         )
-        # q_pred (2, 512, 1), q_target (512, 1)
         critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean()
         return critic_loss, {
             "loss/critic_loss": critic_loss,
@@ -69,7 +65,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
 def update_actor(
     rng: PRNGKey,
     actor: Model,
-    nce: Model,
     nce_target: Model,
     critic: Model,
     batch: Batch,
@@ -95,23 +90,21 @@ def actor_loss_fn(
     return rng, new_actor, metrics
 
 
-class Ctrl_TD3_Agent(TD3Agent):
+class CtrlTD3Agent(TD3Agent):
     """
-    CTRL Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
+    CTRL with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
     """
 
-    name = "CTRLTD3Agent"
+    name = "CtrlTD3Agent"
     model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"]
 
-    def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
+    def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int):
         super().__init__(obs_dim, act_dim, cfg, seed)
         self.cfg = cfg
 
         self.ctrl_coef = cfg.ctrl_coef
         self.critic_coef = cfg.critic_coef
 
-        self.aug_batch_size = cfg.aug_batch_size
-        self.feature_tau = cfg.feature_tau
         self.linear = cfg.linear
         self.ranking = cfg.ranking
         self.feature_dim = cfg.feature_dim
@@ -120,10 +113,10 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
         self.rff_dim = cfg.rff_dim
 
         # sanity checks for the hyper-parameters
-        assert not self.linear, "Removing linear version for now"
+        assert not self.linear, "linear mode is not supported yet"
 
         # networks
-        self.rng, nce_rng, actor_rng, critic_rng = jax.random.split(self.rng, 4)
+        self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5)
         nce_def = FactorizedNCE(
             self.obs_dim,
             self.act_dim,
@@ -139,6 +132,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
             nce_def,
             nce_rng,
             inputs=(
+                nce_init_rng,
                 jnp.ones((1, self.obs_dim)),
                 jnp.ones((1, self.act_dim)),
                 jnp.ones((1, self.obs_dim)),
@@ -150,6 +144,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
             nce_def,
             nce_rng,
             inputs=(
+                nce_init_rng,
                 jnp.ones((1, self.obs_dim)),
                 jnp.ones((1, self.act_dim)),
                 jnp.ones((1, self.obs_dim)),
@@ -201,28 +196,10 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
     def train_step(self, batch: Batch, step: int) -> Metric:
         metrics = {}
 
-        split_index = batch.obs.shape[0] - self.aug_batch_size
-        obs, action, next_obs, reward, terminal = [
-            b[:split_index]
-            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
-                batch
-            )
-        ]
-        fobs, faction, fnext_obs, freward, fterminal = [
-            b[split_index:]
-            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
-                batch
-            )
-        ]
-        rl_batch = Batch(obs, action, reward, terminal, next_obs, None)
-
         self.rng, self.nce, nce_metrics = update_factorized_nce(
             self.rng,
             self.nce,
-            fobs,
-            faction,
-            fnext_obs,
-            freward,
+            batch,
             self.ranking,
             self.reward_coef,
         )
@@ -233,9 +210,8 @@ def train_step(self, batch: Batch, step: int) -> Metric:
             self.critic,
             self.critic_target,
             self.actor_target,
-            self.nce,
             self.nce_target,
-            rl_batch,
+            batch,
             discount=self.cfg.discount,
             target_policy_noise=self.target_policy_noise,
             noise_clip=self.noise_clip,
@@ -247,10 +223,9 @@ def train_step(self, batch: Batch, step: int) -> Metric:
         self.rng, self.actor, actor_metrics = update_actor(
             self.rng,
             self.actor,
-            self.nce,
             self.nce_target,
             self.critic,
-            rl_batch,
+            batch,
         )
        metrics.update(actor_metrics)
 
@@ -262,4 +237,4 @@ def train_step(self, batch: Batch, step: int) -> Metric:
 
     def sync_target(self):
         super().sync_target()
-        self.nce_target = ema_update(self.nce, self.nce_target, self.feature_tau)
+        self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema)
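The sync_target change above takes the NCE target rate from cfg.feature_ema instead of a cached feature_tau. A minimal sketch of the soft-target (EMA) pattern behind ema_update, written with jax.tree_util over parameter pytrees; this illustrates the idea only and is not flowrl's ema_update implementation.

# Illustrative soft-target (EMA) update over parameter pytrees; a generic
# sketch of the pattern, not flowrl's ema_update.
import jax

def ema_update_params(online_params, target_params, ema):
    # target <- ema * online + (1 - ema) * target, applied leaf by leaf
    return jax.tree_util.tree_map(
        lambda o, t: ema * o + (1.0 - ema) * t,
        online_params,
        target_params,
    )

# With feature_ema = 0.005 the NCE target tracks the online NCE slowly,
# mirroring self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema).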
