File tree Expand file tree Collapse file tree 4 files changed +93
-0
lines changed
Expand file tree Collapse file tree 4 files changed +93
-0
lines changed Original file line number Diff line number Diff line change 1+ # @package _global_
2+
3+ algo :
4+ name : qsm
5+ critic_hidden_dims : [512, 512, 512]
6+ critic_lr : 0.0003
7+ discount : 0.99
8+ num_samples : 10
9+ ema : 0.005
10+ temp : 0.2
11+ diffusion :
12+ time_dim : 64
13+ mlp_hidden_dims : [512, 512, 512]
14+ lr : 0.0003
15+ end_lr : null
16+ lr_decay_steps : null
17+ lr_decay_begin : null
18+ steps : 20
19+ clip_sampler : true
20+ x_min : -1.0
21+ x_max : 1.0
22+ solver : ddpm
Original file line number Diff line number Diff line change 2323 "td7" : TD7Agent ,
2424 "sdac" : SDACAgent ,
2525 "dpmd" : DPMDAgent ,
26+ "qsm" : QSMAgent ,
2627}
2728
2829class OffPolicyTrainer ():
Original file line number Diff line number Diff line change 11from ..base import BaseAgent
22from .dpmd import DPMDAgent
3+ from .idem import IDEMAgent
34from .ppo import PPOAgent
45from .qsm import QSMAgent
56from .sac import SACAgent
1617 "DPMDAgent" ,
1718 "PPOAgent" ,
1819 "QSMAgent" ,
20+ "IDEMAgent"
1921]
Original file line number Diff line number Diff line change 1+ # Specify which GPUs to use
2+ GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use
3+ SEEDS=(0 1 2 3 4)
4+ NUM_EACH_GPU=3
5+
6+ PARALLEL=$(( NUM_EACH_GPU * ${# GPUS[@]} ))
7+
8+ TASKS=(
9+ " acrobot-swingup"
10+ " ball_in_cup-catch"
11+ " cartpole-balance"
12+ " cartpole-balance_sparse"
13+ " cartpole-swingup"
14+ " cartpole-swingup_sparse"
15+ " cheetah-run"
16+ " dog-run"
17+ " dog-stand"
18+ " dog-trot"
19+ " dog-walk"
20+ " finger-spin"
21+ " finger-turn_easy"
22+ " finger-turn_hard"
23+ " fish-swim"
24+ " hopper-hop"
25+ " hopper-stand"
26+ " humanoid-run"
27+ " humanoid-stand"
28+ " humanoid-walk"
29+ " pendulum-swingup"
30+ " quadruped-run"
31+ " quadruped-walk"
32+ " reacher-easy"
33+ " reacher-hard"
34+ " walker-run"
35+ " walker-stand"
36+ " walker-walk"
37+ )
38+
39+ SHARED_ARGS=(
40+ " algo=qsm"
41+ " log.tag=default"
42+ " log.project=flow-rl"
43+ " log.entity=lamda-rl"
44+ )
45+
46+ run_task () {
47+ task=$1
48+ seed=$2
49+ slot=$3
50+ num_gpus=${# GPUS[@]}
51+ device_idx=$(( slot % num_gpus))
52+ device=${GPUS[$device_idx]}
53+ echo " Running $env $seed on GPU $device "
54+ command=" python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]} "
55+ if [ -n " $DRY_RUN " ]; then
56+ echo $command
57+ else
58+ echo $command
59+ $command
60+ fi
61+ }
62+
63+ . env_parallel.bash
64+ if [ -n " $DRY_RUN " ]; then
65+ env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
66+ else
67+ env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
68+ fi
You can’t perform that action at this time.
0 commit comments