Skip to content

Commit ae7f36c

Browse files
add switch scheduler callback
1 parent 893d39b commit ae7f36c

File tree

5 files changed

+146
-0
lines changed

5 files changed

+146
-0
lines changed

docs/source/_rst/_code.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ Callbacks
256256

257257
Processing callback <callback/processing_callback.rst>
258258
Optimizer callback <callback/optimizer_callback.rst>
259+
Switch Scheduler <callback/switch_scheduler.rst>
259260
R3 Refinement callback <callback/refinement/r3_refinement.rst>
260261
Refinement Interface callback <callback/refinement/refinement_interface.rst>
261262
Normalizer callback <callback/normalizer_data_callback.rst>
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Switch Scheduler
2+
=====================
3+
4+
.. currentmodule:: pina.callback.switch_scheduler
5+
.. autoclass:: SwitchScheduler
6+
:members:
7+
:show-inheritance:

pina/callback/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
__all__ = [
44
"SwitchOptimizer",
5+
"SwitchScheduler",
56
"MetricTracker",
67
"PINAProgressBar",
78
"R3Refinement",
@@ -12,3 +13,4 @@
1213
from .processing_callback import MetricTracker, PINAProgressBar
1314
from .refinement import R3Refinement
1415
from .normalizer_data_callback import NormalizerDataCallback
16+
from .switch_scheduler import SwitchScheduler

pina/callback/switch_scheduler.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Module for the SwitchScheduler callback."""
2+
3+
from lightning.pytorch.callbacks import Callback
4+
from ..optim import TorchScheduler
5+
from ..utils import check_consistency, check_positive_integer
6+
7+
8+
class SwitchScheduler(Callback):
    """
    Callback to switch scheduler during training.
    """

    def __init__(self, new_schedulers, epoch_switch):
        """
        This callback allows switching between different schedulers during
        training, enabling the exploration of multiple optimization strategies
        without interrupting the training process.

        :param new_schedulers: The scheduler or list of schedulers to switch to.
            Use a single scheduler for single-model solvers, or a list of
            schedulers when working with multiple models.
        :type new_schedulers: pina.optim.TorchScheduler |
            list[pina.optim.TorchScheduler]
        :param int epoch_switch: The epoch at which the scheduler switch occurs.
            Must be greater than 1.
        :raise AssertionError: If ``epoch_switch`` is not greater than 1.
        :raise ValueError: If each scheduler in ``new_schedulers`` is not an
            instance of :class:`pina.optim.TorchScheduler`.

        Example:
            >>> scheduler = TorchScheduler(
            >>>     torch.optim.lr_scheduler.StepLR, step_size=5
            >>> )
            >>> switch_callback = SwitchScheduler(
            >>>     new_schedulers=scheduler, epoch_switch=10
            >>> )
        """
        super().__init__()

        # Require epoch_switch > 1: epoch_switch - 1 must be strictly positive,
        # so epoch_switch == 1 (and below) is rejected.
        check_positive_integer(epoch_switch - 1, strict=True)

        # Normalize a single scheduler to a one-element list
        if not isinstance(new_schedulers, list):
            new_schedulers = [new_schedulers]

        # Each entry must be a TorchScheduler
        for scheduler in new_schedulers:
            check_consistency(scheduler, TorchScheduler)

        # Store the new schedulers and the switching epoch
        self._new_schedulers = new_schedulers
        self._epoch_switch = epoch_switch

    def on_train_epoch_start(self, trainer, __):
        """
        Switch the scheduler at the start of the specified training epoch.

        :param lightning.pytorch.Trainer trainer: The trainer object managing
            the training process.
        :param __: Placeholder argument (not used).
        """
        # Only act once, at the configured epoch
        if trainer.current_epoch == self._epoch_switch:
            schedulers = []

            # Hook each new scheduler to the optimizer at the same index and
            # swap it into the trainer's scheduler configuration.
            for idx, scheduler in enumerate(self._new_schedulers):
                scheduler.hook(trainer.solver._pina_optimizers[idx])
                schedulers.append(scheduler)
                trainer.lr_scheduler_configs[idx].scheduler = scheduler.instance

            # Update the solver's schedulers
            trainer.solver._pina_schedulers = schedulers
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import torch
2+
import pytest
3+
4+
from pina.solver import PINN
5+
from pina.trainer import Trainer
6+
from pina.model import FeedForward
7+
from pina.optim import TorchScheduler
8+
from pina.callback import SwitchScheduler
9+
from pina.problem.zoo import Poisson2DSquareProblem as Poisson
10+
11+
12+
# Problem setup: discretise the Poisson problem and build a matching network
problem = Poisson()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))

# Solver equipped with an initial constant-factor scheduler
scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=0.1)
solver = PINN(problem=problem, model=model, scheduler=scheduler)

# Candidate schedulers the tests will switch to
step = TorchScheduler(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.1)
exp = TorchScheduler(torch.optim.lr_scheduler.ExponentialLR, gamma=0.9)
26+
27+
28+
@pytest.mark.parametrize("epoch_switch", [5, 10])
@pytest.mark.parametrize("new_sched", [step, exp])
def test_switch_scheduler_constructor(new_sched, epoch_switch):
    """Constructor accepts valid arguments and rejects epoch_switch < 1."""
    # A valid epoch must construct without raising
    SwitchScheduler(new_schedulers=new_sched, epoch_switch=epoch_switch)

    # An epoch below 1 must be rejected
    with pytest.raises(AssertionError):
        SwitchScheduler(new_schedulers=new_sched, epoch_switch=0)
38+
39+
40+
@pytest.mark.parametrize("epoch_switch", [5, 10])
@pytest.mark.parametrize("new_sched", [step, exp])
def test_switch_scheduler_routine(new_sched, epoch_switch):
    """Train past the switch epoch and check the scheduler was replaced."""
    # Build a fresh model and solver per run: training the shared
    # module-level solver would leak the switched scheduler (and trained
    # weights) into the subsequent parametrized runs.
    local_model = FeedForward(
        len(problem.input_variables), len(problem.output_variables)
    )
    local_solver = PINN(
        problem=problem,
        model=local_model,
        scheduler=TorchScheduler(
            torch.optim.lr_scheduler.ConstantLR, factor=0.1
        ),
    )

    # Initialize the trainer with the switch callback
    switch_sched_callback = SwitchScheduler(
        new_schedulers=new_sched, epoch_switch=epoch_switch
    )
    trainer = Trainer(
        solver=local_solver,
        callbacks=switch_sched_callback,
        accelerator="cpu",
        max_epochs=epoch_switch + 2,
    )
    trainer.train()

    # Both the solver and the trainer's scheduler config must now hold an
    # instance of the new scheduler class.
    assert (
        local_solver.scheduler.instance.__class__ == new_sched.instance.__class__
    )
    assert (
        trainer.lr_scheduler_configs[0].scheduler.__class__
        == new_sched.instance.__class__
    )

0 commit comments

Comments
 (0)