Skip to content

Commit 40ec550

Browse files
committed
separate q network works
1 parent f148d68 commit 40ec550

File tree

10 files changed

+684
-108
lines changed

10 files changed

+684
-108
lines changed

examples/online/config/dmc/algo/aca.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,4 @@ algo:
3838
linear: false
3939
ranking: true
4040

41-
norm_obs: true
41+
# norm_obs: true
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# @package _global_
2+
3+
algo:
4+
name: aca
5+
target_update_freq: 1
6+
feature_dim: 512
7+
rff_dim: 1024
8+
critic_hidden_dims: [512, 512]
9+
reward_hidden_dims: [512, 512]
10+
phi_hidden_dims: [512, 512]
11+
mu_hidden_dims: [512, 512]
12+
ctrl_coef: 1.0
13+
reward_coef: 1.0
14+
critic_coef: 1.0
15+
critic_activation: elu # not used
16+
back_critic_grad: false
17+
feature_lr: 0.0001
18+
critic_lr: 0.0003
19+
discount: 0.99
20+
num_samples: 10
21+
ema: 0.005
22+
feature_ema: 0.005
23+
clip_grad_norm: null
24+
temp: 0.2
25+
diffusion:
26+
time_dim: 64
27+
mlp_hidden_dims: [512, 512, 512]
28+
lr: 0.0003
29+
end_lr: null
30+
lr_decay_steps: null
31+
lr_decay_begin: null
32+
steps: 20
33+
clip_sampler: true
34+
x_min: -5.0
35+
x_max: 5.0
36+
# solver: ddpm
37+
solver: ddpm
38+
num_noises: 25
39+
linear: false
40+
ranking: true
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# @package _global_
2+
3+
algo:
4+
name: qsm
5+
critic_hidden_dims: [512, 512, 512]
6+
critic_activation: elu
7+
critic_lr: 0.0003
8+
discount: 0.99
9+
num_samples: 10
10+
ema: 0.005
11+
temp: 0.2
12+
diffusion:
13+
time_dim: 64
14+
mlp_hidden_dims: [512, 512, 512]
15+
lr: 0.0003
16+
end_lr: null
17+
lr_decay_steps: null
18+
lr_decay_begin: null
19+
steps: 20
20+
clip_sampler: true
21+
x_min: -5.0
22+
x_max: 5.0
23+
solver: ddpm
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# @package _global_
2+
3+
algo:
4+
name: sdac
5+
critic_hidden_dims: [256, 256]
6+
critic_lr: 0.0003
7+
discount: 0.99
8+
num_samples: 10
9+
num_reverse_samples: 500
10+
ema: 0.005
11+
temp: 0.2
12+
diffusion:
13+
time_dim: 64
14+
mlp_hidden_dims: [256, 256]
15+
lr: 0.0003
16+
end_lr: null
17+
lr_decay_steps: null
18+
lr_decay_begin: null
19+
steps: 20
20+
clip_sampler: false
21+
x_min: -1.0
22+
x_max: 1.0
23+
solver: ddpm

examples/toy2d/main_toy2d.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
import hydra
55
import omegaconf
6-
import wandb
76
from omegaconf import OmegaConf
87
from tqdm import trange
98

9+
import wandb
1010
from examples.toy2d.utils import compute_metrics, plot_data, plot_energy, plot_sample
1111
from flowrl.agent.offline import *
1212
from flowrl.agent.online import *
@@ -19,6 +19,9 @@
1919
SUPPORTED_AGENTS: Dict[str, Type[BaseAgent]] = {
2020
"bdpo": BDPOAgent,
2121
"dac": DACAgent,
22+
"qsm": QSMAgent,
23+
"sdac": SDACAgent,
24+
"aca": ACAAgent,
2225
}
2326

