11import math
2+ import json
3+ import os
4+
25from .core import TopoLoss
36
47valid_modes = ["linear" , "cosine_decay" ]
@@ -66,6 +69,32 @@ def step(self, current_step: int = None):
6669 print (f"[TauScheduler/{ self .mode } ] layer: { loss .layer_name } | { loss .scale :.4f} -> { value :.4f} | step { current_step } /{ self .num_steps } " )
6770 loss .scale = value
6871
72+ def save_json (self , filename : str ):
73+ data = {
74+ "start_value" : self .start_value ,
75+ "end_value" : self .end_value ,
76+ "num_steps" : self .num_steps ,
77+ "mode" : self .mode ,
78+ }
79+ with open (filename , "w" ) as f :
80+ json .dump (data , f , indent = 4 )
81+
82+
83+ @classmethod
84+ def from_json (cls , filename : str , topo_loss : TopoLoss , verbose = False ):
85+ assert isinstance (topo_loss , TopoLoss ), f"topo_loss must be an instance of TopoLoss but got { type (topo_loss )} "
86+ assert os .path .exists (filename ), f"File not found: { filename } "
87+ with open (filename , "r" ) as f :
88+ data = json .load (f )
89+ return cls (
90+ topo_loss = topo_loss ,
91+ start_value = data ["start_value" ],
92+ end_value = data ["end_value" ],
93+ num_steps = data ["num_steps" ],
94+ mode = data ["mode" ],
95+ verbose = verbose ,
96+ )
97+
6998
7099class ChainedTauScheduler :
71100 """Chains multiple TauSchedulers sequentially, mirroring PyTorch's ChainedScheduler pattern."""
@@ -96,4 +125,36 @@ def get_current_tau(self):
96125 if remaining < scheduler .num_steps :
97126 return scheduler .compute_value (remaining )
98127 remaining -= scheduler .num_steps
99- return self .schedulers [- 1 ].compute_value (self .schedulers [- 1 ].num_steps )
128+ return self .schedulers [- 1 ].compute_value (self .schedulers [- 1 ].num_steps )
129+
130+ def save_json (self , filename : str ):
131+ data = [
132+ {
133+ "start_value" : scheduler .start_value ,
134+ "end_value" : scheduler .end_value ,
135+ "num_steps" : scheduler .num_steps ,
136+ "mode" : scheduler .mode ,
137+ }
138+ for scheduler in self .schedulers
139+ ]
140+ with open (filename , "w" ) as f :
141+ json .dump (data , f , indent = 4 )
142+
143+ @classmethod
144+ def from_json (cls , filename : str , topo_loss : TopoLoss , verbose = False ):
145+ assert os .path .exists (filename ), f"File not found: { filename } "
146+ with open (filename , "r" ) as f :
147+ data = json .load (f )
148+ schedulers = []
149+ for item in data :
150+ schedulers .append (
151+ TauScheduler (
152+ topo_loss = topo_loss ,
153+ start_value = item ["start_value" ],
154+ end_value = item ["end_value" ],
155+ num_steps = item ["num_steps" ],
156+ mode = item ["mode" ],
157+ verbose = verbose ,
158+ )
159+ )
160+ return cls (schedulers = schedulers )
0 commit comments