Merged
30 commits
f3df7f0
feat: experiment finetuning compatibility with peft
Kurokabe Dec 18, 2025
76158c9
fix: properly set the model device and dtype
Kurokabe Dec 19, 2025
94a71b4
fix: remove peft as a darts dependency
Kurokabe Jan 15, 2026
9057acb
wip: add transform callback for foundation model fine-tuning
Kurokabe Jan 16, 2026
34b08c1
fix: make lora finetuning work
Kurokabe Jan 16, 2026
eeb93b4
fix: improve callbacks to allow only saving the adapter
Kurokabe Jan 19, 2026
7fb10d6
Merge remote-tracking branch 'origin/master' into finetuning
Kurokabe Jan 30, 2026
76a8efa
feat: modify foundation model to integrate full and partial fine-tuning
Kurokabe Jan 30, 2026
8f7ef24
test: add unit tests for fine-tuning
Kurokabe Jan 30, 2026
260de97
documentation: update example notebook on finetuning
Kurokabe Jan 30, 2026
a20c943
documentation: update changelog
Kurokabe Jan 30, 2026
a568a3d
fix: update on_save_checkpoint of lora callback to avoid potential OO…
Kurokabe Jan 30, 2026
7c813c8
Merge branch 'master' into finetuning
dennisbader Feb 13, 2026
6fb2138
remove FoundationPLModule and enable finetuning for TorchForecastingM…
Kurokabe Feb 23, 2026
45ff4ae
Merge branch 'master' into finetuning
dennisbader Feb 23, 2026
01706ee
feat: use fnmatch instead of startswith for finetuning arguments
Kurokabe Feb 24, 2026
66fa1b2
doc: update example notebook to demonstrate finetuning for pytorch mo…
Kurokabe Feb 24, 2026
7b2c657
rename _setup_fine_tuning to _setup_finetuning and update notebook to…
Kurokabe Feb 24, 2026
b5a6c19
update unit test for foundation model with the new fine tuning logic
Kurokabe Feb 24, 2026
13fc4ab
Modify Chronos to use quantiles by default when fine-tuning
Kurokabe Feb 24, 2026
f38cfc8
Merge branch 'master' into finetuning
dennisbader Feb 27, 2026
9021e20
rename example notebook and add it to tests and docs
dennisbader Feb 27, 2026
9736f21
some refactoring
dennisbader Feb 27, 2026
0f77ac6
training specific quantile loss for foundation models and tests
dennisbader Feb 27, 2026
9987a02
update docs
dennisbader Feb 27, 2026
f93a217
update changelog entry
dennisbader Feb 27, 2026
70169ef
update notebook
dennisbader Feb 28, 2026
2891d0f
update notebook p2
dennisbader Feb 28, 2026
eb433c7
update notebook
dennisbader Mar 7, 2026
b68cebd
better test coverage
dennisbader Mar 7, 2026
2 changes: 1 addition & 1 deletion .github/workflows/merge.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
example-name: [00-quickstart.ipynb, 01-multi-time-series-and-covariates.ipynb, 02-data-processing.ipynb, 03-FFT-examples.ipynb, 04-RNN-examples.ipynb, 05-TCN-examples.ipynb, 06-Transformer-examples.ipynb, 07-NBEATS-examples.ipynb, 08-DeepAR-examples.ipynb, 09-DeepTCN-examples.ipynb, 10-Kalman-filter-examples.ipynb, 11-GP-filter-examples.ipynb, 12-Dynamic-Time-Warping-example.ipynb, 13-TFT-examples.ipynb, 15-static-covariates.ipynb, 16-hierarchical-reconciliation.ipynb, 18-TiDE-examples.ipynb, 19-EnsembleModel-examples.ipynb, 20-SKLearnModel-examples.ipynb, 21-TSMixer-examples.ipynb, 22-anomaly-detection-examples.ipynb, 23-Conformal-Prediction-examples.ipynb, 24-SKLearnClassifierModel-examples.ipynb, 25-Chronos-2-examples.ipynb, 26-NeuralForecast-examples.ipynb]
example-name: [00-quickstart.ipynb, 01-multi-time-series-and-covariates.ipynb, 02-data-processing.ipynb, 03-FFT-examples.ipynb, 04-RNN-examples.ipynb, 05-TCN-examples.ipynb, 06-Transformer-examples.ipynb, 07-NBEATS-examples.ipynb, 08-DeepAR-examples.ipynb, 09-DeepTCN-examples.ipynb, 10-Kalman-filter-examples.ipynb, 11-GP-filter-examples.ipynb, 12-Dynamic-Time-Warping-example.ipynb, 13-TFT-examples.ipynb, 15-static-covariates.ipynb, 16-hierarchical-reconciliation.ipynb, 18-TiDE-examples.ipynb, 19-EnsembleModel-examples.ipynb, 20-SKLearnModel-examples.ipynb, 21-TSMixer-examples.ipynb, 22-anomaly-detection-examples.ipynb, 23-Conformal-Prediction-examples.ipynb, 24-SKLearnClassifierModel-examples.ipynb, 25-Chronos-2-examples.ipynb, 26-NeuralForecast-examples.ipynb, 27-Torch-and-Foundation-Model-Fine-Tuning-examples.ipynb]
steps:
- name: "Clone repository"
uses: actions/checkout@v4
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -13,6 +13,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

- 🚀🚀 Added new forecasting model `NeuralForecastModel` to convert any of the 30+ NeuralForecast base model into a Darts `TorchForecastingModel`. This includes models such as NBEATSx, PatchTST, TimeXer, KAN, and many more. Like all Darts torch models, it supports univariate, multivariate, probabilistic forecasting, optimized backtesting and more. Depending on the base model, it also supports past, future, and static covariates. [#3002](https://github.com/unit8co/darts/pull/3002) by [Zhihao Dai](https://github.com/daidahao)
- Check out our new [NeuralForecastModel Notebook](https://unit8co.github.io/darts/examples/26-NeuralForecast-examples.html) for detailed examples. [#3026](https://github.com/unit8co/darts/pull/3026) by [Dennis Bader](https://github.com/dennisbader).
- 🚀🚀 Added support for fine-tuning to all `TorchForecastingModel` and `FoundationModel` (such as `Chronos2Model` and `TimesFM2p5Model`) via the new `enable_finetuning` parameter. Supports full training and partial fine-tuning by selectively freezing or unfreezing layers by name pattern. [#2964](https://github.com/unit8co/darts/issues/2964) by [Alain Gysi](https://github.com/Kurokabe).
- Check out our new [Fine-Tuning Notebook](https://unit8co.github.io/darts/examples/27-Torch-and-Foundation-Model-Fine-Tuning-examples.html) for detailed examples.
- Created `darts.typing` to collect typical type annotations in one place. Introduced `TimeIndex` & `TimeSeriesLike` type aliases for improved readability & maintainability of the code. Common type annotations can be added to this file in the future. [#3021](https://github.com/unit8co/darts/pull/3021) by [Michel Zeller](https://github.com/mizeller)
- More fine-grained control over Reversible Instance Normalization for all torch models. Apart from the boolean trigger, parameter `use_reversible_instance_norm` now also supports setting the `RINorm` hyperparameters as a dictionary. [#3029](https://github.com/unit8co/darts/pull/3029) by [Zhihao Dai](https://github.com/daidahao).

Expand Down
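The freeze/unfreeze dict semantics introduced by this changelog entry can be sketched in plain Python. This is an illustrative sketch of the pattern-matching idea only (the PR matches parameter names with `fnmatch`), not the actual darts implementation; `resolve_trainable` and the parameter names are hypothetical.

```python
# Illustrative sketch (not the darts implementation): how a
# {"unfreeze": [...]} / {"freeze": [...]} spec can be resolved into a
# per-parameter trainability map using fnmatch-style name patterns.
from fnmatch import fnmatch


def resolve_trainable(param_names, spec):
    """Return {name: requires_grad} for a freeze/unfreeze spec.

    `spec` must contain exactly one key: "unfreeze" (everything else stays
    frozen) or "freeze" (everything else stays trainable).
    """
    if len(spec) != 1:
        raise ValueError("spec must contain exactly one key-value pair")
    ((mode, patterns),) = spec.items()
    if mode not in ("freeze", "unfreeze"):
        raise ValueError(f"unknown mode: {mode}")
    matched_is_trainable = mode == "unfreeze"
    return {
        name: matched_is_trainable
        if any(fnmatch(name, p) for p in patterns)
        else not matched_is_trainable
        for name in param_names
    }


names = ["encoder.layer.0.weight", "encoder.layer.1.weight", "head.weight"]
# unfreeze only the head; all encoder parameters stay frozen
print(resolve_trainable(names, {"unfreeze": ["head.*"]}))
```

In the real API, the resulting map would drive `requires_grad` on the model's named parameters before `fit()` is called.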
4 changes: 1 addition & 3 deletions darts/models/components/huggingface_connector.py
Expand Up @@ -12,9 +12,7 @@
from safetensors.torch import load_file

from darts.logging import get_logger, raise_log
from darts.models.forecasting.pl_forecasting_module import (
PLForecastingModule,
)
from darts.models.forecasting.pl_forecasting_module import PLForecastingModule

logger = get_logger(__name__)

Expand Down
12 changes: 12 additions & 0 deletions darts/models/forecasting/block_rnn_model.py
Expand Up @@ -454,6 +454,18 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
enable_finetuning
Enables model fine-tuning. Only effective if not ``None``.
If a bool, specifies whether to perform full fine-tuning / training (all parameters are updated) or to keep
all parameters frozen. If a dict, specifies which parameters to fine-tune. The dict must contain exactly one
key-value pair. It can be used to:

- Unfreeze specific parameters, while keeping everything else frozen:
``{"unfreeze": ["param.name.patterns.*"]}``
- Freeze specific parameters, while keeping everything else unfrozen:
``{"freeze": ["param.name.patterns.*"]}``

Default: ``None``.

References
----------
Expand Down
71 changes: 54 additions & 17 deletions darts/models/forecasting/chronos2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

* `Chronos-2 Foundation Model Examples
<https://unit8co.github.io/darts/examples/25-Chronos-2-examples.html>`__
* `Fine-Tuning Examples
<https://unit8co.github.io/darts/examples/27-Torch-and-Foundation-Model-Fine-Tuning-examples.html>`__
"""

import math
Expand All @@ -23,14 +25,11 @@
_Patch,
_ResidualBlock,
)
from darts.models.components.huggingface_connector import (
HuggingFaceConnector,
)
from darts.models.forecasting.foundation_model import (
FoundationModel,
)
from darts.models.components.huggingface_connector import HuggingFaceConnector
from darts.models.forecasting.foundation_model import FoundationModel
from darts.models.forecasting.pl_forecasting_module import (
PLForecastingModule,
io_processor,
)
from darts.utils.data.torch_datasets.utils import PLModuleInput, TorchTrainingSample
from darts.utils.likelihood_models.torch import QuantileRegression
Expand Down Expand Up @@ -99,7 +98,8 @@ def __init__(
all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
base class.
"""

# for fine-tuning, model should be trained on pre-trained quantiles
enable_finetuning = kwargs.pop("enable_finetuning", False)
super().__init__(**kwargs)
self.d_model = d_model
self.d_kv = d_kv
Expand Down Expand Up @@ -192,14 +192,23 @@ def __init__(
quantiles_tensor = torch.tensor(quantiles)
self.register_buffer("quantiles", quantiles_tensor, persistent=False)

# gather indices of user-specified quantiles
# gather indices of user-specified quantiles (used at prediction time)
user_quantiles: list[float] = (
self.likelihood.quantiles
if isinstance(self.likelihood, QuantileRegression)
else [0.5]
)
self.user_quantile_indices = [quantiles.index(q) for q in user_quantiles]

# during fine-tuning, train on ALL pre-trained quantiles to preserve the
# full distribution; prediction uses only user-specified quantiles
if enable_finetuning:
self._finetuning_likelihood = QuantileRegression(quantiles)
self._finetuning_quantile_indices = list(range(self.num_quantiles))
else:
self._finetuning_likelihood = None
self._finetuning_quantile_indices = None

self.output_patch_embedding = _ResidualBlock(
in_dim=self.d_model,
h_dim=self.d_ff,
Expand Down Expand Up @@ -461,6 +470,7 @@ def _forward(
# 3. Chronos-2 uses normalized values for loss computation, while Darts uses denormalized values.
# We need to think about how best to implement Chronos-2 `RINorm` in `io_processor()` without
# breaking existing behavior, while also allowing fine-tuning with normalized loss.
@io_processor
def forward(self, x_in: PLModuleInput, *args, **kwargs) -> Any:
"""Chronos-2 model forward pass.

Expand Down Expand Up @@ -549,17 +559,26 @@ def forward(self, x_in: PLModuleInput, *args, **kwargs) -> Any:
# select only target variables
quantile_preds = quantile_preds[:, :, : self.n_targets, :]

# select only user-specified quantiles or median if deterministic
quantile_preds = quantile_preds[:, :, :, self.user_quantile_indices]
# during training (fine-tuning), output all pre-trained quantiles for loss;
# during prediction, output only user-specified quantiles
if self.training:
quantile_preds = quantile_preds[:, :, :, self._finetuning_quantile_indices]
else:
quantile_preds = quantile_preds[:, :, :, self.user_quantile_indices]

return quantile_preds

def _compute_loss(self, output, target, criterion, sample_weight):
if self.training:
# compute loss on pre-trained quantiles
return self._finetuning_likelihood.compute_loss(
output, target, sample_weight
)
else:
return super()._compute_loss(output, target, criterion, sample_weight)

class Chronos2Model(FoundationModel):
# Fine-tuning is turned off for now pending proper fine-tuning support
# and configuration.
_allows_finetuning = False

class Chronos2Model(FoundationModel):
def __init__(
self,
input_chunk_length: int,
Expand Down Expand Up @@ -607,6 +626,11 @@ def __init__(
below for details. It is recommended to call :func:`predict()` with ``predict_likelihood_parameters=True``
or ``num_samples >> 1`` to get meaningful results.

.. tip::
You can perform full or partial fine-tuning of the model by setting the ``enable_finetuning`` parameter.
Read more in the parameter description below and in the `Fine-Tuning Examples
<https://unit8co.github.io/darts/examples/27-Torch-and-Foundation-Model-Fine-Tuning-examples.html>`__.

Parameters
----------
input_chunk_length
Expand Down Expand Up @@ -635,6 +659,9 @@ def __init__(
[0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9,
0.95, 0.99].
Default: ``None``, which will make Chronos-2 deterministic (median quantile only).
When fine-tuning is enabled, the training loss is always computed on all pre-trained quantiles to
preserve the full distribution, regardless of the ``likelihood`` setting. The ``likelihood`` parameter
only affects prediction output.
hub_model_name
The model ID on HuggingFace Hub. Default: ``"amazon/chronos-2"``. Other available variants include
``"autogluon/chronos-2-small"`` and ``"autogluon/chronos-2-synth"``.
Expand Down Expand Up @@ -770,6 +797,18 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
enable_finetuning
Enables model fine-tuning. Only effective if not ``None``.
If a bool, specifies whether to perform full fine-tuning / training (all parameters are updated) or to keep
all parameters frozen. If a dict, specifies which parameters to fine-tune. The dict must contain exactly one
key-value pair. It can be used to:

- Unfreeze specific parameters, while keeping everything else frozen:
``{"unfreeze": ["param.name.patterns.*"]}``
- Freeze specific parameters, while keeping everything else unfrozen:
``{"freeze": ["param.name.patterns.*"]}``

Default: ``None``.

References
----------
Expand Down Expand Up @@ -810,8 +849,6 @@ def encode_year(idx):
[[1005.6928 ]]
[[1005.69617]]]

.. note::
Fine-tuning of Chronos-2 is not supported at the moment.
.. note::
Chronos-2 is licensed under the `Apache-2.0 License <https://github.com/amazon-science/chronos-forecasting/blob/main/LICENSE>`_,
copyright Amazon.com, Inc. or its affiliates. By using this model, you agree to the terms and conditions of
Expand Down Expand Up @@ -878,7 +915,7 @@ def encode_year(idx):
)

self.hf_connector = hf_connector
super().__init__(enable_finetuning=False, **kwargs)
super().__init__(**kwargs)

def _create_model(self, train_sample: TorchTrainingSample) -> PLForecastingModule:
pl_module_params = self.pl_module_params or {}
Expand Down
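The quantile handling in the diff above — fine-tuning trains on all pre-trained quantiles, while prediction selects only the user-requested ones by index — can be illustrated with a small stand-alone sketch of the pinball (quantile) loss. This is not the Chronos-2 code; the quantile levels and forecast values below are made up for illustration.

```python
# Illustrative sketch of the training/prediction quantile split above:
# fine-tuning computes the pinball (quantile) loss over ALL pre-trained
# quantiles, while prediction selects only user-requested quantiles by index.


def pinball_loss(preds, target, quantiles):
    """Mean pinball loss; preds[i] is the forecast for quantiles[i]."""
    total = 0.0
    for pred, q in zip(preds, quantiles):
        err = target - pred
        # under-prediction is weighted by q, over-prediction by (1 - q)
        total += max(q * err, (q - 1) * err)
    return total / len(quantiles)


pretrained_quantiles = [0.1, 0.5, 0.9]  # hypothetical pre-trained quantiles
preds = [0.8, 1.0, 1.3]                 # hypothetical quantile forecasts
train_loss = pinball_loss(preds, 1.1, pretrained_quantiles)

# prediction-time selection, analogous to `user_quantile_indices` above
user_quantiles = [0.5]
indices = [pretrained_quantiles.index(q) for q in user_quantiles]
user_preds = [preds[i] for i in indices]
print(user_preds)  # only the median forecast is returned
```

Training on the full quantile grid keeps the pre-trained distribution intact, which is the point of the `_finetuning_likelihood` / `user_quantile_indices` split in the diff.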
12 changes: 12 additions & 0 deletions darts/models/forecasting/dlinear.py
Expand Up @@ -411,6 +411,18 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
enable_finetuning
Enables model fine-tuning. Only effective if not ``None``.
If a bool, specifies whether to perform full fine-tuning / training (all parameters are updated) or to keep
all parameters frozen. If a dict, specifies which parameters to fine-tune. The dict must contain exactly one
key-value pair. It can be used to:

- Unfreeze specific parameters, while keeping everything else frozen:
``{"unfreeze": ["param.name.patterns.*"]}``
- Freeze specific parameters, while keeping everything else unfrozen:
``{"freeze": ["param.name.patterns.*"]}``

Default: ``None``.

References
----------
Expand Down
58 changes: 29 additions & 29 deletions darts/models/forecasting/foundation_model.py
Review comment (Collaborator):
For fine-tuning the foundation models, we should make sure that during training we use a QuantileRegression(quantiles) with all quantiles that the original weights were trained on.

The user should still be able to specify different quantiles when creating the model with likelihood=QuantileRegression(other_quantiles). These quantiles will only be used for prediction.

Expand Up @@ -10,22 +10,14 @@

from abc import ABC

from darts.logging import get_logger, raise_log
from darts.models.forecasting.torch_forecasting_model import (
MixedCovariatesTorchModel,
)
from darts.logging import get_logger
from darts.models.forecasting.torch_forecasting_model import MixedCovariatesTorchModel

logger = get_logger(__name__)


class FoundationModel(MixedCovariatesTorchModel, ABC):
_allows_finetuning: bool = False

def __init__(
self,
enable_finetuning: bool = False,
**kwargs,
):
def __init__(self, **kwargs):
"""Foundation Forecasting Model with PyTorch Lightning backend.

This class is meant to be inherited to create a new foundation forecasting model.
Expand All @@ -46,11 +38,14 @@ def __init__(
instantiate a :class:`HuggingFaceConnector` and use its methods to load the model configuration
inside :func:`__init__()` and to load the model weights inside :func:`_create_model()`.


.. tip::
You can perform full or partial fine-tuning of the model by setting the ``enable_finetuning`` parameter.
Read more in the parameter description below and in the `Fine-Tuning Examples
<https://unit8co.github.io/darts/examples/27-Torch-and-Foundation-Model-Fine-Tuning-examples.html>`__.

Parameters
----------
enable_finetuning
Whether to enable fine-tuning of the foundation model. If set to ``True``, calling :func:`fit()` will
update the model weights. Default: ``False``.
batch_size
Number of time series (input and output sequences) used in each fine-tuning pass. Default: ``32``.
n_epochs
Expand Down Expand Up @@ -156,22 +151,33 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
enable_finetuning
Enables model fine-tuning. Only effective if not ``None``.
If a bool, specifies whether to perform full fine-tuning / training (all parameters are updated) or to keep
all parameters frozen. If a dict, specifies which parameters to fine-tune. The dict must contain exactly one
key-value pair. It can be used to:

- Unfreeze specific parameters, while keeping everything else frozen:
``{"unfreeze": ["param.name.patterns.*"]}``
- Freeze specific parameters, while keeping everything else unfrozen:
``{"freeze": ["param.name.patterns.*"]}``

Default: ``None``.
"""
# Set default fine-tuning to False for foundation models
if "enable_finetuning" not in self.model_params:
self.model_params["enable_finetuning"] = False

# initialize `TorchForecastingModel` base class
super().__init__(**self._extract_torch_model_params(**self.model_params))

# extract pytorch lightning module kwargs
self.pl_module_params = self._extract_pl_module_params(**self.model_params)

# validate and set fine-tuning flag
if enable_finetuning and not self._allows_finetuning:
raise_log(
ValueError(
f"Fine-tuning is not supported for {self.__class__.__name__}."
" Please set `enable_finetuning=False`."
),
logger,
)
# pass fine-tuning flag to the PLModule so it can set up training-specific
# quantile handling (separate from prediction-time likelihood)
if self.enable_finetuning:
self.pl_module_params["enable_finetuning"] = True

use_reversible_instance_norm: bool | dict = self.pl_module_params.get(
"use_reversible_instance_norm", False
Expand All @@ -193,9 +199,3 @@ def encode_year(idx):
self.pl_module_params["use_reversible_instance_norm"] = (
use_reversible_instance_norm
)

self._enable_finetuning = enable_finetuning

@property
def _requires_training(self) -> bool:
return self._enable_finetuning
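The default-injection step in the `FoundationModel` diff above — fall back to `enable_finetuning=False` when the user passes nothing, and forward a truthy value to the Lightning module's parameters — boils down to a `setdefault` plus a flag hand-off. A minimal stand-alone sketch with a hypothetical helper name:

```python
# Minimal sketch of the default handling shown above: foundation models
# default `enable_finetuning` to False (weights stay frozen), and any
# truthy value (bool True or a freeze/unfreeze dict) is forwarded to the
# PLModule params so it can set up training-specific quantile handling.


def prepare_params(model_params):
    """Hypothetical helper mirroring the __init__ logic in the diff."""
    params = dict(model_params)  # do not mutate the caller's dict
    params.setdefault("enable_finetuning", False)
    pl_module_params = {}
    if params["enable_finetuning"]:
        pl_module_params["enable_finetuning"] = True
    return params, pl_module_params


print(prepare_params({}))  # fine-tuning off by default
print(prepare_params({"enable_finetuning": {"unfreeze": ["head.*"]}}))
```

Note how this replaces the removed `_requires_training` property: zero-shot use (the default) skips training entirely, while any non-falsy `enable_finetuning` value turns `fit()` into an actual weight update.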
12 changes: 12 additions & 0 deletions darts/models/forecasting/nbeats.py
Expand Up @@ -745,6 +745,18 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
enable_finetuning
Enables model fine-tuning. Only effective if not ``None``.
If a bool, specifies whether to perform full fine-tuning / training (all parameters are updated) or to keep
all parameters frozen. If a dict, specifies which parameters to fine-tune. The dict must contain exactly one
key-value pair. It can be used to:

- Unfreeze specific parameters, while keeping everything else frozen:
``{"unfreeze": ["param.name.patterns.*"]}``
- Freeze specific parameters, while keeping everything else unfrozen:
``{"freeze": ["param.name.patterns.*"]}``

Default: ``None``.

References
----------
Expand Down
12 changes: 12 additions & 0 deletions darts/models/forecasting/nf_model.py
Expand Up @@ -571,6 +571,18 @@ def encode_year(idx):
show_warnings
whether to show warnings raised from PyTorch Lightning. Useful to detect potential issues of
your forecasting use case. Default: ``False``.
enable_finetuning
Enables model fine-tuning. Only effective if not ``None``.
If a bool, specifies whether to perform full fine-tuning / training (all parameters are updated) or to keep
all parameters frozen. If a dict, specifies which parameters to fine-tune. The dict must contain exactly one
key-value pair. It can be used to:

- Unfreeze specific parameters, while keeping everything else frozen:
``{"unfreeze": ["param.name.patterns.*"]}``
- Freeze specific parameters, while keeping everything else unfrozen:
``{"freeze": ["param.name.patterns.*"]}``

Default: ``None``.

References
----------
Expand Down