-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_vae.py
More file actions
58 lines (45 loc) · 1.54 KB
/
train_vae.py
File metadata and controls
58 lines (45 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import annotations
import os
import shutil
from pathlib import Path
from backend.app.training.runner import train_run
def train() -> None:
    """Train the GM-VAE on MNIST and export the final checkpoint.

    Builds a fixed training configuration, hands it to ``train_run`` with
    ``outputs/`` as the run directory, and then copies the final checkpoint
    to ``outputs/vae_model.pth`` so the legacy output filename keeps working.

    Raises:
        RuntimeError: If the result returned by ``train_run`` does not carry
            ``"status" == "completed"``.
        FileNotFoundError: If the final checkpoint is missing from the
            location this script expects after a completed run.
    """
    run_dir = Path("outputs")
    run_dir.mkdir(parents=True, exist_ok=True)

    # Single source of truth: the same name is handed to the runner through
    # the config and used below to locate the file for the legacy copy.
    final_checkpoint_name = "final.pth"

    config = {
        "model_type": "gmvae",
        "epochs": 10,
        "batch_size": 128,
        "learning_rate": 1e-3,
        "latent_dim": 64,
        "capacity": 32,
        "num_components": 10,
        "kld_weight": 1.0,
        "checkpoint_interval_epochs": 1,
        "save_final_checkpoint": True,
        "final_checkpoint_name": final_checkpoint_name,
        "latest_recon_name": "recon_latest.png",
        "dataset_name": "mnist",
    }

    print("Starting GM-VAE training...")
    result = train_run(
        config=config,
        run_dir=run_dir,
        emit=_print_event,
        # This script offers no way to cancel, so never report a request.
        is_cancel_requested=lambda: False,
    )
    if result.get("status") != "completed":
        raise RuntimeError(f"GM-VAE training did not complete: {result}")

    # Preserve legacy output filename.
    # NOTE(review): assumes train_run writes the final checkpoint to
    # <run_dir>/checkpoints/<final_checkpoint_name> - confirm against the runner.
    final_ckpt = run_dir / "checkpoints" / final_checkpoint_name
    if not final_ckpt.exists():
        raise FileNotFoundError(
            f"Training completed but the final checkpoint was not found: {final_ckpt}"
        )
    shutil.copyfile(final_ckpt, run_dir / "vae_model.pth")
    print("Training complete. Model saved to outputs/vae_model.pth")
def _print_event(event_type: str, payload: dict) -> None:
    """Print a one-line summary for ``"train.metrics"`` events.

    Any other event type is ignored. For metrics events the payload must
    provide the ``"epoch"``, ``"total_loss"``, ``"recon_loss"`` and
    ``"kld_loss"`` keys; the three losses are printed with four decimals.
    """
    if event_type != "train.metrics":
        return

    # Keys are looked up in this order, so the first missing one raises.
    epoch, total, recon, kld = (
        payload[key]
        for key in ("epoch", "total_loss", "recon_loss", "kld_loss")
    )
    print(f"Epoch [{epoch}] Total: {total:.4f} Recon: {recon:.4f} KLD: {kld:.4f}")
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()