
Commit 7844900

Merge branch 'master' into qsm_idem
2 parents 5c6d871 + f11755b commit 7844900

15 files changed: +164 additions, -105 deletions

examples/online/config/dmc/algo/ctrl_td3.yaml

Lines changed: 10 additions & 21 deletions
@@ -2,49 +2,38 @@
 
 algo:
   name: ctrl_td3
-  cls: ctrl_td3
-  discount: 0.99
-
-  tau: 0.005 # hard update
   actor_update_freq: 1
   target_update_freq: 1
-
+  discount: 0.99
   ema: 0.005
-
-  # critic_hidden_dims: [512, 512, 512] # this is actually not used
   actor_hidden_dims: [512, 512, 512]
-
-  critic_lr: 0.0003
+  # critic_hidden_dims: [512, 512, 512] # not used
+  activation: elu # not used
+  critic_ensemble_size: 2
+  layer_norm: true
   actor_lr: 0.0003
-
+  critic_lr: 0.0003
+  clip_grad_norm: null
   target_policy_noise: 0.2
   noise_clip: 0.3
   exploration_noise: 0.2
 
   # below are params specific to ctrl_td3
   feature_dim: 512
   feature_lr: 0.0001
-  feature_tau: 0.005
+  feature_ema: 0.005
   phi_hidden_dims: [512, 512]
   mu_hidden_dims: [512, 512]
-  ctrl_coef: 1.0
-  reward_coef: 1.0
-  feature_update_ratio: 1
   critic_hidden_dims: [512, ]
   reward_hidden_dims: [512, ]
   rff_dim: 1024
+  ctrl_coef: 1.0
+  reward_coef: 1.0
   back_critic_grad: false
   critic_coef: 1.0
-  aug_batch_size: 512
 
   num_noises: 25
   linear: false
-  beta: 1.0 # not used
   ranking: true
-  activation: elu
-  layer_norm: true
-  critic_ensemble_size: 2
-  clip_grad_norm: null
 
   norm_obs: true
-  batch_size: 1024
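
This hunk drops cls, tau, beta, aug_batch_size, feature_update_ratio, and batch_size, reorders the shared TD3 keys, and renames feature_tau to feature_ema. As a quick sanity check, a minimal sketch (not part of this commit; it assumes the repository-relative path shown above) of loading the updated file with OmegaConf and reading the renamed key:

# Illustrative only; path and keys taken from the diff above.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/online/config/dmc/algo/ctrl_td3.yaml")
print(cfg.algo.ema)          # shared TD3 target EMA rate (0.005)
print(cfg.algo.feature_ema)  # renamed from feature_tau in this commit (0.005)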

examples/online/config/dmc/algo/td3.yaml

Lines changed: 2 additions & 2 deletions
@@ -2,6 +2,8 @@
 
 algo:
   name: td3
+  actor_update_freq: 1
+  target_update_freq: 1
   discount: 0.99
   ema: 0.005
   actor_hidden_dims: [512, 512, 512]
@@ -12,8 +14,6 @@ algo:
   actor_lr: 0.0003
   critic_lr: 0.0003
   clip_grad_norm: null
-  actor_update_freq: 1
-  target_update_freq: 1
   target_policy_noise: 0.2
   noise_clip: 0.3
   exploration_noise: 0.2

examples/online/main_dmc_offpolicy.py

Lines changed: 5 additions & 2 deletions
@@ -2,13 +2,14 @@
 
 import gymnasium as gym
 import hydra
+import jax
 import jax.numpy as jnp
 import numpy as np
 import omegaconf
-import wandb
 from omegaconf import OmegaConf
 from tqdm import tqdm, trange
 
+import wandb
 from flowrl.agent.online import *
 from flowrl.config.online.mujoco import Config
 from flowrl.dataset.buffer.state import ReplayBuffer
@@ -17,14 +18,16 @@
 from flowrl.utils.logger import CompositeLogger
 from flowrl.utils.misc import set_seed_everywhere
 
+jax.config.update("jax_default_matmul_precision", "float32")
+
 SUPPORTED_AGENTS: Dict[str, BaseAgent] = {
     "sac": SACAgent,
     "td3": TD3Agent,
     "td7": TD7Agent,
     "sdac": SDACAgent,
     "dpmd": DPMDAgent,
     "qsm": QSMAgent,
-    "ctrl_td3": Ctrl_TD3_Agent,
+    "ctrl_td3": CtrlTD3Agent,
 }
 
 class OffPolicyTrainer():
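
Besides switching to the renamed CtrlTD3Agent, this entry point now pins JAX's default matmul precision at import time. A minimal standalone sketch of what that flag does (illustrative only, not code from the commit):

import jax
import jax.numpy as jnp

# Force float32 matmul precision; on TPU (and some GPU paths) the default
# allows lower-precision (e.g. bfloat16) math inside matrix multiplies.
jax.config.update("jax_default_matmul_precision", "float32")

a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)
print(jnp.dot(a, b)[0, 0])  # 128.0, computed with float32 matmul precision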

examples/online/main_mujoco_offpolicy.py

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 import gymnasium as gym
 import gymnasium_robotics
 import hydra
+import jax
 import numpy as np
 import omegaconf
 from omegaconf import OmegaConf
@@ -16,6 +17,8 @@
 from flowrl.utils.logger import CompositeLogger
 from flowrl.utils.misc import set_seed_everywhere
 
+jax.config.update("jax_default_matmul_precision", "float32")
+
 SUPPORTED_AGENTS: Dict[str, BaseAgent] = {
     "sac": SACAgent,
     "td3": TD3Agent,

examples/online/main_mujoco_onpolicy.py

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,7 @@
 import gymnasium as gym
 import gymnasium_robotics
 import hydra
+import jax
 import jax.numpy as jnp
 import numpy as np
 import wandb
@@ -16,6 +17,8 @@
 from flowrl.utils.logger import CompositeLogger
 from flowrl.utils.misc import set_seed_everywhere
 
+jax.config.update("jax_default_matmul_precision", "float32")
+
 SUPPORTED_AGENTS: Dict[str, BaseAgent] = {
     "ppo": PPOAgent,
 }

flowrl/agent/online/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,6 @@
 from ..base import BaseAgent
 from .alac.alac import ALACAgent
-from .ctrl.ctrl import Ctrl_TD3_Agent
+from .ctrl.ctrl import CtrlTD3Agent
 from .dpmd import DPMDAgent
 from .idem import IDEMAgent
 from .ppo import PPOAgent
@@ -21,5 +21,5 @@
     "QSMAgent",
     "IDEMAgent",
     "ALACAgent",
-    "Ctrl_TD3_Agent",
+    "CtrlTD3Agent",
 ]

flowrl/agent/online/ctrl/ctrl.py

Lines changed: 13 additions & 38 deletions
@@ -1,5 +1,4 @@
 from functools import partial
-from operator import attrgetter
 from typing import Tuple
 
 import jax
@@ -8,7 +7,7 @@
 
 from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
 from flowrl.agent.online.td3 import TD3Agent
-from flowrl.config.online.mujoco.algo.ctrl_td3 import CTRL_TD3_Config
+from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config
 from flowrl.functional.ema import ema_update
 from flowrl.module.actor import SquashedDeterministicActor
 from flowrl.module.mlp import MLP
@@ -23,7 +22,6 @@ def update_critic(
     critic: Model,
     critic_target: Model,
     actor_target: Model,
-    nce: Model,
     nce_target: Model,
     batch: Batch,
     discount: float,
@@ -42,7 +40,6 @@
 
     back_critic_grad = False
     if back_critic_grad:
-        # this part will use feature
        raise NotImplementedError("no back critic grad exists")
 
     feature = nce_target(batch.obs, batch.action, method="forward_phi")
@@ -53,7 +50,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
            feature,
            rngs={"dropout": dropout_rng},
        )
-        # q_pred (2, 512, 1), q_target (512, 1)
        critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean()
        return critic_loss, {
            "loss/critic_loss": critic_loss,
@@ -69,7 +65,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
 def update_actor(
     rng: PRNGKey,
     actor: Model,
-    nce: Model,
     nce_target: Model,
     critic: Model,
     batch: Batch,
@@ -95,23 +90,21 @@ def actor_loss_fn(
     return rng, new_actor, metrics
 
 
-class Ctrl_TD3_Agent(TD3Agent):
+class CtrlTD3Agent(TD3Agent):
     """
-    CTRL Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
+    CTRL with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
     """
 
-    name = "CTRLTD3Agent"
+    name = "CtrlTD3Agent"
     model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"]
 
-    def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
+    def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int):
         super().__init__(obs_dim, act_dim, cfg, seed)
         self.cfg = cfg
 
         self.ctrl_coef = cfg.ctrl_coef
         self.critic_coef = cfg.critic_coef
 
-        self.aug_batch_size = cfg.aug_batch_size
-        self.feature_tau = cfg.feature_tau
         self.linear = cfg.linear
         self.ranking = cfg.ranking
         self.feature_dim = cfg.feature_dim
@@ -120,10 +113,10 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
         self.rff_dim = cfg.rff_dim
 
         # sanity checks for the hyper-parameters
-        assert not self.linear, "Removing linear version for now"
+        assert not self.linear, "linear mode is not supported yet"
 
         # networks
-        self.rng, nce_rng, actor_rng, critic_rng = jax.random.split(self.rng, 4)
+        self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5)
         nce_def = FactorizedNCE(
             self.obs_dim,
             self.act_dim,
@@ -139,6 +132,7 @@
             nce_def,
             nce_rng,
             inputs=(
+                nce_init_rng,
                 jnp.ones((1, self.obs_dim)),
                 jnp.ones((1, self.act_dim)),
                 jnp.ones((1, self.obs_dim)),
@@ -150,6 +144,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
             nce_def,
             nce_rng,
             inputs=(
+                nce_init_rng,
                 jnp.ones((1, self.obs_dim)),
                 jnp.ones((1, self.act_dim)),
                 jnp.ones((1, self.obs_dim)),
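
The constructor now splits the agent's PRNG key into five subkeys so the FactorizedNCE initialization gets a dedicated nce_init_rng, which is threaded into both inputs tuples above. A minimal sketch of the splitting pattern (standard jax.random usage, not code from the repository):

import jax

# One key in, five independent subkeys out; each consumer (NCE definition,
# NCE init inputs, actor, critic) draws from its own key.
rng = jax.random.PRNGKey(0)
rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(rng, 5)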
@@ -201,28 +196,10 @@
     def train_step(self, batch: Batch, step: int) -> Metric:
         metrics = {}
 
-        split_index = batch.obs.shape[0] - self.aug_batch_size
-        obs, action, next_obs, reward, terminal = [
-            b[:split_index]
-            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
-                batch
-            )
-        ]
-        fobs, faction, fnext_obs, freward, fterminal = [
-            b[split_index:]
-            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
-                batch
-            )
-        ]
-        rl_batch = Batch(obs, action, reward, terminal, next_obs, None)
-
         self.rng, self.nce, nce_metrics = update_factorized_nce(
             self.rng,
             self.nce,
-            fobs,
-            faction,
-            fnext_obs,
-            freward,
+            batch,
             self.ranking,
             self.reward_coef,
         )
@@ -233,9 +210,8 @@ def train_step(self, batch: Batch, step: int) -> Metric:
             self.critic,
             self.critic_target,
             self.actor_target,
-            self.nce,
             self.nce_target,
-            rl_batch,
+            batch,
             discount=self.cfg.discount,
             target_policy_noise=self.target_policy_noise,
             noise_clip=self.noise_clip,
@@ -247,10 +223,9 @@
         self.rng, self.actor, actor_metrics = update_actor(
             self.rng,
             self.actor,
-            self.nce,
             self.nce_target,
             self.critic,
-            rl_batch,
+            batch,
         )
         metrics.update(actor_metrics)
 
@@ -262,4 +237,4 @@ def train_step(self, batch: Batch, step: int) -> Metric:
 
     def sync_target(self):
         super().sync_target()
-        self.nce_target = ema_update(self.nce, self.nce_target, self.feature_tau)
+        self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema)
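
sync_target now reads the feature-encoder EMA rate from cfg.feature_ema instead of the removed self.feature_tau. For context, a minimal sketch of what a Polyak/EMA update over JAX pytrees typically looks like; this is an assumption about the behavior of flowrl.functional.ema.ema_update, not its actual implementation:

import jax

def ema_update_sketch(src_params, tgt_params, rate):
    # tgt <- rate * src + (1 - rate) * tgt, applied leaf-wise over both pytrees.
    return jax.tree_util.tree_map(
        lambda s, t: rate * s + (1.0 - rate) * t, src_params, tgt_params
    )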
