Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pina/_src/callback/refinement/base_refinement.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module for the Base Refinement class."""

from pina._src.solver.pinn import PINN
from pina._src.solver.physics_informed_single_model_solver import (
PhysicsInformedSingleModelSolver,
)
from lightning.pytorch import Callback
from pina._src.core.utils import check_consistency, check_positive_integer
from pina._src.callback.refinement.refinement_interface import (
Expand Down Expand Up @@ -65,7 +67,7 @@ def on_train_start(self, trainer, solver):
'domain' attribute for sampling.
"""
# Check solver consistency
if not isinstance(solver, PINN):
if not isinstance(solver, PhysicsInformedSingleModelSolver):
raise RuntimeError(
"Refinement strategies require a physics-informed solver. "
f"Got '{type(solver).__name__}'."
Expand Down
6 changes: 3 additions & 3 deletions pina/_src/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import warnings
import torch
import lightning
from pina._src.solver.mixin.physics_informed_mixin import _PhysicsInformedMixin
from pina._src.solver.base_solver import BaseSolver
from pina._src.data.data_module import DataModule
from pina._src.solver.pinn import PINN
from pina._src.core.utils import (
check_consistency,
custom_warning_format,
Expand Down Expand Up @@ -132,8 +132,8 @@ def __init__(
f"Expected one of: {sorted(self._AVAIL_BATCHING_MODES)}."
)

# Set inference mode to false for PINN solvers to track gradients
if isinstance(solver, PINN):
# Set inference mode to false when usiing physics-informed mixin
if isinstance(solver, _PhysicsInformedMixin):
kwargs["inference_mode"] = False

# Set log_every_n_steps to 0 if batch_size is None, otherwise default
Expand Down
Empty file removed pina/_src/solver/__init__.py
Empty file.
47 changes: 47 additions & 0 deletions pina/_src/solver/autoregressive_ensemble_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pina._src.solver.mixin.autoregressive_mixin import _AutoregressiveMixin
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.solver.ensemble_solver import EnsembleSolver


class AutoregressiveEnsembleSolver(_AutoregressiveMixin, EnsembleSolver):
"""
Ensemble solver specialized for autoregressive conditions.
"""

# Accepted conditions types for this solver
accepted_conditions_types = (TimeSeriesCondition,)

def __init__(
self,
problem,
models,
optimizers=None,
schedulers=None,
weighting=None,
loss=None,
use_lt=True,
eps=0.0,
reset_weights_at_epoch_start=True,
kwargs=None,
):
"""
Initialization of the :class:`AutoregressiveEnsembleSolver` class.
"""
# Initialize the parent class
EnsembleSolver.__init__(
self,
problem=problem,
models=models,
optimizers=optimizers,
schedulers=schedulers,
weighting=weighting,
loss=loss,
use_lt=use_lt,
)

# Initialize the autoregressive components
self._init_autoregressive_components(
eps=eps,
reset_weights_at_epoch_start=reset_weights_at_epoch_start,
kwargs=kwargs,
)
48 changes: 48 additions & 0 deletions pina/_src/solver/autoregressive_single_model_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from pina._src.solver.mixin.autoregressive_mixin import _AutoregressiveMixin
from pina._src.condition.time_series_condition import TimeSeriesCondition
from pina._src.solver.single_model_solver import SingleModelSolver


class AutoregressiveSingleModelSolver(_AutoregressiveMixin, SingleModelSolver):
r"""
The autoregressive solver for learning dynamical systems.
"""

# Accepted conditions types for this solver
accepted_conditions_types = (TimeSeriesCondition,)

def __init__(
self,
problem,
model,
loss=None,
optimizer=None,
scheduler=None,
weighting=None,
use_lt=False,
eps=0.0,
reset_weights_at_epoch_start=True,
kwargs=None,
):
"""
Initialization of the :class:`AutoregressiveSingleModelSolver` class.
"""

# Initialize the parent class
SingleModelSolver.__init__(
self,
problem=problem,
model=model,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
loss=loss,
use_lt=use_lt,
)

# Initialize the autoregressive components
self._init_autoregressive_components(
eps=eps,
reset_weights_at_epoch_start=reset_weights_at_epoch_start,
kwargs=kwargs,
)
Loading
Loading