Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- `TFTExplainer` plotting improvements: [#3039](https://github.com/unit8co/darts/pull/3039) by [ReinerBRO](https://github.com/ReinerBRO).
- `plot_variable_selection()` now returns the matplotlib figures for downstream usage (saving, editing, ...). It returns a single figure when explaining a single TimeSeries. Otherwise, it returns a list of figures.
- `plot_variable_selection()` now accepts a `show_plot: bool = True` parameter that allows suppressing the display of the plot.
- 🔴 `plot_attention()` now also returns the matplotlib figures for all explained series, instead of only the matplotlib axis for the last series.

**Fixed**

**Dependencies**
Expand Down
40 changes: 36 additions & 4 deletions darts/explainability/tft_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from torch import Tensor

from darts import TimeSeries
Expand Down Expand Up @@ -243,7 +244,8 @@ def plot_variable_selection(
expl_result: TFTExplainabilityResult,
fig_size=None,
max_nr_series: int = 5,
):
show_plot: bool = True,
) -> Figure | list[Figure]:
"""Plots the variable selection / feature importances of the `TFTModel` based on the input.
The figure includes three subplots:

Expand All @@ -260,6 +262,14 @@ def plot_variable_selection(
The size of the figure to be plotted.
max_nr_series
The maximum number of plots to show in case `expl_result` was computed on multiple series.
show_plot
Whether to show the plot.

Returns
-------
Figure | list[Figure]
The matplotlib figures used for plotting. Returns a single Figure when explaining a single series,
and a list of Figures when explaining multiple series.
"""
encoder_importance = expl_result.get_encoder_importance()
decoder_importance = expl_result.get_decoder_importance()
Expand All @@ -269,6 +279,7 @@ def plot_variable_selection(
decoder_importance = [decoder_importance]
static_covariates_importance = [static_covariates_importance]

plotted_figures: list[Figure] = []
uses_static_covariates = not static_covariates_importance[0].empty
for idx, (enc_imp, dec_imp, stc_imp) in enumerate(
zip(encoder_importance, decoder_importance, static_covariates_importance)
Expand All @@ -292,11 +303,17 @@ def plot_variable_selection(
ax=axes[2],
)
fig.tight_layout()
plt.show()
if show_plot:
plt.show()
plotted_figures.append(fig)

if idx + 1 == max_nr_series:
break

if len(plotted_figures) == 1:
return plotted_figures[0]
return plotted_figures

def plot_attention(
self,
expl_result: TFTExplainabilityResult,
Expand All @@ -305,7 +322,7 @@ def plot_attention(
ax: matplotlib.axes.Axes | None = None,
max_nr_series: int = 5,
show_plot: bool = True,
) -> matplotlib.axes.Axes:
) -> Figure | list[Figure]:
"""Plots the attention heads of the `TFTModel`.

Parameters
Expand All @@ -331,16 +348,26 @@ def plot_attention(
The maximum number of plots to show in case `expl_result` was computed on multiple series.
show_plot
Whether to show the plot.

Returns
-------
Figure | list[Figure]
The matplotlib figures used for plotting. Returns a single Figure when explaining a single series,
and a list of Figures when explaining multiple series.
"""
single_series = False
attentions = expl_result.get_explanation(component="attention")
if isinstance(attentions, TimeSeries):
attentions = [attentions]
single_series = True

plotted_figures: list[Figure] = []
for idx, attention in enumerate(attentions):
if ax is None or not single_series:
fig, ax = plt.subplots()
else:
fig = ax.get_figure()

if show_index_as == "relative":
x_ticks = generate_index(
start=-self.model.input_chunk_length, end=self.n - 1
Expand Down Expand Up @@ -406,9 +433,14 @@ def plot_attention(
if show_plot:
plt.show()

plotted_figures.append(fig)

if idx + 1 == max_nr_series:
break
return ax

if len(plotted_figures) == 1:
return plotted_figures[0]
return plotted_figures

@property
def _encoder_importance(self) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion darts/tests/ad/test_anomaly_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def test_multivariate_ForecastingAnomalyModel(self):
)
np.testing.assert_array_almost_equal(auc_pr_from_scores, true_auc_pr, decimal=1)

def test_visualization(self):
def test_visualization(self, mpl_safe_plotting):
# test function show_anomalies() and show_anomalies_from_scores()
forecasting_anomaly_model = ForecastingAnomalyModel(
model=SKLearnModel(lags=10), scorer=Norm()
Expand Down
2 changes: 1 addition & 1 deletion darts/tests/ad/test_scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def expects_deterministic_input(self, scorer, **kwargs):
scorer.score_from_prediction(self.probabilistic, self.train)
)

def test_WassersteinScorer(self):
def test_WassersteinScorer(self, mpl_safe_plotting):
# Check parameters and inputs
self.component_wise_parameter(WassersteinScorer)
self.helper_window_parameter(WassersteinScorer)
Expand Down
10 changes: 10 additions & 0 deletions darts/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import shutil
import tempfile
from typing import Any
from unittest.mock import patch

import matplotlib.pyplot as plt
import pandas as pd
import pytest
from packaging import version
Expand Down Expand Up @@ -184,3 +186,11 @@ def tmpdir_fn():
os.chdir(cwd)
# remove temp dir
shutil.rmtree(temp_work_dir)


@pytest.fixture(scope="function")
def mpl_safe_plotting():
    """Patch `matplotlib.pyplot.show` for one test and close all figures afterwards.

    `matplotlib.pyplot.show` is replaced by a mock for the duration of the test.
    The mock is yielded, so a test can inspect calls to it if needed. After the
    test, `plt.close("all")` is called to release every open figure from memory.

    Yields
    ------
    unittest.mock.MagicMock
        The mock object standing in for `matplotlib.pyplot.show`.
    """
    with patch("matplotlib.pyplot.show") as patched_show:
        yield patched_show
    # close every open figure so plots do not pile up in memory across tests
    plt.close("all")
4 changes: 1 addition & 3 deletions darts/tests/dataprocessing/dtw/test_dtw.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -208,11 +207,10 @@ def test_nans(self):

dtw.dtw(series1, series2)

def test_plot(self):
def test_plot(self, mpl_safe_plotting):
align = dtw.dtw(self.series2, self.series1)
align.plot()
align.plot_alignment()
plt.close()

def test_multivariate(self):
n = 2
Expand Down
8 changes: 2 additions & 6 deletions darts/tests/explainability/test_shap_explainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import copy
from datetime import date, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -541,7 +540,7 @@ def test_explain_with_lags_covariates_series_older_timestamps_than_target(self):
# that at the start of the target series we have sufficient information to explain the prediction.
assert explanation.start_time() == self.target_ts.start_time()

def test_plot(self):
def test_plot(self, mpl_safe_plotting):
model_cls = LightGBMModel if LGBM_AVAILABLE else LinearRegressionModel
m_0 = model_cls(
lags=4,
Expand Down Expand Up @@ -576,7 +575,6 @@ def test_plot(self):
"power",
)
assert isinstance(fplot, shap.plots._force.BaseVisualizer)
plt.close()

# no component name -> multivariate error
with pytest.raises(ValueError):
Expand Down Expand Up @@ -641,7 +639,6 @@ def test_plot(self):
target_component="power",
)
assert isinstance(fplot, shap.plots._force.BaseVisualizer)
plt.close()

def test_feature_values_align_with_input(self):
model_cls = LightGBMModel if LGBM_AVAILABLE else LinearRegressionModel
Expand Down Expand Up @@ -839,7 +836,7 @@ def test_shapley_multiple_series_with_different_static_covs(self):
comps_out = explained_forecast[1]["price"].columns.tolist()
assert comps_out[-1] == "type_statcov_target_price"

def test_shap_regressor_component_specific_lags(self):
def test_shap_regressor_component_specific_lags(self, mpl_safe_plotting):
model = LinearRegressionModel(
lags={"price": [-3, -2], "power": [-1]},
output_chunk_length=1,
Expand Down Expand Up @@ -881,7 +878,6 @@ def test_shap_regressor_component_specific_lags(self):

# check that explain() can be called
explanation_results = shap_explain.explain()
plt.close()
for comp in ts.components:
comps_out = explanation_results.explained_forecasts[1][comp].columns
assert all(comps_out == expected_columns)
100 changes: 60 additions & 40 deletions darts/tests/explainability/test_tft_explainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import itertools
from unittest.mock import patch

import matplotlib.pyplot as plt
import matplotlib.figure
import numpy as np
import pandas as pd
import pytest
Expand Down Expand Up @@ -301,20 +300,35 @@ def test_explainer_multiple_multivariate_series(self, test_case):
"or equal to the model's batch size (32)."
)

def test_variable_selection_explanation(self):
@pytest.mark.parametrize("n_series", [1, 2])
def test_variable_selection_explanation(self, n_series, mpl_safe_plotting):
"""Test variable selection (feature importance) explanation results and plotting."""
model = self.helper_create_model(use_encoders=True, add_relative_idx=True)
series, pc, fc = self.helper_get_input(series_option="multivariate")
model.fit(series, past_covariates=pc, future_covariates=fc)
explainer = TFTExplainer(model)
results = explainer.explain()
results = explainer.explain(
foreground_series=series if n_series == 1 else [series] * 2,
foreground_past_covariates=pc if n_series == 1 else [pc] * 2,
foreground_future_covariates=fc if n_series == 1 else [fc] * 2,
)

imps = results.get_feature_importances()
enc_imp = results.get_encoder_importance()
dec_imp = results.get_decoder_importance()
stc_imp = results.get_static_covariates_importance()
imps_direct = [enc_imp, dec_imp, stc_imp]

# check that all importances are the same across series (since the series have identical values)
if n_series > 1:
for imp in imps.values():
np.testing.assert_array_almost_equal(imp[0].values, imp[1].values)
for imp in imps_direct:
np.testing.assert_array_almost_equal(imp[0].values, imp[1].values)

imps = {k: v[0] for k, v in imps.items()}
imps_direct = [imp[0] for imp in imps_direct]

imp_names = [
"encoder_importance",
"decoder_importance",
Expand All @@ -339,6 +353,7 @@ def test_variable_selection_explanation(self):
index=[0],
)
# relaxed comparison because M1 chip gives slightly different results than intel chip
enc_imp = imps_direct[0]
assert ((enc_imp.round(decimals=1) - enc_expected).abs() <= 3).all().all()

dec_expected = pd.DataFrame(
Expand All @@ -350,19 +365,26 @@ def test_variable_selection_explanation(self):
},
index=[0],
)
dec_imp = imps_direct[1]
# relaxed comparison because M1 chip gives slightly different results than intel chip
assert ((dec_imp.round(decimals=1) - dec_expected).abs() <= 0.6).all().all()

stc_expected = pd.DataFrame(
{"num_statcov": 11.9, "cat_statcov": 88.1}, index=[0]
)
stc_imp = imps_direct[2]
# relaxed comparison because M1 chip gives slightly different results than intel chip
assert ((stc_imp.round(decimals=1) - stc_expected).abs() <= 0.1).all().all()

with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_variable_selection(results)
figs = explainer.plot_variable_selection(results)
if n_series == 1:
figs = [figs]
for fig in figs:
assert isinstance(fig, matplotlib.figure.Figure)
assert len(fig.get_axes()) == 3

def test_attention_explanation(self):
@pytest.mark.parametrize("n_series", [1, 2])
def test_attention_explanation(self, n_series, mpl_safe_plotting):
"""Test attention (feature importance) explanation results and plotting."""
# past attention (full_attention=False) only attends to values in the past relative to each horizon
# (look at the last 0 values in the array)
Expand Down Expand Up @@ -397,42 +419,40 @@ def test_attention_explanation(self):
series, pc, fc = self.helper_get_input(series_option="multivariate")
model.fit(series, past_covariates=pc, future_covariates=fc)
explainer = TFTExplainer(model)
results = explainer.explain()
results = explainer.explain(
foreground_series=series if n_series == 1 else [series] * 2,
foreground_past_covariates=pc if n_series == 1 else [pc] * 2,
foreground_future_covariates=fc if n_series == 1 else [fc] * 2,
)

att = results.get_attention()
attns = results.get_attention()
# relaxed comparison because M1 chip gives slightly different results than intel chip
assert np.all(np.abs(np.round(att.values(), decimals=1) - att_exp) <= 0.2)
assert att.columns.tolist() == ["horizon 1", "horizon 2"]
with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_attention(
results, plot_type="all", show_index_as="relative"
)
plt.close()
with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_attention(
results, plot_type="all", show_index_as="time"
)
plt.close()
with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_attention(
results, plot_type="time", show_index_as="relative"
)
plt.close()
with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_attention(
results, plot_type="time", show_index_as="time"
)
plt.close()
with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_attention(
results, plot_type="heatmap", show_index_as="relative"
)
plt.close()
with patch("matplotlib.pyplot.show") as _:
_ = explainer.plot_attention(
results, plot_type="heatmap", show_index_as="time"
if n_series == 1:
attns = [attns]

for att in attns:
assert np.all(
np.abs(np.round(att.values(), decimals=1) - att_exp) <= 0.2
)
plt.close()
assert att.columns.tolist() == ["horizon 1", "horizon 2"]

def _check_plot(n_figs_expected, n_axes_expected, **kwargs):
    """Call `explainer.plot_attention()` on `results` and validate the returned figure(s).

    Parameters
    ----------
    n_figs_expected
        The number of figures expected. When it is 1, `plot_attention()` is expected to
        return a single figure, which is wrapped in a list for uniform checking.
    n_axes_expected
        The number of axes each returned figure must contain.
    **kwargs
        Forwarded to `explainer.plot_attention()` (e.g. `plot_type`, `show_index_as`).
    """
    figs = explainer.plot_attention(results, **kwargs)
    if n_figs_expected == 1:
        figs = [figs]
    # without this check, a list with too few / too many figures would pass silently
    assert len(figs) == n_figs_expected
    for fig in figs:
        assert isinstance(fig, matplotlib.figure.Figure)
        assert len(fig.get_axes()) == n_axes_expected

# only a single axis should be plotted
_check_plot(n_series, 1, plot_type="all", show_index_as="relative")
_check_plot(n_series, 1, plot_type="all", show_index_as="time")
_check_plot(n_series, 1, plot_type="time", show_index_as="relative")
_check_plot(n_series, 1, plot_type="time", show_index_as="time")
# the heatmap additionally plots a colorbar axis
_check_plot(n_series, 2, plot_type="heatmap", show_index_as="relative")
_check_plot(n_series, 2, plot_type="heatmap", show_index_as="time")

def helper_create_model(
self, use_encoders=True, add_relative_idx=True, full_attention=False
Expand Down
Loading
Loading