@@ -15,30 +15,39 @@ class Reinforce:
     alpha: float = 0.05
     normalize_adv: bool = True
     baseline_fn: Optional[Callable[[object], float]] = None
-    seed: int | None = None
+    seed: Optional[int] = None
+
     def __post_init__(self):
         self.rng = np.random.default_rng(self.seed)
+
     def run_episode_discrete(self, env, policy, feature_fn: Callable[[object], np.ndarray]):
-        s = env.reset(); S,A,R,L = [],[],[],[]; done = False
+        s = env.reset()
+        S, A, R, L = [], [], [], []
+        done = False
         while not done:
-            x = feature_fn(s); a = policy.sample(x)
-            logp,_ = policy.logprob_and_grad(x,a)
+            x = feature_fn(s)
+            a = policy.sample(x)
+            logp, _ = policy.logprob_and_grad(x, a)
             ns, r, done, _ = env.step(a)
-            S.append(s); A.append(a); R.append(r); L.append(logp); s = ns
-        return Trajectory(S,A,R,L)
+            S.append(s); A.append(a); R.append(r); L.append(logp)
+            s = ns
+        return Trajectory(S, A, R, L)
+
     def update_discrete(self, traj: Trajectory, policy, feature_fn: Callable[[object], np.ndarray]):
         G = returns_to_go(traj.rewards, self.gamma)
         if self.baseline_fn is not None:
-            b = np.array([self.baseline_fn(s) for s in traj.states], dtype=float); adv = G - b
+            b = np.array([self.baseline_fn(s) for s in traj.states], dtype=float)
+            adv = G - b
         else:
             adv = G.copy()
         if self.normalize_adv:
-            # Only standardize when there is variability; for 1-step episodes std==0 leads to zero updates.
             if len(adv) >= 2 and np.std(adv) > 1e-8:
                 adv = standardize(adv)
+
         total_grad = np.zeros_like(policy.theta)
-        for s,a,adv_t in zip(traj.states, traj.actions, adv):
-            x = feature_fn(s); _, grad = policy.logprob_and_grad(x,a)
+        for s, a, adv_t in zip(traj.states, traj.actions, adv):
+            x = feature_fn(s)
+            _, grad = policy.logprob_and_grad(x, a)
             total_grad += adv_t * grad
         policy.theta += self.alpha * total_grad
         return {"G": G, "adv": adv}
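
For anyone reviewing this hunk without the rest of the module: below is a minimal, self-contained sketch of the interfaces the diff assumes. `Trajectory`, `returns_to_go`, `standardize`, and the linear-softmax policy are reconstructions inferred from the call sites above, not the repo's actual implementations, and the toy bandit environment is purely illustrative.

```python
# Hypothetical stand-ins inferred from the call sites in the diff; the real
# repo defines its own Trajectory, returns_to_go, standardize, and policy.
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class Trajectory:
    states: List[object]
    actions: List[int]
    rewards: List[float]
    logprobs: List[float]  # field name is a guess; the diff only passes it positionally


def returns_to_go(rewards, gamma):
    # G_t = r_t + gamma * G_{t+1}, accumulated backwards over the episode.
    G = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        G[t] = running
    return G


def standardize(x):
    # Zero-mean, unit-variance rescaling; update_discrete guards against std == 0.
    return (x - np.mean(x)) / np.std(x)


class SoftmaxPolicy:
    """Linear softmax policy over discrete actions: pi(a|x) = softmax(theta @ x)[a]."""

    def __init__(self, n_features, n_actions, rng):
        self.theta = np.zeros((n_actions, n_features))
        self.rng = rng

    def _probs(self, x):
        z = self.theta @ x
        z = z - z.max()  # shift logits for numerical stability
        e = np.exp(z)
        return e / e.sum()

    def sample(self, x):
        p = self._probs(x)
        return int(self.rng.choice(len(p), p=p))

    def logprob_and_grad(self, x, a):
        p = self._probs(x)
        # d/dtheta log pi(a|x) = (onehot(a) - p) outer x for a linear softmax.
        grad = -np.outer(p, x)
        grad[a] += x
        return float(np.log(p[a])), grad


class TwoArmedBandit:
    """One-step toy env using the old 4-tuple gym step() signature seen in the diff."""

    def reset(self):
        return 0

    def step(self, action):
        return 0, (1.0 if action == 1 else 0.0), True, {}


env = TwoArmedBandit()
agent = Reinforce(seed=0)  # assumes gamma (used above) also has a dataclass default
policy = SoftmaxPolicy(n_features=1, n_actions=2, rng=agent.rng)
feature_fn = lambda s: np.ones(1)

for _ in range(500):
    traj = agent.run_episode_discrete(env, policy, feature_fn)
    agent.update_discrete(traj, policy, feature_fn)
# After training, the policy should strongly prefer the rewarding arm (a == 1).
```

Note that the `len(adv) >= 2` guard in `update_discrete` matters for exactly this kind of one-step episode: the advantage vector has zero variance, and standardizing it would divide by zero, which is what the comment deleted in this hunk was pointing out.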