Skip to content

Commit 2e877a8

Browse files
committed
Speed up the batch solver
1 parent 973d45b commit 2e877a8

3 files changed

Lines changed: 50 additions & 31 deletions

File tree

isaaclab_arena/relations/relation_solver.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,24 +85,22 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) -
8585
loss = strategy.compute_loss(
8686
relation=relation,
8787
child_pos=child_pos,
88-
child_bbox=obj.get_bounding_box(),
88+
child_bbox=obj.get_bounding_box().to(device),
8989
)
9090
if debug:
9191
_print_unary_relation_debug(obj, relation, child_pos[0], loss.mean())
9292
# Handle binary relations (with parent) like On, NextTo
9393
elif isinstance(relation, Relation):
94-
# Build parent world bbox: anchors have a known fixed pose,
95-
# optimizable parents use the current solver position + local bbox.
9694
parent = relation.parent
9795
if parent in state.anchor_objects:
98-
parent_world_bbox = parent.get_world_bounding_box()
96+
parent_world_bbox = parent.get_world_bounding_box().to(device)
9997
else:
10098
parent_pos = state.get_position(parent)
101-
parent_world_bbox = parent.get_bounding_box().translated(parent_pos)
99+
parent_world_bbox = parent.get_bounding_box().to(device).translated(parent_pos)
102100
loss = strategy.compute_loss(
103101
relation=relation,
104102
child_pos=child_pos,
105-
child_bbox=obj.get_bounding_box(),
103+
child_bbox=obj.get_bounding_box().to(device),
106104
parent_world_bbox=parent_world_bbox,
107105
)
108106
if debug:
@@ -132,7 +130,8 @@ def solve(
132130
Returns:
133131
List of dicts (one per env) mapping objects to their solved (x, y, z) positions.
134132
"""
135-
state = RelationSolverState(objects, initial_positions)
133+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
134+
state = RelationSolverState(objects, initial_positions, device=device)
136135

137136
if self.params.verbose:
138137
anchor_names = [obj.name for obj in state.anchor_objects]

isaaclab_arena/relations/relation_solver_state.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
self,
3030
objects: list[ObjectBase],
3131
initial_positions: list[dict[ObjectBase, tuple[float, float, float]]],
32+
device: torch.device | None = None,
3233
):
3334
"""Initialize optimization state.
3435
@@ -37,6 +38,7 @@ def __init__(
3738
object marked with IsAnchor() which serves as a fixed reference.
3839
initial_positions: List of dicts (one per env). Length 1 = single-env,
3940
length > 1 = batched.
41+
device: Torch device for all tensors. Defaults to CPU.
4042
"""
4143
assert len(initial_positions) >= 1, "initial_positions must contain at least one dict."
4244
anchor_objects = get_anchor_objects(objects)
@@ -49,39 +51,47 @@ def __init__(
4951
# Build object-to-index mapping
5052
self._obj_to_idx: dict[ObjectBase, int] = {obj: i for i, obj in enumerate(objects)}
5153

52-
# Extract positions from each env's dict
54+
self._device = device or torch.device("cpu")
5355
self._num_envs = len(initial_positions)
54-
positions_per_env = []
56+
57+
# Validate that every dict contains all objects before building the tensor.
5558
for d in initial_positions:
56-
positions = []
5759
for obj in objects:
5860
assert obj in d, f"Missing initial position for {obj.name}"
59-
positions.append(torch.tensor(d[obj], dtype=torch.float32))
60-
positions_per_env.append(positions)
61+
62+
# Build all positions as a single (N, num_objects, 3) tensor in one call.
63+
pos_nested = [[d[obj] for obj in objects] for d in initial_positions]
64+
all_positions = torch.tensor(pos_nested, dtype=torch.float32, device=self._device)
6165

6266
# Separate anchor positions from optimizable positions
6367
self._anchor_indices: set[int] = {self._obj_to_idx[obj] for obj in self._anchor_objects}
6468
# Anchors must be identical across envs (they are fixed reference points).
6569
for idx in self._anchor_indices:
66-
for e in range(1, self._num_envs):
67-
assert torch.allclose(positions_per_env[0][idx], positions_per_env[e][idx]), (
70+
for env_idx in range(1, self._num_envs):
71+
assert torch.allclose(all_positions[0, idx], all_positions[env_idx, idx]), (
6872
f"Anchor '{objects[idx].name}' has different positions across envs "
69-
f"(env 0: {positions_per_env[0][idx].tolist()}, env {e}: {positions_per_env[e][idx].tolist()})"
73+
f"(env 0: {all_positions[0, idx].tolist()}, env {env_idx}: {all_positions[env_idx, idx].tolist()})"
7074
)
7175
self._anchor_positions: dict[int, torch.Tensor] = {
72-
idx: positions_per_env[0][idx].clone() for idx in self._anchor_indices
76+
idx: all_positions[0, idx].clone() for idx in self._anchor_indices
7377
}
7478

75-
# Build optimizable positions tensor (excludes all anchors)
76-
# Always stored as (N, num_opt, 3) where N = num_envs
79+
# Pre-build anchor positions as (1, num_objects, 3) for fast _reconstruct_all_positions.
80+
self._anchor_pos_tensor = torch.zeros(1, len(objects), 3, dtype=torch.float32, device=self._device)
81+
for idx, pos in self._anchor_positions.items():
82+
self._anchor_pos_tensor[0, idx, :] = pos
83+
84+
# Build optimizable positions tensor by slicing from the full tensor.
7785
self._optimizable_indices = [i for i in range(len(objects)) if i not in self._anchor_indices]
86+
self._global_to_opt_idx: dict[int, int] = {
87+
global_idx: opt_idx for opt_idx, global_idx in enumerate(self._optimizable_indices)
88+
}
7889
if self._optimizable_indices:
79-
opt_tensors = [
80-
torch.stack([positions_per_env[e][i] for e in range(self._num_envs)]) for i in self._optimizable_indices
81-
]
82-
self._optimizable_positions = torch.stack(opt_tensors, dim=1) # (N, num_opt, 3)
90+
self._opt_idx_tensor = torch.tensor(self._optimizable_indices, dtype=torch.long, device=self._device)
91+
self._optimizable_positions = all_positions[:, self._opt_idx_tensor, :].clone()
8392
self._optimizable_positions.requires_grad = True
8493
else:
94+
self._opt_idx_tensor = None
8595
self._optimizable_positions = None
8696

8797
@property
@@ -125,7 +135,7 @@ def get_position(self, obj: ObjectBase) -> torch.Tensor:
125135
return self._anchor_positions[idx].unsqueeze(0).expand(self._num_envs, 3)
126136
if self._optimizable_positions is None:
127137
raise RuntimeError(f"No optimizable positions available for object '{obj.name}'")
128-
opt_idx = self._optimizable_indices.index(idx)
138+
opt_idx = self._global_to_opt_idx[idx]
129139
return self._optimizable_positions[:, opt_idx, :]
130140

131141
def get_all_positions_snapshot(self) -> list[tuple[float, float, float]]:
@@ -142,11 +152,17 @@ def get_final_positions(self) -> list[dict[ObjectBase, tuple[float, float, float
142152
Returns:
143153
List of dictionaries with object instances as keys and (x, y, z) tuples as values.
144154
"""
145-
out = []
146-
for e in range(self._num_envs):
147-
d: dict[ObjectBase, tuple[float, float, float]] = {}
148-
for obj in self._all_objects:
149-
pos = self.get_position(obj)[e].detach().tolist()
150-
d[obj] = (pos[0], pos[1], pos[2])
151-
out.append(d)
152-
return out
155+
# Reconstruct the full (N, num_objects, 3) tensor and transfer to CPU in one call.
156+
full = self._reconstruct_all_positions()
157+
pos_list = full.detach().cpu().tolist()
158+
return [
159+
{obj: tuple(pos_list[env_idx][obj_idx]) for obj_idx, obj in enumerate(self._all_objects)}
160+
for env_idx in range(self._num_envs)
161+
]
162+
163+
def _reconstruct_all_positions(self) -> torch.Tensor:
164+
"""Reconstruct a full (N, num_objects, 3) tensor from anchor and optimizable parts."""
165+
full = self._anchor_pos_tensor.expand(self._num_envs, -1, -1).clone()
166+
if self._optimizable_positions is not None:
167+
full[:, self._opt_idx_tensor, :] = self._optimizable_positions
168+
return full

isaaclab_arena/utils/bounding_box.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ def scaled(self, scale: tuple[float, float, float] | torch.Tensor) -> "AxisAlign
124124
scale = self._to_batched_tensor(scale)
125125
return AxisAlignedBoundingBox(min_point=self._min_point * scale, max_point=self._max_point * scale)
126126

127+
def to(self, device: torch.device) -> "AxisAlignedBoundingBox":
128+
"""Return a new bounding box with tensors on *device*."""
129+
return AxisAlignedBoundingBox(min_point=self._min_point.to(device), max_point=self._max_point.to(device))
130+
127131
def translated(self, offset: tuple[float, float, float] | torch.Tensor) -> "AxisAlignedBoundingBox":
128132
"""Return a new bounding box translated by an offset.
129133

0 commit comments

Comments (0)