@@ -8726,8 +8726,10 @@ def set_container(self, container):
87268726class ActionMask (Transform ):
87278727 """An adaptive action masker.
87288728
8729- This transform reads the mask from the input tensordict after the step is executed,
8730- and adapts the mask of the one-hot / categorical action spec.
8729+ This transform is useful to ensure that randomly generated actions
8730+ respect legal actions, by masking the action specs.
8731+ It reads the mask from the input tensordict after the step is executed,
8732+ and adapts the mask of the finite action spec.
87318733
87328734 .. note:: This transform will fail when used without an environment.
87338735
@@ -8773,8 +8775,6 @@ class ActionMask(Transform):
87738775 >>> base_env = MaskedEnv()
87748776 >>> env = TransformedEnv(base_env, ActionMask())
87758777 >>> r = env.rollout(10)
8776- >>> env = TransformedEnv(base_env, ActionMask())
8777- >>> r = env.rollout(10)
87788778 >>> r["action_mask"]
87798779 tensor([[ True, True, True, True],
87808780 [ True, True, False, True],
@@ -8810,45 +8810,29 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
88108810 raise RuntimeError (FORWARD_NOT_IMPLEMENTED .format (type (self )))
88118811
@property
def action_spec(self) -> TensorSpec:
    """Return the action spec whose mask this transform maintains.

    The spec is read from the container's ``full_action_spec`` under the
    first entry of ``self.in_keys``.

    Raises:
        ValueError: if the retrieved spec is not an instance of
            ``self.ACCEPTED_SPECS``.
    """
    spec = self.container.full_action_spec[self.in_keys[0]]
    # Guard clause: hand back the spec as soon as its type is validated.
    if isinstance(spec, self.ACCEPTED_SPECS):
        return spec
    raise ValueError(
        self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(spec))
    )
88278820
def _call(self, next_tensordict: TensorDictBase) -> TensorDictBase:
    """Update the action spec's mask from the mask stored in the tensordict.

    Args:
        next_tensordict: tensordict holding the mask under
            ``self.in_keys[1]``.

    Returns:
        The ``next_tensordict`` object that was passed in.

    Raises:
        RuntimeError: if the transform has no parent, i.e. it is not
            executed within an environment.
    """
    if self.parent is None:
        raise RuntimeError(
            f"{type(self)}.parent cannot be None: make sure this transform is executed within an environment."
        )

    mask = next_tensordict.get(self.in_keys[1])

    # Resolve the property once: every access repeats the container lookup
    # and the spec type validation, so reuse the result for both the device
    # query and the mask update.
    action_spec = self.action_spec
    action_spec.update_mask(mask.to(action_spec.device))

    return next_tensordict
88388831
def _reset(
    self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
    """Apply the mask update to the tensordict produced by a reset.

    Delegates to ``_call`` on ``tensordict_reset``; the ``tensordict``
    argument is not read here.

    Returns:
        Whatever ``_call`` returns for ``tensordict_reset``.
    """
    updated_reset = self._call(tensordict_reset)
    return updated_reset
88528836
88538837
88548838class VecGymEnvTransform (Transform ):
0 commit comments