Merged
33 commits
79aba95
split WeatherGenReader functionality to allow reading only JSON
Dec 4, 2025
7131983
informative error when metrics are not there
Dec 4, 2025
4831b40
Merge branch 'develop' into json_reader
s6sebusc Jan 7, 2026
45dfaf2
restore JSONreader after rebase
Jan 7, 2026
28db658
JSONreader mostly restored
Jan 7, 2026
0a79833
Merge branch 'develop' into json_reader
s6sebusc Jan 7, 2026
8fd16e9
MLFlow logging independent of JSON/zarr
Jan 7, 2026
305c63c
Merge branch 'develop' into json_reader
iluise Jan 8, 2026
2863b2f
Merge branch 'develop' into json_reader
s6sebusc Jan 8, 2026
7cbcdba
linting, properly cheking fsteps, ens, samples in JSONreader
Jan 9, 2026
c08f0cc
Merge branch 'develop' into json_reader
s6sebusc Jan 12, 2026
c06a653
Merge branch 'develop' into json_reader
s6sebusc Jan 13, 2026
90bfe4b
tiny change to restore the MergeReader
Jan 13, 2026
816e8cd
lint
iluise Jan 14, 2026
8645bd9
adding upstream changes to zarr file handling
Jan 19, 2026
6fe6d49
enabling JSONreader to skip plots and missing scores gracefully
Jan 19, 2026
969ac63
required reformatting
Jan 19, 2026
9e009cd
move skipping of metrics to the reader class
Jan 20, 2026
f7fd406
slighly more explicit formulations
Jan 20, 2026
aca608a
Merge branch 'develop' into json_reader
s6sebusc Jan 20, 2026
2c88a5e
Merge remote-tracking branch 'upstream/develop' into json_reader
Feb 3, 2026
19b51c6
first attempt passing parameters to individual metrics
Feb 3, 2026
5827a23
Merge branch 'develop' into metric_parameters
s6sebusc Feb 3, 2026
f221261
merged parameters into the metrics entry of the config
Feb 18, 2026
35893d6
Merge branch 'develop' into metric_parameters
s6sebusc Feb 18, 2026
2308c0c
updated type hints, trying to restore plot_maps
Feb 18, 2026
305dde5
Merge branch 'develop' into metric_parameters
s6sebusc Feb 18, 2026
ee8bcc6
Merge branch 'develop' into metric_parameters
s6sebusc Feb 19, 2026
91c9489
restored mergereader, bugfixes, lintig
Feb 19, 2026
45414e3
Merge branch 'develop' into metric_parameters
s6sebusc Feb 20, 2026
23d053b
Merge branch 'develop' into metric_parameters
iluise Feb 25, 2026
894239e
metric_list_to_json style change, restored line plots
Feb 25, 2026
aed3c2a
Merge branch 'develop' into metric_parameters
iluise Feb 25, 2026
10 changes: 9 additions & 1 deletion config/evaluate/eval_config.yml
@@ -184,4 +184,12 @@ run_ids :
forecast_step: "all"
sample: "all"


#metrics with parameters example:
###############################

#Example of syntax to pass parameters to individual metrics
evaluation:
metrics:
- fbi:
thresh: 280
- rmse
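
For reference, a brief sketch of what this entry becomes once parsed: the parse_metric_params helper added in this PR (see utils.py below) normalizes the mixed list into a mapping from metric name to parameter dict. Shown here as plain Python literals; the code itself builds OmegaConf objects.

# the YAML list above, as loaded:
[{"fbi": {"thresh": 280}}, "rmse"]
# after parse_metric_params:
{"fbi": {"thresh": 280}, "rmse": {}}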
38 changes: 23 additions & 15 deletions packages/evaluate/src/weathergen/evaluate/io/wegen_reader.py
@@ -152,7 +152,7 @@ def get_channels(self, stream: str) -> list[str]:
return all_channels

def load_scores(
self, stream: str, regions: list[str], metrics: list[str]
self, stream: str, regions: list[str], metrics: dict[str, object]
) -> xr.DataArray | None:
"""
Load multiple pre-computed scores for a given run, stream and metric and epoch.
@@ -180,8 +180,8 @@ def load_scores(
local_scores = {}
missing_metrics = {}
for region in regions:
for metric in metrics:
score = self.load_single_score(stream, region, metric)
for metric, parameters in metrics.items():
score = self.load_single_score(stream, region, metric, parameters)
if score is not None:
available_data = self.check_availability(stream, score, mode="evaluation")
if available_data.score_availability:
@@ -196,26 +196,34 @@ def load_scores(
continue

# all other cases: recompute scores
missing_metrics.setdefault(region, []).append(metric)
missing_metrics.setdefault(region, {}).update({metric: parameters})
continue
recomputable_missing_metrics = self.get_recomputable_metrics(missing_metrics)
return local_scores, recomputable_missing_metrics

def load_single_score(self, stream: str, region: str, metric: str) -> xr.DataArray | None:
def load_single_score(
self, stream: str, region: str, metric: str, parameters: dict | None = None
) -> xr.DataArray | None:
"""
Load a single pre-computed score for a given run, stream and metric
"""
if parameters is None:
parameters = {}
score_path = (
Path(self.metrics_dir)
/ f"{self.run_id}_{stream}_{region}_{metric}_chkpt{self.mini_epoch:05d}.json"
)
_logger.debug(f"Looking for: {score_path}")
score = None
if score_path.exists():
with open(score_path) as f:
data_dict = json.load(f)
score = xr.DataArray.from_dict(data_dict)
else:
score = None
if "scores" not in data_dict:
data_dict = {"scores": [data_dict]}
for score_version in data_dict["scores"]:
if score_version["attrs"] == parameters:
score = xr.DataArray.from_dict(score_version)
break
return score

def get_recomputable_metrics(self, metrics):
@@ -253,7 +261,7 @@ def __init__(
run_id: str,
private_paths: dict | None = None,
regions: list[str] | None = None,
metrics: list[str] | None = None,
metrics: dict[str, object] | None = None,
):
super().__init__(eval_cfg, run_id, private_paths)
# goes looking for the coordinates available for all streams, regions, metrics
@@ -265,8 +273,8 @@ def __init__(
} # remember who had which coords, so we can warn about it later.
for stream in streams:
for region in regions:
for metric in metrics:
score = self.load_single_score(stream, region, metric)
for metric, parameters in metrics.items():
score = self.load_single_score(stream, region, metric, parameters)
if score is not None:
for name in coord_names:
vals = set(score[name].values)
@@ -321,8 +329,8 @@ def __init__(self, eval_cfg: dict, run_id: str, private_paths: dict | None = Non
):
self.fname_zarr = fname_zarr
else:
_logger.error(f"Zarr file {self.fname_zarr} does not exist.")
raise FileNotFoundError(f"Zarr file {self.fname_zarr} does not exist")
_logger.error(f"Zarr file {fname_zarr} does not exist.")
raise FileNotFoundError(f"Zarr file {fname_zarr} does not exist")

def get_data(
self,
@@ -875,9 +883,9 @@ def load_scores(
if isinstance(self.readers[0], WeatherGenZarrReader):
# TODO: implement this properly. Now it is skipping loading scores
for region in regions:
for metric in metrics:
for metric, parameters in metrics.items():
# all other cases: recompute scores
missing_metrics.setdefault(region, []).append(metric)
missing_metrics.setdefault(region, {}).update({metric: parameters})
else: # JsonReader
# deep merge dicts
for reader in self.readers:
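For orientation, a hedged sketch of the per-metric JSON layout that load_single_score above and metric_list_to_json (further down, in utils.py) now share: each file holds a top-level "scores" list, and every entry is a DataArray.to_dict() whose "attrs" record the parameters the metric was computed with. The channel names, values and the second threshold below are invented for illustration.

import xarray as xr

# Two parameterizations of the same metric, stored side by side in one file.
fbi_280 = xr.DataArray(
    [0.9, 1.1], dims=["channel"], coords={"channel": ["2t", "10u"]}, attrs={"thresh": 280}
)
fbi_290 = xr.DataArray(
    [0.7, 0.8], dims=["channel"], coords={"channel": ["2t", "10u"]}, attrs={"thresh": 290}
)
data_dict = {"scores": [fbi_280.to_dict(), fbi_290.to_dict()]}

# load_single_score-style lookup: return the entry whose attrs equal the requested parameters.
requested = {"thresh": 280}
match = next((s for s in data_dict["scores"] if s["attrs"] == requested), None)
score = xr.DataArray.from_dict(match) if match is not None else None
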
20 changes: 8 additions & 12 deletions packages/evaluate/src/weathergen/evaluate/run_evaluation.py
@@ -21,7 +21,7 @@
# Third-party
import mlflow
from mlflow.client import MlflowClient
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig, OmegaConf, open_dict

# Local application / package
from weathergen.common.config import _REPO_ROOT
@@ -39,6 +39,7 @@
calc_scores_per_stream,
merge,
metric_list_to_json,
parse_metric_params,
plot_data,
plot_summary,
triple_nested_dict,
@@ -153,6 +154,8 @@ def evaluate_from_args(argl: list[str], log_queue: mp.Queue) -> None:
_logger.info(f"MLFlow client set up: {mlflow_client}")

cf = OmegaConf.load(config)
with open_dict(cf):
cf.evaluation.metrics = parse_metric_params(cf.evaluation.metrics)
assert isinstance(cf, DictConfig)
evaluate_from_config(cf, mlflow_client, log_queue)

@@ -163,7 +166,7 @@ def get_reader(
run_id: str,
private_paths: dict[str, str],
region: str | None = None,
metric: str | None = None,
metric: dict[str, object] | None = None,
):
if reader_type == "zarr":
reader = WeatherGenZarrReader(run, run_id, private_paths)
@@ -195,7 +198,7 @@ def _process_stream(
private_paths: dict[str, str],
global_plotting_opts: dict[str, object],
regions: list[str],
metrics: list[str],
metrics: dict[str, object],
plot_score_maps: bool,
) -> tuple[str, str, dict[str, dict[str, dict[str, float]]]]:
"""
@@ -217,7 +220,7 @@
regions:
List of regions to be processed.
metrics:
List of metrics to be processed.
Dict of metrics to be processed and their parameters.
plot_score_maps:
Bool to define if the score maps need to be plotted or not.
"""
@@ -237,11 +240,7 @@
if not stream_dict.get("evaluation"):
return run_id, stream, {}

stream_loaded_scores, recomputable_metrics = reader.load_scores(
stream,
regions,
metrics,
)
stream_loaded_scores, recomputable_metrics = reader.load_scores(stream, regions, metrics)
scores_dict = stream_loaded_scores

if recomputable_metrics or (plot_score_maps and type_ == "zarr"):
@@ -312,9 +311,6 @@ def evaluate_from_config(
if "streams" not in run:
run["streams"] = default_streams

regions = cfg.evaluation.regions
metrics = cfg.evaluation.metrics

reader = get_reader(type_, run, run_id, private_paths, regions, metrics)

for stream in reader.streams:
13 changes: 9 additions & 4 deletions packages/evaluate/src/weathergen/evaluate/scores/score.py
@@ -107,6 +107,7 @@ def get_score(
group_by_coord: str | None = None,
ens_dim: str = "ens",
compute: bool = False,
parameters: dict | None = None,
**kwargs,
) -> xr.DataArray:
"""
@@ -137,9 +138,10 @@
xr.DataArray
Calculated score as an xarray DataArray.
"""
if parameters is None:
parameters = {}
sc = Scores(agg_dims=agg_dims, ens_dim=ens_dim)

score_data = sc.get_score(data, score_name, group_by_coord, **kwargs)
score_data = sc.get_score(data, score_name, group_by_coord, parameters=parameters, **kwargs)
if compute:
# If compute is True, compute the score immediately
return score_data.compute()
@@ -208,6 +210,7 @@ def get_score(
score_name: str,
group_by_coord: str | None = None,
compute: bool = False,
parameters: dict | None = None,
**kwargs,
):
"""
@@ -245,6 +248,8 @@
Calculated score as an xarray DataArray.

"""
if parameters is None:
parameters = {}
if score_name in self.det_metrics_dict.keys():
f = self.det_metrics_dict[score_name]
_logger.debug(f"Using deterministic metric: {score_name}")
@@ -316,14 +321,14 @@ def get_score(
group_slice = {
k: (v[name] if v is not None else v) for k, v in grouped_args.items()
}
res = f(**group_slice)
res = f(**group_slice, **parameters)
# Add coordinate for concatenation
res = res.expand_dims({group_by_coord: [name]})
results.append(res)
result = xr.concat(results, dim=group_by_coord)
else:
# No grouping: just call the function
result = f(**args)
result = f(**args, **parameters)

if compute:
return result.compute()
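To make the new call path concrete, here is a hypothetical parameterized metric. The argument names (prediction, target), the aggregation dimension "ipoint" and the formula are assumptions for illustration, not the implementation registered in det_metrics_dict; the point is only that parameters={"thresh": 280} from the config arrives as a keyword argument through f(**args, **parameters).

import xarray as xr

# Hypothetical deterministic metric with one tunable parameter (illustration only).
def fbi(prediction: xr.DataArray, target: xr.DataArray, thresh: float = 273.15) -> xr.DataArray:
    """Frequency bias index: predicted exceedances of `thresh` over observed ones."""
    n_pred = (prediction >= thresh).sum("ipoint")
    n_obs = (target >= thresh).sum("ipoint")
    return n_pred / n_obs

# With "- fbi: {thresh: 280}" in the config, the dispatch above effectively calls
# fbi(prediction=..., target=..., thresh=280).
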
78 changes: 63 additions & 15 deletions packages/evaluate/src/weathergen/evaluate/utils/utils.py
@@ -131,7 +131,6 @@ def calc_scores_per_stream(
f"RUN {reader.run_id} - {stream}: Calculating scores for region {region}"
f" and metrics {metrics}..."
)

metric_stream = xr.DataArray(
np.full(
(len(samples), len(fsteps), len(channels), len(metrics), len(ensemble)),
Expand All @@ -141,7 +140,7 @@ def calc_scores_per_stream(
"sample": samples,
"forecast_step": fsteps,
"channel": channels,
"metric": metrics,
"metric": list(metrics.keys()),
"ens": ensemble,
},
)
@@ -180,9 +179,13 @@
valid_scores = []
valid_metric_names = []

for metric in metrics:
for metric, parameters in metrics.items():
score = get_score(
score_data, metric, agg_dims="ipoint", group_by_coord=group_by_coord
score_data,
metric,
agg_dims="ipoint",
group_by_coord=group_by_coord,
parameters=parameters,
)
if score is not None:
valid_scores.append(score)
@@ -288,8 +291,8 @@ def calc_scores_per_stream(
_logger.debug(f"all_metric_attrs keys: {list(all_metric_attrs.keys())}")

# Build local dictionary for this region
for metric in metrics:
metric_data = metric_stream.sel({"metric": metric})
for metric, parameters in metrics.items():
metric_data = metric_stream.sel({"metric": metric}).assign_attrs(parameters)
# Restore metric-specific attributes from all forecast steps
# Attributes are the same across forecast steps for a given metric
for (_stored_fstep, stored_metric), attrs in all_metric_attrs.items():
@@ -311,7 +314,7 @@ def _plot_score_maps_per_stream(
stream: str,
region: str,
score_data: VerifiedData,
metrics: list[str],
metrics: dict[str, object],
fstep: int,
) -> None:
"""Plot 2D score maps for all metrics and channels.
@@ -354,7 +357,7 @@ def _plot_score_maps_per_stream(
preds = score_data.prediction

plot_metrics = xr.concat(
[get_score(score_data, m, agg_dims="sample") for m in metrics],
[get_score(score_data, m, agg_dims="sample", parameters=p) for m, p in metrics.items()],
dim="metric",
coords="minimal",
combine_attrs="drop_conflicts",
@@ -363,7 +366,7 @@
plot_metrics = plot_metrics.assign_coords(
lat=preds.lat.reset_coords(drop=True),
lon=preds.lon.reset_coords(drop=True),
metric=metrics,
metric=list(metrics.keys()),
).compute()

if "ens" in preds.dims:
@@ -521,10 +524,7 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None:


def metric_list_to_json(
reader: Reader,
stream: str,
metrics_dict: list[xr.DataArray],
regions: list[str],
reader: Reader, stream: str, metrics_dict: list[xr.DataArray], regions: list[str]
):
"""
Write the evaluation results collected in a list of xarray DataArrays for the metrics
@@ -552,10 +552,34 @@
reader.metrics_dir
/ f"{run_id}_{stream}_{region}_{metric}_chkpt{reader.mini_epoch:05d}.json"
)
metric_data_dict = metric_data.to_dict()

if save_path.exists():
Contributor

these lines can be maybe refactored a bit for readability?

metric_dict = metric_data.to_dict()

if save_path.exists():
    _logger.info(f"{save_path} already present")

    with save_path.open("r") as f:
        data_dict = json.load(f)

    # Normalize structure
    scores = data_dict.get("scores")
    if scores is None:
        scores = [data_dict]
        data_dict = {"scores": scores}

    # Try to replace existing metric with same attrs
    for i, existing_score in enumerate(scores):
        if existing_score["attrs"] == metric_data.attrs:
            _logger.warning("Metric with same parameters found, replacing")
            scores[i] = metric_dict
            break
    else:
        scores.append(metric_dict)
        _logger.info(f"Appending results to {save_path}")

else:
    _logger.info(f"Saving results to new file {save_path}")
    data_dict = {"scores": [metric_dict]}

Contributor Author

yea that looks better, I made two small changes to your suggestion:

  • metric_dict sounds almost exactly like metrics_dict which was already there so I said metric_data_dict = metric_data.to_dict() instead.
  • I think scores = [data_dict]; data_dict = {"scores": scores} looks suspicious, like it might create infinite recursion (it doesn't really but I think it looks weird) so I slightly rephrased that.

_logger.info(f"{save_path} already present")

with save_path.open("r") as f:
data_dict = json.load(f)

# Normalize structure
if "scores" not in data_dict:
data_dict = {"scores": [data_dict]}
scores = data_dict.get("scores")

_logger.info(f"Saving results to {save_path}")
# Try to replace existing metric with same attrs
for i, existing_score in enumerate(scores):
if existing_score["attrs"] == metric_data.attrs:
_logger.warning("Metric with same parameters found, replacing")
scores[i] = metric_data_dict
break
else:
scores.append(metric_data_dict)
_logger.info(f"Appending results to {save_path}")

else:
_logger.info(f"Saving results to new file {save_path}")
data_dict = {"scores": [metric_data_dict]}
with open(save_path, "w") as f:
json.dump(metric_data.to_dict(), f, indent=4)
json.dump(data_dict, f, indent=4)

_logger.info(
f"Saved all results of inference run {reader.run_id} - mini_epoch {reader.mini_epoch:d} "
@@ -784,3 +808,27 @@ def merge(dst: dict, src: dict) -> dict:
else:
dst[k] = v
return dst


def parse_metric_params(metrics):
"""
Convert a mixed list of str and dict metrics into a dict where the metric
names are the keys and the values are dicts of parameters for that metric.
The config might read
metrics:
- fbi:
thresh: 280
- rmse
...
In python, metrics then looks like
[{'fbi':{'thresh':280}},'rmse']
This function converts it to
{'fbi':{'thresh':280}, 'rmse':{}}
"""
out = oc.DictConfig({})
for metric in metrics:
if isinstance(metric, str):
out = oc.OmegaConf.merge(out, {metric: {}})
else:
out = oc.OmegaConf.merge(out, metric)
return out
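
A small usage sketch mirroring the call added in evaluate_from_args; the metric entries are the ones from the eval_config.yml example in this PR, and the import path is assumed from the file location.

from omegaconf import OmegaConf

# import path assumed; parse_metric_params lives in this utils module
from weathergen.evaluate.utils.utils import parse_metric_params

cfg = OmegaConf.create({"evaluation": {"metrics": [{"fbi": {"thresh": 280}}, "rmse"]}})
metrics = parse_metric_params(cfg.evaluation.metrics)
assert OmegaConf.to_container(metrics) == {"fbi": {"thresh": 280}, "rmse": {}}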