-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplot_inv_dec_loss_train.py
More file actions
54 lines (47 loc) · 1.91 KB
/
plot_inv_dec_loss_train.py
File metadata and controls
54 lines (47 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import re
import matplotlib.pyplot as plt
import argparse
import os
# Parse an inverse-decoder training log and plot per-family train/validation
# loss curves to plots/inverse_decoder_loss_curve_<job_id>.png.

# Model families plotted by default, in subplot order. Any other family that
# appears in the log is plotted after these instead of raising KeyError.
FAMILIES = ("rar", "var")

# A float literal including scientific notation (e.g. "1.5e-04"); a plain
# "[0-9.]+" would silently truncate such values to their mantissa ("1.5").
_FLOAT = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"
_FAMILY_RE = re.compile(r"\[Training\] Starting Inverse Decoder for (\w+)")
_EPOCH_RE = re.compile(rf"Epoch\s+(\d+):\s+Loss\s+({_FLOAT})")
_VAL_RE = re.compile(rf"Validation Loss:\s+({_FLOAT})")
_JOB_ID_RE = re.compile(r"inverse_decoder_(\d+)\.out")


def extract_job_id(log_file):
    """Return the job id embedded in an ``inverse_decoder_<id>.out`` path.

    Falls back to the file's base name without extension when the path does
    not follow that pattern, so the output plot still gets a distinct name.
    """
    match = _JOB_ID_RE.search(log_file)
    if match:
        return match.group(1)
    return os.path.splitext(os.path.basename(log_file))[0]


def parse_log(lines):
    """Collect training and validation losses per model family.

    Args:
        lines: Iterable of log lines (e.g. an open text file).

    Returns:
        ``(losses, val_losses)``: two dicts mapping family name to a list of
        ``(epoch, loss)`` tuples in log order. Both always contain the keys in
        ``FAMILIES`` plus any other family announced in the log.

    Lines seen before the first "[Training] Starting Inverse Decoder for ..."
    header are ignored. A validation loss is attributed to the epoch of the
    current family's most recent training loss, and is dropped if that family
    has no training loss yet.
    """
    losses = {family: [] for family in FAMILIES}
    val_losses = {family: [] for family in FAMILIES}
    current_family = None
    for line in lines:
        fam_match = _FAMILY_RE.search(line)
        if fam_match:
            current_family = fam_match.group(1)
            losses.setdefault(current_family, [])
            val_losses.setdefault(current_family, [])
        if current_family is None:
            continue

        epoch_match = _EPOCH_RE.search(line)
        if epoch_match:
            losses[current_family].append(
                (int(epoch_match.group(1)), float(epoch_match.group(2)))
            )

        val_match = _VAL_RE.search(line)
        train_history = losses[current_family]
        if val_match and train_history:
            # Use this family's own last training epoch, not whatever epoch
            # was parsed last for a different family.
            last_epoch = train_history[-1][0]
            val_losses[current_family].append((last_epoch, float(val_match.group(1))))
    return losses, val_losses


def plot_losses(losses, val_losses, output_path, show=True):
    """Draw one subplot per family with train/val curves and save the figure.

    Families in ``FAMILIES`` come first (always drawn, even if empty),
    followed by any extra families present in ``losses`` in sorted order.
    """
    families = list(FAMILIES) + sorted(f for f in losses if f not in FAMILIES)
    # 6 inches of width per subplot: (12, 5) for the default two families.
    plt.figure(figsize=(6 * len(families), 5))
    for index, family in enumerate(families, start=1):
        plt.subplot(1, len(families), index)
        train = losses.get(family, [])
        val = val_losses.get(family, [])
        if train:
            epochs, loss_vals = zip(*train)
            plt.plot(epochs, loss_vals, marker="o", label="Train")
        if val:
            epochs, val_loss_vals = zip(*val)
            plt.plot(epochs, val_loss_vals, marker="x", linestyle="--", label="Val")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title(f"{family.upper()} Inverse Decoder Loss")
        # legend() warns when there are no labelled artists on the axes.
        if train or val:
            plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(output_path)
    if show:
        plt.show()


def main(argv=None):
    """Command-line entry point: parse the log given by --log_file and plot it."""
    parser = argparse.ArgumentParser(description="Plot training loss curves for inverse decoder.")
    parser.add_argument("--log_file", type=str, required=True, help="Path to the log file to parse.")
    args = parser.parse_args(argv)

    job_id = extract_job_id(args.log_file)
    # errors="replace" keeps stray non-UTF-8 bytes in logs from aborting the parse.
    with open(args.log_file, "r", encoding="utf-8", errors="replace") as f:
        losses, val_losses = parse_log(f)

    os.makedirs("plots", exist_ok=True)
    plot_losses(losses, val_losses, f"plots/inverse_decoder_loss_curve_{job_id}.png")


if __name__ == "__main__":
    main()