Skip to content

Commit 7a73e69

Browse files
committed
add tests for scheduler
1 parent d8ee610 commit 7a73e69

1 file changed

Lines changed: 184 additions & 0 deletions

File tree

tests/test_loss_scheduler.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from topoloss import TopoLoss, LaplacianPyramid
2+
from topoloss.scheduler import TauScheduler, ChainedTauScheduler
3+
import pytest
4+
import torch
5+
import torch.nn as nn
6+
7+
8+
def make_topo_loss():
9+
model = nn.Sequential(nn.Linear(30, 20))
10+
tl = TopoLoss(
11+
losses=[
12+
LaplacianPyramid.from_layer(
13+
model=model, layer=model[0], scale=1.0, factor_h=2.0, factor_w=2.0
14+
)
15+
]
16+
)
17+
return model, tl
18+
19+
20+
@pytest.mark.parametrize("num_steps", [10, 50])
21+
def test_linear_reaches_end_value(num_steps: int):
22+
"""Linear tau should equal end_value after num_steps steps."""
23+
_, tl = make_topo_loss()
24+
scheduler = TauScheduler(
25+
topo_loss=tl,
26+
start_value=0.0,
27+
end_value=1.0,
28+
num_steps=num_steps,
29+
mode="linear",
30+
verbose=False,
31+
)
32+
for _ in range(num_steps):
33+
scheduler.step()
34+
35+
assert abs(scheduler.get_current_tau() - 1.0) < 1e-5, (
36+
f"Expected tau=1.0 after {num_steps} steps, got {scheduler.get_current_tau():.6f}"
37+
)
38+
39+
40+
@pytest.mark.parametrize("num_steps", [10, 50])
41+
def test_cosine_decay_reaches_end_value(num_steps: int):
42+
"""Cosine decay tau should equal end_value after num_steps steps (requires start > end)."""
43+
_, tl = make_topo_loss()
44+
scheduler = TauScheduler(
45+
topo_loss=tl,
46+
start_value=1.0,
47+
end_value=0.0,
48+
num_steps=num_steps,
49+
mode="cosine_decay",
50+
verbose=False,
51+
)
52+
for _ in range(num_steps):
53+
scheduler.step()
54+
55+
assert abs(scheduler.get_current_tau() - 0.0) < 1e-5, (
56+
f"Expected tau=0.0 after {num_steps} steps, got {scheduler.get_current_tau():.6f}"
57+
)
58+
59+
60+
@pytest.mark.parametrize("mode,start,end", [
61+
("linear", 0.3, 1.0),
62+
("cosine_decay", 1.0, 0.0), # cosine_decay requires start > end
63+
])
64+
def test_scheduler_starts_at_start_value(mode: str, start: float, end: float):
65+
"""Tau should equal start_value before any steps."""
66+
_, tl = make_topo_loss()
67+
scheduler = TauScheduler(
68+
topo_loss=tl,
69+
start_value=start,
70+
end_value=end,
71+
num_steps=10,
72+
mode=mode,
73+
verbose=False,
74+
)
75+
assert abs(scheduler.get_current_tau() - start) < 1e-5, (
76+
f"Expected tau={start} before any steps, got {scheduler.get_current_tau():.6f}"
77+
)
78+
79+
80+
def test_linear_scheduler_is_monotone():
81+
"""Linear warmup should be strictly increasing."""
82+
_, tl = make_topo_loss()
83+
num_steps = 20
84+
scheduler = TauScheduler(
85+
topo_loss=tl,
86+
start_value=0.0,
87+
end_value=1.0,
88+
num_steps=num_steps,
89+
mode="linear",
90+
verbose=False,
91+
)
92+
taus = [scheduler.get_current_tau()]
93+
for _ in range(num_steps):
94+
scheduler.step()
95+
taus.append(scheduler.get_current_tau())
96+
97+
for i in range(len(taus) - 1):
98+
assert taus[i] <= taus[i + 1], (
99+
f"Linear scheduler not monotone at step {i}: {taus[i]:.4f} > {taus[i+1]:.4f}"
100+
)
101+
102+
103+
def test_cosine_decay_is_monotone():
104+
"""Cosine decay should be strictly decreasing."""
105+
_, tl = make_topo_loss()
106+
num_steps = 20
107+
scheduler = TauScheduler(
108+
topo_loss=tl,
109+
start_value=1.0,
110+
end_value=0.0,
111+
num_steps=num_steps,
112+
mode="cosine_decay",
113+
verbose=False,
114+
)
115+
taus = [scheduler.get_current_tau()]
116+
for _ in range(num_steps):
117+
scheduler.step()
118+
taus.append(scheduler.get_current_tau())
119+
120+
for i in range(len(taus) - 1):
121+
assert taus[i] >= taus[i + 1], (
122+
f"Cosine decay not monotone at step {i}: {taus[i]:.4f} < {taus[i+1]:.4f}"
123+
)
124+
125+
126+
def test_chained_scheduler_transitions():
127+
"""ChainedTauScheduler should run warmup then decay, hitting expected midpoint."""
128+
_, tl = make_topo_loss()
129+
num_steps = 100
130+
scheduler = ChainedTauScheduler(
131+
schedulers=[
132+
TauScheduler(
133+
topo_loss=tl,
134+
start_value=0.0,
135+
end_value=1.0,
136+
num_steps=num_steps // 2,
137+
mode="linear",
138+
verbose=False,
139+
),
140+
TauScheduler(
141+
topo_loss=tl,
142+
start_value=1.0,
143+
end_value=0.0,
144+
num_steps=num_steps // 2,
145+
mode="cosine_decay",
146+
verbose=False,
147+
),
148+
]
149+
)
150+
151+
taus = []
152+
for _ in range(num_steps):
153+
taus.append(scheduler.get_current_tau())
154+
scheduler.step()
155+
156+
# Should start near 0
157+
assert taus[0] < 0.1, f"Expected tau near 0 at start, got {taus[0]:.4f}"
158+
# Should peak near 1 at the midpoint
159+
assert taus[num_steps // 2 - 1] > 0.9, (
160+
f"Expected tau near 1 at midpoint, got {taus[num_steps // 2 - 1]:.4f}"
161+
)
162+
# Should end near 0
163+
assert taus[-1] < 0.1, f"Expected tau near 0 at end, got {taus[-1]:.4f}"
164+
165+
166+
def test_scheduler_does_not_exceed_bounds():
167+
"""Linear tau should never go outside [start_value, end_value], even past num_steps."""
168+
_, tl = make_topo_loss()
169+
num_steps = 30
170+
start, end = 0.0, 1.0
171+
scheduler = TauScheduler(
172+
topo_loss=tl,
173+
start_value=start,
174+
end_value=end,
175+
num_steps=num_steps,
176+
mode="linear",
177+
verbose=False,
178+
)
179+
for _ in range(num_steps + 5): # a few extra steps past the end
180+
tau = scheduler.get_current_tau()
181+
assert start - 1e-5 <= tau <= end + 1e-5, (
182+
f"Tau {tau:.4f} out of bounds [{start}, {end}]"
183+
)
184+
scheduler.step()

0 commit comments

Comments
 (0)