1+ from topoloss import TopoLoss , LaplacianPyramid
2+ from topoloss .scheduler import TauScheduler , ChainedTauScheduler
3+ import pytest
4+ import torch
5+ import torch .nn as nn
6+
7+
8+ def make_topo_loss ():
9+ model = nn .Sequential (nn .Linear (30 , 20 ))
10+ tl = TopoLoss (
11+ losses = [
12+ LaplacianPyramid .from_layer (
13+ model = model , layer = model [0 ], scale = 1.0 , factor_h = 2.0 , factor_w = 2.0
14+ )
15+ ]
16+ )
17+ return model , tl
18+
19+
20+ @pytest .mark .parametrize ("num_steps" , [10 , 50 ])
21+ def test_linear_reaches_end_value (num_steps : int ):
22+ """Linear tau should equal end_value after num_steps steps."""
23+ _ , tl = make_topo_loss ()
24+ scheduler = TauScheduler (
25+ topo_loss = tl ,
26+ start_value = 0.0 ,
27+ end_value = 1.0 ,
28+ num_steps = num_steps ,
29+ mode = "linear" ,
30+ verbose = False ,
31+ )
32+ for _ in range (num_steps ):
33+ scheduler .step ()
34+
35+ assert abs (scheduler .get_current_tau () - 1.0 ) < 1e-5 , (
36+ f"Expected tau=1.0 after { num_steps } steps, got { scheduler .get_current_tau ():.6f} "
37+ )
38+
39+
40+ @pytest .mark .parametrize ("num_steps" , [10 , 50 ])
41+ def test_cosine_decay_reaches_end_value (num_steps : int ):
42+ """Cosine decay tau should equal end_value after num_steps steps (requires start > end)."""
43+ _ , tl = make_topo_loss ()
44+ scheduler = TauScheduler (
45+ topo_loss = tl ,
46+ start_value = 1.0 ,
47+ end_value = 0.0 ,
48+ num_steps = num_steps ,
49+ mode = "cosine_decay" ,
50+ verbose = False ,
51+ )
52+ for _ in range (num_steps ):
53+ scheduler .step ()
54+
55+ assert abs (scheduler .get_current_tau () - 0.0 ) < 1e-5 , (
56+ f"Expected tau=0.0 after { num_steps } steps, got { scheduler .get_current_tau ():.6f} "
57+ )
58+
59+
60+ @pytest .mark .parametrize ("mode,start,end" , [
61+ ("linear" , 0.3 , 1.0 ),
62+ ("cosine_decay" , 1.0 , 0.0 ), # cosine_decay requires start > end
63+ ])
64+ def test_scheduler_starts_at_start_value (mode : str , start : float , end : float ):
65+ """Tau should equal start_value before any steps."""
66+ _ , tl = make_topo_loss ()
67+ scheduler = TauScheduler (
68+ topo_loss = tl ,
69+ start_value = start ,
70+ end_value = end ,
71+ num_steps = 10 ,
72+ mode = mode ,
73+ verbose = False ,
74+ )
75+ assert abs (scheduler .get_current_tau () - start ) < 1e-5 , (
76+ f"Expected tau={ start } before any steps, got { scheduler .get_current_tau ():.6f} "
77+ )
78+
79+
80+ def test_linear_scheduler_is_monotone ():
81+ """Linear warmup should be strictly increasing."""
82+ _ , tl = make_topo_loss ()
83+ num_steps = 20
84+ scheduler = TauScheduler (
85+ topo_loss = tl ,
86+ start_value = 0.0 ,
87+ end_value = 1.0 ,
88+ num_steps = num_steps ,
89+ mode = "linear" ,
90+ verbose = False ,
91+ )
92+ taus = [scheduler .get_current_tau ()]
93+ for _ in range (num_steps ):
94+ scheduler .step ()
95+ taus .append (scheduler .get_current_tau ())
96+
97+ for i in range (len (taus ) - 1 ):
98+ assert taus [i ] <= taus [i + 1 ], (
99+ f"Linear scheduler not monotone at step { i } : { taus [i ]:.4f} > { taus [i + 1 ]:.4f} "
100+ )
101+
102+
103+ def test_cosine_decay_is_monotone ():
104+ """Cosine decay should be strictly decreasing."""
105+ _ , tl = make_topo_loss ()
106+ num_steps = 20
107+ scheduler = TauScheduler (
108+ topo_loss = tl ,
109+ start_value = 1.0 ,
110+ end_value = 0.0 ,
111+ num_steps = num_steps ,
112+ mode = "cosine_decay" ,
113+ verbose = False ,
114+ )
115+ taus = [scheduler .get_current_tau ()]
116+ for _ in range (num_steps ):
117+ scheduler .step ()
118+ taus .append (scheduler .get_current_tau ())
119+
120+ for i in range (len (taus ) - 1 ):
121+ assert taus [i ] >= taus [i + 1 ], (
122+ f"Cosine decay not monotone at step { i } : { taus [i ]:.4f} < { taus [i + 1 ]:.4f} "
123+ )
124+
125+
126+ def test_chained_scheduler_transitions ():
127+ """ChainedTauScheduler should run warmup then decay, hitting expected midpoint."""
128+ _ , tl = make_topo_loss ()
129+ num_steps = 100
130+ scheduler = ChainedTauScheduler (
131+ schedulers = [
132+ TauScheduler (
133+ topo_loss = tl ,
134+ start_value = 0.0 ,
135+ end_value = 1.0 ,
136+ num_steps = num_steps // 2 ,
137+ mode = "linear" ,
138+ verbose = False ,
139+ ),
140+ TauScheduler (
141+ topo_loss = tl ,
142+ start_value = 1.0 ,
143+ end_value = 0.0 ,
144+ num_steps = num_steps // 2 ,
145+ mode = "cosine_decay" ,
146+ verbose = False ,
147+ ),
148+ ]
149+ )
150+
151+ taus = []
152+ for _ in range (num_steps ):
153+ taus .append (scheduler .get_current_tau ())
154+ scheduler .step ()
155+
156+ # Should start near 0
157+ assert taus [0 ] < 0.1 , f"Expected tau near 0 at start, got { taus [0 ]:.4f} "
158+ # Should peak near 1 at the midpoint
159+ assert taus [num_steps // 2 - 1 ] > 0.9 , (
160+ f"Expected tau near 1 at midpoint, got { taus [num_steps // 2 - 1 ]:.4f} "
161+ )
162+ # Should end near 0
163+ assert taus [- 1 ] < 0.1 , f"Expected tau near 0 at end, got { taus [- 1 ]:.4f} "
164+
165+
166+ def test_scheduler_does_not_exceed_bounds ():
167+ """Linear tau should never go outside [start_value, end_value], even past num_steps."""
168+ _ , tl = make_topo_loss ()
169+ num_steps = 30
170+ start , end = 0.0 , 1.0
171+ scheduler = TauScheduler (
172+ topo_loss = tl ,
173+ start_value = start ,
174+ end_value = end ,
175+ num_steps = num_steps ,
176+ mode = "linear" ,
177+ verbose = False ,
178+ )
179+ for _ in range (num_steps + 5 ): # a few extra steps past the end
180+ tau = scheduler .get_current_tau ()
181+ assert start - 1e-5 <= tau <= end + 1e-5 , (
182+ f"Tau { tau :.4f} out of bounds [{ start } , { end } ]"
183+ )
184+ scheduler .step ()
0 commit comments