@@ -29,6 +29,7 @@ def __init__(
2929 self ,
3030 objects : list [ObjectBase ],
3131 initial_positions : list [dict [ObjectBase , tuple [float , float , float ]]],
32+ device : torch .device | None = None ,
3233 ):
3334 """Initialize optimization state.
3435
@@ -37,6 +38,7 @@ def __init__(
3738 object marked with IsAnchor() which serves as a fixed reference.
3839 initial_positions: List of dicts (one per env). Length 1 = single-env,
3940 length > 1 = batched.
41+ device: Torch device for all tensors. Defaults to CPU.
4042 """
4143 assert len (initial_positions ) >= 1 , "initial_positions must contain at least one dict."
4244 anchor_objects = get_anchor_objects (objects )
@@ -49,39 +51,47 @@ def __init__(
4951 # Build object-to-index mapping
5052 self ._obj_to_idx : dict [ObjectBase , int ] = {obj : i for i , obj in enumerate (objects )}
5153
52- # Extract positions from each env's dict
54+ self . _device = device or torch . device ( "cpu" )
5355 self ._num_envs = len (initial_positions )
54- positions_per_env = []
56+
57+ # Validate that every dict contains all objects before building the tensor.
5558 for d in initial_positions :
56- positions = []
5759 for obj in objects :
5860 assert obj in d , f"Missing initial position for { obj .name } "
59- positions .append (torch .tensor (d [obj ], dtype = torch .float32 ))
60- positions_per_env .append (positions )
61+
62+ # Build all positions as a single (N, num_objects, 3) tensor in one call.
63+ pos_nested = [[d [obj ] for obj in objects ] for d in initial_positions ]
64+ all_positions = torch .tensor (pos_nested , dtype = torch .float32 , device = self ._device )
6165
6266 # Separate anchor positions from optimizable positions
6367 self ._anchor_indices : set [int ] = {self ._obj_to_idx [obj ] for obj in self ._anchor_objects }
6468 # Anchors must be identical across envs (they are fixed reference points).
6569 for idx in self ._anchor_indices :
66- for e in range (1 , self ._num_envs ):
67- assert torch .allclose (positions_per_env [ 0 ][ idx ], positions_per_env [ e ][ idx ]), (
70+ for env_idx in range (1 , self ._num_envs ):
71+ assert torch .allclose (all_positions [ 0 , idx ], all_positions [ env_idx , idx ]), (
6872 f"Anchor '{ objects [idx ].name } ' has different positions across envs "
69- f"(env 0: { positions_per_env [ 0 ][ idx ].tolist ()} , env { e } : { positions_per_env [ e ][ idx ].tolist ()} )"
73+ f"(env 0: { all_positions [ 0 , idx ].tolist ()} , env { env_idx } : { all_positions [ env_idx , idx ].tolist ()} )"
7074 )
7175 self ._anchor_positions : dict [int , torch .Tensor ] = {
72- idx : positions_per_env [ 0 ][ idx ].clone () for idx in self ._anchor_indices
76+ idx : all_positions [ 0 , idx ].clone () for idx in self ._anchor_indices
7377 }
7478
75- # Build optimizable positions tensor (excludes all anchors)
76- # Always stored as (N, num_opt, 3) where N = num_envs
79+ # Pre-build anchor positions as (1, num_objects, 3) for fast _reconstruct_all_positions.
80+ self ._anchor_pos_tensor = torch .zeros (1 , len (objects ), 3 , dtype = torch .float32 , device = self ._device )
81+ for idx , pos in self ._anchor_positions .items ():
82+ self ._anchor_pos_tensor [0 , idx , :] = pos
83+
84+ # Build optimizable positions tensor by slicing from the full tensor.
7785 self ._optimizable_indices = [i for i in range (len (objects )) if i not in self ._anchor_indices ]
86+ self ._global_to_opt_idx : dict [int , int ] = {
87+ global_idx : opt_idx for opt_idx , global_idx in enumerate (self ._optimizable_indices )
88+ }
7889 if self ._optimizable_indices :
79- opt_tensors = [
80- torch .stack ([positions_per_env [e ][i ] for e in range (self ._num_envs )]) for i in self ._optimizable_indices
81- ]
82- self ._optimizable_positions = torch .stack (opt_tensors , dim = 1 ) # (N, num_opt, 3)
90+ self ._opt_idx_tensor = torch .tensor (self ._optimizable_indices , dtype = torch .long , device = self ._device )
91+ self ._optimizable_positions = all_positions [:, self ._opt_idx_tensor , :].clone ()
8392 self ._optimizable_positions .requires_grad = True
8493 else :
94+ self ._opt_idx_tensor = None
8595 self ._optimizable_positions = None
8696
8797 @property
@@ -125,7 +135,7 @@ def get_position(self, obj: ObjectBase) -> torch.Tensor:
125135 return self ._anchor_positions [idx ].unsqueeze (0 ).expand (self ._num_envs , 3 )
126136 if self ._optimizable_positions is None :
127137 raise RuntimeError (f"No optimizable positions available for object '{ obj .name } '" )
128- opt_idx = self ._optimizable_indices . index ( idx )
138+ opt_idx = self ._global_to_opt_idx [ idx ]
129139 return self ._optimizable_positions [:, opt_idx , :]
130140
131141 def get_all_positions_snapshot (self ) -> list [tuple [float , float , float ]]:
@@ -142,11 +152,17 @@ def get_final_positions(self) -> list[dict[ObjectBase, tuple[float, float, float
142152 Returns:
143153 List of dictionaries with object instances as keys and (x, y, z) tuples as values.
144154 """
145- out = []
146- for e in range (self ._num_envs ):
147- d : dict [ObjectBase , tuple [float , float , float ]] = {}
148- for obj in self ._all_objects :
149- pos = self .get_position (obj )[e ].detach ().tolist ()
150- d [obj ] = (pos [0 ], pos [1 ], pos [2 ])
151- out .append (d )
152- return out
155+ # Reconstruct the full (N, num_objects, 3) tensor and transfer to CPU in one call.
156+ full = self ._reconstruct_all_positions ()
157+ pos_list = full .detach ().cpu ().tolist ()
158+ return [
159+ {obj : tuple (pos_list [env_idx ][obj_idx ]) for obj_idx , obj in enumerate (self ._all_objects )}
160+ for env_idx in range (self ._num_envs )
161+ ]
162+
163+ def _reconstruct_all_positions (self ) -> torch .Tensor :
164+ """Reconstruct a full (N, num_objects, 3) tensor from anchor and optimizable parts."""
165+ full = self ._anchor_pos_tensor .expand (self ._num_envs , - 1 , - 1 ).clone ()
166+ if self ._optimizable_positions is not None :
167+ full [:, self ._opt_idx_tensor , :] = self ._optimizable_positions
168+ return full
0 commit comments