diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index f23817376..9962ebebb 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -256,6 +256,7 @@ Callbacks
 
     Processing callback
     Optimizer callback
+    Switch Scheduler
     R3 Refinment callback
     Refinment Interface callback
     Normalizer callback
diff --git a/docs/source/_rst/callback/switch_scheduler.rst b/docs/source/_rst/callback/switch_scheduler.rst
new file mode 100644
index 000000000..0e69ef0fb
--- /dev/null
+++ b/docs/source/_rst/callback/switch_scheduler.rst
@@ -0,0 +1,7 @@
+Switch Scheduler
+================
+
+.. currentmodule:: pina.callback.switch_scheduler
+.. autoclass:: SwitchScheduler
+    :members:
+    :show-inheritance:
\ No newline at end of file
diff --git a/pina/callback/__init__.py b/pina/callback/__init__.py
index f71a89f91..f2057257e 100644
--- a/pina/callback/__init__.py
+++ b/pina/callback/__init__.py
@@ -2,6 +2,7 @@
 
 __all__ = [
     "SwitchOptimizer",
+    "SwitchScheduler",
     "MetricTracker",
     "PINAProgressBar",
     "R3Refinement",
@@ -12,3 +13,4 @@
 from .processing_callback import MetricTracker, PINAProgressBar
 from .refinement import R3Refinement
 from .normalizer_data_callback import NormalizerDataCallback
+from .switch_scheduler import SwitchScheduler
diff --git a/pina/callback/switch_scheduler.py b/pina/callback/switch_scheduler.py
new file mode 100644
index 000000000..22ae8bd08
--- /dev/null
+++ b/pina/callback/switch_scheduler.py
@@ -0,0 +1,75 @@
+"""Module for the SwitchScheduler callback."""
+
+from lightning.pytorch.callbacks import Callback
+from ..optim import TorchScheduler
+from ..utils import check_consistency, check_positive_integer
+
+
+class SwitchScheduler(Callback):
+    """
+    Callback to switch schedulers during training.
+    """
+
+    def __init__(self, new_schedulers, epoch_switch):
+        """
+        This callback allows switching between different schedulers during
+        training, enabling the exploration of multiple learning rate
+        strategies without interrupting the training process.
+
+        :param new_schedulers: The scheduler or list of schedulers to switch to.
+            Use a single scheduler for single-model solvers, or a list of
+            schedulers when working with multiple models.
+        :type new_schedulers: pina.optim.TorchScheduler |
+            list[pina.optim.TorchScheduler]
+        :param int epoch_switch: The epoch at which the scheduler switch occurs.
+        :raises AssertionError: If ``epoch_switch`` is not greater than 1.
+        :raises ValueError: If any scheduler in ``new_schedulers`` is not an
+            instance of :class:`pina.optim.TorchScheduler`.
+
+        Example:
+            >>> scheduler = TorchScheduler(
+            >>>     torch.optim.lr_scheduler.StepLR, step_size=5
+            >>> )
+            >>> switch_callback = SwitchScheduler(
+            >>>     new_schedulers=scheduler, epoch_switch=10
+            >>> )
+        """
+        super().__init__()
+
+        # Ensure epoch_switch is greater than 1
+        check_positive_integer(epoch_switch - 1, strict=True)
+
+        # If new_schedulers is not a list, convert it to a list
+        if not isinstance(new_schedulers, list):
+            new_schedulers = [new_schedulers]
+
+        # Check consistency
+        for scheduler in new_schedulers:
+            check_consistency(scheduler, TorchScheduler)
+
+        # Store the new schedulers and the switch epoch
+        self._new_schedulers = new_schedulers
+        self._epoch_switch = epoch_switch
+
+    def on_train_epoch_start(self, trainer, __):
+        """
+        Switch the scheduler at the start of the specified training epoch.
+
+        :param lightning.pytorch.Trainer trainer: The trainer object managing
+            the training process.
+        :param __: Placeholder argument (not used).
+        """
+        # Check if the current epoch matches the switch epoch
+        if trainer.current_epoch == self._epoch_switch:
+            schedulers = []
+
+            # Hook the new schedulers to the solver's optimizers
+            for idx, scheduler in enumerate(self._new_schedulers):
+                scheduler.hook(trainer.solver._pina_optimizers[idx])
+                schedulers.append(scheduler)
+
+                # Update the trainer's scheduler configs
+                trainer.lr_scheduler_configs[idx].scheduler = scheduler.instance
+
+            # Update the solver's schedulers
+            trainer.solver._pina_schedulers = schedulers
diff --git a/tests/test_callback/test_switch_scheduler.py b/tests/test_callback/test_switch_scheduler.py
new file mode 100644
index 000000000..df91f0c59
--- /dev/null
+++ b/tests/test_callback/test_switch_scheduler.py
@@ -0,0 +1,61 @@
+import torch
+import pytest
+
+from pina.solver import PINN
+from pina.trainer import Trainer
+from pina.model import FeedForward
+from pina.optim import TorchScheduler
+from pina.callback import SwitchScheduler
+from pina.problem.zoo import Poisson2DSquareProblem as Poisson
+
+
+# Define the problem
+problem = Poisson()
+problem.discretise_domain(10)
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
+
+# Define the scheduler
+scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=0.1)
+
+# Initialize the solver
+solver = PINN(problem=problem, model=model, scheduler=scheduler)
+
+# Define new schedulers for testing
+step = TorchScheduler(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.1)
+exp = TorchScheduler(torch.optim.lr_scheduler.ExponentialLR, gamma=0.9)
+
+
+@pytest.mark.parametrize("epoch_switch", [5, 10])
+@pytest.mark.parametrize("new_sched", [step, exp])
+def test_switch_scheduler_constructor(new_sched, epoch_switch):
+
+    # Constructor
+    SwitchScheduler(new_schedulers=new_sched, epoch_switch=epoch_switch)
+
+    # Should fail if epoch_switch is not greater than 1
+    with pytest.raises(AssertionError):
+        SwitchScheduler(new_schedulers=new_sched, epoch_switch=0)
+
+
+@pytest.mark.parametrize("epoch_switch", [5, 10])
+@pytest.mark.parametrize("new_sched", [step, exp])
+def test_switch_scheduler_routine(new_sched, epoch_switch):
+
+    # Initialize the trainer
+    switch_sched_callback = SwitchScheduler(
+        new_schedulers=new_sched, epoch_switch=epoch_switch
+    )
+    trainer = Trainer(
+        solver=solver,
+        callbacks=switch_sched_callback,
+        accelerator="cpu",
+        max_epochs=epoch_switch + 2,
+    )
+    trainer.train()
+
+    # Check that the solver and trainer schedulers have been updated
+    assert solver.scheduler.instance.__class__ == new_sched.instance.__class__
+    assert (
+        trainer.lr_scheduler_configs[0].scheduler.__class__
+        == new_sched.instance.__class__
+    )