From ef2deea7cea7d274cfbaecfa5a91acf05641e6f5 Mon Sep 17 00:00:00 2001 From: Oliver Hennigh Date: Thu, 14 May 2026 14:34:43 -0700 Subject: [PATCH] Add eval prognostic model configs --- earth2studio/models/px/fengwu.py | 14 +++- recipes/eval/README.md | 14 ++++ recipes/eval/cfg/model/ace2.yaml | 5 ++ recipes/eval/cfg/model/aifs.yaml | 5 ++ recipes/eval/cfg/model/aifsens.yaml | 5 ++ recipes/eval/cfg/model/atlas.yaml | 5 ++ recipes/eval/cfg/model/aurora.yaml | 5 ++ recipes/eval/cfg/model/cbottle_video.yaml | 7 ++ recipes/eval/cfg/model/fcn.yaml | 5 ++ recipes/eval/cfg/model/fengwu.yaml | 5 ++ recipes/eval/cfg/model/gencast_mini.yaml | 7 ++ .../eval/cfg/model/graphcast_operational.yaml | 5 ++ recipes/eval/cfg/model/graphcast_small.yaml | 5 ++ recipes/eval/cfg/model/pangu24.yaml | 5 ++ recipes/eval/cfg/model/pangu3.yaml | 5 ++ recipes/eval/cfg/model/pangu6.yaml | 5 ++ recipes/eval/cfg/model/sfno.yaml | 5 ++ recipes/eval/src/data.py | 2 +- recipes/eval/src/distributed.py | 9 ++- recipes/eval/src/models.py | 18 +++-- recipes/eval/src/output.py | 8 +- recipes/eval/src/pipelines/base.py | 2 +- recipes/eval/src/pipelines/dlesym.py | 13 ++-- recipes/eval/src/pipelines/forecast.py | 15 +++- recipes/eval/src/pipelines/stormscope.py | 3 +- recipes/eval/src/regrid.py | 2 +- recipes/eval/src/scoring.py | 2 +- recipes/eval/test/test_models.py | 74 +++++++++++++++++-- test/models/px/test_fengwu.py | 2 +- 29 files changed, 225 insertions(+), 32 deletions(-) create mode 100644 recipes/eval/cfg/model/ace2.yaml create mode 100644 recipes/eval/cfg/model/aifs.yaml create mode 100644 recipes/eval/cfg/model/aifsens.yaml create mode 100644 recipes/eval/cfg/model/atlas.yaml create mode 100644 recipes/eval/cfg/model/aurora.yaml create mode 100644 recipes/eval/cfg/model/cbottle_video.yaml create mode 100644 recipes/eval/cfg/model/fcn.yaml create mode 100644 recipes/eval/cfg/model/fengwu.yaml create mode 100644 recipes/eval/cfg/model/gencast_mini.yaml create mode 100644 recipes/eval/cfg/model/graphcast_operational.yaml create mode 100644 recipes/eval/cfg/model/graphcast_small.yaml create mode 100644 recipes/eval/cfg/model/pangu24.yaml create mode 100644 recipes/eval/cfg/model/pangu3.yaml create mode 100644 recipes/eval/cfg/model/pangu6.yaml create mode 100644 recipes/eval/cfg/model/sfno.yaml diff --git a/earth2studio/models/px/fengwu.py b/earth2studio/models/px/fengwu.py index 318e2711c..a2905205b 100644 --- a/earth2studio/models/px/fengwu.py +++ b/earth2studio/models/px/fengwu.py @@ -274,6 +274,18 @@ def _forward( x: torch.Tensor, ort_session: InferenceSession, ) -> torch.Tensor: + x = (x - self.center) / self.scale # Normalize + x = x.view(x.shape[0], -1, 721, 1440) # Concat time-steps + + if self.device.type == "cpu": + # ONNXRuntime IO binding expects device-backed buffers; use the + # regular NumPy path for CPU execution. + output_np = ort_session.run( + ["output"], + {"input": x.detach().cpu().contiguous().numpy().astype(np.float32)}, + )[0] + output_tensor = torch.from_numpy(output_np).to(self.device) + return self.scale * output_tensor[:, :69].unsqueeze(1) + self.center # Ref https://onnxruntime.ai/docs/api/python/api_summary.html binding = ort_session.io_binding() @@ -301,8 +313,6 @@ def bind_output(name: str, like: torch.Tensor) -> torch.Tensor: ) return out - x = (x - self.center) / self.scale # Normalize - x = x.view(x.shape[0], -1, 721, 1440) # Concat time-steps # Forward pass, fengwu onnx supports batched bind_input("input", x) output = bind_output("output", like=x) diff --git a/recipes/eval/README.md b/recipes/eval/README.md index b324241ff..94913f06e 100644 --- a/recipes/eval/README.md +++ b/recipes/eval/README.md @@ -381,6 +381,20 @@ can switch on the command line: python main.py model=fcn3 ``` +Bundled standard forecast model configs include `ace2`, `aifs`, `aifsens`, +`atlas`, `aurora`, `cbottle_video`, `dlwp`, `fcn`, `fcn3`, `fengwu`, +`gencast_mini`, `graphcast_operational`, `graphcast_small`, `pangu3`, +`pangu6`, `pangu24`, and `sfno`. Specialized configs are also provided for +`dlesym` and `stormscope_goes_mrms`, which use custom pipelines. + +Some model families require optional dependencies beyond the base eval +environment. Add the matching Earth2Studio extra to the recipe's +`earth2studio[...]` dependency before syncing; the recipe does not install +every model backend by default. Most model configs use the same name as +their extra, for example `aifs`, `atlas`, `aurora`, `fcn`, `fengwu`, +`graphcast`, `pangu`, or `sfno`; `cbottle_video` uses `cbottle`, and +`gencast_mini` uses `gencast`. + ### Ensemble runs Set `ensemble_size > 1` and provide a perturbation config: diff --git a/recipes/eval/cfg/model/ace2.yaml b/recipes/eval/cfg/model/ace2.yaml new file mode 100644 index 000000000..969895be1 --- /dev/null +++ b/recipes/eval/cfg/model/ace2.yaml @@ -0,0 +1,5 @@ +# ACE2-ERA5 prognostic model +# 6-hour time step, 1-degree lat/lon grid, with external forcing data. +architecture: earth2studio.models.px.ace2.ACE2ERA5 +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/aifs.yaml b/recipes/eval/cfg/model/aifs.yaml new file mode 100644 index 000000000..119ed5290 --- /dev/null +++ b/recipes/eval/cfg/model/aifs.yaml @@ -0,0 +1,5 @@ +# AIFS single deterministic prognostic model +# 6-hour time step, native octahedral grid internally with lat/lon I/O wrapper. +architecture: earth2studio.models.px.aifs.AIFS +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/aifsens.yaml b/recipes/eval/cfg/model/aifsens.yaml new file mode 100644 index 000000000..a1d23ec5e --- /dev/null +++ b/recipes/eval/cfg/model/aifsens.yaml @@ -0,0 +1,5 @@ +# AIFS ensemble prognostic model +# 6-hour time step, native model grid internally with lat/lon I/O wrapper. +architecture: earth2studio.models.px.aifsens.AIFSENS +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/atlas.yaml b/recipes/eval/cfg/model/atlas.yaml new file mode 100644 index 000000000..4ca1df789 --- /dev/null +++ b/recipes/eval/cfg/model/atlas.yaml @@ -0,0 +1,5 @@ +# Atlas prognostic model +# 6-hour time step, 721x1440 lat/lon grid, ERA5 variables. +architecture: earth2studio.models.px.atlas.Atlas +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/aurora.yaml b/recipes/eval/cfg/model/aurora.yaml new file mode 100644 index 000000000..2def0fe8c --- /dev/null +++ b/recipes/eval/cfg/model/aurora.yaml @@ -0,0 +1,5 @@ +# Aurora prognostic model +# 6-hour time step, 720x1440 lat/lon grid, two-time-level input. +architecture: earth2studio.models.px.aurora.Aurora +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/cbottle_video.yaml b/recipes/eval/cfg/model/cbottle_video.yaml new file mode 100644 index 000000000..79eefd120 --- /dev/null +++ b/recipes/eval/cfg/model/cbottle_video.yaml @@ -0,0 +1,7 @@ +# cBottle video prognostic model +# Generative global climate/video model with lat/lon I/O by default. +architecture: earth2studio.models.px.cbottle_video.CBottleVideo +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint +load_args: + lat_lon: true diff --git a/recipes/eval/cfg/model/fcn.yaml b/recipes/eval/cfg/model/fcn.yaml new file mode 100644 index 000000000..de00102ff --- /dev/null +++ b/recipes/eval/cfg/model/fcn.yaml @@ -0,0 +1,5 @@ +# FCN / FourCastNet prognostic model +# 6-hour time step, 720x1440 lat/lon grid, 26 prognostic variables. +architecture: earth2studio.models.px.fcn.FCN +# Omit package_path to use the default NGC checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/fengwu.yaml b/recipes/eval/cfg/model/fengwu.yaml new file mode 100644 index 000000000..85ed74ffd --- /dev/null +++ b/recipes/eval/cfg/model/fengwu.yaml @@ -0,0 +1,5 @@ +# FengWu prognostic model +# 6-hour time step, 721x1440 lat/lon grid, ONNX runtime backend. +architecture: earth2studio.models.px.fengwu.FengWu +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/gencast_mini.yaml b/recipes/eval/cfg/model/gencast_mini.yaml new file mode 100644 index 000000000..ffb3258aa --- /dev/null +++ b/recipes/eval/cfg/model/gencast_mini.yaml @@ -0,0 +1,7 @@ +# GenCast Mini prognostic model +# 12-hour time step, 1-degree lat/lon grid, stochastic JAX backend. +architecture: earth2studio.models.px.gencast_mini.GenCastMini +# Omit package_path to use the default Google Cloud checkpoint. +# package_path: /path/to/custom/checkpoint +load_args: + jit_compile: false diff --git a/recipes/eval/cfg/model/graphcast_operational.yaml b/recipes/eval/cfg/model/graphcast_operational.yaml new file mode 100644 index 000000000..dadded416 --- /dev/null +++ b/recipes/eval/cfg/model/graphcast_operational.yaml @@ -0,0 +1,5 @@ +# GraphCast operational prognostic model +# 6-hour time step, 721x1440 lat/lon grid, JAX backend. +architecture: earth2studio.models.px.graphcast_operational.GraphCastOperational +# Omit package_path to use the default Google Cloud checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/graphcast_small.yaml b/recipes/eval/cfg/model/graphcast_small.yaml new file mode 100644 index 000000000..d77c70455 --- /dev/null +++ b/recipes/eval/cfg/model/graphcast_small.yaml @@ -0,0 +1,5 @@ +# GraphCast Small prognostic model +# 6-hour time step, 181x360 lat/lon grid, JAX backend. +architecture: earth2studio.models.px.graphcast_small.GraphCastSmall +# Omit package_path to use the default Google Cloud checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/pangu24.yaml b/recipes/eval/cfg/model/pangu24.yaml new file mode 100644 index 000000000..68f2328fc --- /dev/null +++ b/recipes/eval/cfg/model/pangu24.yaml @@ -0,0 +1,5 @@ +# Pangu-Weather 24-hour prognostic model +# 24-hour time step, 721x1440 lat/lon grid, ONNX runtime backend. +architecture: earth2studio.models.px.pangu.Pangu24 +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/pangu3.yaml b/recipes/eval/cfg/model/pangu3.yaml new file mode 100644 index 000000000..bd94781ee --- /dev/null +++ b/recipes/eval/cfg/model/pangu3.yaml @@ -0,0 +1,5 @@ +# Pangu-Weather 3-hour prognostic model +# 3-hour time step, 721x1440 lat/lon grid, ONNX runtime backend. +architecture: earth2studio.models.px.pangu.Pangu3 +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/pangu6.yaml b/recipes/eval/cfg/model/pangu6.yaml new file mode 100644 index 000000000..ff0de88a4 --- /dev/null +++ b/recipes/eval/cfg/model/pangu6.yaml @@ -0,0 +1,5 @@ +# Pangu-Weather 6-hour prognostic model +# 6-hour time step, 721x1440 lat/lon grid, ONNX runtime backend. +architecture: earth2studio.models.px.pangu.Pangu6 +# Omit package_path to use the default HuggingFace checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/cfg/model/sfno.yaml b/recipes/eval/cfg/model/sfno.yaml new file mode 100644 index 000000000..86d5b7cf1 --- /dev/null +++ b/recipes/eval/cfg/model/sfno.yaml @@ -0,0 +1,5 @@ +# SFNO prognostic model +# 6-hour time step, 721x1440 lat/lon grid, 73 prognostic variables. +architecture: earth2studio.models.px.sfno.SFNO +# Omit package_path to use the default NGC checkpoint. +# package_path: /path/to/custom/checkpoint diff --git a/recipes/eval/src/data.py b/recipes/eval/src/data.py index 1a1877cc4..c32c6af44 100644 --- a/recipes/eval/src/data.py +++ b/recipes/eval/src/data.py @@ -30,7 +30,7 @@ from loguru import logger from omegaconf import DictConfig -from earth2studio.data import DataSource +from earth2studio.data.base import DataSource from earth2studio.utils.type import TimeArray, VariableArray diff --git a/recipes/eval/src/distributed.py b/recipes/eval/src/distributed.py index 0b8581e51..f4a67000a 100644 --- a/recipes/eval/src/distributed.py +++ b/recipes/eval/src/distributed.py @@ -24,6 +24,9 @@ import torch from loguru import logger from physicsnemo.distributed import DistributedManager +from physicsnemo.distributed.manager import ( + PhysicsNeMoUninitializedDistributedManagerWarning, +) T = TypeVar("T") @@ -49,7 +52,11 @@ def run_on_rank0_first(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: T Return value of *func*. """ - dist = DistributedManager() + try: + dist = DistributedManager() + except PhysicsNeMoUninitializedDistributedManagerWarning: + DistributedManager.initialize() + dist = DistributedManager() if not dist.distributed: return func(*args, **kwargs) diff --git a/recipes/eval/src/models.py b/recipes/eval/src/models.py index be8fbf873..978816dfe 100644 --- a/recipes/eval/src/models.py +++ b/recipes/eval/src/models.py @@ -16,17 +16,25 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any import hydra from loguru import logger from omegaconf import DictConfig -from earth2studio.models.dx import DiagnosticModel -from earth2studio.models.px import PrognosticModel +from earth2studio.models.px.base import PrognosticModel from .distributed import run_on_rank0_first +if TYPE_CHECKING: + from earth2studio.models.dx.base import DiagnosticModel + + +def _instantiate_load_args(cfg: DictConfig) -> dict[str, Any]: + if "load_args" not in cfg: + return {} + return dict(hydra.utils.instantiate(cfg.load_args)) + def load_prognostic(cfg: DictConfig) -> PrognosticModel: """Load a prognostic model from the Hydra config. @@ -60,7 +68,7 @@ def load_prognostic(cfg: DictConfig) -> PrognosticModel: else: pkg = run_on_rank0_first(cls.load_default_package) - load_kwargs: dict[str, Any] = dict(model_cfg.get("load_args", {})) + load_kwargs = _instantiate_load_args(model_cfg) model: PrognosticModel = cls.load_model(package=pkg, **load_kwargs) logger.success(f"Loaded prognostic model: {cls.__name__}") @@ -94,7 +102,7 @@ def load_diagnostics(cfg: DictConfig) -> list[DiagnosticModel]: elif "architecture" in dx_cfg: cls = hydra.utils.get_class(dx_cfg.architecture) pkg = run_on_rank0_first(cls.load_default_package) - load_kwargs: dict[str, Any] = dict(dx_cfg.get("load_args", {})) + load_kwargs = _instantiate_load_args(dx_cfg) dx = cls.load_model(package=pkg, **load_kwargs) else: raise ValueError( diff --git a/recipes/eval/src/output.py b/recipes/eval/src/output.py index 9731b5fd5..20eb09b04 100644 --- a/recipes/eval/src/output.py +++ b/recipes/eval/src/output.py @@ -22,7 +22,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from pathlib import Path from types import TracebackType -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import torch @@ -32,12 +32,14 @@ from physicsnemo.distributed import DistributedManager from earth2studio.io import ZarrBackend -from earth2studio.models.dx import DiagnosticModel -from earth2studio.models.px import PrognosticModel +from earth2studio.models.px.base import PrognosticModel from earth2studio.utils.coords import CoordSystem, handshake_coords, split_coords from .distributed import run_on_rank0_first +if TYPE_CHECKING: + from earth2studio.models.dx.base import DiagnosticModel + _NON_SPATIAL_DIMS = frozenset({"batch", "time", "lead_time", "variable", "ensemble"}) diff --git a/recipes/eval/src/pipelines/base.py b/recipes/eval/src/pipelines/base.py index 2d131e58d..ffe96a786 100644 --- a/recipes/eval/src/pipelines/base.py +++ b/recipes/eval/src/pipelines/base.py @@ -51,7 +51,7 @@ from physicsnemo.distributed import DistributedManager from tqdm import tqdm -from earth2studio.data import DataSource +from earth2studio.data.base import DataSource from earth2studio.utils.coords import CoordSystem, map_coords from src.data import CompositeSource, PredownloadedSource from src.output import OutputManager, build_output_coords diff --git a/recipes/eval/src/pipelines/dlesym.py b/recipes/eval/src/pipelines/dlesym.py index e121b24dc..6de662ca8 100644 --- a/recipes/eval/src/pipelines/dlesym.py +++ b/recipes/eval/src/pipelines/dlesym.py @@ -26,8 +26,8 @@ from physicsnemo.distributed import DistributedManager from tqdm import tqdm -from earth2studio.data import DataSource, fetch_data -from earth2studio.models.px.dlesym import DLESyM +from earth2studio.data.base import DataSource +from earth2studio.data.utils import fetch_data from earth2studio.utils.coords import CoordSystem, map_coords from ..models import load_prognostic @@ -207,14 +207,15 @@ def _mask_invalid_ocean( if not self._ocean_variables: return x_step - if not isinstance(self.prognostic, DLESyM): + retrieve_valid_ocean_outputs = getattr( + self.prognostic, "retrieve_valid_ocean_outputs", None + ) + if retrieve_valid_ocean_outputs is None: raise ValueError( "DLESyMPipeline expects the loaded prognostic to be a DLESyM model; " f"Got: {type(self.prognostic).__name__}" ) - _, valid_coords = self.prognostic.retrieve_valid_ocean_outputs( - x_step, coords_step - ) + _, valid_coords = retrieve_valid_ocean_outputs(x_step, coords_step) valid_lt = set(valid_coords["lead_time"].tolist()) all_lt = list(coords_step["lead_time"].tolist()) all_vars = list(coords_step["variable"]) diff --git a/recipes/eval/src/pipelines/forecast.py b/recipes/eval/src/pipelines/forecast.py index 7cac66bed..7465a25cc 100644 --- a/recipes/eval/src/pipelines/forecast.py +++ b/recipes/eval/src/pipelines/forecast.py @@ -20,6 +20,7 @@ from collections import OrderedDict from collections.abc import Iterator +from typing import TYPE_CHECKING import hydra import numpy as np @@ -29,9 +30,9 @@ from physicsnemo.distributed import DistributedManager from tqdm import tqdm -from earth2studio.data import DataSource, fetch_data -from earth2studio.models.dx import DiagnosticModel -from earth2studio.models.px import PrognosticModel +from earth2studio.data.base import DataSource +from earth2studio.data.utils import fetch_data +from earth2studio.models.px.base import PrognosticModel from earth2studio.perturbation import Perturbation from earth2studio.utils.coords import CoordSystem, cat_coords, map_coords @@ -40,6 +41,9 @@ from ..work import WorkItem from .base import Pipeline, PredownloadStore +if TYPE_CHECKING: + from earth2studio.models.dx.base import DiagnosticModel + def _align_to_grid( x: torch.Tensor, @@ -80,7 +84,10 @@ def _align_to_grid( new_coords = OrderedDict(coords) new_coords["lat"] = np.asarray(tgt_lat) new_coords["lon"] = np.asarray(tgt_lon) - return torch.from_numpy(np.asarray(da.values)).to(x.device), new_coords + return ( + torch.from_numpy(np.asarray(da.values)).to(device=x.device, dtype=x.dtype), + new_coords, + ) class ForecastPipeline(Pipeline): diff --git a/recipes/eval/src/pipelines/stormscope.py b/recipes/eval/src/pipelines/stormscope.py index f9de9c225..410c28a18 100644 --- a/recipes/eval/src/pipelines/stormscope.py +++ b/recipes/eval/src/pipelines/stormscope.py @@ -33,7 +33,8 @@ from physicsnemo.distributed import DistributedManager from tqdm import tqdm -from earth2studio.data import DataSource, fetch_data +from earth2studio.data.base import DataSource +from earth2studio.data.utils import fetch_data from earth2studio.utils.coords import CoordSystem, cat_coords from ..data import ( diff --git a/recipes/eval/src/regrid.py b/recipes/eval/src/regrid.py index 1ed1b97a5..68b55dbc3 100644 --- a/recipes/eval/src/regrid.py +++ b/recipes/eval/src/regrid.py @@ -47,7 +47,7 @@ import xarray as xr from numpy.typing import ArrayLike -from earth2studio.data import DataSource +from earth2studio.data.base import DataSource from earth2studio.utils.coords import CoordSystem from earth2studio.utils.interp import NearestNeighborInterpolator from earth2studio.utils.type import TimeArray, VariableArray diff --git a/recipes/eval/src/scoring.py b/recipes/eval/src/scoring.py index b538027ca..0771b9aa3 100644 --- a/recipes/eval/src/scoring.py +++ b/recipes/eval/src/scoring.py @@ -40,7 +40,7 @@ from physicsnemo.distributed import DistributedManager from tqdm import tqdm -from earth2studio.data import DataSource +from earth2studio.data.base import DataSource from earth2studio.statistics.weights import lat_weight from earth2studio.utils.coords import CoordSystem diff --git a/recipes/eval/test/test_models.py b/recipes/eval/test/test_models.py index 6d792c7db..706edfd08 100644 --- a/recipes/eval/test/test_models.py +++ b/recipes/eval/test/test_models.py @@ -17,15 +17,17 @@ from __future__ import annotations from collections import OrderedDict +from pathlib import Path from unittest.mock import MagicMock, patch import numpy as np import pytest +from hydra.utils import get_class from omegaconf import OmegaConf from src.models import load_diagnostics, load_prognostic -from earth2studio.models.dx import Identity -from earth2studio.models.px import Persistence +from earth2studio.models.dx.identity import Identity +from earth2studio.models.px.persistence import Persistence _RANK0_PATH = "src.models.run_on_rank0_first" @@ -34,6 +36,11 @@ VARIABLES = ["t2m", "z500"] +class FakeLoadArg: + def __init__(self, value: str): + self.value = value + + def _passthrough(fn, *a, **kw): return fn(*a, **kw) @@ -105,6 +112,30 @@ def test_load_args_forwarded(self): assert kwargs["pretrained"] is True assert kwargs["precision"] == "fp16" + def test_nested_load_args_instantiated(self): + fake_cls, _ = _make_fake_prognostic_cls() + cfg = OmegaConf.create( + { + "model": { + "architecture": "some.module.FakePx", + "load_args": { + "helper": { + "_target_": "test.test_models.FakeLoadArg", + "value": "ready", + } + }, + } + } + ) + + with patch("src.models.hydra.utils.get_class", return_value=fake_cls): + with patch(_RANK0_PATH, side_effect=_passthrough): + load_prognostic(cfg) + + _, kwargs = fake_cls.load_model.call_args + assert isinstance(kwargs["helper"], FakeLoadArg) + assert kwargs["helper"].value == "ready" + class TestLoadDiagnostics: def test_no_diagnostics_returns_empty(self): @@ -117,7 +148,7 @@ def test_target_style_instantiation(self): { "diagnostics": { "identity": { - "_target_": "earth2studio.models.dx.Identity", + "_target_": "earth2studio.models.dx.identity.Identity", } } } @@ -158,11 +189,44 @@ def test_multiple_diagnostics(self): cfg = OmegaConf.create( { "diagnostics": { - "dx1": {"_target_": "earth2studio.models.dx.Identity"}, - "dx2": {"_target_": "earth2studio.models.dx.Identity"}, + "dx1": {"_target_": "earth2studio.models.dx.identity.Identity"}, + "dx2": {"_target_": "earth2studio.models.dx.identity.Identity"}, } } ) result = load_diagnostics(cfg) assert len(result) == 2 assert all(isinstance(d, Identity) for d in result) + + +class TestModelConfigCatalog: + @pytest.mark.parametrize( + "name, class_name", + [ + ("ace2", "ACE2ERA5"), + ("aifs", "AIFS"), + ("aifsens", "AIFSENS"), + ("atlas", "Atlas"), + ("aurora", "Aurora"), + ("cbottle_video", "CBottleVideo"), + ("dlesym", "DLESyMLatLon"), + ("dlwp", "DLWP"), + ("fcn", "FCN"), + ("fcn3", "FCN3"), + ("fengwu", "FengWu"), + ("gencast_mini", "GenCastMini"), + ("graphcast_operational", "GraphCastOperational"), + ("graphcast_small", "GraphCastSmall"), + ("pangu3", "Pangu3"), + ("pangu6", "Pangu6"), + ("pangu24", "Pangu24"), + ("sfno", "SFNO"), + ], + ) + def test_model_config_resolves_architecture(self, name: str, class_name: str): + cfg_path = Path(__file__).parents[1] / "cfg" / "model" / f"{name}.yaml" + cfg = OmegaConf.load(cfg_path) + + cls = get_class(cfg.architecture) + + assert cls.__name__ == class_name diff --git a/test/models/px/test_fengwu.py b/test/models/px/test_fengwu.py index 699c61a65..f0932a952 100644 --- a/test/models/px/test_fengwu.py +++ b/test/models/px/test_fengwu.py @@ -57,7 +57,7 @@ def fengwu_test_package(tmp_path_factory): opset_version=17, input_names=["input"], output_names=["output"], - dynamic_shapes=({0: "batch_size"},), + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, ) # Create fake normalization files np.save(tmp_path / "global_means.npy", np.zeros(69))