Merged
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
@@ -256,6 +256,7 @@ Callbacks

Processing callback <callback/processing_callback.rst>
Optimizer callback <callback/optimizer_callback.rst>
Switch Scheduler <callback/switch_scheduler.rst>
Collaborator

I would keep the Scheduler callback, so that in the future we can add other callbacks there.

Collaborator Author

This differs from #445, where we chose to name each file after the class it contains.

That said, your observation raises a valid point. My proposal is the following: within the callback directory, we create a dedicated scheduler subdirectory that, for now, will include only switch_scheduler.py, while leaving room for future additions. This prevents the file from becoming overly long as more scheduler-related callbacks are introduced.

The same structure would apply to optimizer_callback and processing_callback. In particular, the layout I have in mind is the following:

-- callback
-------- scheduler_callback
--------------- switch_scheduler.py
-------- optimizer_callback
--------------- switch_optimizer.py
-------- refinement
--------------- refinement_interface.py
--------------- r3_refinement.py
-------- processing_callback
--------------- metric_tracker.py
--------------- pina_progress_bar.py
--------------- normalizer_data_callback.py

Happy to hear your thoughts on this, @dario-coscia @FilippoOlivo @ndem0.

All these changes would be implemented in a dedicated PR.

Member

-- callback
-------- scheduler
--------------- switch_scheduler.py
-------- optimizer
--------------- switch_optimizer.py
-------- refinement
--------------- refinement_interface.py
--------------- r3_refinement.py
-------- processing
--------------- metric_tracker.py
--------------- pina_progress_bar.py
--------------- normalizer_data_callback.py

I prefer something like this (just remove "callback" from the subfolder names).

Collaborator

Very good! Maybe I would fuse scheduler and optimizer into an optim directory (this is similar to from pina.optim import ...).
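
To make the fused layout concrete, here is a minimal sketch of how pina/callback/__init__.py could re-export everything so the public API (from pina.callback import SwitchScheduler) stays unchanged; the optim subpackage follows the suggestion above, and all subpackage paths are hypothetical until the dedicated PR lands:

# pina/callback/__init__.py -- hypothetical layout from the discussion above
__all__ = [
    "SwitchOptimizer",
    "SwitchScheduler",
    "MetricTracker",
    "PINAProgressBar",
    "R3Refinement",
    "NormalizerDataCallback",
]

# Re-exports keep `from pina.callback import ...` working after the move.
from .optim.switch_optimizer import SwitchOptimizer
from .optim.switch_scheduler import SwitchScheduler
from .processing.metric_tracker import MetricTracker
from .processing.pina_progress_bar import PINAProgressBar
from .refinement.r3_refinement import R3Refinement
from .processing.normalizer_data_callback import NormalizerDataCallback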

R3 Refinement callback <callback/refinement/r3_refinement.rst>
Refinement Interface callback <callback/refinement/refinement_interface.rst>
Normalizer callback <callback/normalizer_data_callback.rst>
7 changes: 7 additions & 0 deletions docs/source/_rst/callback/switch_scheduler.rst
@@ -0,0 +1,7 @@
Switch Scheduler
=====================

.. currentmodule:: pina.callback.switch_scheduler
.. autoclass:: SwitchScheduler
:members:
:show-inheritance:
2 changes: 2 additions & 0 deletions pina/callback/__init__.py
@@ -2,6 +2,7 @@

__all__ = [
"SwitchOptimizer",
"SwitchScheduler",
"MetricTracker",
"PINAProgressBar",
"R3Refinement",
@@ -12,3 +13,4 @@
from .processing_callback import MetricTracker, PINAProgressBar
from .refinement import R3Refinement
from .normalizer_data_callback import NormalizerDataCallback
from .switch_scheduler import SwitchScheduler
75 changes: 75 additions & 0 deletions pina/callback/switch_scheduler.py
@@ -0,0 +1,75 @@
"""Module for the SwitchScheduler callback."""

from lightning.pytorch.callbacks import Callback
from ..optim import TorchScheduler
from ..utils import check_consistency, check_positive_integer


class SwitchScheduler(Callback):
"""
Callback to switch scheduler during training.
"""

def __init__(self, new_schedulers, epoch_switch):
"""
This callback allows switching between different schedulers during
training, enabling the exploration of multiple optimization strategies
without interrupting the training process.

:param new_schedulers: The scheduler or list of schedulers to switch to.
Use a single scheduler for single-model solvers, or a list of
schedulers when working with multiple models.
:type new_schedulers: pina.optim.TorchScheduler |
list[pina.optim.TorchScheduler]
:param int epoch_switch: The epoch at which the scheduler switch occurs.
:raise AssertionError: If ``epoch_switch`` is not greater than 1.
:raise ValueError: If any scheduler in ``new_schedulers`` is not an
instance of :class:`pina.optim.TorchScheduler`.

Example:
>>> scheduler = TorchScheduler(
...     torch.optim.lr_scheduler.StepLR, step_size=5
... )
>>> switch_callback = SwitchScheduler(
...     new_schedulers=scheduler, epoch_switch=10
... )
"""
super().__init__()

# Check if epoch_switch is greater than 1
check_positive_integer(epoch_switch - 1, strict=True)

# If new_schedulers is not a list, convert it to a list
if not isinstance(new_schedulers, list):
new_schedulers = [new_schedulers]

# Check consistency
for scheduler in new_schedulers:
check_consistency(scheduler, TorchScheduler)

# Store the new schedulers and epoch switch
self._new_schedulers = new_schedulers
self._epoch_switch = epoch_switch

def on_train_epoch_start(self, trainer, __):
"""
Switch the scheduler at the start of the specified training epoch.

:param lightning.pytorch.Trainer trainer: The trainer object managing
the training process.
:param __: Placeholder argument (not used).
"""
# Check if the current epoch matches the switch epoch
if trainer.current_epoch == self._epoch_switch:
schedulers = []

# Hook the new schedulers to the model parameters
for idx, scheduler in enumerate(self._new_schedulers):
scheduler.hook(trainer.solver._pina_optimizers[idx])
schedulers.append(scheduler)

# Update the trainer's scheduler configs
trainer.lr_scheduler_configs[idx].scheduler = scheduler.instance

# Update the solver's schedulers
trainer.solver._pina_schedulers = schedulers
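
Because the switch hooks each new scheduler to trainer.solver._pina_optimizers[idx] by position, a multi-model solver needs one new scheduler per optimizer, in matching order. A minimal sketch of that case (the multi-model solver construction is omitted, since it depends on the solver class):

# One scheduler per optimizer, ordered to match the solver's optimizers.
sched_a = TorchScheduler(torch.optim.lr_scheduler.StepLR, step_size=10)
sched_b = TorchScheduler(torch.optim.lr_scheduler.ExponentialLR, gamma=0.9)
callback = SwitchScheduler(new_schedulers=[sched_a, sched_b], epoch_switch=10)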
61 changes: 61 additions & 0 deletions tests/test_callback/test_switch_scheduler.py
@@ -0,0 +1,61 @@
import torch
import pytest

from pina.solver import PINN
from pina.trainer import Trainer
from pina.model import FeedForward
from pina.optim import TorchScheduler
from pina.callback import SwitchScheduler
from pina.problem.zoo import Poisson2DSquareProblem as Poisson


# Define the problem
problem = Poisson()
problem.discretise_domain(10)
model = FeedForward(len(problem.input_variables), len(problem.output_variables))

# Define the scheduler
scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=0.1)

# Initialize the solver
solver = PINN(problem=problem, model=model, scheduler=scheduler)

# Define new schedulers for testing
step = TorchScheduler(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.1)
exp = TorchScheduler(torch.optim.lr_scheduler.ExponentialLR, gamma=0.9)


@pytest.mark.parametrize("epoch_switch", [5, 10])
@pytest.mark.parametrize("new_sched", [step, exp])
def test_switch_scheduler_constructor(new_sched, epoch_switch):

# Constructor
SwitchScheduler(new_schedulers=new_sched, epoch_switch=epoch_switch)

# Should fail if epoch_switch is less than 1
with pytest.raises(AssertionError):
SwitchScheduler(new_schedulers=new_sched, epoch_switch=0)


@pytest.mark.parametrize("epoch_switch", [5, 10])
@pytest.mark.parametrize("new_sched", [step, exp])
def test_switch_scheduler_routine(new_sched, epoch_switch):

# Initialize the trainer
switch_sched_callback = SwitchScheduler(
new_schedulers=new_sched, epoch_switch=epoch_switch
)
trainer = Trainer(
solver=solver,
callbacks=switch_sched_callback,
accelerator="cpu",
max_epochs=epoch_switch + 2,
)
trainer.train()

# Check that the solver and trainer strategy schedulers have been updated
assert solver.scheduler.instance.__class__ == new_sched.instance.__class__
assert (
trainer.lr_scheduler_configs[0].scheduler.__class__
== new_sched.instance.__class__
)