2427
class Trainer():
@@ -30,14 +33,15 @@ def __init__(self, cfg: Config):
3033
log_dir="/".join([cfg.log.dir, cfg.algo.name, cfg.log.tag, cfg.task]),
3134
name="seed"+str(cfg.seed),
3235
logger_config={
33-
"TensorboardLogger": {"activate": True},
34-
"WandbLogger": {
35-
"activate": True,
36-
"config": OmegaConf.to_container(cfg),
37-
"settings": wandb.Settings(_disable_stats=True),
38-
"project": cfg.log.project,
39-
"entity": cfg.log.entity
40-
} if ("project" in cfg.log and "entity" in cfg.log) else {"activate": False},
36+
"CsvLogger": {"activate": True},
37+
# "TensorboardLogger": {"activate": True},
38+
# "WandbLogger": {
39+
# "activate": True,
40+
# "config": OmegaConf.to_container(cfg),
41+
# "settings": wandb.Settings(_disable_stats=True),
42+
# "project": cfg.log.project,
43+
# "entity": cfg.log.entity
44+
# } if ("project" in cfg.log and "entity" in cfg.log) else {"activate": False},
4145
}
4246
)
4347
self.ckpt_save_dir = os.path.join(self.logger.log_dir, "ckpt")

examples/toy2d/utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from flowrl.agent.base import BaseAgent
1111
from flowrl.agent.offline import *
12+
from flowrl.agent.online import *
1213
from flowrl.dataset.toy2d import Toy2dDataset, inf_train_gen
1314

1415
SAMPLE_GRAPH_SIZE = 2000
@@ -108,6 +109,9 @@ def plot_energy(out_dir, task: str, agent: BaseAgent):
108109
vmin = e.min()
109110
vmax = e.max()
110111

112+
def default_plot():
113+
pass
114+
111115
def bdpo_plot():
112116
tt = [0, 1, 3, 5, 10, 20, 30, 40, 50]
113117
plt.figure(figsize=(30, 3.0))
@@ -168,12 +172,54 @@ def dac_plot():
168172
plt.close()
169173
tqdm.write(f"Saved value plot to {saveto}")
170174

175+
def aca_plot():
176+
tt = [0, 1, 3, 5, 10, 20]
177+
plt.figure(figsize=(20, 3.0))
178+
axes = []
179+
for i, t in enumerate(tt):
180+
plt.subplot(1, len(tt), i+1)
181+
if t == 0:
182+
model = agent.critic_target
183+
c = model(zero, id_matrix).mean(axis=0).reshape(90, 90)
184+
else:
185+
model = agent.value_target
186+
t_input = np.ones((90*90, 1)) * t
187+
c = model(zero, id_matrix, t_input).mean(axis=0).reshape(90, 90)
188+
plt.gca().set_aspect("equal", adjustable="box")
189+
plt.xlim(0, 89)
190+
plt.ylim(0, 89)
191+
if i == 0:
192+
mappable = plt.imshow(
193+
c, origin="lower", vmin=vmin, vmax=vmax,
194+
cmap="viridis", rasterized=True
195+
)
196+
plt.yticks(ticks=[5, 25, 45, 65, 85], labels=[-4, -2, 0, 2, 4])
197+
else:
198+
plt.imshow(
199+
c, origin="lower", vmin=vmin, vmax=vmax,
200+
cmap="viridis", rasterized=True
201+
)
202+
plt.yticks(ticks=[5, 25, 45, 65, 85], labels=[None, None, None, None, None])
203+
204+
axes.append(plt.gca())
205+
plt.xticks(ticks=[5, 25, 45, 65, 85], labels=[-4, -2, 0, 2, 4])
206+
plt.title(f't={t}')
207+
plt.tight_layout()
208+
cbar = plt.gcf().colorbar(mappable, ax=axes, fraction=0.1, pad=0.02, aspect=12)
209+
plt.gcf().axes[-1].yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
210+
saveto = os.path.join(out_dir, "qt_space.png")
211+
plt.savefig(saveto, dpi=300)
212+
plt.close()
213+
tqdm.write(f"Saved value plot to {saveto}")
214+
171215
if isinstance(agent, BDPOAgent):
172216
bdpo_plot()
173217
elif isinstance(agent, DACAgent):
174218
dac_plot()
219+
elif isinstance(agent, ACAAgent):
220+
aca_plot()
175221
else:
176-
raise NotImplementedError(f"Plotting for {type(agent)} is not implemented")
222+
default_plot()
177223

178224

179225

0 commit comments

Comments
 (0)