Skip to content

Commit 668b38f

Browse files
committed
qsm scripts
1 parent fa3cb0d commit 668b38f

File tree

4 files changed

+93
-0
lines changed

4 files changed

+93
-0
lines changed
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# @package _global_
2+
3+
algo:
4+
name: qsm
5+
critic_hidden_dims: [512, 512, 512]
6+
critic_lr: 0.0003
7+
discount: 0.99
8+
num_samples: 10
9+
ema: 0.005
10+
temp: 0.2
11+
diffusion:
12+
time_dim: 64
13+
mlp_hidden_dims: [512, 512, 512]
14+
lr: 0.0003
15+
end_lr: null
16+
lr_decay_steps: null
17+
lr_decay_begin: null
18+
steps: 20
19+
clip_sampler: true
20+
x_min: -1.0
21+
x_max: 1.0
22+
solver: ddpm

examples/online/main_dmc_offpolicy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"td7": TD7Agent,
2424
"sdac": SDACAgent,
2525
"dpmd": DPMDAgent,
26+
"qsm": QSMAgent,
2627
}
2728

2829
class OffPolicyTrainer():

flowrl/agent/online/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ..base import BaseAgent
22
from .dpmd import DPMDAgent
3+
from .idem import IDEMAgent
34
from .ppo import PPOAgent
45
from .qsm import QSMAgent
56
from .sac import SACAgent
@@ -16,4 +17,5 @@
1617
"DPMDAgent",
1718
"PPOAgent",
1819
"QSMAgent",
20+
"IDEMAgent"
1921
]

scripts/dmc/qsm.sh

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Specify which GPUs to use
2+
GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use
3+
SEEDS=(0 1 2 3 4)
4+
NUM_EACH_GPU=3
5+
6+
PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
7+
8+
TASKS=(
9+
"acrobot-swingup"
10+
"ball_in_cup-catch"
11+
"cartpole-balance"
12+
"cartpole-balance_sparse"
13+
"cartpole-swingup"
14+
"cartpole-swingup_sparse"
15+
"cheetah-run"
16+
"dog-run"
17+
"dog-stand"
18+
"dog-trot"
19+
"dog-walk"
20+
"finger-spin"
21+
"finger-turn_easy"
22+
"finger-turn_hard"
23+
"fish-swim"
24+
"hopper-hop"
25+
"hopper-stand"
26+
"humanoid-run"
27+
"humanoid-stand"
28+
"humanoid-walk"
29+
"pendulum-swingup"
30+
"quadruped-run"
31+
"quadruped-walk"
32+
"reacher-easy"
33+
"reacher-hard"
34+
"walker-run"
35+
"walker-stand"
36+
"walker-walk"
37+
)
38+
39+
SHARED_ARGS=(
40+
"algo=qsm"
41+
"log.tag=default"
42+
"log.project=flow-rl"
43+
"log.entity=lamda-rl"
44+
)
45+
46+
run_task() {
47+
task=$1
48+
seed=$2
49+
slot=$3
50+
num_gpus=${#GPUS[@]}
51+
device_idx=$((slot % num_gpus))
52+
device=${GPUS[$device_idx]}
53+
echo "Running $env $seed on GPU $device"
54+
command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}"
55+
if [ -n "$DRY_RUN" ]; then
56+
echo $command
57+
else
58+
echo $command
59+
$command
60+
fi
61+
}
62+
63+
. env_parallel.bash
64+
if [ -n "$DRY_RUN" ]; then
65+
env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
66+
else
67+
env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
68+
fi

0 commit comments

Comments
 (0)