Skip to content

Commit 448aae3

Browse files
committed
stash ctrl_qsm
1 parent f444f94 commit 448aae3

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

flowrl/agent/online/ctrl/ctrl_qsm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,6 @@ def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray:
107107
return q.min(axis=0).mean()
108108
q_grad_fn = jax.vmap(jax.grad(get_q_value))
109109
q_grad = q_grad_fn(at, batch.obs)
110-
q_grad = alpha1 * q_grad - alpha2 * at
111110
eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6)
112111

113112
def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]:
@@ -123,6 +122,7 @@ def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarra
123122
return loss, {
124123
"loss/actor_loss": loss,
125124
"misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(),
125+
"misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(),
126126
}
127127

128128
new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn)

flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,5 @@ class CtrlQSMConfig(BaseAlgoConfig):
3737
ranking: bool
3838

3939
num_samples: int
40+
temp: float
4041
diffusion: QSMDiffusionConfig

scripts/dmc/ctrl_qsm.sh

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Specify which GPUs to use
2+
GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use
3+
SEEDS=(0 1 2 3)
4+
NUM_EACH_GPU=3
5+
6+
PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
7+
8+
TASKS=(
9+
"acrobot-swingup"
10+
"ball_in_cup-catch"
11+
"cartpole-balance"
12+
"cartpole-balance_sparse"
13+
"cartpole-swingup"
14+
"cartpole-swingup_sparse"
15+
"cheetah-run"
16+
"dog-run"
17+
"dog-stand"
18+
"dog-trot"
19+
"dog-walk"
20+
"finger-spin"
21+
"finger-turn_easy"
22+
"finger-turn_hard"
23+
"fish-swim"
24+
"hopper-hop"
25+
"hopper-stand"
26+
"humanoid-run"
27+
"humanoid-stand"
28+
"humanoid-walk"
29+
"pendulum-swingup"
30+
"quadruped-run"
31+
"quadruped-walk"
32+
"reacher-easy"
33+
"reacher-hard"
34+
"walker-run"
35+
"walker-stand"
36+
"walker-walk"
37+
)
38+
39+
SHARED_ARGS=(
40+
"algo=ctrl_qsm"
41+
"log.tag=default"
42+
"log.project=flow-rl"
43+
"log.entity=lambda-rl"
44+
)
45+
46+
# Launch one training run for a (task, seed) pair on the GPU chosen for a job slot.
#
# Arguments:
#   $1 - task name (one entry of TASKS)
#   $2 - random seed (one entry of SEEDS)
#   $3 - job slot number passed by the caller ({%}); mapped onto GPUS by modulo
# Globals read:
#   GPUS        - array of GPU ids to distribute jobs over
#   SHARED_ARGS - array of extra key=value overrides appended to every run
#   DRY_RUN     - when non-empty, the command is printed but not executed
run_task() {
    local task=$1
    local seed=$2
    local slot=$3
    local num_gpus=${#GPUS[@]}
    # Wrap the slot number around the GPU list so several slots can share one GPU.
    local device_idx=$((slot % num_gpus))
    local device=${GPUS[$device_idx]}

    # Report the task being launched; the task name lives in $task
    # (this script defines no $env variable).
    echo "Running $task $seed on GPU $device"

    # Build the command as an argv array so each override stays a single
    # argument and no unquoted word-splitting/globbing is involved.
    local -a cmd=(
        python3 examples/online/main_dmc_offpolicy.py
        "task=$task"
        "device=$device"
        "seed=$seed"
        "${SHARED_ARGS[@]}"
    )

    # The command line is always printed; it is executed only outside dry-run mode.
    echo "${cmd[*]}"
    if [ -z "$DRY_RUN" ]; then
        "${cmd[@]}"
    fi
}
62+
63+
# Source GNU parallel's env_parallel wrapper; it is used below instead of plain
# `parallel` so that the shell function run_task can be invoked as the job command.
. env_parallel.bash

# Run every (task, seed) combination with at most PARALLEL concurrent jobs.
# {1} and {2} are the task and seed; {%} is the job slot handed to run_task.
if [ -n "$DRY_RUN" ]; then
    env_parallel -P"${PARALLEL}" run_task {1} {2} {%} ::: "${TASKS[@]}" ::: "${SEEDS[@]}"
else
    # `name` is not assigned anywhere in this script. Honour it when the caller
    # exports it, otherwise fall back to a fixed directory so per-job results
    # do not land directly in log/parallel/.
    env_parallel --bar --results "log/parallel/${name:-ctrl_qsm}" -P"${PARALLEL}" run_task {1} {2} {%} ::: "${TASKS[@]}" ::: "${SEEDS[@]}"
fi

0 commit comments

Comments
 (0)