Skip to content

Commit fd680a5

Browse files
committed
set RL-DAS logic to follow strictly the one of tis original version
1 parent 5a8925b commit fd680a5

8 files changed

Lines changed: 1036 additions & 231 deletions

File tree

agents/rl_das/env.py

Lines changed: 209 additions & 215 deletions
Large diffs are not rendered by default.

agents/rl_das/network.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@
66
77
Architecture (following Guo et al. 2024):
88
9-
Input: flat obs = [features_6d | best_move_0_dim | worst_move_0_dim | ...]
9+
Input: flat obs = [features_9d | best_move_0_dim | worst_move_0_dim | ...]
1010
1111
For each of the 2*n_opt movement blocks:
1212
embedder_k : Linear(dim, 64) -> ReLU -> Linear(64, 1) -> ReLU
1313
14-
backbone_input = cat(features_6d, *[emb_k(move_k) for k]) shape: (6+2*n_opt,)
15-
backbone : Linear(6+2*n_opt, 64) -> Tanh -> Linear(64, 16) -> Tanh
14+
backbone_input = cat(features_9d, *[emb_k(move_k) for k]) shape: (9+2*n_opt,)
15+
backbone : Linear(9+2*n_opt, 64) -> Tanh -> Linear(64, 16) -> Tanh
1616
1717
Actor head : Linear(16, n_opt) -> Softmax
1818
Critic head : Linear(16, 1)
@@ -42,7 +42,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4242
class _RLDASBackbone(nn.Module):
4343
"""Shared feature extractor used by both Actor and Critic."""
4444

45-
N_FEATURES = 6 # must match env.RLDASEnv.N_FEATURES
45+
N_FEATURES = 9 # must match env.RLDASEnv.N_FEATURES
4646

4747
def __init__(self, dim: int, n_opt: int) -> None:
4848
super().__init__()

0 commit comments

Comments
 (0)