diff --git a/recipes/tc_tracking/README.md b/recipes/tc_tracking/README.md index 12d3a3840..6c353d67a 100644 --- a/recipes/tc_tracking/README.md +++ b/recipes/tc_tracking/README.md @@ -43,13 +43,12 @@ process. - [2.1 Generate Ensemble](#21-generate-ensemble) - [2.2 Reproduce Individual Ensemble Members](#22-reproduce-individual-ensemble-members) - [2.3 Extract Reference Tracks from ERA5](#23-extract-reference-tracks-from-era5) -3. [Visualisation](#3-visualisation) *(coming soon)* +3. [Visualisation](#3-visualisation) 4. [TempestExtremes Integration](#4-tempestextremes-integration) 5. [Example Workflow](#5-example-workflow) - [5.1 Extract Baseline](#51-extract-baseline-optional) - [5.2 Produce Ensemble Forecasts](#52-produce-ensemble-forecasts) - [5.3 Analyse Tracks](#53-analyse-tracks) - *(coming soon)* - [5.4 Reproduce Interesting Members](#54-reproduce-interesting-members-to-extract-fields) ## 1. Setting up the Environment @@ -570,29 +569,36 @@ and a warning is logged. ## 3. Visualisation -> [!Note] -> Visualisation tools will be available in a future update. - - -
-Preview - -Two Jupyter notebooks are provided in `./plotting` for -analysing and visualising tropical cyclone tracking results: +Two [JupyText](https://jupytext.readthedocs.io/) notebook scripts are +provided in `./plotting` for analysing and visualising tropical cyclone +tracking results: -- **`tracks_slayground.ipynb`**: Ensemble track analysis +- **`tracks_slayground_notebook.py`**: Ensemble track analysis including spaghetti plots (trajectory visualisation), absolute and relative intensity metrics (wind speed, MSLP), comparisons against ERA5 reference tracks and IBTrACS observations, extreme value statistics, and error moment analysis over lead time. -- **`plot_tracks_n_fields.ipynb`**: Create animated +- **`plot_tracks_n_fields_notebook.py`**: Create animated visualisations of storm tracks overlaid on atmospheric field data. -
- +Both scripts can be run as plain Python files or converted to Jupyter +notebooks via JupyText. From the recipe root: + +```bash +cd plotting +jupytext --to notebook tracks_slayground_notebook.py +jupytext --to notebook plot_tracks_n_fields_notebook.py +jupyter notebook tracks_slayground.ipynb +``` + +`./plotting/` must be the working directory because the notebooks import +their helpers (`analyse_n_plot`, `plotting_helpers`, `data_handling`) by +bare module name. See `./plotting/README.md` for the full layout, including +the `analyse_n_plot.py` batch entry point for running the analysis across +many storms at once. ## 4. TempestExtremes Integration @@ -708,13 +714,6 @@ trajectories. ### 5.3 Analyse Tracks -> [!Note] -> Visualisation tools will be available in a future update. - - -
-Preview - Visualise the results using the notebook `plotting/tracks_slayground.ipynb`. @@ -739,9 +738,6 @@ tru_track_dir = '/path/to/outputs_reference_tracks' # tru_track_dir = '/path/to/test/aux_data' ``` -
- - ### 5.4 Reproduce Interesting Members to Extract Fields Suppose that after conducting the above analysis you want diff --git a/recipes/tc_tracking/plotting/.gitignore b/recipes/tc_tracking/plotting/.gitignore new file mode 100644 index 000000000..ab7c38f24 --- /dev/null +++ b/recipes/tc_tracking/plotting/.gitignore @@ -0,0 +1,2 @@ +plots/ +*.ipynb diff --git a/recipes/tc_tracking/plotting/README.md b/recipes/tc_tracking/plotting/README.md new file mode 100644 index 000000000..b17b615c3 --- /dev/null +++ b/recipes/tc_tracking/plotting/README.md @@ -0,0 +1,50 @@ +# Analysing and Plotting TC Tracks + +## Notebooks + +- **`plot_tracks_n_fields_notebook.py`** – Plotting tracks and fields for + individual ensemble members +- **`tracks_slayground_notebook.py`** – Analysing and plotting complete + tracks from a full ensemble run for a case study on a given storm + +Both scripts are [JupyText](https://jupytext.readthedocs.io/) Python files +that can be run directly or converted to Jupyter notebooks: + +```bash +# Convert to a Jupyter notebook +jupytext --to notebook plot_tracks_n_fields_notebook.py +jupytext --to notebook tracks_slayground_notebook.py +``` + +> [!Note] +> The notebooks and the modules in this directory use bare module names +> (`from analyse_n_plot import ...`, `from plotting_helpers import ...`). +> Run them with `plotting/` as the working directory so Python can resolve +> those imports: +> +> ```bash +> cd recipes/tc_tracking/plotting +> jupyter notebook tracks_slayground.ipynb +> ``` + +## Scripts and Library Modules + +- **`analyse_n_plot.py`** – Batch entry point. Drives + `analyse_individual_storms` (one plot set per storm) and + `analyse_ensemble_of_storms` (error metrics aggregated across many + storms). Run with `python analyse_n_plot.py` after editing the storm + selection and paths near the bottom of the file. +- **`data_handling.py`** – Library: track ingestion, matching against the + reference, ensemble averaging on the sphere, and lead-time error + metrics. Imported by the notebooks; not intended to be run directly. +- **`plotting_helpers.py`** – Library: the individual plotting routines + (spaghetti, intensities over time, histograms, error metrics). Also + imported by the notebooks; not intended to be run directly. + +## Additional Information + +- Each notebook specifies at the beginning what data is required and how to + produce it using the TC tracking pipeline. +- All plotting and analysis routines take a `time_step` keyword argument + defaulting to 6 h, matching the stock FCN3 and AIFS-ENS configurations. + Override it if you run the upstream pipeline at a different cadence. diff --git a/recipes/tc_tracking/plotting/analyse_n_plot.py b/recipes/tc_tracking/plotting/analyse_n_plot.py new file mode 100644 index 000000000..d690fadd4 --- /dev/null +++ b/recipes/tc_tracking/plotting/analyse_n_plot.py @@ -0,0 +1,471 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any + +import numpy as np +import pandas as pd +from data_handling import ( + compute_averages_of_errors_over_lead_time, + extract_tracks, + extract_tracks_from_file, + get_ensemble_averages, + match_tracks, +) +from plotting_helpers import ( + plot_errors_over_lead_time, + plot_extreme_extremes_histograms, + plot_ib_era5, + plot_over_time, + plot_relative_over_time, + plot_spaghetti, +) +from tqdm import tqdm + +_DEFAULT_TIME_STEP = np.timedelta64(6, "h") + + +def load_tracks( + case: str, + pred_track_dir: str, + tru_track_dir: str, + out_dir: str | None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> tuple[pd.DataFrame, list[dict[str, Any]], dict[str, Any], int, str | None]: + """Load predicted and reference tracks for a named storm. + + Parameters + ---------- + case : str + Storm identifier in the format ``{name}_{year}_{basin}``. + pred_track_dir : str + Directory containing predicted track CSV files. + tru_track_dir : str + Directory containing the reference track CSV file. + out_dir : str | None + Base output directory for plots. A case-specific sub-directory + is created automatically. + time_step : np.timedelta64, optional + Model time step, by default 6 h. + + Returns + ------- + tuple + ``(tru_track, pred_tracks, ens_mean, n_members, out_dir)`` + """ + tru_track = extract_tracks_from_file( + os.path.join(tru_track_dir, f"reference_track_{case}.csv") + ) + tru_track["dist"] = np.zeros(len(tru_track)) + + pred_tracks = extract_tracks(in_dir=os.path.join(pred_track_dir)) + n_members = len(pred_tracks) + + pred_tracks = match_tracks(pred_tracks, tru_track) + + if out_dir: + out_dir = os.path.join(out_dir, case) + os.makedirs(out_dir, exist_ok=True) + + ens_mean = get_ensemble_averages( + pred_tracks=pred_tracks, tru_track=tru_track, time_step=time_step + ) + + return tru_track, pred_tracks, ens_mean, n_members, out_dir + + +def analyse_individual_storms( + cases: str | list[str], + pred_track_dir: str, + tru_track_dir: str, + out_path: str | None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> None: + """Generate a full set of analysis plots for each storm individually. + + Parameters + ---------- + cases : str | list[str] + Storm identifier(s). + pred_track_dir : str + Directory containing predicted track CSV files. + tru_track_dir : str + Directory containing reference track CSV files. + out_path : str | None + Base output directory for plots. + time_step : np.timedelta64, optional + Model time step, by default 6 h. + """ + if isinstance(cases, str): + cases = [cases] + + for case in tqdm(cases): + tru_track, pred_tracks, ens_mean, n_members, out_dir = load_tracks( + case=case, + pred_track_dir=pred_track_dir, + tru_track_dir=tru_track_dir, + out_dir=out_path, + time_step=time_step, + ) + + plot_spaghetti( + true_track=tru_track, + pred_tracks=pred_tracks, + ensemble_mean=ens_mean["mean"], + case=case, + n_members=n_members, + out_dir=out_dir, + ) + + plot_over_time( + pred_tracks=pred_tracks, + tru_track=tru_track, + ensemble_mean=ens_mean, + case=case, + n_members=n_members, + out_dir=out_dir, + time_step=time_step, + ) + + plot_relative_over_time( + pred_tracks=pred_tracks, + tru_track=tru_track, + ensemble_mean=ens_mean, + case=case, + n_members=n_members, + out_dir=out_dir, + time_step=time_step, + ) + + plot_ib_era5( + tru_track=tru_track, + case=case, + variables=["msl", "wind_speed"], + out_dir=out_dir, + ) + + plot_extreme_extremes_histograms( + pred_tracks=pred_tracks, + tru_track=tru_track, + ensemble_mean=ens_mean, + case=case, + out_dir=out_dir, + ) + + err_dict, _ = compute_averages_of_errors_over_lead_time( + pred_tracks=pred_tracks, + tru_track=tru_track, + variables=["wind_speed", "msl", "dist"], + ) + + plot_errors_over_lead_time( + err_dict=err_dict, + case=case, + ic=pred_tracks[0]["ic"], + n_members=n_members, + n_tracks=len(pred_tracks), + out_dir=out_dir, + time_step=time_step, + ) + + +def stack_metrics(err_dict: dict[str, dict[str, np.ndarray]]) -> np.ndarray: + """Stack per-variable, per-metric arrays into a single 3-D array. + + Parameters + ---------- + err_dict : dict[str, dict[str, np.ndarray]] + Per-variable error metrics. + + Returns + ------- + np.ndarray + Array of shape ``(n_vars, n_metrics, lead_time)``. + """ + var_errs = [] + for var in err_dict.keys(): + metrics = np.stack([err_dict[var][metric] for metric in err_dict[var]], axis=0) + var_errs.append(metrics) + + return np.stack(var_errs, axis=0) + + +def stack_cases(storm_metrics: dict[str, Any], max_len: int) -> dict[str, Any]: + """Pad per-storm metric and weight arrays to *max_len* and stack them. + + Parameters + ---------- + storm_metrics : dict[str, Any] + Accumulator with ``"data"`` (list of 3-D arrays) and ``"weights"`` + (list of 1-D arrays). + max_len : int + Target lead-time dimension (shorter storms are NaN-padded for + ``"data"`` and zero-padded for ``"weights"``). + + Returns + ------- + dict[str, Any] + Updated *storm_metrics* with ``"data"`` as a 4-D ``np.ndarray`` + of shape ``(n_cases, n_vars, n_metrics, max_len)`` and + ``"weights"`` as a 2-D ``np.ndarray`` of shape + ``(n_cases, max_len)``. + """ + for ii in range(len(storm_metrics["case"])): + storm_metrics["data"][ii] = np.pad( + storm_metrics["data"][ii], + pad_width=( + (0, 0), + (0, 0), + (0, max_len - storm_metrics["data"][ii].shape[-1]), + ), + mode="constant", + constant_values=np.nan, + ) + storm_metrics["weights"][ii] = np.pad( + storm_metrics["weights"][ii], + pad_width=(0, max_len - storm_metrics["weights"][ii].shape[-1]), + mode="constant", + constant_values=0, + ) + + storm_metrics["data"] = np.stack(storm_metrics["data"], axis=0) + storm_metrics["weights"] = np.stack(storm_metrics["weights"], axis=0).astype(int) + + should_shape = ( + len(storm_metrics["case"]), + len(storm_metrics["var"]), + len(storm_metrics["metric"]), + max_len, + ) + if storm_metrics["data"].shape != should_shape: + raise ValueError( + f"shapes not matching when stacking cases: " + f'{storm_metrics["data"].shape=} {should_shape=}' + ) + + return storm_metrics + + +def get_individual_storm_metrics( + cases: list[str], + pred_track_dir: str, + tru_track_dir: str, + out_path: str | None, + variables: list[str] | None = None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> tuple[dict[str, Any], int, dict[str, Any], dict[str, Any]]: + """Compute per-storm error metrics and collect ensemble averages and extremes. + + Parameters + ---------- + cases : list[str] + Storm identifiers. + pred_track_dir : str + Directory containing predicted track CSV files. + tru_track_dir : str + Directory containing reference track CSV files. + out_path : str | None + Base output directory for plots. + variables : list[str] | None, optional + Variable names to evaluate, by default + ``["wind_speed", "msl", "dist"]`` + time_step : np.timedelta64, optional + Model time step used for the lead-time axis, by default 6 h. + + Returns + ------- + tuple[dict[str, Any], int, dict[str, Any], dict[str, Any]] + ``(storm_metrics, max_len, ensemble_averages, extremes)`` + """ + if variables is None: + variables = ["wind_speed", "msl", "dist"] + + storm_metrics: dict[str, Any] = { + "case": [], + "var": None, + "metric": None, + "lead time": None, + "data": [], + "weights": [], + } + max_len: int = 0 + ensemble_averages: dict[str, Any] = {} + extremes: dict[str, Any] = {} + for case in tqdm(cases, desc="loading storm data"): + tru_track, pred_tracks, ens_mean, _n_members, _out_dir = load_tracks( + case=case, + pred_track_dir=pred_track_dir, + tru_track_dir=tru_track_dir, + out_dir=out_path, + time_step=time_step, + ) + ensemble_averages[case] = ens_mean + + err_dict, _max_len = compute_averages_of_errors_over_lead_time( + pred_tracks=pred_tracks, tru_track=tru_track, variables=variables + ) + + # Strip per-storm extremes and ensemble-member counts from the metric + # axis so only the real error metrics remain. ``n_members`` is the + # same array for every variable, so we read it once per storm. + extremes[case] = {} + storm_weights: np.ndarray | None = None + for var in variables: + extremes[case][var] = {} + for ext, npfun in zip(["min", "max"], [np.nanmin, np.nanmax]): + extremes[case][var][ext + "_pred"] = err_dict[var].pop(ext) + extremes[case][var][ext + "_tru"] = npfun(tru_track[var]) + counts = err_dict[var].pop("n_members") + if storm_weights is None: + storm_weights = np.nan_to_num(counts, nan=0).astype(int) + + max_len = max(max_len, _max_len) + storm_metrics["case"].append(case) + storm_metrics["data"].append(stack_metrics(err_dict)) + storm_metrics["weights"].append(storm_weights) + + storm_metrics["var"] = list(err_dict.keys()) + storm_metrics["metric"] = list(err_dict[list(err_dict.keys())[0]].keys()) + storm_metrics["lead time"] = np.arange(max_len) * time_step + + return storm_metrics, max_len, ensemble_averages, extremes + + +def reduce_over_all_storms( + storm_metrics: dict[str, Any], +) -> dict[str, Any]: + """Average error metrics across all storms. + + Parameters + ---------- + storm_metrics : dict[str, Any] + Stacked storm metrics with ``"weights"`` key. + + Returns + ------- + dict[str, Any] + Per-variable metrics averaged over storms, plus aggregate + ``"n_members"`` counts. + """ + ensemble_metrics: dict[str, Any] = {} + for var in storm_metrics["var"]: + ensemble_metrics[var] = {} + var_idx = storm_metrics["var"].index(var) + for metric in storm_metrics["metric"]: + met_idx = storm_metrics["metric"].index(metric) + ensemble_metrics[var][metric] = np.nanmean( + storm_metrics["data"][:, var_idx, met_idx, :], axis=0 + ) + + ensemble_metrics["n_members"] = np.sum(storm_metrics["weights"], axis=0) + + return ensemble_metrics + + +def analyse_ensemble_of_storms( + cases: list[str], + pred_track_dir: str, + tru_track_dir: str, + out_path: str | None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> dict[str, Any]: + """Compute and aggregate error metrics across multiple storms. + + Parameters + ---------- + cases : list[str] + Storm identifiers. + pred_track_dir : str + Directory containing predicted track CSV files. + tru_track_dir : str + Directory containing reference track CSV files. + out_path : str | None + Base output directory for plots. + time_step : np.timedelta64, optional + Model time step used for the lead-time axis, by default 6 h. + + Returns + ------- + dict[str, Any] + Ensemble-aggregated error metrics. + """ + storm_metrics, max_len, _ens_means, _extremes = get_individual_storm_metrics( + cases, pred_track_dir, tru_track_dir, out_path, time_step=time_step + ) + + storm_metrics = stack_cases(storm_metrics, max_len) + + ensemble_metrics = reduce_over_all_storms(storm_metrics) + + return ensemble_metrics + + +def analyse_n_plot_tracks() -> None: + """Entry point for batch analysis of multiple storm cases.""" + cases = [ + "amphan_2020_north_indian", # 00 + "beryl_2024_north_atlantic", # 01 + "debbie_2017_southern_pacific", # 02 + "dorian_2019_north_atlantic", # 03 + "harvey_2017_north_atlantic", # 04 + "hato_2017_west_pacific", # 05 + "helene_2024_north_atlantic", # 06 + "ian_2022_north_atlantic", # 07 + "iota_2020_north_atlantic", # 08 + "irma_2017_north_atlantic", # 09 + "lan_2017_west_pacific", # 10 + "lee_2023_north_atlantic", # 11 + "lorenzo_2019_north_atlantic", # 12 + "maria_2017_north_atlantic", # 13 + "mawar_2023_west_pacific", # 14 + "michael_2018_north_atlantic", # 15 + "milton_2024_north_atlantic", # 16 + "ophelia_2017_north_atlantic", # 17 + "yagi_2024_west_pacific", # 18 + ] + + # case_selection = list(range(len(cases))) + case_selection = [6, 13] + individual_storms = False + ensemble_of_storms = True + time_step = _DEFAULT_TIME_STEP + + pred_track_dir = "/path/to/predictions/cyclone_tracks_te" + tru_track_dir = "/path/to/reference_tracks" + out_dir = "./plots" + + if individual_storms: + analyse_individual_storms( + cases=[cases[ii] for ii in case_selection], + pred_track_dir=pred_track_dir, + tru_track_dir=tru_track_dir, + out_path=out_dir, + time_step=time_step, + ) + + if ensemble_of_storms: + analyse_ensemble_of_storms( + cases=[cases[ii] for ii in case_selection], + pred_track_dir=pred_track_dir, + tru_track_dir=tru_track_dir, + out_path=out_dir, + time_step=time_step, + ) + + +if __name__ == "__main__": + analyse_n_plot_tracks() diff --git a/recipes/tc_tracking/plotting/data_handling.py b/recipes/tc_tracking/plotting/data_handling.py new file mode 100644 index 000000000..150479ec1 --- /dev/null +++ b/recipes/tc_tracking/plotting/data_handling.py @@ -0,0 +1,587 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import os +import sys +from typing import Any + +import numpy as np +import pandas as pd +from loguru import logger + +# Make ``src/`` importable when the plotting modules are run from +# ``recipes/tc_tracking/plotting/`` (the conventional working directory for +# the notebooks). Mirrors how ``tc_hunt.py`` exposes ``src`` to the rest of +# the recipe. +_RECIPE_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _RECIPE_ROOT not in sys.path: + sys.path.insert(0, _RECIPE_ROOT) + +from src.tc_hunt_utils import EARTH_RADIUS_M, great_circle_distance # noqa: E402 + +_DEFAULT_TIME_STEP = np.timedelta64(6, "h") + + +def merge_tracks_by_time(track: pd.DataFrame, tru_track: pd.DataFrame) -> pd.DataFrame: + """Left-join a predicted track onto a reference track by time. + + Parameters + ---------- + track : pd.DataFrame + Predicted track. + tru_track : pd.DataFrame + Reference (true) track. + + Returns + ------- + pd.DataFrame + Merged frame with ``_tru`` suffixes on reference columns, + clipped to the time range of the reference track. + """ + merged_track = pd.merge( + track, tru_track, on="time", how="left", suffixes=("", "_tru") + ) + + merged_track = merged_track[merged_track["time"] <= tru_track["time"].max()] + + return merged_track + + +def add_track_distance(track: pd.DataFrame, tru_track: pd.DataFrame) -> pd.DataFrame: + """Augment *track* with a ``dist`` column measuring great-circle distance to *tru_track*. + + Parameters + ---------- + track : pd.DataFrame + Predicted track containing ``lat`` and ``lon`` columns. + tru_track : pd.DataFrame + Reference track containing ``lat`` and ``lon`` columns. + + Returns + ------- + pd.DataFrame + Copy of *track* with an additional ``dist`` column (metres). + """ + merged_track = merge_tracks_by_time(track, tru_track)[ + ["time", "lat", "lon", "lat_tru", "lon_tru"] + ] + + dist = great_circle_distance( + merged_track["lat"], + merged_track["lon"], + merged_track["lat_tru"], + merged_track["lon_tru"], + ) + + merged_track["dist"] = dist + + track = pd.merge( + track, merged_track[["time", "dist"]], on="time", how="left", suffixes=("", "") + ) + + return track + + +def match_tracks( + pred_tracks: list[dict[str, Any]], + true_track: pd.DataFrame, + max_dist: float = 300000, +) -> list[dict[str, Any]]: + """Match predicted tracks to a reference track by proximity at first overlap. + + A predicted track is considered a match if its first position is within + ``max_dist`` metres of the reference track at the same time step. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + List of prediction dicts, each containing ``"ic"``, ``"member"``, + and ``"tracks"`` (a DataFrame). + true_track : pd.DataFrame + Reference track with ``lat_ib`` and ``lon_ib`` columns for + IBTrACS positions. + max_dist : float, optional + Maximum great-circle distance (in metres) between the first + predicted position and the reference position to count as a + match, by default 300000 (300 km). + + Returns + ------- + list[dict[str, Any]] + Subset of matched tracks, each augmented with ``"first_match"``, + ``"initial_dist"``, and a ``dist`` column on the track DataFrame. + """ + matched_tracks: list[dict[str, Any]] = [] + min_seen, max_seen = float("inf"), float("-inf") + + for _pred_track_dict in pred_tracks: + _pred_tracks = _pred_track_dict["tracks"] + + if len(_pred_tracks) == 0: + continue + + n_tracks = _pred_tracks["track_id"].iloc[-1] + 1 + + for ii in range(n_tracks): + track = _pred_tracks.loc[_pred_tracks["track_id"] == ii].copy() + + lat_pred = track["lat"].iloc[0] + lon_pred = track["lon"].iloc[0] + + time_mask = true_track["time"] == track["time"].iloc[0] + if not time_mask.any(): + continue + + ref_row = true_track.loc[time_mask] + lat_true = ref_row["lat_ib"].item() + lon_true = ref_row["lon_ib"].item() + dist = great_circle_distance(lat_pred, lon_pred, lat_true, lon_true) + + if dist <= max_dist: + min_seen, max_seen = min(min_seen, dist), max(max_seen, dist) + + track = add_track_distance(track, true_track) + + matched_tracks.append( + { + "ic": _pred_track_dict["ic"], + "member": _pred_track_dict["member"], + "first_match": track["time"].iloc[0], + "initial_dist": dist, + "tracks": track, + } + ) + break + + if matched_tracks: + logger.info( + f"matched {len(matched_tracks)} out of {len(pred_tracks)} tracks, " + f"with distances ranging from {min_seen/1000:.1f} to " + f"{max_seen/1000:.1f} km" + ) + else: + logger.info(f"matched 0 out of {len(pred_tracks)} tracks") + + return matched_tracks + + +def extract_tracks_from_file(csv_file: str) -> pd.DataFrame: + """Read a TempestExtremes track CSV and convert date columns to a single ``time`` column. + + Parameters + ---------- + csv_file : str + Path to a CSV file produced by TempestExtremes ``StitchNodes``. + + Returns + ------- + pd.DataFrame + Track data with a ``time`` column (datetime64) prepended. + """ + tracks = pd.read_csv(csv_file, sep=",") + tracks.columns = tracks.columns.str.strip() + + times = pd.to_datetime(tracks[["year", "month", "day", "hour"]].astype(int)) + + tracks.drop(columns=["year", "month", "day", "hour"], inplace=True) + if "i" in tracks.columns: + tracks.drop(columns=["i", "j"], inplace=True) + + tracks.insert(0, "time", times) + + return tracks + + +def extract_tracks(in_dir: str) -> list[dict[str, Any]]: + """Load all track CSV files from a directory. + + Parameters + ---------- + in_dir : str + Directory containing track CSV files whose names encode the + initial condition timestamp, member ID, and random seed. + + Returns + ------- + list[dict[str, Any]] + One dict per file with keys ``"ic"`` (Timestamp), ``"member"`` + (int), and ``"tracks"`` (DataFrame). + """ + tracks: list[dict[str, Any]] = [] + files = glob.glob(f"{in_dir}/*.csv") + files.sort() + + for csv_file in files: + _tracks = extract_tracks_from_file(csv_file) + + mem = int(csv_file.split("_mem_")[-1].split("_seed_")[0]) + ic = pd.to_datetime(csv_file.split("_mem_")[0][-19:]) + + tracks.append({"ic": ic, "member": mem, "tracks": _tracks}) + + return tracks + + +def compute_mae(tru_vars: np.ndarray, pred_vars: np.ndarray) -> np.ndarray: + """Compute mean absolute error along the first axis, ignoring NaNs.""" + return np.nanmean(np.abs(tru_vars - pred_vars), axis=0) + + +def compute_mse(tru_vars: np.ndarray, pred_vars: np.ndarray) -> np.ndarray: + """Compute mean squared error along the first axis, ignoring NaNs.""" + return np.nanmean((tru_vars - pred_vars) ** 2, axis=0) + + +def compute_variance(arr: np.ndarray) -> np.ndarray: + """Compute variance along the first axis, ignoring NaNs.""" + return np.nanvar(arr, axis=0) + + +def remove_trailing_nans(merged_track: pd.DataFrame, var: str) -> pd.DataFrame: + """Trim rows after the last time step where both predicted and true values are present. + + Parameters + ---------- + merged_track : pd.DataFrame + Merged track with ``var`` and ``var_tru`` columns. + var : str + Variable name (the true column is ``var + "_tru"``). + + Returns + ------- + pd.DataFrame + Truncated frame. + """ + either_nans = np.logical_or( + merged_track[var + "_tru"].isna(), merged_track[var].isna() + ) + cut_off = np.where(~either_nans)[0][-1] + + return merged_track.iloc[: cut_off + 1] + + +def rebase_by_lead_time( + pred_tracks: list[dict[str, Any]], + tru_track: pd.DataFrame, + variables: list[str], +) -> tuple[dict[str, dict[str, list]], int]: + """Align predicted and true values by lead time for error computation. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + tru_track : pd.DataFrame + Reference track. + variables : list[str] + Variable names to extract. + + Returns + ------- + tuple[dict[str, dict[str, list]], int] + Per-variable dict of ``{"pred": [...], "tru": [...]}`` lists + and the maximum lead-time length across all tracks. + """ + err_dict: dict[str, dict[str, list]] = {} + for var in variables: + err_dict[var] = {"pred": [], "tru": []} + + max_len = 0 + for track in pred_tracks: + merged_track = merge_tracks_by_time(track["tracks"], tru_track) + merged_track = remove_trailing_nans(merged_track, "msl") + + max_len = max(max_len, len(merged_track)) + + for var in err_dict.keys(): + err_dict[var]["pred"].append(merged_track[var]) + err_dict[var]["tru"].append(merged_track[var + "_tru"]) + + return err_dict, max_len + + +def compute_error_metrics( + err_dict: dict[str, dict[str, list]], max_len: int +) -> dict[str, dict[str, np.ndarray]]: + """Compute MAE, MSE, variance, extremes and member counts from aligned predictions. + + Parameters + ---------- + err_dict : dict[str, dict[str, list]] + Per-variable dict of ``{"pred": [...], "tru": [...]}`` lists + as returned by :func:`rebase_by_lead_time`. + max_len : int + Maximum lead-time length used for NaN-padding. + + Returns + ------- + dict[str, dict[str, np.ndarray]] + Per-variable dict of computed metrics. + """ + for var in err_dict.keys(): + pred_vars = err_dict[var]["pred"] + tru_vars = err_dict[var]["tru"] + + counts = np.zeros(max_len, dtype=int) + for ii in range(len(pred_vars)): + counts[: len(pred_vars[ii])] += 1 + + pred_vars[ii] = np.pad( + pred_vars[ii], + (0, max_len - len(pred_vars[ii])), + mode="constant", + constant_values=np.nan, + ) + tru_vars[ii] = np.pad( + tru_vars[ii], + (0, max_len - len(tru_vars[ii])), + mode="constant", + constant_values=np.nan, + ) + + pred_vars, tru_vars = np.array(pred_vars), np.array(tru_vars) + + err_dict[var] = { + "mae": compute_mae(tru_vars, pred_vars), + "mse": compute_mse(tru_vars, pred_vars), + "variance": compute_variance(pred_vars), + "max": np.nanmax(pred_vars, axis=-1), + "min": np.nanmin(pred_vars, axis=-1), + "n_members": counts, + } + + return err_dict + + +def compute_averages_of_errors_over_lead_time( + pred_tracks: list[dict[str, Any]], + tru_track: pd.DataFrame, + variables: list[str], +) -> tuple[dict[str, dict[str, np.ndarray]], int]: + """Compute error metrics averaged over ensemble members as a function of lead time. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + tru_track : pd.DataFrame + Reference track. + variables : list[str] + Variable names to evaluate. + + Returns + ------- + tuple[dict[str, dict[str, np.ndarray]], int] + Per-variable error metrics and the maximum lead-time length. + """ + err_dict, max_len = rebase_by_lead_time(pred_tracks, tru_track, variables) + + err_dict = compute_error_metrics(err_dict, max_len) + + return err_dict, max_len + + +def lat_lon_to_xyz( + lat: float | np.ndarray, + lon: float | np.ndarray, + radius: float = EARTH_RADIUS_M, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Convert latitude/longitude to 3-D Cartesian coordinates. + + Parameters + ---------- + lat : float or np.ndarray + Latitude(s) in degrees (range [-90, 90]). + lon : float or np.ndarray + Longitude(s) in degrees (range [0, 360)). + radius : float, optional + Sphere radius in metres, by default ``EARTH_RADIUS_M`` (6 371 km). + + Returns + ------- + tuple[np.ndarray, np.ndarray, np.ndarray] + ``(x, y, z)`` Cartesian coordinates. + """ + lat_rad, lon_rad = np.radians(lat), np.radians(lon) + + xx = radius * np.cos(lat_rad) * np.cos(lon_rad) + yy = radius * np.cos(lat_rad) * np.sin(lon_rad) + zz = radius * np.sin(lat_rad) + + return xx, yy, zz + + +def xyz_to_lat_lon( + xx: float | np.ndarray, + yy: float | np.ndarray, + zz: float | np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Convert 3-D Cartesian coordinates back to latitude/longitude. + + Parameters + ---------- + xx : float or np.ndarray + X coordinate(s). + yy : float or np.ndarray + Y coordinate(s). + zz : float or np.ndarray + Z coordinate(s). + + Returns + ------- + tuple[np.ndarray, np.ndarray] + ``(lat, lon)`` in degrees. Longitude is in [0, 360). + """ + radius = np.sqrt(xx**2 + yy**2 + zz**2) + + lat_rad = np.arcsin(zz / (radius + 1e-9)) + + lon_rad = np.arctan2(yy, xx) + + lat = np.degrees(lat_rad) + lon = (np.degrees(lon_rad) + 360) % 360 + + return lat, lon + + +def cartesian_to_spherical_track( + stats: dict[str, Any], + tru_track: pd.DataFrame, + frame_of_reference: pd.DataFrame, +) -> dict[str, Any]: + """Convert Cartesian ensemble-mean position back to spherical and compute distance. + + Replaces the ``x, y, z`` entries in ``stats["mean"]`` and + ``stats["variance"]`` with ``lat``, ``lon``, and ``dist``. + + Parameters + ---------- + stats : dict[str, Any] + Ensemble statistics dict (modified in place). + tru_track : pd.DataFrame + Reference track. + frame_of_reference : pd.DataFrame + Single-column ``time`` frame covering all lead times. + + Returns + ------- + dict[str, Any] + Updated *stats*. + """ + mean_lat, mean_lon = xyz_to_lat_lon( + stats["mean"]["x"], stats["mean"]["y"], stats["mean"]["z"] + ) + + for var in ["x", "y", "z"]: + for metric in ["mean", "variance"]: + del stats[metric][var] + + stats["mean"]["lat"] = mean_lat + stats["mean"]["lon"] = mean_lon + + tru_cont = pd.merge(frame_of_reference, tru_track, on="time", how="left") + + dist = great_circle_distance( + tru_cont["lat"], tru_cont["lon"], stats["mean"]["lat"], stats["mean"]["lon"] + ) + + stats["mean"]["dist"] = np.asarray(dist, dtype=float) + + return stats + + +def get_ensemble_averages( + pred_tracks: list[dict[str, Any]], + tru_track: pd.DataFrame, + variables: list[str] | None = None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> dict[str, Any]: + """Compute ensemble mean and variance on the sphere. + + Averaging is performed in Cartesian space to avoid artefacts near + the antimeridian, then converted back to lat/lon. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + tru_track : pd.DataFrame + Reference track. + variables : list[str] | None, optional + Variables to average (must include ``x``, ``y``, ``z`` for the + Cartesian round-trip), by default + ``["msl", "wind_speed", "x", "y", "z"]`` + time_step : np.timedelta64, optional + Spacing of the output time axis, by default 6 h. + + Returns + ------- + dict[str, Any] + Dict with keys ``"time"``, ``"n_members"``, ``"mean"``, and + ``"variance"``. + """ + if variables is None: + variables = ["msl", "wind_speed", "x", "y", "z"] + + stats: dict[str, Any] = { + "time": None, + "n_members": None, + "mean": {var: [] for var in variables}, + "variance": {var: [] for var in variables}, + } + + last_time = pred_tracks[0]["ic"] + for track in pred_tracks: + last_time = max(last_time, track["tracks"]["time"].values[-1]) + + all_times = np.arange(pred_tracks[0]["ic"], last_time, time_step) + stats["time"] = all_times + + frame_of_reference = pd.DataFrame( + data=all_times, index=np.arange(len(all_times)), columns=["time"] + ) + + for track in pred_tracks: + xx, yy, zz = lat_lon_to_xyz(track["tracks"]["lat"], track["tracks"]["lon"]) + track["tracks"]["x"] = xx + track["tracks"]["y"] = yy + track["tracks"]["z"] = zz + + contextualised = pd.merge( + frame_of_reference, track["tracks"], on="time", how="left" + ) + + for var in variables: + stats["mean"][var].append(contextualised[var]) + + for var in variables: + stacked = np.stack(stats["mean"][var]) + counts = np.count_nonzero(~np.isnan(stacked), axis=0) + + if stats["n_members"] is None: + stats["n_members"] = counts + elif not np.all(stats["n_members"] == counts): + raise ValueError( + "n_members is not the same for all variables but should be" + ) + + stats["variance"][var] = np.nanvar(stacked, axis=0) + stats["mean"][var] = np.nanmean(stacked, axis=0) + + stats = cartesian_to_spherical_track(stats, tru_track, frame_of_reference) + + return stats diff --git a/recipes/tc_tracking/plotting/plot_tracks_n_fields_notebook.py b/recipes/tc_tracking/plotting/plot_tracks_n_fields_notebook.py new file mode 100644 index 000000000..c7d9f1d08 --- /dev/null +++ b/recipes/tc_tracking/plotting/plot_tracks_n_fields_notebook.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% [markdown] +# # Plot Evolution of Storm Tracks and Atmospheric Fields +# +# This notebook demonstrates how to visualise tropical cyclone (TC) tracks +# overlaid on atmospheric field data. Use this to validate tracking algorithm +# outputs or create publication-ready animations. +# +# ### Workflow Overview +# 1. Configure paths and plotting parameters +# 2. Load and subset field data from NetCDF +# 3. Load and align track data from CSV files +# 4. Generate animated visualisation +# 5. (Optional) Save animation to file +# +# ### Data Prerequisites +# - NetCDF file containing atmospheric field data +# - Track CSV files generated by the TC tracking algorithm +# +# Both files can be produced by configuring a tc_hunt run with +# `store_type: "netcdf"` and cyclone tracking switched on, as e.g. done +# in `cfg/reproduce_helene.yaml`. +# +# ### Configuration +# +# **Path Settings** +# - `field_data` - path to NetCDF file (e.g. +# `/path/to/outputs_reproduce_helene/helene_2024-09-24T00.00.00_mems0000-0013.nc`) +# - `track_dir` - folder containing track CSVs (e.g. +# `/path/to/outputs_reproduce_helene/cyclone_tracks_te`) +# +# **Plotting Settings** +# - `variable` - field variable to plot (e.g., `'u10m'`, `'mslp'`). +# Use `'wind_speed'` to compute 10m wind magnitude from `u10m` and `v10m` +# - `ensemble_member` - which ensemble member to visualise +# - `region` - geographic region (`'global'`, `'north_atlantic'`, +# `'gulf_of_mexico'`) +# +# **Animation Settings** +# - `max_frames` - limit number of frames (set high to include full forecast) +# - `scale` - spatial coarsening factor (1 = full resolution, 2 = half, etc.) +# - `fps` - frames per second for animation playback + +# %% +field_data = "/path/to/outputs_reproduce_helene/reproduce_helene_2024-09-24T00.00.00_mems0000-0013.nc" +track_dir = "/path/to/outputs_reproduce_helene/cyclone_tracks_te" + +variable = "wind_speed" +ensemble_member = 3 +region = "gulf_of_mexico" + +max_frames = 99 # maximum number of frames to plot +scale = 1 +fps = 4 + +# Map region label to (lat_min, lat_max, lon_min, lon_max). Add new +# entries by appending another row. +REGIONS: dict[str, tuple[float, float, float, float]] = { + "global": (-90, 90, 0, 359.75), + "west_pacific": (9, 60, 100, 180), + "north_atlantic": (9, 65, 250, 360), + "gulf_of_mexico": (9, 45, 250, 310), + "north_indian": (9, 40, 50, 100), + "southern_pacific": (-40, 5, 140, 240), +} + +try: + lat_min, lat_max, lon_min, lon_max = REGIONS[region] +except KeyError as exc: + raise ValueError( + f"region {region!r} not yet implemented. " + "Feel free to add it by providing lat/lon coords of the bounding " + "box in the REGIONS dict." + ) from exc + +# %% [markdown] +# ### Step 1: Load Field Data +# +# - Read field data from NetCDF file +# - Sub-select region and variable; coarsen data if `scale > 1` for faster +# iteration +# - Compute min/max values for consistent colour map across all timesteps + +# %% +import os +import warnings + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib.animation as animation +import matplotlib.pyplot as plt +import numpy as np +import xarray as xr +from cartopy.mpl.ticker import LatitudeFormatter, LongitudeFormatter +from loguru import logger + +# fall back to ``print`` when not running inside an IPython kernel so that +# ``display(...)`` calls work both as a plain script and after +# ``jupytext --to notebook`` conversion +try: + from IPython.display import display +except ImportError: + display = print + +base_name = field_data.split("/")[-1].split(".")[0] +out_dir = f"outputs_{base_name}" + +ds = xr.open_dataset(field_data) + +# subselect lat/lon box, coarsen data if scale factor > 1 +sub_ds = ds.sel( + lat=list(np.arange(lat_min, lat_max, scale * 0.25)), + lon=list(np.arange(lon_min, lon_max, scale * 0.25)), +) + +# extract variable and obtain global min/max values +if variable == "wind_speed": + sub_ds = np.sqrt(np.square(sub_ds.u10m) + np.square(sub_ds.v10m)) +else: + sub_ds = sub_ds[variable] + +display(ds) + +# %% [markdown] +# ### Step 2: Load Track Data +# +# This step loads and processes the cyclone track data to align with field +# timestamps. +# +# **Processing steps:** +# 1. **Select file** - filter by `ensemble_member` from filenames ignoring +# the random seed +# 2. **Separate tracks** - split into individual DataFrames for each track +# 3. **Align timestamps** - reindex each track to match field data timesteps +# +# **Why reindexing matters:** +# Tracks may not have positions at every timestep (e.g., storm +# genesis/dissipation). Reindexing fills gaps with NaN, enabling +# frame-by-frame plotting without index errors. + +# %% +from data_handling import extract_tracks_from_file + +track_dir = os.path.abspath(track_dir) + +# extract time steps of field data +time_stamps = ds.time.values + ds.lead_time.values + +# select track file of ensemble member +track_file = [ + f + for f in os.listdir(track_dir) + if f.endswith(".csv") + and int(f.split("_mem_")[-1].split("_seed_")[0]) == ensemble_member +][0] +track_file = os.path.join(track_dir, track_file) + +# extract tracks from prediction +tracks = extract_tracks_from_file(track_file) + +# separate individual tracks in prediction +n_tracks = tracks["track_id"].iloc[-1] + 1 +tracks = [tracks.loc[tracks["track_id"] == ii].copy() for ii in range(n_tracks)] + +# align track data with simulation time steps +for ii in range(n_tracks): + # extract the lines of tracks for which track['time'] is in time_stamps + tracks[ii] = tracks[ii][tracks[ii]["time"].isin(time_stamps)] + + # fill the 'time' column with the time_stamps values + tracks[ii] = tracks[ii].set_index("time").reindex(time_stamps).reset_index() + +# %% [markdown] +# ### Step 3: Create Animation +# +# Creates an interactive animation where you can click through timesteps. +# Each frame shows the field data with tracks drawn progressively up to +# that point. + +# %% +colour_map = "plasma" +projection = ccrs.PlateCarree() + +# suppress line warnings stemming from potential NANs in track data +warnings.filterwarnings("ignore", message="invalid value encountered in linestrings") + +# get index of ensemble member +ensemble_idx = np.argwhere(ds.ensemble.values == ensemble_member)[0, 0] + +# get min and max vals for colour map (read using named dims for clarity) +member_field = sub_ds.isel(ensemble=ensemble_idx) +min_val = float(np.min(member_field)) +max_val = float(np.max(member_field)) + + +# define plots +def make_figure() -> tuple[plt.Figure, plt.Axes]: + """Create a figure with a single PlateCarree subplot.""" + fig = plt.figure(figsize=(11, 7)) + ax = fig.add_subplot(1, 1, 1, projection=projection) + + lon_formatter = LongitudeFormatter(zero_direction_label=False) + lat_formatter = LatitudeFormatter() + ax.xaxis.set_major_formatter(lon_formatter) + ax.yaxis.set_major_formatter(lat_formatter) + + return fig, ax + + +# make animation +plt.rcParams["animation.html"] = "jshtml" +fig, ax = make_figure() + + +def make_frame(frame: int) -> plt.Artist: + """Render a single animation frame with field data and tracks.""" + logger.info( + f"processing frame {frame+1} of {min(max_frames, sub_ds.sizes['lead_time'])}" + ) + + # Clear previous plot objects + for artist in ax.get_children(): + if hasattr(artist, "get_array"): # This targets pcolormesh objects + artist.remove() + + plot_ds = sub_ds.isel(ensemble=ensemble_idx, time=0, lead_time=max(frame, 0)) + pc = ax.pcolormesh( + sub_ds.lon, + sub_ds.lat, + plot_ds, + transform=projection, + cmap=colour_map, + vmin=min_val, + vmax=max_val, + ) + + ax.add_feature(cfeature.COASTLINE, lw=0.5) + ax.add_feature(cfeature.RIVERS, lw=0.5) + for track in tracks: + ax.plot( + track["lon"][: frame + 1], + track["lat"][: frame + 1], + transform=ccrs.PlateCarree(), + color="lime", + linewidth=1.0, + alpha=1, + ) + + # Enforce the plotting region extent + ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()) + + if frame == -1: + cbar = fig.colorbar(pc, extend="both", shrink=0.8, ax=ax) # noqa: F841 + + header = f"{base_name} {variable} lead time {frame*6}h" + ax.set_title(header, fontsize=14) + + return pc + + +def animate(frame: int) -> plt.Artist: + """Animation callback forwarding to ``make_frame``.""" + return make_frame(frame) + + +def first_frame() -> plt.Artist: + """Initialisation callback rendering frame -1 (colourbar only).""" + return make_frame(-1) + + +ani = animation.FuncAnimation( + fig, + animate, + min(max_frames, sub_ds.sizes["lead_time"]), + init_func=first_frame, + blit=False, + repeat=False, + interval=1000 / fps, +) +plt.close("all") +ani + +# %% [markdown] +# ### Step 4 (Optional): Save Animation +# +# Uncomment and run the cell below to save the animation as GIF in +# ``outputs_/``. + +# %% +# plt.close("all") +# +# # Recreate figure and animation +# cbar = None +# fig = plt.figure(figsize=(12, 8)) +# ax = fig.add_subplot(1, 1, 1, projection=projection) +# +# ani = animation.FuncAnimation( +# fig, +# animate, +# min(max_frames, sub_ds.sizes["lead_time"]), +# init_func=first_frame, +# blit=False, +# repeat=False, +# interval=1000 / fps, +# ) +# +# os.makedirs(out_dir, exist_ok=True) +# ani.save( +# os.path.join(out_dir, "tracks_n_fields_ani.gif"), +# fps=fps, +# savefig_kwargs={"bbox_inches": "tight", "pad_inches": 0.1}, +# ) +# plt.close(fig) diff --git a/recipes/tc_tracking/plotting/plotting_helpers.py b/recipes/tc_tracking/plotting/plotting_helpers.py new file mode 100644 index 000000000..e2f5432c9 --- /dev/null +++ b/recipes/tc_tracking/plotting/plotting_helpers.py @@ -0,0 +1,817 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Any + +import cartopy.crs as ccrs +import cartopy.feature as cfeature +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from data_handling import merge_tracks_by_time +from matplotlib.collections import LineCollection + +_DEFAULT_TIME_STEP = np.timedelta64(6, "h") +_DEFAULT_P_REF_PA = 101325 + + +def _make_suptitle( + case: str, + ic: np.datetime64 | None = None, + n_tracks: int | None = None, + n_members: int | None = None, +) -> str: + """Build a standardised plot suptitle from storm metadata.""" + title = case.split("_")[0].upper() + if ic is not None: + title += f"\n initialised on {ic}" + if n_tracks is not None and n_members is not None: + title += f"\n {n_tracks} tracks in {n_members} ensemble members" + return title + + +def _var_display_info(var: str) -> tuple[str, str, float]: + """Return ``(display_label, unit_string, scale_divisor)`` for a tracked variable.""" + info = { + "msl": ("msl", "hPa", 100), + "dist": ("distance", "km", 1000), + "wind_speed": ("maximum instantaneous wind speed", "m/s", 1), + } + return info.get(var, (var, "", 1)) + + +def add_some_gap( + lat_min: float, lat_max: float, lon_min: float, lon_max: float +) -> tuple[float, float, float, float]: + """Expand a lat/lon bounding box by 10 % on each side and correct extreme aspect ratios. + + Parameters + ---------- + lat_min, lat_max : float + Latitude bounds in degrees. + lon_min, lon_max : float + Longitude bounds in degrees. + + Returns + ------- + tuple[float, float, float, float] + ``(lat_min, lat_max, lon_min, lon_max)`` with padding applied. + """ + gap_fac = 0.1 + lat_gap = (lat_max - lat_min) * gap_fac + lon_gap = (lon_max - lon_min) * gap_fac + + lat_min, lat_max = lat_min - lat_gap, lat_max + lat_gap + lon_min, lon_max = lon_min - lon_gap, lon_max + lon_gap + + if lat_gap / lon_gap > 2: + d_lon = 0.5 * (lat_max - lat_min) + med_lon = 0.5 * (lon_min + lon_max) + lon_min, lon_max = med_lon - d_lon / 2, med_lon + d_lon / 2 + + elif lon_gap / lat_gap > 2: + d_lat = 0.5 * (lon_max - lon_min) + med_lat = 0.5 * (lat_min + lat_max) + lat_min, lat_max = med_lat - d_lat / 2, med_lat + d_lat / 2 + + return lat_min, lat_max, lon_min, lon_max + + +def get_central_coords(track: pd.DataFrame) -> tuple[float, float]: + """Return the median latitude and longitude of a track. + + Parameters + ---------- + track : pd.DataFrame + Track with ``lat`` and ``lon`` columns. + + Returns + ------- + tuple[float, float] + ``(lat_median, lon_median)`` + """ + lat_cen = track["lat"].median() + lon_cen = track["lon"].median() + + return lat_cen, lon_cen + + +def plot_spaghetti( + true_track: pd.DataFrame, + pred_tracks: list[dict[str, Any]], + ensemble_mean: dict[str, Any], + case: str, + n_members: int, + out_dir: str | None = None, + alpha: float = 0.2, + line_width: float = 2, + ic: np.datetime64 | list[np.datetime64] | None = None, +) -> None: + """Plot ensemble track trajectories (spaghetti plot) on a map. + + Parameters + ---------- + true_track : pd.DataFrame + Reference track (plotted in red). + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + ensemble_mean : dict[str, Any] + Ensemble-mean track with ``"lat"`` and ``"lon"`` arrays. + case : str + Storm identifier for the plot title. + n_members : int + Total number of ensemble members (including unmatched). + out_dir : str | None, optional + If provided, the figure is saved here. + alpha : float, optional + Transparency for ensemble member lines, by default 0.2 + line_width : float, optional + Line width for all tracks, by default 2 + ic : np.datetime64 | list[np.datetime64] | None, optional + If provided, only plot members whose ``"ic"`` is in *ic*. + """ + plt.close("all") + + lat_cen, lon_cen = get_central_coords(true_track) + + fig = plt.figure(figsize=(22, 10)) + fig.suptitle( + _make_suptitle(case, pred_tracks[0]["ic"], len(pred_tracks), n_members), + fontsize=16, + ) + + projection = ccrs.LambertAzimuthalEqualArea( + central_longitude=lon_cen, central_latitude=lat_cen + ) + ax = fig.add_subplot(1, 1, 1, projection=projection) + + ax.add_feature(cfeature.COASTLINE, lw=0.5) + ax.add_feature(cfeature.RIVERS, lw=0.5) + if case != "debbie_2017_southern_pacific": # cartopy issues with small islands + ax.add_feature(cfeature.OCEAN, facecolor="#b0c4de") + ax.add_feature(cfeature.LAND, facecolor="#c4b9a3") + + # Seed lat/lon bounds from the true track so the loop only needs to widen + # them; works even when ``ic`` filters out every predicted member. + lat_min, lat_max = true_track["lat"].min(), true_track["lat"].max() + lon_min, lon_max = true_track["lon"].min(), true_track["lon"].max() + + segments = [] + for _track in pred_tracks: + track = _track["tracks"] + if ic is not None and _track["ic"] not in ic: + continue + + lat_min, lat_max = min(lat_min, track["lat"].min()), max( + lat_max, track["lat"].max() + ) + lon_min, lon_max = min(lon_min, track["lon"].min()), max( + lon_max, track["lon"].max() + ) + + segments.append(np.column_stack([track["lon"].values, track["lat"].values])) + + if segments: + ax.add_collection( + LineCollection( + segments, + colors="black", + linewidths=line_width, + alpha=alpha, + transform=ccrs.PlateCarree(), + ) + ) + + ax.plot( + true_track["lon"], + true_track["lat"], + transform=ccrs.PlateCarree(), + color="red", + linewidth=line_width, + alpha=1.0, + ) + + ax.plot( + ensemble_mean["lon"], + ensemble_mean["lat"], + transform=ccrs.PlateCarree(), + color="lime", + linewidth=line_width, + alpha=1.0, + ) + + lat_min, lat_max, lon_min, lon_max = add_some_gap( + lat_min, lat_max, lon_min, lon_max + ) + + plt.tight_layout() + + ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree()) + ax.gridlines(draw_labels=False, dms=False, x_inline=False, y_inline=False) + + if out_dir: + fig.savefig(os.path.join(out_dir, f"{case}_tracks.png")) + + +def normalised_intensities( + track: pd.DataFrame, + tru_track: pd.DataFrame, + var: str, + p_ref: float = _DEFAULT_P_REF_PA, +) -> pd.DataFrame: + """Normalise a track variable relative to the reference track. + + For pressure (``msl``), the normalisation is + ``(pred - ref) / (p_ref - ref)``. For other variables it is + ``(pred - ref) / ref``. + + Parameters + ---------- + track : pd.DataFrame + Predicted or ensemble-mean track. + tru_track : pd.DataFrame + Reference track. + var : str + Variable name to normalise. + p_ref : float, optional + Reference pressure (Pa) used for the ``msl`` normalisation, by + default 101 325 Pa. + + Returns + ------- + pd.DataFrame + Merged frame with *var* replaced by its normalised values. + """ + merged_track = merge_tracks_by_time(track, tru_track) + + if var == "msl": + merged_track[var] = (merged_track[var] - merged_track[var + "_tru"]) / ( + p_ref - merged_track[var + "_tru"] + ) + else: + merged_track[var] = ( + merged_track[var] - merged_track[var + "_tru"] + ) / merged_track[var + "_tru"] + + return merged_track + + +def plot_relative_over_time( + pred_tracks: list[dict[str, Any]], + tru_track: pd.DataFrame, + ensemble_mean: dict[str, Any], + case: str, + n_members: int, + ics: np.datetime64 | list[np.datetime64] | None = None, + out_dir: str | None = None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> None: + """Plot normalised intensity deviations from the reference track over time. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + tru_track : pd.DataFrame + Reference track. + ensemble_mean : dict[str, Any] + Ensemble statistics dict with ``"time"`` and ``"mean"`` keys. + case : str + Storm identifier for the plot title. + n_members : int + Total number of ensemble members. + ics : np.datetime64 | list[np.datetime64] | None, optional + If provided, only plot members whose ``"ic"`` is in *ics*. + out_dir : str | None, optional + If provided, the figure is saved here. + time_step : np.timedelta64, optional + Model time step, by default 6 h. + """ + fig, _ax = plt.subplots(2, 1, figsize=(11, 11), sharex=True) + fig.suptitle( + _make_suptitle(case, pred_tracks[0]["ic"], len(pred_tracks), n_members), + fontsize=16, + ) + + variables = ["msl", "wind_speed"] + labels = [ + "(msl - msl_ref)/(101325Pa - msl_ref)", + "max_wind/max_wind_ref - 1", + ] + + rel_steps = int( + ((tru_track["time"].max() - pred_tracks[0]["ic"]) / time_step + 1) * 0.75 + ) + + for ii in range(_ax.shape[0]): + ax = _ax[ii] + # ``+inf``/``-inf`` are the natural identity values for a running + # min/max and survive an empty inner loop without flipping the axis. + vmin, vmax = float("inf"), float("-inf") + for _track in pred_tracks: + track = _track["tracks"] + if ics is not None and _track["ic"] not in ics: + continue + + track = normalised_intensities(track, tru_track, variables[ii]) + + vmin = min(vmin, track[variables[ii]][:rel_steps].min()) + vmax = max(vmax, track[variables[ii]][:rel_steps].max()) + + ax.plot(track["time"], track[variables[ii]], color="black", alpha=0.1) + + ax.set_ylabel(labels[ii]) + ax.grid(True) + if np.isfinite(vmin) and np.isfinite(vmax): + ax.set_ylim(vmin, vmax) + + ax.plot( + tru_track["time"], + np.zeros(len(tru_track)), + color="orangered", + linewidth=2.5, + label="era5 comparison", + ) + + mean = pd.DataFrame( + { + "time": ensemble_mean["time"], + variables[ii]: ensemble_mean["mean"][variables[ii]], + } + ) + _track = normalised_intensities(mean, tru_track, variables[ii]) + ax.plot( + _track["time"], + _track[variables[ii]], + color="lime", + linewidth=2.5, + label="ensemble mean", + linestyle="--", + ) + + ax.legend() + + _ax[-1].set_xlabel("time [UTC]") + + plt.xlim( + pred_tracks[0]["ic"] - time_step, + tru_track["time"].max() + time_step, + ) + + if out_dir: + plt.savefig(os.path.join(out_dir, f"{case}_rel_intensities.png")) + + +def plot_over_time( + pred_tracks: list[dict[str, Any]], + tru_track: pd.DataFrame, + ensemble_mean: dict[str, Any], + case: str, + n_members: int, + variables: list[str] | None = None, + labels: list[str] | None = None, + ics: np.datetime64 | list[np.datetime64] | None = None, + out_dir: str | None = None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> None: + """Plot absolute intensity and distance time series for all ensemble members. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + tru_track : pd.DataFrame + Reference track. + ensemble_mean : dict[str, Any] + Ensemble statistics dict. + case : str + Storm identifier for the plot title. + n_members : int + Total number of ensemble members. + variables : list[str] | None, optional + Variables to plot (one subplot each), by default + ``["msl", "wind_speed", "dist"]`` + labels : list[str] | None, optional + Y-axis labels corresponding to *variables*. Auto-derived from + :func:`_var_display_info` when *None*. + ics : np.datetime64 | list[np.datetime64] | None, optional + If provided, only plot members whose ``"ic"`` is in *ics*. + out_dir : str | None, optional + If provided, the figure is saved here. + time_step : np.timedelta64, optional + Model time step, by default 6 h. + """ + if variables is None: + variables = ["msl", "wind_speed", "dist"] + if labels is None: + labels = [ + f"{_var_display_info(v)[0]} [{_var_display_info(v)[1]}]" for v in variables + ] + + fig, _ax = plt.subplots(len(variables), 1, figsize=(11, 15), sharex=True) + fig.suptitle( + _make_suptitle(case, pred_tracks[0]["ic"], len(pred_tracks), n_members), + fontsize=16, + ) + + # Seed from the first predicted track and widen as the loop progresses. + first_times = pred_tracks[0]["tracks"]["time"] + t_min, t_max = first_times.min(), first_times.max() + + for ii in range(_ax.shape[0]): + _, _, scale = _var_display_info(variables[ii]) + + for _track in pred_tracks: + track = _track["tracks"] + if ics is not None and _track["ic"] not in ics: + continue + + _ax[ii].plot( + track["time"], track[variables[ii]] / scale, color="black", alpha=0.1 + ) + + t_min, t_max = min(t_min, track["time"].min()), max( + t_max, track["time"].max() + ) + + _ax[ii].set_xlim(t_min - time_step, t_max + time_step) + _ax[ii].set_ylabel(labels[ii]) + _ax[ii].grid(True) + + _ax[ii].plot( + tru_track["time"], + tru_track[variables[ii]] / scale, + color="orangered", + linewidth=2.5, + label="era5 comparison", + ) + + _ax[ii].plot( + ensemble_mean["time"], + ensemble_mean["mean"][variables[ii]] / scale, + color="lime", + linewidth=2.5, + label="ensemble mean", + linestyle="--", + ) + _ax[ii].legend() + + _ax[-1].set_xlabel("time [UTC]") + + if out_dir: + plt.savefig(os.path.join(out_dir, f"{case}_abs_intensities.png")) + + +def plot_ib_era5( + tru_track: pd.DataFrame, + case: str, + variables: list[str] | None = None, + out_dir: str | None = None, + p_ref: float = _DEFAULT_P_REF_PA, +) -> None: + """Plot ERA5-vs-IBTrACS intensity ratios on twin y-axes. + + Parameters + ---------- + tru_track : pd.DataFrame + Reference track containing both ERA5 and IBTrACS columns. + case : str + Storm identifier for the plot title. + variables : list[str] | None, optional + Variables to compare (``"msl"`` and/or ``"wind_speed"``), by + default ``["msl", "wind_speed"]`` + out_dir : str | None, optional + If provided, the figure is saved here. + p_ref : float, optional + Reference pressure (Pa) used for the ``msl`` ratio, by default + 101 325 Pa. + """ + if variables is None: + variables = ["msl", "wind_speed"] + + plt.close("all") + + fig, ax1 = plt.subplots(1, 1, figsize=(8, 5)) + fig.suptitle(_make_suptitle(case), fontsize=16) + + ax2 = ax1.twinx() + + if "msl" in variables: + ax1.plot( + tru_track["time"], + (p_ref - tru_track["msl"]) / (p_ref - tru_track["msl_ib"]), + "black", + ) + ax1.set_ylabel("(1013.25hPa-msl_era5)/(1013.25hPa-msl_ib)", color="black") + + if "wind_speed" in variables: + ax2.plot( + tru_track["time"], + tru_track["wind_speed"] / tru_track["wind_speed_ib"], + "orangered", + ) + ax2.set_ylabel("wind_speed_era5/wind_speed_ib", color="orangered") + + fig.tight_layout() + if out_dir: + plt.savefig(os.path.join(out_dir, f"{case}_ib_era5_wind_speed.png")) + + +def root_metrics( + err_dict: dict[str, dict[str, np.ndarray]], +) -> dict[str, dict[str, np.ndarray]]: + """Replace MSE/variance with RMSE/standard-deviation and drop member counts. + + Parameters + ---------- + err_dict : dict[str, dict[str, np.ndarray]] + Per-variable error metrics (modified in place). + + Returns + ------- + dict[str, dict[str, np.ndarray]] + Updated *err_dict*. + """ + for var in err_dict.keys(): + mse = err_dict[var].pop("mse") + err_dict[var]["rmse"] = np.sqrt(mse) + variance = err_dict[var].pop("variance") + err_dict[var]["standard_deviation"] = np.sqrt(variance) + err_dict[var].pop("n_members") + + return err_dict + + +def plot_errors_over_lead_time( + err_dict: dict[str, dict[str, np.ndarray]], + case: str, + ic: np.datetime64, + n_members: int, + n_tracks: int, + norm_dict: dict[str, float] | None = None, + unit_dict: dict[str, str] | None = None, + out_dir: str | None = None, + time_step: np.timedelta64 = _DEFAULT_TIME_STEP, +) -> None: + """Plot error metrics (RMSE, MAE, standard deviation) as a function of lead time. + + Parameters + ---------- + err_dict : dict[str, dict[str, np.ndarray]] + Per-variable error metrics. + case : str + Storm identifier for the plot title. + ic : np.datetime64 + Initial condition timestamp. + n_members : int + Total number of ensemble members. + n_tracks : int + Number of matched tracks. + norm_dict : dict[str, float] | None, optional + Normalisation divisors for display units. Auto-derived from + :func:`_var_display_info` when *None*. + unit_dict : dict[str, str] | None, optional + Display unit strings. Auto-derived when *None*. + out_dir : str | None, optional + If provided, the figure is saved here. + time_step : np.timedelta64, optional + Model time step used for the lead-time axis, by default 6 h. + """ + if "mse" in err_dict[list(err_dict.keys())[0]].keys(): + err_dict = root_metrics(err_dict) + + if norm_dict is None: + norm_dict = {v: _var_display_info(v)[2] for v in err_dict} + if unit_dict is None: + unit_dict = {v: _var_display_info(v)[1] for v in err_dict} + + variables = list(err_dict.keys()) + metrics = list(err_dict[variables[0]].keys()) + + for extreme in ["min", "max"]: + if extreme in metrics: + metrics.remove(extreme) + + lead_time = np.arange(err_dict[variables[0]][metrics[0]].shape[0]) * time_step + + fig, ax = plt.subplots( + len(variables), + len(metrics), + figsize=((len(metrics) + 1) * 2, (len(variables) + 1) * 2), + sharex=True, + ) + + for ivar, var in enumerate(err_dict.keys()): + for imet, metric in enumerate(metrics): + + ax[ivar, imet].plot(lead_time, err_dict[var][metric] / norm_dict[var]) + + if ivar == 0: + ax[ivar, imet].set_title(metric, fontsize=12, weight="bold") + + if imet == 0: + ax[ivar, imet].set_ylabel( + f"{var} [{unit_dict[var]}]", fontsize=12, weight="bold" + ) + + if ivar == len(variables) - 1: + ax[ivar, imet].set_xlabel("lead time [h]", fontsize=12) + + fig.suptitle(_make_suptitle(case, ic, n_tracks, n_members), fontsize=16) + + fig.tight_layout() + if out_dir: + plt.savefig(os.path.join(out_dir, f"{case}_error_metrics_over_lead_time.png")) + + +def extract_reference_extremes( + tru_track: pd.DataFrame, + pred_tracks: list[dict[str, Any]], + ens_mean: dict[str, Any], + variables: list[str], +) -> dict[str, dict[str, Any]]: + """Extract per-member extreme values and the corresponding reference extremes. + + Parameters + ---------- + tru_track : pd.DataFrame + Reference track. + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + ens_mean : dict[str, Any] + Ensemble statistics dict. + variables : list[str] + Variables to extract extremes for. + + Returns + ------- + dict[str, dict[str, Any]] + Per-variable dict with ``"pred"`` (array), ``"tru"`` (scalar), + and ``"ens_mean"`` (scalar). + """ + extreme_dict: dict[str, dict[str, Any]] = {} + for var in variables: + if var in ["wind_speed"]: + reduce_fn = np.nanmax + elif var in ["msl"]: + reduce_fn = np.nanmin + else: + continue + + extreme_dict[var] = { + "pred": np.zeros(len(pred_tracks)), + "tru": reduce_fn(tru_track[var]), + "ens_mean": reduce_fn(ens_mean["mean"][var]), + } + for ii, track in enumerate(pred_tracks): + extreme_dict[var]["pred"][ii] = reduce_fn(track["tracks"][var]) + + return extreme_dict + + +def add_stats_box( + ax: plt.Axes, + pred_var: np.ndarray, + tru_var: float, + var: str, + reduction: str, + unit: str, +) -> None: + """Add a text box with summary statistics below a histogram axis. + + Parameters + ---------- + ax : plt.Axes + Matplotlib axes to annotate. + pred_var : np.ndarray + Per-member extreme values. + tru_var : float + Reference extreme value. + var : str + Variable name. + reduction : str + ``"max"`` or ``"min"``. + unit : str + Display unit string. + """ + # For wind speed, "more intense" means larger values; for msl it means + # smaller values. Pick the comparator that actually counts members that + # are *more intense* than the reference. + if var == "wind_speed": + n_beyond_ref = int((pred_var > tru_var).sum()) + comp = "exceeding" + else: + n_beyond_ref = int((pred_var < tru_var).sum()) + comp = "below" + n_total = len(pred_var) + + stats = [ + ("era5 reference:", f"{tru_var:.1f} {unit}"), + ( + f"members {comp} ref:", + f"{n_beyond_ref} of {n_total} ({(n_beyond_ref/n_total)*100:.1f}%)", + ), + (f"max {reduction} {var}:", f"{pred_var.max():.1f} {unit}"), + (f"min {reduction} {var}:", f"{pred_var.min():.1f} {unit}"), + (f"avg {reduction} {var}:", f"{pred_var.mean():.1f} {unit}"), + (f"std {reduction} {var}:", f"{pred_var.std():.1f} {unit}"), + ] + + max_label_width = max(len(label) for label, _ in stats) + text = "\n".join([f"{label:<{max_label_width}} {value}" for label, value in stats]) + + ax.text( + 0.01, + -0.25, + text, + transform=ax.transAxes, + ha="left", + va="top", + bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray", alpha=0.8), + fontfamily="monospace", + ) + + +def plot_extreme_extremes_histograms( + pred_tracks: list[dict[str, Any]], + tru_track: pd.DataFrame, + ensemble_mean: dict[str, Any], + case: str, + variables: list[str] | None = None, + out_dir: str | None = None, + nbins: int = 12, +) -> None: + """Plot histograms of per-member extreme values with reference lines. + + Parameters + ---------- + pred_tracks : list[dict[str, Any]] + Matched prediction dicts. + tru_track : pd.DataFrame + Reference track. + ensemble_mean : dict[str, Any] + Ensemble statistics dict. + case : str + Storm identifier for the plot title. + variables : list[str] | None, optional + Variables to plot (one subplot each), by default + ``["wind_speed", "msl"]`` + out_dir : str | None, optional + If provided, the figure is saved here. + nbins : int, optional + Number of histogram bins, by default 12 + """ + if variables is None: + variables = ["wind_speed", "msl"] + + extreme_dict = extract_reference_extremes( + tru_track, pred_tracks, ensemble_mean, variables + ) + + fig, ax = plt.subplots( + 1, len(variables), figsize=(3 * (len(variables) + 1), 6), sharey=True + ) + fig.suptitle( + _make_suptitle(case, pred_tracks[0]["ic"]), + fontsize=16, + ) + ax[0].set_ylabel("count") + + for ii, var in enumerate(variables): + + reduction = "max" if var in ["wind_speed"] else "min" + _, unit, scale = _var_display_info(var) + + pred_var = extreme_dict[var]["pred"] / scale + tru_var = extreme_dict[var]["tru"] / scale + mean_var = extreme_dict[var]["ens_mean"] / scale + + ax[ii].hist(pred_var, bins=nbins) + ax[ii].axvline( + tru_var, color="orangered", linestyle="--", label="era5 reference" + ) + ax[ii].axvline(mean_var, color="lime", linestyle="--", label="ensemble mean") + + ax[ii].set_title(f"{reduction} {var} (x, t)") + ax[ii].set_xlabel(f"{var} [{unit}]") + ax[ii].legend() + + add_stats_box(ax[ii], pred_var, tru_var, var, reduction, unit) + + fig.tight_layout() + if out_dir: + fig.savefig(os.path.join(out_dir, f"{case}_histograms.png")) diff --git a/recipes/tc_tracking/plotting/tracks_slayground_notebook.py b/recipes/tc_tracking/plotting/tracks_slayground_notebook.py new file mode 100644 index 000000000..2754bd687 --- /dev/null +++ b/recipes/tc_tracking/plotting/tracks_slayground_notebook.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% [markdown] +# # Analyse Tropical Cyclone Track Ensembles +# +# This notebook demonstrates how to analyse and validate ensemble tropical +# cyclone (TC) track predictions. Use this to compare forecast tracks against +# observations +# ([IBTrACS](https://www.ncei.noaa.gov/products/international-best-track-archive)) +# and reanalysis-based reference tracks (ERA5). We explain the workflow using +# the example of +# [Hurricane Helene](https://en.wikipedia.org/wiki/Hurricane_Helene), but it +# can be easily configured to investigate other storms. +# +# ### Workflow Overview +# 1. Configure case and paths for predicted ensemble and reference track +# 2. Plot track trajectories (spaghetti plot) +# 3. Plot absolute intensities over time (wind speed, MSLP) +# 4. Plot relative intensities over time (normalised by reference) +# 5. Compare ERA5 reference against IBTrACS observations +# 6. Analyse extreme value statistics (histograms) +# 7. Compute error moments over lead time +# +# ### Data Prerequisites +# - Ensemble of predicted TC tracks (CSV files from tracking algorithm) +# - Reference track from ERA5 or IBTrACS observations (CSV file) +# +# Both can be produced by configuring a `tc_hunt.py` run with cyclone tracking +# enabled. The ensemble of predicted tracks can, for example, be produced with +# the config `cfg/helene.yaml`. Reference tracks can be extracted using +# `cfg/extract_era5.yaml`. +# +# ### Analysis Notes +# - Track positions are typically well-represented by the model +# - Intensity metrics (wind speed, pressure) show larger biases due to the +# coarse (~0.25°) resolution, which cannot fully resolve TC structure +# - IBTrACS provides observed best track data; ERA5 represents the best +# achievable reference for reanalysis-driven forecasts +# +# ### Running the notebook +# This script (and its sibling helpers) imports from `data_handling.py`, +# `plotting_helpers.py` and `analyse_n_plot.py` using bare module names, so +# the working directory must be `recipes/tc_tracking/plotting/`. From the +# recipe root: +# +# ```bash +# cd plotting +# jupytext --to notebook tracks_slayground_notebook.py +# jupyter notebook tracks_slayground.ipynb +# ``` +# +# ### Configuration +# - `case` - named storm to analyse. Can be expanded with additional named +# storms by following the pattern `{name}_{YYYY}_{basin}` +# - `pred_track_dir` - folder containing predicted track CSVs (if using data +# produced with `cfg/helene.yaml`, this would be +# `/path/to/outputs_helene/cyclone_tracks_te`) +# - `tru_track_dir` - folder containing reference track CSV file (if using +# data produced with `cfg/extract_era5.yaml`, this would be +# `/path/to/outputs_reference_tracks/`) +# - `out_dir` - path for storing plots +# - `time_step` - cadence of the predictions; defaults to 6 h to match the +# stock FCN3 / AIFS-ENS configurations + +# %% +import numpy as np +from analyse_n_plot import load_tracks +from data_handling import compute_averages_of_errors_over_lead_time +from plotting_helpers import ( + plot_errors_over_lead_time, + plot_extreme_extremes_histograms, + plot_ib_era5, + plot_over_time, + plot_relative_over_time, + plot_spaghetti, +) + +# case = 'amphan_2020_north_indian' +# case = 'beryl_2024_north_atlantic' +# case = 'debbie_2017_southern_pacific' +# case = 'dorian_2019_north_atlantic' +# case = 'harvey_2017_north_atlantic' +case = "hato_2017_west_pacific" +# case = 'helene_2024_north_atlantic' +# case = 'ian_2022_north_atlantic' +# case = 'iota_2020_north_atlantic' +# case = 'irma_2017_north_atlantic' +# case = 'lan_2017_west_pacific' +# case = 'lee_2023_north_atlantic' +# case = 'lorenzo_2019_north_atlantic' +# case = 'maria_2017_north_atlantic' +# case = 'mawar_2023_west_pacific' +# case = 'michael_2018_north_atlantic' +# case = 'milton_2024_north_atlantic' +# case = 'ophelia_2017_north_atlantic' +# case = 'yagi_2024_west_pacific' +# case = 'erin_2025_north_atlantic' + +pred_track_dir = "/path/to/outputs_hato/cyclone_tracks_te" +tru_track_dir = "/path/to/outputs_reference_tracks" +out_dir = "./plots" +time_step = np.timedelta64(6, "h") + +tru_track, pred_tracks, ens_mean, n_members, out_dir = load_tracks( + case=case, + pred_track_dir=pred_track_dir, + tru_track_dir=tru_track_dir, + out_dir=out_dir, + time_step=time_step, +) + +# %% [markdown] +# ### Spaghetti Plot +# +# - **Ensemble members** are shown in grey +# - **Ensemble mean** is displayed in green +# - **ERA5 reference** is shown in red + +# %% +plot_spaghetti( + true_track=tru_track, + pred_tracks=pred_tracks, + ensemble_mean=ens_mean["mean"], + case=case, + n_members=n_members, + out_dir=out_dir, +) + +# %% [markdown] +# ### Plot Absolute Intensities and Track Distance Over Time +# +# This section examines the temporal evolution of cyclone intensities and the +# distance from the reference track. Intensities are represented by minimum +# sea level pressure and maximum wind speed. The reference track should mostly +# fall within the ensemble spread for intensity predictions. However, some +# storms exhibit phenomena such as rapid intensification that cannot be +# adequately captured by models on quarter-degree resolution. In such cases, +# the reference intensity may lie outside the ensemble spread, indicating +# model limitations in resolving fine-scale processes. + +# %% +plot_over_time( + pred_tracks=pred_tracks, + tru_track=tru_track, + ensemble_mean=ens_mean, + case=case, + n_members=n_members, + out_dir=out_dir, + time_step=time_step, +) + +# %% [markdown] +# ### Plot Relative Intensities Over Time +# +# This cell shows the same intensity metrics as in the previous cell, this +# time normalised by the reference. Note that for the pressure field we +# normalise the deviation from normal pressure. This normalisation helps to +# identify systematic biases in the ensemble predictions and highlights +# periods where the model over- or underestimates cyclone intensity relative +# to observations. + +# %% +plot_relative_over_time( + pred_tracks=pred_tracks, + tru_track=tru_track, + ensemble_mean=ens_mean, + case=case, + n_members=n_members, + out_dir=out_dir, + time_step=time_step, +) + +# %% [markdown] +# ### Plot ERA5 Against IBTrACS Variables +# +# This cell compares the intensities reached in the ERA5 reanalysis data +# against those obtained from IBTrACS observations. The deviation between +# both datasets is usually larger the more intense the storm becomes. Note +# that there are two separate y-axes for the different intensity metrics +# (pressure and wind speed). + +# %% +plot_ib_era5( + tru_track=tru_track, + case=case, + variables=["msl", "wind_speed"], + out_dir=out_dir, +) + +# %% [markdown] +# ### Extreme Values Over Lifetime of Storm +# +# This cell computes the maximum intensity reached along each track +# throughout the storm's lifetime and displays the distribution across +# ensemble members as histograms. For comparison, the extreme values from the +# reference track are shown as vertical lines. + +# %% +plot_extreme_extremes_histograms( + pred_tracks=pred_tracks, + tru_track=tru_track, + ensemble_mean=ens_mean, + case=case, + out_dir=out_dir, +) + +# %% [markdown] +# ### Statistics +# +# This cell computes error metrics as a function of lead time across all +# ensemble members. The following statistics are calculated: mean absolute +# error, root mean square error, and standard deviation for wind speed, +# pressure intensity, and track distance. + +# %% +variables = ["wind_speed", "msl", "dist"] + +err_dict, _ = compute_averages_of_errors_over_lead_time( + pred_tracks=pred_tracks, + tru_track=tru_track, + variables=variables, +) + +plot_errors_over_lead_time( + err_dict=err_dict, + case=case, + ic=pred_tracks[0]["ic"], + n_members=n_members, + n_tracks=len(pred_tracks), + out_dir=out_dir, + time_step=time_step, +) diff --git a/recipes/tc_tracking/src/tc_hunt_utils.py b/recipes/tc_tracking/src/tc_hunt_utils.py index b80b6732d..9a0858263 100644 --- a/recipes/tc_tracking/src/tc_hunt_utils.py +++ b/recipes/tc_tracking/src/tc_hunt_utils.py @@ -27,6 +27,8 @@ from earth2studio.utils.time import to_time_array +EARTH_RADIUS_M = 6371000 + def set_initial_times(cfg: DictConfig) -> np.ndarray: """Build array of initial conditions. @@ -342,28 +344,34 @@ def __call__( def great_circle_distance( - lat1: float, lon1: float, lat2: float, lon2: float, radius: float = 6371000 -) -> float: + lat1: float | np.ndarray, + lon1: float | np.ndarray, + lat2: float | np.ndarray, + lon2: float | np.ndarray, + radius: float = EARTH_RADIUS_M, +) -> float | np.ndarray: """Compute the great-circle distance between two points on a sphere. - Uses the Haversine formula on the sphere, the radius of which is - defautlting to Earth's mean radius of 6371 km. + Uses the Haversine formula on the sphere, the radius of which defaults + to Earth's mean radius of 6371 km. Parameters ---------- - lat1 : float - Latitude of the first point in degrees. - lon1 : float - Longitude of the first point in degrees. - lat2 : float - Latitude of the second point in degrees. - lon2 : float - Longitude of the second point in degrees. + lat1 : float or np.ndarray + Latitude(s) of the first point in degrees. + lon1 : float or np.ndarray + Longitude(s) of the first point in degrees. + lat2 : float or np.ndarray + Latitude(s) of the second point in degrees. + lon2 : float or np.ndarray + Longitude(s) of the second point in degrees. + radius : float, optional + Sphere radius in metres, by default ``EARTH_RADIUS_M`` (6 371 km). Returns ------- - float - Distance in metres. + float or np.ndarray + Distance(s) in metres. """ lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2]) dlon = lon2 - lon1