@@ -104,6 +104,9 @@ class PPOLoss(LossModule):
104104 * **Scalar**: one value applied to the summed entropy of every action head.
105105 * **Mapping** ``{head_name: coef}`` gives an individual coefficient for each action-head's entropy.
106106 Defaults to ``0.01``.
107+ log_explained_variance (bool, optional): if ``True``, the explained variance of the critic
108+ predictions w.r.t. value targets will be computed and logged as ``"explained_variance"``.
109+ This can help monitor critic quality during training. The best possible score is 1.0; lower values are worse. Defaults to ``True``.
107110 critic_coef (scalar, optional): critic loss multiplier when computing the total
108111 loss. Defaults to ``1.0``. Set ``critic_coef`` to ``None`` to exclude the value
109112 loss from the forward outputs.
@@ -349,6 +352,7 @@ def __init__(
349352 entropy_bonus : bool = True ,
350353 samples_mc_entropy : int = 1 ,
351354 entropy_coeff : float | Mapping [str , float ] = 0.01 ,
355+ log_explained_variance : bool = True ,
352356 critic_coef : float | None = None ,
353357 loss_critic_type : str = "smooth_l1" ,
354358 normalize_advantage : bool = False ,
@@ -413,6 +417,7 @@ def __init__(
413417 self .critic_network_params = None
414418 self .target_critic_network_params = None
415419
420+ self .log_explained_variance = log_explained_variance
416421 self .samples_mc_entropy = samples_mc_entropy
417422 self .entropy_bonus = entropy_bonus
418423 self .separate_losses = separate_losses
@@ -745,6 +750,16 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
745750 self .loss_critic_type ,
746751 )
747752
753+ explained_variance = None
754+ if self .log_explained_variance :
755+ with torch .no_grad (): # <-- break grad-flow
756+ tgt = target_return .detach ()
757+ pred = state_value .detach ()
758+ eps = torch .finfo (tgt .dtype ).eps
759+ resid = torch .var (tgt - pred , unbiased = False , dim = 0 )
760+ total = torch .var (tgt , unbiased = False , dim = 0 )
761+ explained_variance = 1.0 - resid / (total + eps )
762+
748763 self ._clear_weakrefs (
749764 tensordict ,
750765 "actor_network_params" ,
@@ -753,8 +768,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
753768 "target_critic_network_params" ,
754769 )
755770 if self ._has_critic :
756- return self .critic_coef * loss_value , clip_fraction
757- return loss_value , clip_fraction
771+ return self .critic_coef * loss_value , clip_fraction , explained_variance
772+ return loss_value , clip_fraction , explained_variance
758773
759774 @property
760775 @_cache_values
@@ -804,10 +819,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
804819 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
805820 td_out .set ("loss_entropy" , self ._weighted_loss_entropy (entropy ))
806821 if self ._has_critic :
807- loss_critic , value_clip_fraction = self .loss_critic (tensordict )
822+ loss_critic , value_clip_fraction , explained_variance = self .loss_critic (tensordict )
808823 td_out .set ("loss_critic" , loss_critic )
809824 if value_clip_fraction is not None :
810825 td_out .set ("value_clip_fraction" , value_clip_fraction )
826+ if explained_variance is not None :
827+ td_out .set ("explained_variance" , explained_variance )
811828 td_out = td_out .named_apply (
812829 lambda name , value : _reduce (value , reduction = self .reduction ).squeeze (- 1 )
813830 if name .startswith ("loss_" )
@@ -1172,10 +1189,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
11721189 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
11731190 td_out .set ("loss_entropy" , self ._weighted_loss_entropy (entropy ))
11741191 if self ._has_critic :
1175- loss_critic , value_clip_fraction = self .loss_critic (tensordict )
1192+ loss_critic , value_clip_fraction , explained_variance = self .loss_critic (tensordict )
11761193 td_out .set ("loss_critic" , loss_critic )
11771194 if value_clip_fraction is not None :
11781195 td_out .set ("value_clip_fraction" , value_clip_fraction )
1196+ if explained_variance is not None :
1197+ td_out .set ("explained_variance" , explained_variance )
11791198
11801199 td_out .set ("ESS" , _reduce (ess , self .reduction ) / batch )
11811200 td_out = td_out .named_apply (
@@ -1518,10 +1537,12 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
15181537 td_out .set ("entropy" , entropy .detach ().mean ()) # for logging
15191538 td_out .set ("loss_entropy" , self ._weighted_loss_entropy (entropy ))
15201539 if self ._has_critic :
1521- loss_critic , value_clip_fraction = self .loss_critic (tensordict_copy )
1540+ loss_critic , value_clip_fraction , explained_variance = self .loss_critic (tensordict_copy )
15221541 td_out .set ("loss_critic" , loss_critic )
15231542 if value_clip_fraction is not None :
15241543 td_out .set ("value_clip_fraction" , value_clip_fraction )
1544+ if explained_variance is not None :
1545+ td_out .set ("explained_variance" , explained_variance )
15251546 td_out = td_out .named_apply (
15261547 lambda name , value : _reduce (value , reduction = self .reduction ).squeeze (- 1 )
15271548 if name .startswith ("loss_" )
0 commit comments