@@ -21,8 +21,7 @@ class RelationSolverState:
     keeping anchor (fixed) and optimizable positions separate internally
     while providing an interface for position lookups.
 
-    Positions are always stored as (N, num_objects, 3) where N = num_envs
-    (N=1 for single-env).
+    Positions are always stored as (batch_size, num_objects, 3).
     """
 
     def __init__(
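As a quick illustration of the storage convention described in the docstring above, here is a minimal sketch of how a list of per-batch position dicts maps onto the (batch_size, num_objects, 3) layout. The object names and `initial_positions` values are invented for the example and do not appear in this diff:

    import torch

    # Two batch entries; each dict maps an object to an (x, y, z) position.
    # "table" and "mug" are hypothetical stand-ins for ObjectBase instances.
    initial_positions = [
        {"table": (0.0, 0.0, 0.0), "mug": (0.3, 0.1, 0.8)},
        {"table": (0.0, 0.0, 0.0), "mug": (0.5, -0.2, 0.8)},
    ]
    objects = ["table", "mug"]

    # Stack one row per batch entry into the (batch_size, num_objects, 3) tensor.
    all_positions = torch.stack([
        torch.tensor([d[obj] for obj in objects], dtype=torch.float32)
        for d in initial_positions
    ])
    assert all_positions.shape == (2, 2, 3)  # (batch_size, num_objects, 3)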
@@ -52,7 +51,7 @@ def __init__(
         self._obj_to_idx: dict[ObjectBase, int] = {obj: i for i, obj in enumerate(objects)}
 
         self._device = device or torch.device("cpu")
-        self._num_envs = len(initial_positions)
+        self._batch_size = len(initial_positions)
 
         # Validate that every dict contains all objects before building the tensor.
         for d in initial_positions:
@@ -67,7 +66,7 @@ def __init__(
         self._anchor_indices: set[int] = {self._obj_to_idx[obj] for obj in self._anchor_objects}
         # Anchors must be identical across envs (they are fixed reference points).
         for idx in self._anchor_indices:
-            for env_idx in range(1, self._num_envs):
+            for env_idx in range(1, self._batch_size):
                 assert torch.allclose(all_positions[0, idx], all_positions[env_idx, idx]), (
                     f"Anchor '{objects[idx].name}' has different positions across envs "
                     f"(env 0: {all_positions[0, idx].tolist()}, env {env_idx}: {all_positions[env_idx, idx].tolist()})"
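The loop above compares every batch entry against entry 0, so a single mismatched anchor trips the assertion. A tiny invented example of the situation it catches:

    import torch

    # Index 0 plays the anchor; its position differs between entries 0 and 1.
    all_positions = torch.tensor([
        [[0.0, 0.0, 0.0], [0.3, 0.1, 0.8]],
        [[1.0, 0.0, 0.0], [0.5, -0.2, 0.8]],
    ])
    # The check in __init__ would fail on this input:
    print(torch.allclose(all_positions[0, 0], all_positions[1, 0]))  # False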
@@ -95,13 +94,13 @@ def __init__(
         self._optimizable_positions = None
 
     @property
-    def num_envs(self) -> int:
-        """Number of environments (leading dimension N)."""
-        return self._num_envs
+    def batch_size(self) -> int:
+        """Number of independent position sets (leading dimension of position tensors)."""
+        return self._batch_size
 
     @property
     def optimizable_positions(self) -> torch.Tensor | None:
-        """Tensor of optimizable positions (N, num_opt, 3), or None if all objects are anchors.
+        """Tensor of optimizable positions (batch_size, num_optimizable, 3), or None if all objects are anchors.
 
         This is the tensor that should be passed to the optimizer.
         """
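As the docstring notes, this property is the optimizer's handle on the state. A hedged sketch of how it might be consumed; `state` and `compute_relation_loss` are assumed names for illustration, not part of this diff:

    import torch

    params = state.optimizable_positions  # (batch_size, num_optimizable, 3) or None
    if params is not None:
        params.requires_grad_(True)
        optimizer = torch.optim.Adam([params], lr=1e-2)
        for _ in range(100):
            optimizer.zero_grad()
            loss = compute_relation_loss(state)  # hypothetical scalar loss over all batch entries
            loss.backward()
            optimizer.step()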
@@ -124,15 +123,15 @@ def get_position(self, obj: ObjectBase) -> torch.Tensor:
             obj: The object to get position for.
 
         Returns:
-            Position tensor of shape (N, 3).
+            Position tensor of shape (batch_size, 3).
 
         Raises:
             KeyError: If object is not tracked by this state.
             RuntimeError: If requesting position for optimizable object when none exist.
         """
         idx = self._obj_to_idx[obj]
         if idx in self._anchor_indices:
-            return self._anchor_positions[idx].unsqueeze(0).expand(self._num_envs, 3)
+            return self._anchor_positions[idx].unsqueeze(0).expand(self._batch_size, 3)
         if self._optimizable_positions is None:
             raise RuntimeError(f"No optimizable positions available for object '{obj.name}'")
         opt_idx = self._global_to_opt_idx[idx]
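One usage detail worth noting: for anchors, the returned tensor is an `expand`ed view, so every batch entry aliases the same underlying anchor storage. Clone it before any in-place edits. A small assumed-usage sketch (`state` and `table` are hypothetical):

    pos = state.get_position(table)  # anchor -> expanded view of shape (batch_size, 3)
    pos = pos.clone()                # materialize a writable copy before mutating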
@@ -157,12 +156,12 @@ def get_final_positions(self) -> list[dict[ObjectBase, tuple[float, float, float]]]:
         pos_list = full.detach().cpu().tolist()
         return [
             {obj: tuple(pos_list[env_idx][obj_idx]) for obj_idx, obj in enumerate(self._all_objects)}
-            for env_idx in range(self._num_envs)
+            for env_idx in range(self._batch_size)
         ]
 
     def _reconstruct_all_positions(self) -> torch.Tensor:
-        """Reconstruct a full (N, num_objects, 3) tensor from anchor and optimizable parts."""
-        full = self._anchor_pos_tensor.expand(self._num_envs, -1, -1).clone()
+        """Reconstruct a full (batch_size, num_objects, 3) tensor from anchor and optimizable parts."""
+        full = self._anchor_pos_tensor.expand(self._batch_size, -1, -1).clone()
         if self._optimizable_positions is not None:
             full[:, self._opt_idx_tensor, :] = self._optimizable_positions
         return full
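The `.expand(...).clone()` in `_reconstruct_all_positions` is what makes the subsequent scatter write legal: `expand` alone returns a view in which all batch entries share one copy of the anchor data, and index assignment into such a view raises a runtime error. A minimal demonstration with invented shapes:

    import torch

    base = torch.zeros(3, 3)             # (num_objects, 3) anchor template
    view = base.expand(4, -1, -1)        # (4, 3, 3) view; batch entries share storage
    full = view.clone()                  # independent, writable per-batch copy
    full[:, torch.tensor([1]), :] = 1.0  # safe; the same write on `view` would error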