from functools import partial
-from operator import attrgetter
from typing import Tuple

import jax

from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce
from flowrl.agent.online.td3 import TD3Agent
-from flowrl.config.online.mujoco.algo.ctrl_td3 import CTRL_TD3_Config
+from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config
from flowrl.functional.ema import ema_update
from flowrl.module.actor import SquashedDeterministicActor
from flowrl.module.mlp import MLP
@@ -23,7 +22,6 @@ def update_critic(
    critic: Model,
    critic_target: Model,
    actor_target: Model,
-    nce: Model,
    nce_target: Model,
    batch: Batch,
    discount: float,
@@ -42,7 +40,6 @@ def update_critic(

    back_critic_grad = False
    if back_critic_grad:
-        # this part will use feature
        raise NotImplementedError("no back critic grad exists")

    feature = nce_target(batch.obs, batch.action, method="forward_phi")
@@ -53,7 +50,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
            feature,
            rngs={"dropout": dropout_rng},
        )
-        # q_pred (2, 512, 1), q_target (512, 1)
        critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean()
        return critic_loss, {
            "loss/critic_loss": critic_loss,
@@ -69,7 +65,6 @@ def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndar
def update_actor(
    rng: PRNGKey,
    actor: Model,
-    nce: Model,
    nce_target: Model,
    critic: Model,
    batch: Batch,
@@ -95,23 +90,21 @@ def actor_loss_fn(
    return rng, new_actor, metrics


-class Ctrl_TD3_Agent(TD3Agent):
+class CtrlTD3Agent(TD3Agent):
    """
-    CTRL Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
+    CTRL with Twin Delayed Deep Deterministic Policy Gradient (TD3) agent.
    """

-    name = "CTRLTD3Agent"
+    name = "CtrlTD3Agent"
    model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"]

-    def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
+    def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlTD3Config, seed: int):
        super().__init__(obs_dim, act_dim, cfg, seed)
        self.cfg = cfg

        self.ctrl_coef = cfg.ctrl_coef
        self.critic_coef = cfg.critic_coef

-        self.aug_batch_size = cfg.aug_batch_size
-        self.feature_tau = cfg.feature_tau
        self.linear = cfg.linear
        self.ranking = cfg.ranking
        self.feature_dim = cfg.feature_dim
@@ -120,10 +113,10 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
        self.rff_dim = cfg.rff_dim

        # sanity checks for the hyper-parameters
-        assert not self.linear, "Removing linear version for now"
+        assert not self.linear, "linear mode is not supported yet"

        # networks
-        self.rng, nce_rng, actor_rng, critic_rng = jax.random.split(self.rng, 4)
+        self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5)
        nce_def = FactorizedNCE(
            self.obs_dim,
            self.act_dim,
@@ -139,6 +132,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
            nce_def,
            nce_rng,
            inputs=(
+                nce_init_rng,
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
@@ -150,6 +144,7 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
            nce_def,
            nce_rng,
            inputs=(
+                nce_init_rng,
                jnp.ones((1, self.obs_dim)),
                jnp.ones((1, self.act_dim)),
                jnp.ones((1, self.obs_dim)),
@@ -201,28 +196,10 @@ def __init__(self, obs_dim: int, act_dim: int, cfg: CTRL_TD3_Config, seed: int):
    def train_step(self, batch: Batch, step: int) -> Metric:
        metrics = {}

-        split_index = batch.obs.shape[0] - self.aug_batch_size
-        obs, action, next_obs, reward, terminal = [
-            b[:split_index]
-            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
-                batch
-            )
-        ]
-        fobs, faction, fnext_obs, freward, fterminal = [
-            b[split_index:]
-            for b in attrgetter("obs", "action", "next_obs", "reward", "terminal")(
-                batch
-            )
-        ]
-        rl_batch = Batch(obs, action, reward, terminal, next_obs, None)
-
        self.rng, self.nce, nce_metrics = update_factorized_nce(
            self.rng,
            self.nce,
-            fobs,
-            faction,
-            fnext_obs,
-            freward,
+            batch,
            self.ranking,
            self.reward_coef,
        )
@@ -233,9 +210,8 @@ def train_step(self, batch: Batch, step: int) -> Metric:
            self.critic,
            self.critic_target,
            self.actor_target,
-            self.nce,
            self.nce_target,
-            rl_batch,
+            batch,
            discount=self.cfg.discount,
            target_policy_noise=self.target_policy_noise,
            noise_clip=self.noise_clip,
@@ -247,10 +223,9 @@ def train_step(self, batch: Batch, step: int) -> Metric:
        self.rng, self.actor, actor_metrics = update_actor(
            self.rng,
            self.actor,
-            self.nce,
            self.nce_target,
            self.critic,
-            rl_batch,
+            batch,
        )
        metrics.update(actor_metrics)

@@ -262,4 +237,4 @@ def train_step(self, batch: Batch, step: int) -> Metric:

    def sync_target(self):
        super().sync_target()
-        self.nce_target = ema_update(self.nce, self.nce_target, self.feature_tau)
+        self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema)
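
For reference, `sync_target` relies on `flowrl.functional.ema.ema_update(source, target, coef)` to softly track the NCE feature network. Below is a minimal sketch of what such an EMA (Polyak) update typically looks like over JAX parameter pytrees, assuming `coef` weights the source parameters; the helper name `ema_update_sketch` and the plain-dict parameters are illustrative only and not flowrl's actual implementation, which may operate on `Model` objects directly.

```python
import jax
import jax.numpy as jnp


def ema_update_sketch(source_params, target_params, ema: float):
    # new_target = ema * source + (1 - ema) * target, applied leaf-by-leaf
    return jax.tree_util.tree_map(
        lambda s, t: ema * s + (1.0 - ema) * t,
        source_params,
        target_params,
    )


# Tiny usage example with plain pytrees standing in for network parameters.
source = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
target = {"w": jnp.zeros((2, 2)), "b": jnp.ones((2,))}
target = ema_update_sketch(source, target, ema=0.005)
```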