-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
63 lines (56 loc) · 1.69 KB
/
utils.py
File metadata and controls
63 lines (56 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import json
import torch
import pandas as pd
import matplotlib.pyplot as plt
def load_json(file_path):
    """Parse a UTF-8 encoded JSON file and return the decoded value.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The deserialized JSON content (dict, list, str, number, bool or None).
    """
    with open(file_path, encoding="utf-8") as handle:
        return json.load(handle)
def save_json(text, filepath):
    """Serialize a JSON-compatible object to ``filepath``.

    The file is created or overwritten and written as UTF-8. Non-ASCII
    characters are kept literal (``ensure_ascii=False``) and the output is
    pretty-printed with a 2-space indent.

    Args:
        text: Any JSON-serializable object (despite the name, not only str).
        filepath: Destination path.
    """
    with open(filepath, mode="w", encoding="utf-8") as out:
        json.dump(text, out, indent=2, ensure_ascii=False)
def load_txt(filepath: str):
    """Return the full contents of a UTF-8 text file as one string.

    Args:
        filepath: Path of the text file to read.

    Returns:
        The file contents, with newlines normalized by text-mode reading.
    """
    with open(filepath, encoding="utf-8") as src:
        contents = src.read()
    return contents
def save_data(out_src: str, text: str):
    """Write ``text`` to the file at ``out_src`` as UTF-8.

    Any existing content at ``out_src`` is replaced.

    Args:
        out_src: Destination file path.
        text: String content to write.
    """
    with open(out_src, mode="w", encoding="utf-8") as sink:
        sink.write(text)
def _plot_train_val_curve(df, train_col, val_col, metric_name):
    """Draw and show one figure comparing a train column and a val column per epoch."""
    plt.figure(figsize=(10, 5))
    plt.plot(df["epoch"], df[train_col], marker="o", label=f"Train {metric_name}")
    plt.plot(df["epoch"], df[val_col], marker="o", label=f"Validation {metric_name}")
    plt.xlabel("Epoch")
    plt.ylabel(metric_name)
    plt.title(f"Training vs Validation {metric_name}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_traing_result(csv_file):
    """Read a training-log CSV and plot train/validation curves per epoch.

    Two figures are shown in sequence: loss, then perplexity.

    Args:
        csv_file: Path (or anything ``pandas.read_csv`` accepts) to a CSV
            containing the columns ``epoch``, ``train_loss``, ``val_loss``,
            ``train_ppl`` and ``val_ppl``.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    df = pd.read_csv(csv_file)
    _plot_train_val_curve(df, "train_loss", "val_loss", "Loss")
    _plot_train_val_curve(df, "train_ppl", "val_ppl", "Perplexity")


# Correctly spelled alias; the misspelled name is kept for existing callers.
plot_training_result = plot_traing_result