Skip to content

Commit 40fca6d

Browse files
committed
enable loading and saving from jsons
1 parent 7a73e69 commit 40fca6d

1 file changed

Lines changed: 62 additions & 1 deletion

File tree

topoloss/scheduler.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
import math
2+
import json
3+
import os
4+
25
from .core import TopoLoss
36

47
valid_modes = ["linear", "cosine_decay"]
@@ -66,6 +69,32 @@ def step(self, current_step: int = None):
6669
print(f"[TauScheduler/{self.mode}] layer: {loss.layer_name} | {loss.scale:.4f} -> {value:.4f} | step {current_step}/{self.num_steps}")
6770
loss.scale = value
6871

72+
def save_json(self, filename: str):
73+
data = {
74+
"start_value": self.start_value,
75+
"end_value": self.end_value,
76+
"num_steps": self.num_steps,
77+
"mode": self.mode,
78+
}
79+
with open(filename, "w") as f:
80+
json.dump(data, f, indent=4)
81+
82+
83+
@classmethod
84+
def from_json(cls, filename: str, topo_loss: TopoLoss, verbose=False):
85+
assert isinstance(topo_loss, TopoLoss), f"topo_loss must be an instance of TopoLoss but got {type(topo_loss)}"
86+
assert os.path.exists(filename), f"File not found: {filename}"
87+
with open(filename, "r") as f:
88+
data = json.load(f)
89+
return cls(
90+
topo_loss=topo_loss,
91+
start_value=data["start_value"],
92+
end_value=data["end_value"],
93+
num_steps=data["num_steps"],
94+
mode=data["mode"],
95+
verbose=verbose,
96+
)
97+
6998

7099
class ChainedTauScheduler:
71100
"""Chains multiple TauSchedulers sequentially, mirroring PyTorch's ChainedScheduler pattern."""
@@ -96,4 +125,36 @@ def get_current_tau(self):
96125
if remaining < scheduler.num_steps:
97126
return scheduler.compute_value(remaining)
98127
remaining -= scheduler.num_steps
99-
return self.schedulers[-1].compute_value(self.schedulers[-1].num_steps)
128+
return self.schedulers[-1].compute_value(self.schedulers[-1].num_steps)
129+
130+
def save_json(self, filename: str):
131+
data = [
132+
{
133+
"start_value": scheduler.start_value,
134+
"end_value": scheduler.end_value,
135+
"num_steps": scheduler.num_steps,
136+
"mode": scheduler.mode,
137+
}
138+
for scheduler in self.schedulers
139+
]
140+
with open(filename, "w") as f:
141+
json.dump(data, f, indent=4)
142+
143+
@classmethod
144+
def from_json(cls, filename: str, topo_loss: TopoLoss, verbose=False):
145+
assert os.path.exists(filename), f"File not found: {filename}"
146+
with open(filename, "r") as f:
147+
data = json.load(f)
148+
schedulers = []
149+
for item in data:
150+
schedulers.append(
151+
TauScheduler(
152+
topo_loss=topo_loss,
153+
start_value=item["start_value"],
154+
end_value=item["end_value"],
155+
num_steps=item["num_steps"],
156+
mode=item["mode"],
157+
verbose=verbose,
158+
)
159+
)
160+
return cls(schedulers=schedulers)

0 commit comments

Comments
 (0)