Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions earth2studio/models/px/fengwu.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ def _forward(
x: torch.Tensor,
ort_session: InferenceSession,
) -> torch.Tensor:
x = (x - self.center) / self.scale # Normalize
x = x.view(x.shape[0], -1, 721, 1440) # Concat time-steps

if self.device.type == "cpu":
# ONNXRuntime IO binding expects device-backed buffers; use the
# regular NumPy path for CPU execution.
output_np = ort_session.run(
["output"],
{"input": x.detach().cpu().contiguous().numpy().astype(np.float32)},
)[0]
output_tensor = torch.from_numpy(output_np).to(self.device)
return self.scale * output_tensor[:, :69].unsqueeze(1) + self.center

# Ref https://onnxruntime.ai/docs/api/python/api_summary.html
binding = ort_session.io_binding()
Expand Down Expand Up @@ -301,8 +313,6 @@ def bind_output(name: str, like: torch.Tensor) -> torch.Tensor:
)
return out

x = (x - self.center) / self.scale # Normalize
x = x.view(x.shape[0], -1, 721, 1440) # Concat time-steps
# Forward pass, fengwu onnx supports batched
bind_input("input", x)
output = bind_output("output", like=x)
Expand Down
14 changes: 14 additions & 0 deletions recipes/eval/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,20 @@ can switch on the command line:
python main.py model=fcn3
```

Bundled standard forecast model configs include `ace2`, `aifs`, `aifsens`,
`atlas`, `aurora`, `cbottle_video`, `dlwp`, `fcn`, `fcn3`, `fengwu`,
`gencast_mini`, `graphcast_operational`, `graphcast_small`, `pangu3`,
`pangu6`, `pangu24`, and `sfno`. Specialized configs are also provided for
`dlesym` and `stormscope_goes_mrms`, which use custom pipelines.

Some model families require optional dependencies beyond the base eval
environment. Add the matching Earth2Studio extra to the recipe's
`earth2studio[...]` dependency before syncing; the recipe does not install
every model backend by default. Extras are named after the model family,
which usually matches the config name, for example `aifs`, `atlas`,
`aurora`, `fcn`, `fengwu`, or `sfno`. Where the names differ, the
`graphcast_*` configs use `graphcast`, the `pangu*` configs use `pangu`,
`cbottle_video` uses `cbottle`, and `gencast_mini` uses `gencast`.

### Ensemble runs

Set `ensemble_size > 1` and provide a perturbation config:
Expand Down
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/ace2.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# ACE2-ERA5 prognostic model
# 6-hour time step, 1-degree lat/lon grid, with external forcing data.
architecture: earth2studio.models.px.ace2.ACE2ERA5
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/aifs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# AIFS single deterministic prognostic model
# 6-hour time step, native octahedral grid internally with lat/lon I/O wrapper.
architecture: earth2studio.models.px.aifs.AIFS
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/aifsens.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# AIFS ensemble prognostic model
# 6-hour time step, native model grid internally with lat/lon I/O wrapper.
architecture: earth2studio.models.px.aifsens.AIFSENS
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/atlas.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Atlas prognostic model
# 6-hour time step, 721x1440 lat/lon grid, ERA5 variables.
architecture: earth2studio.models.px.atlas.Atlas
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/aurora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Aurora prognostic model
# 6-hour time step, 720x1440 lat/lon grid, two-time-level input.
architecture: earth2studio.models.px.aurora.Aurora
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
7 changes: 7 additions & 0 deletions recipes/eval/cfg/model/cbottle_video.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# cBottle video prognostic model
# Generative global climate/video model with lat/lon I/O by default.
architecture: earth2studio.models.px.cbottle_video.CBottleVideo
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
load_args:
lat_lon: true
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/fcn.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# FCN / FourCastNet prognostic model
# 6-hour time step, 720x1440 lat/lon grid, 26 prognostic variables.
architecture: earth2studio.models.px.fcn.FCN
# Omit package_path to use the default NGC checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/fengwu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# FengWu prognostic model
# 6-hour time step, 721x1440 lat/lon grid, ONNX runtime backend.
architecture: earth2studio.models.px.fengwu.FengWu
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
7 changes: 7 additions & 0 deletions recipes/eval/cfg/model/gencast_mini.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# GenCast Mini prognostic model
# 12-hour time step, 1-degree lat/lon grid, stochastic JAX backend.
architecture: earth2studio.models.px.gencast_mini.GenCastMini
# Omit package_path to use the default Google Cloud checkpoint.
# package_path: /path/to/custom/checkpoint
load_args:
jit_compile: false
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/graphcast_operational.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# GraphCast operational prognostic model
# 6-hour time step, 721x1440 lat/lon grid, JAX backend.
architecture: earth2studio.models.px.graphcast_operational.GraphCastOperational
# Omit package_path to use the default Google Cloud checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/graphcast_small.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# GraphCast Small prognostic model
# 6-hour time step, 181x360 lat/lon grid, JAX backend.
architecture: earth2studio.models.px.graphcast_small.GraphCastSmall
# Omit package_path to use the default Google Cloud checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/pangu24.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Pangu-Weather 24-hour prognostic model
# 24-hour time step, 721x1440 lat/lon grid, ONNX runtime backend.
architecture: earth2studio.models.px.pangu.Pangu24
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/pangu3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Pangu-Weather 3-hour prognostic model
# 3-hour time step, 721x1440 lat/lon grid, ONNX runtime backend.
architecture: earth2studio.models.px.pangu.Pangu3
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/pangu6.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Pangu-Weather 6-hour prognostic model
# 6-hour time step, 721x1440 lat/lon grid, ONNX runtime backend.
architecture: earth2studio.models.px.pangu.Pangu6
# Omit package_path to use the default HuggingFace checkpoint.
# package_path: /path/to/custom/checkpoint
5 changes: 5 additions & 0 deletions recipes/eval/cfg/model/sfno.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SFNO prognostic model
# 6-hour time step, 721x1440 lat/lon grid, 73 prognostic variables.
architecture: earth2studio.models.px.sfno.SFNO
# Omit package_path to use the default NGC checkpoint.
# package_path: /path/to/custom/checkpoint
2 changes: 1 addition & 1 deletion recipes/eval/src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from loguru import logger
from omegaconf import DictConfig

from earth2studio.data import DataSource
from earth2studio.data.base import DataSource
from earth2studio.utils.type import TimeArray, VariableArray


Expand Down
9 changes: 8 additions & 1 deletion recipes/eval/src/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
import torch
from loguru import logger
from physicsnemo.distributed import DistributedManager
from physicsnemo.distributed.manager import (
PhysicsNeMoUninitializedDistributedManagerWarning,
)

T = TypeVar("T")

Expand All @@ -49,7 +52,11 @@ def run_on_rank0_first(func: Callable[..., T], *args: Any, **kwargs: Any) -> T:
T
Return value of *func*.
"""
dist = DistributedManager()
try:
dist = DistributedManager()
except PhysicsNeMoUninitializedDistributedManagerWarning:
DistributedManager.initialize()
dist = DistributedManager()

if not dist.distributed:
return func(*args, **kwargs)
Expand Down
18 changes: 13 additions & 5 deletions recipes/eval/src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,25 @@

from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING, Any

import hydra
from loguru import logger
from omegaconf import DictConfig

from earth2studio.models.dx import DiagnosticModel
from earth2studio.models.px import PrognosticModel
from earth2studio.models.px.base import PrognosticModel

from .distributed import run_on_rank0_first

if TYPE_CHECKING:
from earth2studio.models.dx.base import DiagnosticModel


def _instantiate_load_args(cfg: DictConfig) -> dict[str, Any]:
if "load_args" not in cfg:
return {}
return dict(hydra.utils.instantiate(cfg.load_args))


def load_prognostic(cfg: DictConfig) -> PrognosticModel:
"""Load a prognostic model from the Hydra config.
Expand Down Expand Up @@ -60,7 +68,7 @@ def load_prognostic(cfg: DictConfig) -> PrognosticModel:
else:
pkg = run_on_rank0_first(cls.load_default_package)

load_kwargs: dict[str, Any] = dict(model_cfg.get("load_args", {}))
load_kwargs = _instantiate_load_args(model_cfg)
model: PrognosticModel = cls.load_model(package=pkg, **load_kwargs)

logger.success(f"Loaded prognostic model: {cls.__name__}")
Expand Down Expand Up @@ -94,7 +102,7 @@ def load_diagnostics(cfg: DictConfig) -> list[DiagnosticModel]:
elif "architecture" in dx_cfg:
cls = hydra.utils.get_class(dx_cfg.architecture)
pkg = run_on_rank0_first(cls.load_default_package)
load_kwargs: dict[str, Any] = dict(dx_cfg.get("load_args", {}))
load_kwargs = _instantiate_load_args(dx_cfg)
dx = cls.load_model(package=pkg, **load_kwargs)
else:
raise ValueError(
Expand Down
8 changes: 5 additions & 3 deletions recipes/eval/src/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from concurrent.futures import Future, ThreadPoolExecutor
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np
import torch
Expand All @@ -32,12 +32,14 @@
from physicsnemo.distributed import DistributedManager

from earth2studio.io import ZarrBackend
from earth2studio.models.dx import DiagnosticModel
from earth2studio.models.px import PrognosticModel
from earth2studio.models.px.base import PrognosticModel
from earth2studio.utils.coords import CoordSystem, handshake_coords, split_coords

from .distributed import run_on_rank0_first

if TYPE_CHECKING:
from earth2studio.models.dx.base import DiagnosticModel

_NON_SPATIAL_DIMS = frozenset({"batch", "time", "lead_time", "variable", "ensemble"})


Expand Down
2 changes: 1 addition & 1 deletion recipes/eval/src/pipelines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from physicsnemo.distributed import DistributedManager
from tqdm import tqdm

from earth2studio.data import DataSource
from earth2studio.data.base import DataSource
from earth2studio.utils.coords import CoordSystem, map_coords
from src.data import CompositeSource, PredownloadedSource
from src.output import OutputManager, build_output_coords
Expand Down
13 changes: 7 additions & 6 deletions recipes/eval/src/pipelines/dlesym.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from physicsnemo.distributed import DistributedManager
from tqdm import tqdm

from earth2studio.data import DataSource, fetch_data
from earth2studio.models.px.dlesym import DLESyM
from earth2studio.data.base import DataSource
from earth2studio.data.utils import fetch_data
from earth2studio.utils.coords import CoordSystem, map_coords

from ..models import load_prognostic
Expand Down Expand Up @@ -207,14 +207,15 @@ def _mask_invalid_ocean(
if not self._ocean_variables:
return x_step

if not isinstance(self.prognostic, DLESyM):
retrieve_valid_ocean_outputs = getattr(
self.prognostic, "retrieve_valid_ocean_outputs", None
)
if retrieve_valid_ocean_outputs is None:
raise ValueError(
"DLESyMPipeline expects the loaded prognostic to be a DLESyM model; "
f"Got: {type(self.prognostic).__name__}"
)
_, valid_coords = self.prognostic.retrieve_valid_ocean_outputs(
x_step, coords_step
)
_, valid_coords = retrieve_valid_ocean_outputs(x_step, coords_step)
valid_lt = set(valid_coords["lead_time"].tolist())
all_lt = list(coords_step["lead_time"].tolist())
all_vars = list(coords_step["variable"])
Expand Down
15 changes: 11 additions & 4 deletions recipes/eval/src/pipelines/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from collections import OrderedDict
from collections.abc import Iterator
from typing import TYPE_CHECKING

import hydra
import numpy as np
Expand All @@ -29,9 +30,9 @@
from physicsnemo.distributed import DistributedManager
from tqdm import tqdm

from earth2studio.data import DataSource, fetch_data
from earth2studio.models.dx import DiagnosticModel
from earth2studio.models.px import PrognosticModel
from earth2studio.data.base import DataSource
from earth2studio.data.utils import fetch_data
from earth2studio.models.px.base import PrognosticModel
from earth2studio.perturbation import Perturbation
from earth2studio.utils.coords import CoordSystem, cat_coords, map_coords

Expand All @@ -40,6 +41,9 @@
from ..work import WorkItem
from .base import Pipeline, PredownloadStore

if TYPE_CHECKING:
from earth2studio.models.dx.base import DiagnosticModel


def _align_to_grid(
x: torch.Tensor,
Expand Down Expand Up @@ -80,7 +84,10 @@ def _align_to_grid(
new_coords = OrderedDict(coords)
new_coords["lat"] = np.asarray(tgt_lat)
new_coords["lon"] = np.asarray(tgt_lon)
return torch.from_numpy(np.asarray(da.values)).to(x.device), new_coords
return (
torch.from_numpy(np.asarray(da.values)).to(device=x.device, dtype=x.dtype),
new_coords,
)


class ForecastPipeline(Pipeline):
Expand Down
3 changes: 2 additions & 1 deletion recipes/eval/src/pipelines/stormscope.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
from physicsnemo.distributed import DistributedManager
from tqdm import tqdm

from earth2studio.data import DataSource, fetch_data
from earth2studio.data.base import DataSource
from earth2studio.data.utils import fetch_data
from earth2studio.utils.coords import CoordSystem, cat_coords

from ..data import (
Expand Down
2 changes: 1 addition & 1 deletion recipes/eval/src/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import xarray as xr
from numpy.typing import ArrayLike

from earth2studio.data import DataSource
from earth2studio.data.base import DataSource
from earth2studio.utils.coords import CoordSystem
from earth2studio.utils.interp import NearestNeighborInterpolator
from earth2studio.utils.type import TimeArray, VariableArray
Expand Down
2 changes: 1 addition & 1 deletion recipes/eval/src/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from physicsnemo.distributed import DistributedManager
from tqdm import tqdm

from earth2studio.data import DataSource
from earth2studio.data.base import DataSource
from earth2studio.statistics.weights import lat_weight
from earth2studio.utils.coords import CoordSystem

Expand Down
Loading