Skip to content

Commit da8e12a

Browse files
committed
rename N to avoid abbreviation
1 parent b152c0c commit da8e12a

3 files changed

Lines changed: 21 additions & 19 deletions

File tree

isaaclab_arena/relations/object_placer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def place(
119119
num_results = num_envs if result_per_env else 1
120120
num_candidates = self.params.max_placement_attempts * num_results
121121

122-
initial_positions: list[dict] = []
122+
initial_positions: list[dict[ObjectBase, tuple[float, float, float]]] = []
123123
for candidate_idx in range(num_candidates):
124124
if generator is not None:
125125
generator.manual_seed(self.params.placement_seed + candidate_idx)
@@ -149,7 +149,9 @@ def place(
149149
f" {total_valid} valid, selected best {num_results} ({n_valid} valid)"
150150
)
151151

152-
final_per_env: list[dict] = [candidate.positions for candidate in selected]
152+
final_per_env: list[dict[ObjectBase, tuple[float, float, float]]] = [
153+
candidate.positions for candidate in selected
154+
]
153155
results_per_env = [
154156
PlacementResult(
155157
success=candidate.is_valid,

isaaclab_arena/relations/relation_solver.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def _compute_total_loss(self, state: RelationSolverState, debug: bool = False) -
6868
debug: If True, print detailed loss breakdown.
6969
7070
Returns:
71-
Scalar loss tensor (mean over envs). Per-env loss stored in _last_loss_per_env.
71+
Scalar loss tensor (mean over environments).
7272
"""
73-
N = state.num_envs
73+
batch_size = state.batch_size
7474
device = state.optimizable_positions.device if state.optimizable_positions is not None else None
75-
total_loss = torch.zeros(N, device=device, dtype=torch.float32)
75+
total_loss = torch.zeros(batch_size, device=device, dtype=torch.float32)
7676

7777
# Compute loss from all spatial relations using strategies
7878
for obj in state.optimizable_objects:
@@ -145,6 +145,7 @@ def solve(
145145
if self.params.verbose:
146146
print("No optimizable objects, skipping solver.")
147147
self._last_loss_history = [0.0]
148+
self._last_loss_per_env = torch.zeros(state.batch_size)
148149
self._last_position_history = [state.get_all_positions_snapshot()]
149150
return state.get_final_positions()
150151

@@ -203,7 +204,7 @@ def last_loss_history(self) -> list[float]:
203204

204205
@property
205206
def last_loss_per_env(self) -> torch.Tensor | None:
206-
"""Per-env loss (N,) from the last solve() call."""
207+
"""Per-candidate loss tensor of shape (batch_size,) from the last solve() call."""
207208
return self._last_loss_per_env
208209

209210
@property

isaaclab_arena/relations/relation_solver_state.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ class RelationSolverState:
2121
keeping anchor (fixed) and optimizable positions separate internally
2222
while providing an interface for position lookups.
2323
24-
Positions are always stored as (N, num_objects, 3) where N = num_envs
25-
(N=1 for single-env).
24+
Positions are always stored as (batch_size, num_objects, 3).
2625
"""
2726

2827
def __init__(
@@ -52,7 +51,7 @@ def __init__(
5251
self._obj_to_idx: dict[ObjectBase, int] = {obj: i for i, obj in enumerate(objects)}
5352

5453
self._device = device or torch.device("cpu")
55-
self._num_envs = len(initial_positions)
54+
self._batch_size = len(initial_positions)
5655

5756
# Validate that every dict contains all objects before building the tensor.
5857
for d in initial_positions:
@@ -67,7 +66,7 @@ def __init__(
6766
self._anchor_indices: set[int] = {self._obj_to_idx[obj] for obj in self._anchor_objects}
6867
# Anchors must be identical across envs (they are fixed reference points).
6968
for idx in self._anchor_indices:
70-
for env_idx in range(1, self._num_envs):
69+
for env_idx in range(1, self._batch_size):
7170
assert torch.allclose(all_positions[0, idx], all_positions[env_idx, idx]), (
7271
f"Anchor '{objects[idx].name}' has different positions across envs "
7372
f"(env 0: {all_positions[0, idx].tolist()}, env {env_idx}: {all_positions[env_idx, idx].tolist()})"
@@ -95,13 +94,13 @@ def __init__(
9594
self._optimizable_positions = None
9695

9796
@property
98-
def num_envs(self) -> int:
99-
"""Number of environments (leading dimension N)."""
100-
return self._num_envs
97+
def batch_size(self) -> int:
98+
"""Number of independent position sets (leading dimension of position tensors)."""
99+
return self._batch_size
101100

102101
@property
103102
def optimizable_positions(self) -> torch.Tensor | None:
104-
"""Tensor of optimizable positions (N, num_opt, 3), or None if all objects are anchors.
103+
"""Tensor of optimizable positions (batch_size, num_optimizable, 3), or None if all objects are anchors.
105104
106105
This is the tensor that should be passed to the optimizer.
107106
"""
@@ -124,15 +123,15 @@ def get_position(self, obj: ObjectBase) -> torch.Tensor:
124123
obj: The object to get position for.
125124
126125
Returns:
127-
Position tensor of shape (N, 3).
126+
Position tensor of shape (batch_size, 3).
128127
129128
Raises:
130129
KeyError: If object is not tracked by this state.
131130
RuntimeError: If requesting position for optimizable object when none exist.
132131
"""
133132
idx = self._obj_to_idx[obj]
134133
if idx in self._anchor_indices:
135-
return self._anchor_positions[idx].unsqueeze(0).expand(self._num_envs, 3)
134+
return self._anchor_positions[idx].unsqueeze(0).expand(self._batch_size, 3)
136135
if self._optimizable_positions is None:
137136
raise RuntimeError(f"No optimizable positions available for object '{obj.name}'")
138137
opt_idx = self._global_to_opt_idx[idx]
@@ -157,12 +156,12 @@ def get_final_positions(self) -> list[dict[ObjectBase, tuple[float, float, float
157156
pos_list = full.detach().cpu().tolist()
158157
return [
159158
{obj: tuple(pos_list[env_idx][obj_idx]) for obj_idx, obj in enumerate(self._all_objects)}
160-
for env_idx in range(self._num_envs)
159+
for env_idx in range(self._batch_size)
161160
]
162161

163162
def _reconstruct_all_positions(self) -> torch.Tensor:
164-
"""Reconstruct a full (N, num_objects, 3) tensor from anchor and optimizable parts."""
165-
full = self._anchor_pos_tensor.expand(self._num_envs, -1, -1).clone()
163+
"""Reconstruct a full (batch_size, num_objects, 3) tensor from anchor and optimizable parts."""
164+
full = self._anchor_pos_tensor.expand(self._batch_size, -1, -1).clone()
166165
if self._optimizable_positions is not None:
167166
full[:, self._opt_idx_tensor, :] = self._optimizable_positions
168167
return full

0 commit comments

Comments
 (0)