Skip to content

Commit d6a6976

Browse files
authored
Merge pull request #8 from IDEALLab/feature/seed_artifacts
Tag artifacts with seed
2 parents 88172d7 + b69f857 commit d6a6976

3 files changed

Lines changed: 11 additions & 10 deletions

File tree

engiopt/cgan_1d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,12 @@ def sample_designs(n_designs: int) -> th.Tensor:
291291

292292
th.save(ckpt_gen, "generator.pth")
293293
th.save(ckpt_disc, "discriminator.pth")
294-
artifact_gen = wandb.Artifact("generator", type="model")
294+
artifact_gen = wandb.Artifact(f"{args.algo}_generator", type="model")
295295
artifact_gen.add_file("generator.pth")
296-
artifact_disc = wandb.Artifact("discriminator", type="model")
296+
artifact_disc = wandb.Artifact(f"{args.algo}_discriminator", type="model")
297297
artifact_disc.add_file("discriminator.pth")
298298

299-
wandb.log_artifact(artifact_gen)
300-
wandb.log_artifact(artifact_disc)
299+
wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
300+
wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
301301

302302
wandb.finish()

engiopt/diffusion_1d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,9 @@ class Args:
175175
}
176176

177177
th.save(ckpt, "model.pth")
178-
artifact = wandb.Artifact("diffusion", type="model")
178+
artifact = wandb.Artifact(f"{args.algo}_model", type="model")
179179
artifact.add_file("model.pth")
180-
wandb.log_artifact(artifact)
180+
181+
wandb.log_artifact(artifact, aliases=[f"seed_{args.seed}"])
181182

182183
wandb.finish()

engiopt/gan_1d.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,12 @@ def forward(self, design: th.Tensor) -> th.Tensor: # noqa: D102
261261

262262
th.save(ckpt_gen, "generator.pth")
263263
th.save(ckpt_disc, "discriminator.pth")
264-
artifact_gen = wandb.Artifact("generator", type="model")
264+
artifact_gen = wandb.Artifact(f"{args.algo}_generator", type="model")
265265
artifact_gen.add_file("generator.pth")
266-
artifact_disc = wandb.Artifact("discriminator", type="model")
266+
artifact_disc = wandb.Artifact(f"{args.algo}_discriminator", type="model")
267267
artifact_disc.add_file("discriminator.pth")
268268

269-
wandb.log_artifact(artifact_gen)
270-
wandb.log_artifact(artifact_disc)
269+
wandb.log_artifact(artifact_gen, aliases=[f"seed_{args.seed}"])
270+
wandb.log_artifact(artifact_disc, aliases=[f"seed_{args.seed}"])
271271

272272
wandb.finish()

0 commit comments

Comments
 (0)