diff --git a/.claude/skills/create-assimilation-wrapper/SKILL.md b/.claude/skills/create-assimilation-wrapper/SKILL.md
new file mode 100644
index 000000000..fd2b7a30e
--- /dev/null
+++ b/.claude/skills/create-assimilation-wrapper/SKILL.md
@@ -0,0 +1,1187 @@
+---
+name: create-assimilation-wrapper
+description: >-
+ Create a new Earth2Studio data assimilation model (da) wrapper from a
+ reference inference script or repository. DA models ingest sparse
+ observations (DataFrames) and/or gridded state arrays (DataArrays) and
+ produce analysis output — they do NOT use @batch_func, @batch_coords,
+ or PrognosticMixin.
+argument-hint: URL or local path to reference inference script/repo (optional — will be asked if not provided)
+---
+
+# Create Assimilation Model Wrapper
+
+Create a new Earth2Studio data assimilation model wrapper by following every step below in order.
+Each confirmation gate marked by starting with:
+
+```markdown
+### **[CONFIRM —
]**
+```
+
+requires explicit user approval before proceeding.
+
+> **Environment note**: Use `uv run python` for all Python
+> commands. The project uses a `uv`-managed virtual environment
+> — never install packages globally or use bare `python`.
+
+---
+
+## Critical Differences from px/dx Models
+
+DA models are **not** tensor-in / tensor-out like px and dx models. The entire I/O
+contract is different. Review this table **before** writing any code:
+
+| Aspect | px / dx | da |
+| ------------------- | ------------------------------ | ---------------------------------------- |
+| **Primary input** | `Tensor` + `CoordSystem` | `pd.DataFrame` or `xr.DataArray` |
+| **Primary output** | `Tensor` + `CoordSystem` | `xr.DataArray` or `pd.DataFrame` |
+| **Coord schemas** | `CoordSystem` only | `FrameSchema` + `CoordSystem` |
+| **Coord returns** | Single dict | **Tuple** (even for single) |
+| **Batch handling** | `@batch_func` / `@batch_coords` | **Neither** — N/A for DataFrame I/O |
+| **PrognosticMixin** | Yes (px) / No (dx) | **No** |
+| **Time integration** | `create_iterator` yields | `create_generator` with **send** |
+| **Generator prime** | Yields initial condition | Yields `None` / state, `.send()` |
+| **Generator cleanup** | N/A | Must handle `GeneratorExit` |
+| **Init data** | No | Optional via `init_coords()` |
+| **Time metadata** | Coordinate dimension | `obs.attrs["request_time"]` |
+| **GPU data** | `Tensor` on device | **cupy** arrays, **cudf** DFs |
+| **Input validation** | `handshake_dim`/`handshake_coords` | `validate_observation_fields()` |
+| **Time filtering** | N/A | `filter_time_range()` |
+| **Tensor conversion** | Already tensors | `dfseries_to_torch()` |
+| **Device tracking** | `device_buffer` / `parameters()` | `device_buffer` + `@property` |
+
+---
+
+## Step 0 — Obtain Reference & Analyze Model
+
+### 0a. Obtain reference script / repository
+
+If `$ARGUMENTS` is provided, use it as the reference inference script or repository.
+
+- If it is a URL, use WebFetch to retrieve the content.
+- If it is a local file path, read it directly.
+
+If `$ARGUMENTS` is empty or not provided, ask the user:
+
+> Please provide a reference inference script or
+> repository URL/path that demonstrates how this model
+> runs inference. This will be used to understand the
+> model architecture, dependencies, input/output shapes,
+> observation schema, and variable mapping.
+
+Store the reference code content for use in subsequent steps.
+
+### 0b. Analyze reference model
+
+After obtaining the reference, analyze it for:
+
+- **Input types**: Does the model ingest `pd.DataFrame`
+ (sparse observations), `xr.DataArray` (gridded fields),
+ or both?
+- **Output types**: Does it produce `xr.DataArray`
+ (gridded analysis), `pd.DataFrame` (corrected
+ observations), or both?
+- **Stateful vs stateless**: Does the model maintain
+ internal state across time steps (e.g., background field),
+ or is each call independent? Stateless models return
+ `None` from `init_coords()`.
+- **`@torch.inference_mode()` safety**: Does the forward
+ pass require gradients (e.g., DPS guidance through a
+ denoiser)? If so, `@torch.inference_mode()` must be
+ omitted and the reason documented.
+- **Dependencies**: External packages required (physicsnemo, scipy, healpy, cudf, cupy, etc.)
+
+Present the analysis to the user:
+
+> **Model Analysis Summary**
+>
+> - **Input type(s):** [DataFrame / DataArray / both]
+> - **Output type(s):** [DataArray / DataFrame / both]
+> - **Stateful/Stateless:** [stateful — needs init data / stateless — no init data]
+> - **Inference mode safe:** [yes / no — reason]
+> - **Key dependencies:** [list]
+>
+> Evidence:
+>
+> - [list key indicators from reference code]
+
+### **[CONFIRM — Model Analysis]**
+
+Ask the user to confirm the analysis before proceeding.
+
+---
+
+## Step 1 — Examine Reference & Propose Dependencies
+
+### 1a. Analyze the reference code
+
+Examine the reference inference script/repo to identify:
+
+- **Python packages** required (e.g., `physicsnemo`, `scipy`, `healpy`, `cudf`, `cupy`, custom packages)
+- **Model architecture** (PyTorch module, ONNX, etc.)
+- **Observation schema** (DataFrame columns, variable names, coordinate dimensions)
+- **Output grid specification** (lat/lon resolution, projection, etc.)
+- **Checkpoint format** (`.pt`, `.onnx`, `.safetensors`, `.mdlus`, etc.)
+
+### 1b. Propose pyproject.toml dependency group
+
+Propose a new optional dependency group for `pyproject.toml`.
+The group name must follow the pattern `da-`:
+
+```toml
+# In [project.optional-dependencies] section of pyproject.toml
+da-model-name = [
+ "package1>=version",
+ "package2",
+]
+```
+
+Look at the existing groups in `pyproject.toml` for reference on naming and version pinning conventions.
+
+Also propose adding the new group to the `all` aggregate in the appropriate line (da models).
+
+Highlight `cudf` and `cupy` as optional GPU acceleration
+packages — these are not required but improve performance.
+
+### **[CONFIRM — Dependencies]**
+
+Present to the user:
+
+1. The proposed dependency group name (`da-`)
+2. The list of packages with versions
+3. Ask if the packages and group name look correct
+
+---
+
+## Step 2 — Add Dependencies to pyproject.toml
+
+After confirmation, edit `pyproject.toml`:
+
+1. Add the new optional dependency group in alphabetical order among the per-model extras
+2. Add the group to the `all` aggregate (in the da models line)
+
+---
+
+## Step 3 — Create Skeleton Class File
+
+### 3a. Determine class name and file name
+
+Based on the model name from the reference, propose:
+
+- **Class name**: PascalCase (e.g., `StormCastSDA`, `InterpEquirectangular`, `HealDA`)
+- **File name**: lowercase with underscores (e.g., `sda_stormcast.py`, `interp.py`, `healda.py`)
+- **File path**: `earth2studio/models/da/.py`
+
+### 3b. Write skeleton with pseudocode
+
+Create the file with the full structure but pseudocode implementations.
+Every `.py` file in `earth2studio/` **must** start with this license
+header:
+
+```python
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+```
+
+Dual inheritance: `torch.nn.Module + AutoModelMixin` (NO `PrognosticMixin`).
+Class-level `@check_optional_dependencies()` decorator.
+
+### Canonical method ordering for DA models
+
+Methods in the class **must** appear in this order (11 method slots):
+
+1. `__init__` — register `device_buffer`, store model params, normalize time tolerance
+2. `device` property — `return self.device_buffer.device`
+3. `init_coords` — `None` for stateless, tuple of `CoordSystem`/`FrameSchema` for stateful
+4. `input_coords` — tuple of `FrameSchema` (for DataFrame) or `CoordSystem` (for DataArray)
+5. `output_coords` — accept `input_coords` tuple + `request_time` kwarg, return tuple
+6. `load_default_package` — classmethod returning default `Package`
+7. `load_model` — classmethod with `@check_optional_dependencies()`
+8. `to` — device management, return `AssimilationModel`
+9. Private/support methods (e.g., `_interpolate`, `_forward`, spatial lookups)
+10. `__call__` — stateless forward, accept `*args: pd.DataFrame | xr.DataArray | None`
+11. `create_generator` — bidirectional generator with send protocol
+
+### Complete skeleton template
+
+```python
+from __future__ import annotations
+
+from collections import OrderedDict
+from collections.abc import Generator
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import torch
+import xarray as xr
+from loguru import logger
+
+from earth2studio.models.auto import AutoModelMixin, Package
+from earth2studio.models.da.base import AssimilationModel
+from earth2studio.models.da.utils import (
+ dfseries_to_torch,
+ filter_time_range,
+ validate_observation_fields,
+)
+from earth2studio.utils.imports import (
+ OptionalDependencyFailure,
+ check_optional_dependencies,
+)
+from earth2studio.utils.time import normalize_time_tolerance
+from earth2studio.utils.type import CoordSystem, FrameSchema, TimeTolerance
+
+try:
+ import cupy as cp
+except ImportError:
+ cp = None # type: ignore[assignment]
+
+try:
+ import cudf
+except ImportError:
+ cudf = None # type: ignore[assignment, misc]
+
+try:
+ from some_package import CoreModel
+except ImportError:
+ OptionalDependencyFailure("da-mymodel")
+ CoreModel = None
+
+
+@check_optional_dependencies()
+class MyDAModel(torch.nn.Module, AutoModelMixin):
+ """One-line description of the DA model.
+
+ Extended description of the model, its source, observation types it
+ handles, and the analysis output it produces.
+
+ Parameters
+ ----------
+ model : torch.nn.Module
+ Core neural network or inference model
+ time_tolerance : TimeTolerance, optional
+ Observation time tolerance window, by default np.timedelta64(3, "h")
+
+ Note
+ ----
+ For more information see:
+
+ Badges
+ ------
+ region:global class:da product:atmos product:insitu
+ """
+
+ OUTPUT_VARIABLES = ["u10m", "v10m", "t2m"]
+
+ # 1. Constructor
+ def __init__(
+ self,
+ model: torch.nn.Module,
+ time_tolerance: TimeTolerance = np.timedelta64(3, "h"),
+ ) -> None:
+ super().__init__()
+ self._model = model
+ self._tolerance = normalize_time_tolerance(time_tolerance)
+ self.register_buffer("device_buffer", torch.empty(0))
+ # TODO: Register normalization buffers, static fields, etc.
+
+ # 2. Device property
+ @property
+ def device(self) -> torch.device:
+ """Model device."""
+ return self.device_buffer.device
+
+ # 3. Init coords
+ def init_coords(self) -> None:
+ """Initialization coordinate system.
+
+ Returns None for stateless models. Override to return a tuple of
+ CoordSystem / FrameSchema for stateful models that require initial
+ state data.
+
+ Returns
+ -------
+ None
+ No initialization data required (stateless model)
+ """
+ return None
+
+ # For stateful models, return a tuple instead:
+ # return (
+ # CoordSystem(OrderedDict({
+ # "time": np.empty(0),
+ # "lead_time": np.array([np.timedelta64(0, "h")]),
+ # "variable": np.array(self.OUTPUT_VARIABLES),
+ # "lat": self._lat,
+ # "lon": self._lon,
+ # })),
+ # )
+
+ # 4. Input coords
+ def input_coords(self) -> tuple[FrameSchema]:
+ """Input coordinate system specifying required DataFrame fields.
+
+ Returns
+ -------
+ tuple[FrameSchema]
+ Tuple containing FrameSchema with field names as keys.
+ Use np.empty(0, dtype=...) for dynamic dimensions and
+ np.array([...]) for enumerated allowed values.
+ """
+ return (
+ FrameSchema({
+ "time": np.empty(0, dtype="datetime64[ns]"),
+ "lat": np.empty(0, dtype=np.float32),
+ "lon": np.empty(0, dtype=np.float32),
+ "observation": np.empty(0, dtype=np.float32),
+ "variable": np.array(self.OUTPUT_VARIABLES, dtype=str),
+ }),
+ )
+
+ # 5. Output coords
+ def output_coords(
+ self,
+ input_coords: tuple[FrameSchema | CoordSystem, ...],
+ request_time: np.ndarray | None = None,
+ **kwargs: Any,
+ ) -> tuple[CoordSystem]:
+ """Output coordinate system.
+
+ Parameters
+ ----------
+ input_coords : tuple[FrameSchema | CoordSystem, ...]
+ Input coordinate system(s)
+ request_time : np.ndarray | None, optional
+ Analysis valid time(s)
+
+ Returns
+ -------
+ tuple[CoordSystem]
+ Output coordinate system(s)
+ """
+ if request_time is None:
+ request_time = np.array([np.datetime64("NaT")], dtype="datetime64[ns]")
+
+ # Extract variables from first input coord system
+ if len(input_coords) > 0 and "variable" in input_coords[0]:
+ variables = input_coords[0]["variable"]
+ else:
+ variables = np.array(self.OUTPUT_VARIABLES, dtype=str)
+
+ return (
+ CoordSystem(OrderedDict({
+ "time": request_time,
+ "variable": variables,
+ "lat": np.linspace(90, -90, 181),
+ "lon": np.linspace(0, 360, 360, endpoint=False),
+ })),
+ )
+
+ # 6. Default package
+ @classmethod
+ def load_default_package(cls) -> Package:
+ """Default pre-trained model package.
+
+ Returns
+ -------
+ Package
+ Model package with default checkpoint location
+ """
+ return Package(
+ "hf://nvidia/my-da-model@",
+ cache_options={"same_names": True},
+ )
+
+ # 7. Load model
+ @classmethod
+ @check_optional_dependencies()
+ def load_model(
+ cls,
+ package: Package,
+ time_tolerance: TimeTolerance = np.timedelta64(3, "h"),
+ ) -> AssimilationModel:
+ """Load assimilation model from package.
+
+ Parameters
+ ----------
+ package : Package
+ Package containing model checkpoint and statistics
+ time_tolerance : TimeTolerance, optional
+ Observation time tolerance window
+
+ Returns
+ -------
+ AssimilationModel
+ Loaded assimilation model
+ """
+ # TODO: Load model from package
+ model = CoreModel.from_checkpoint(package.resolve("model.mdlus"))
+ model.eval()
+
+ # Load normalization stats, static fields, etc.
+ stats = np.load(package.resolve("stats.npy"))
+
+ return cls(model=model, time_tolerance=time_tolerance)
+
+ # 8. Device management
+ def to(self, device: torch.device | str) -> AssimilationModel:
+ """Move model to device.
+
+ Parameters
+ ----------
+ device : torch.device | str
+ Target device
+
+ Returns
+ -------
+ AssimilationModel
+ Model on target device
+ """
+ super().to(device)
+ return self
+
+ # 9. Private/support methods
+ def _forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ """Internal forward pass.
+
+ Parameters
+ ----------
+ inputs : torch.Tensor
+ Preprocessed input tensor on device
+
+ Returns
+ -------
+ torch.Tensor
+ Raw model output tensor
+ """
+ # TODO: normalize -> model -> denormalize
+ return self._model(inputs)
+
+ # 10. Stateless forward pass
+ @torch.inference_mode()
+ def __call__(self, obs: pd.DataFrame | None = None) -> xr.DataArray:
+ """Run single-step assimilation.
+
+ Parameters
+ ----------
+ obs : pd.DataFrame | None, optional
+ Observation DataFrame with required columns and
+ 'request_time' in attrs
+
+ Returns
+ -------
+ xr.DataArray
+ Analysis output on the same device as the model
+ """
+ if obs is None:
+ raise ValueError("obs must be provided")
+
+ # Validate required columns
+ input_coords = self.input_coords()
+ validate_observation_fields(obs, required_fields=list(input_coords[0].keys()))
+
+ # Extract request_time from DataFrame attrs
+ request_time = obs.attrs.get("request_time")
+ if request_time is None:
+ raise ValueError(
+ "Observation DataFrame must have 'request_time' in attrs. "
+ "This is typically set by earth2studio data sources."
+ )
+
+ # Get output coordinate system
+ (output_coords,) = self.output_coords(
+ input_coords, request_time=request_time
+ )
+
+ # Filter observations within time tolerance window
+ filtered_obs = filter_time_range(
+ obs,
+ request_time=request_time,
+ tolerance=self._tolerance,
+ time_column="time",
+ )
+
+ # Convert DataFrame columns to torch tensors (zero-copy for cudf)
+ obs_values = dfseries_to_torch(
+ filtered_obs["observation"],
+ dtype=torch.float32,
+ device=self.device,
+ )
+ obs_lat = dfseries_to_torch(
+ filtered_obs["lat"],
+ dtype=torch.float32,
+ device=self.device,
+ )
+ obs_lon = dfseries_to_torch(
+ filtered_obs["lon"],
+ dtype=torch.float32,
+ device=self.device,
+ )
+
+ # TODO: Run model forward pass
+ prediction = self._forward(obs_values)
+
+ # Build output xr.DataArray with cupy (GPU) or numpy (CPU)
+ if self.device.type == "cuda" and cp is not None:
+ data = cp.asarray(prediction)
+ else:
+ data = prediction.cpu().numpy()
+
+ return xr.DataArray(
+ data=data,
+ dims=list(output_coords.keys()),
+ coords=output_coords,
+ )
+
+ # 11. Stateful generator
+ def create_generator(
+ self,
+ ) -> Generator[xr.DataArray, pd.DataFrame | None, None]:
+ """Creates a generator for sequential data assimilation.
+
+ Yields the current analysis state and receives the next observations
+ via generator.send(). Prime the generator with next(gen) before
+ sending observations.
+
+ Yields
+ ------
+ xr.DataArray
+ Analysis at each step
+
+ Receives
+ --------
+ pd.DataFrame | None
+ Observations for the next step. Pass None for steps with no data.
+
+ Example
+ -------
+ >>> gen = model.create_generator()
+ >>> next(gen) # prime generator (yields None)
+ >>> result = gen.send(obs_df) # step 1 with observations
+ >>> result = gen.send(None) # step 2 without observations
+ >>> gen.close() # clean up
+ """
+ # Prime generator — yield None for stateless models
+ observations = yield None # type: ignore[misc]
+ try:
+ while True:
+ result = self.__call__(observations)
+ observations = yield result
+ except GeneratorExit:
+ logger.debug("MyDAModel generator clean up complete.")
+```
+
+For **stateful models** that maintain internal state, the `create_generator` pattern is:
+
+```python
+def create_generator(
+ self,
+ x: xr.DataArray, # initial state
+) -> Generator[xr.DataArray, pd.DataFrame | None, None]:
+ """Creates a stateful generator for sequential data assimilation.
+
+ Parameters
+ ----------
+ x : xr.DataArray
+ Initial state DataArray
+
+ Yields
+ ------
+ xr.DataArray
+ Analysis at each step
+
+ Receives
+ --------
+ pd.DataFrame | None
+ Observations for the next step
+
+ Example
+ -------
+ >>> gen = model.create_generator(x0)
+ >>> state = next(gen) # yields initial state
+ >>> state = gen.send(obs_df) # step 1 with observations
+ >>> state = gen.send(None) # step 2 without observations
+ >>> gen.close() # clean up
+ """
+ # Yield initial state to prime the generator
+ obs = yield x
+
+ try:
+ while True:
+ result = self.__call__(x, obs)
+ x = result # update internal state
+ obs = yield result
+ except GeneratorExit:
+ logger.debug("MyDAModel generator clean up complete.")
+```
+
+### **[CONFIRM — Skeleton]**
+
+Present to the user:
+
+1. The proposed class name
+2. The proposed file name and path
+3. The canonical method ordering (11 methods)
+4. Whether stateful or stateless generator pattern applies
+5. Ask if these are acceptable
+
+---
+
+## Step 4 — Implement Coordinate System
+
+### 4a. Map variables to E2STUDIO_VOCAB
+
+Read `earth2studio/lexicon/base.py` and verify every variable the
+model uses exists in `E2STUDIO_VOCAB`. The vocab contains 282
+entries including:
+
+| Category | Examples |
+| --------------- | ------------------------------------------- |
+| Surface wind | `u10m`, `v10m`, `ws10m`, `u100m`, `v100m` |
+| Surface temp | `t2m`, `d2m`, `sst`, `skt` |
+| Humidity | `r2m`, `q2m`, `tcwv` |
+| Pressure | `sp`, `msl` |
+| Precipitation | `tp`, `lsp`, `cp`, `tp06` |
+| Pressure-level | `u50`-`u1000`, `v50`-`v1000`, `z50`-`z1000` |
+| Cloud/radiation | `tcc`, `rlut`, `rsut` |
+
+If a variable in the reference model does NOT exist in
+`E2STUDIO_VOCAB`, flag it to the user and discuss whether to
+map it to an existing vocab entry or propose adding a new one.
+
+### 4b. Implement init_coords
+
+**Stateless models** — return `None`:
+
+```python
+def init_coords(self) -> None:
+ """No initialization data required."""
+ return None
+```
+
+**Stateful models** — return a tuple of `CoordSystem` and/or `FrameSchema`:
+
+```python
+def init_coords(self) -> tuple[CoordSystem]:
+ """Initialization coordinate system for the background state.
+
+ Returns
+ -------
+ tuple[CoordSystem]
+ Tuple containing one CoordSystem for the initial gridded state
+ """
+ return (
+ CoordSystem(OrderedDict({
+ "time": np.empty(0),
+ "lead_time": np.array([np.timedelta64(0, "h")]),
+ "variable": np.array(self.OUTPUT_VARIABLES),
+ "lat": self._lat,
+ "lon": self._lon,
+ })),
+ )
+```
+
+### 4c. Implement input_coords
+
+Return a **tuple** of `FrameSchema | CoordSystem` — one entry per
+positional argument that `__call__` accepts.
+
+For DataFrame observations, use `FrameSchema`:
+
+```python
+def input_coords(self) -> tuple[FrameSchema]:
+ """Input coordinate system specifying required DataFrame fields.
+
+ Returns
+ -------
+ tuple[FrameSchema]
+ Tuple containing FrameSchema for observation DataFrame
+ """
+ return (
+ FrameSchema({
+ "time": np.empty(0, dtype="datetime64[ns]"),
+ "lat": np.empty(0, dtype=np.float32),
+ "lon": np.empty(0, dtype=np.float32),
+ "observation": np.empty(0, dtype=np.float32),
+ "variable": np.array(self.OUTPUT_VARIABLES, dtype=str),
+ }),
+ )
+```
+
+For models accepting multiple inputs (e.g., conventional + satellite observations):
+
+```python
+def input_coords(self) -> tuple[FrameSchema, FrameSchema]:
+ conv_schema = FrameSchema({...})
+ sat_schema = FrameSchema({...})
+ return (conv_schema, sat_schema)
+```
+
+**Rules:**
+
+- Use `np.empty(0, dtype=...)` for unbounded/dynamic dimensions (individual observations)
+- Use `np.array([...])` for enumerated allowed values (variable names)
+- Always return a **tuple**, even for a single input
+
+### 4d. Implement output_coords
+
+Accept the input coordinate tuple plus `request_time` kwargs.
+Return a tuple of output `CoordSystem | FrameSchema`:
+
+```python
+def output_coords(
+ self,
+ input_coords: tuple[FrameSchema | CoordSystem, ...],
+ request_time: np.ndarray | None = None,
+ **kwargs: Any,
+) -> tuple[CoordSystem]:
+ """Output coordinate system.
+
+ Parameters
+ ----------
+ input_coords : tuple[FrameSchema | CoordSystem, ...]
+ Input coordinate system(s)
+ request_time : np.ndarray | None, optional
+ Analysis valid time(s)
+
+ Returns
+ -------
+ tuple[CoordSystem]
+ Output coordinate system(s)
+ """
+ if request_time is None:
+ request_time = np.array([np.datetime64("NaT")], dtype="datetime64[ns]")
+
+ # Extract variables from first input coord system
+ if len(input_coords) > 0 and "variable" in input_coords[0]:
+ variables = input_coords[0]["variable"]
+ else:
+ variables = np.array(self.OUTPUT_VARIABLES, dtype=str)
+
+ return (
+ CoordSystem(OrderedDict({
+ "time": request_time,
+ "variable": variables,
+ "lat": self._lat,
+ "lon": self._lon,
+ })),
+ )
+```
+
+**Key points:**
+
+- Validate `CoordSystem` inputs using `handshake_dim` / `handshake_size` / `handshake_coords`
+- Validate `FrameSchema` inputs using `validate_observation_fields()`
+- Always return a **tuple**, even for a single output
+
+### **[CONFIRM — Coordinates]**
+
+Present to the user:
+
+1. The input and output variable lists and any mapping issues with `E2STUDIO_VOCAB`
+2. Whether `init_coords` returns `None` (stateless) or a tuple (stateful)
+3. The `FrameSchema` column names and types for observation inputs
+4. The output grid specification (lat/lon dimensions)
+
+---
+
+## Step 5 — Implement Forward Pass
+
+### 5a. Implement `__call__`
+
+The stateless forward pass accepts typed arguments matching
+`input_coords`, runs inference, and returns a tuple of
+`xr.DataArray | pd.DataFrame`.
+
+```python
+@torch.inference_mode()
+def __call__(self, obs: pd.DataFrame | None = None) -> xr.DataArray:
+ """Run single-step assimilation.
+
+ Parameters
+ ----------
+ obs : pd.DataFrame | None, optional
+ Observation DataFrame
+
+ Returns
+ -------
+ xr.DataArray
+ Analysis output on the same device as the model
+ """
+ if obs is None:
+ raise ValueError("obs must be provided")
+
+ # 1. Validate required columns
+ input_coords = self.input_coords()
+ validate_observation_fields(obs, required_fields=list(input_coords[0].keys()))
+
+ # 2. Extract request_time from DataFrame attrs
+ request_time = obs.attrs.get("request_time")
+ if request_time is None:
+ raise ValueError(
+ "Observation DataFrame must have 'request_time' in attrs. "
+ "This is typically set by earth2studio data sources."
+ )
+
+ # 3. Get output coordinate system
+ (output_coords,) = self.output_coords(input_coords, request_time=request_time)
+
+ # 4. Filter observations within time tolerance window
+ filtered_obs = filter_time_range(
+ obs,
+ request_time=request_time,
+ tolerance=self._tolerance,
+ time_column="time",
+ )
+
+ # 5. Convert DataFrame columns to torch tensors (zero-copy for cudf)
+ obs_values = dfseries_to_torch(
+ filtered_obs["observation"], dtype=torch.float32, device=self.device,
+ )
+ obs_lat = dfseries_to_torch(
+ filtered_obs["lat"], dtype=torch.float32, device=self.device,
+ )
+ obs_lon = dfseries_to_torch(
+ filtered_obs["lon"], dtype=torch.float32, device=self.device,
+ )
+
+ # 6. Run model forward pass
+ prediction = self._forward(obs_values)
+
+ # 7. Build output xr.DataArray with cupy (GPU) or numpy (CPU)
+ if self.device.type == "cuda" and cp is not None:
+ data = cp.asarray(prediction)
+ else:
+ data = prediction.cpu().numpy()
+
+ return xr.DataArray(
+ data=data,
+ dims=list(output_coords.keys()),
+ coords=output_coords,
+ )
+```
+
+**Key points:**
+
+- Use `@torch.inference_mode()` on `__call__` or on the internal `_forward` method — place it
+ on `_forward` if preprocessing outside `_forward` needs gradients. Omit entirely if the
+ forward pass requires gradients (e.g., DPS guidance); document the reason if omitted
+- Expect `request_time` in `obs.attrs` — validate early
+- Use `validate_observation_fields()` to check required DataFrame columns
+- Use `filter_time_range()` for time-window filtering
+- Use `dfseries_to_torch()` for zero-copy cudf→torch conversion
+- Output data must live on the same device as the model (cupy for GPU, numpy for CPU)
+- Do **NOT** use `@batch_func()` — this is a px/dx convention only
+
+### 5b. Implement `create_generator`
+
+**Stateless pattern** (no initial state, delegates to `__call__`):
+
+```python
+def create_generator(
+ self,
+) -> Generator[xr.DataArray, pd.DataFrame | None, None]:
+ """Creates a generator for sequential data assimilation.
+
+ Yields
+ ------
+ xr.DataArray
+ Analysis at each step
+
+ Receives
+ --------
+ pd.DataFrame | None
+ Observations for the next step
+ """
+ # Prime generator — yield None for stateless models
+ observations = yield None # type: ignore[misc]
+ try:
+ while True:
+ result = self.__call__(observations)
+ observations = yield result
+ except GeneratorExit:
+ logger.debug("MyDAModel generator clean up complete.")
+```
+
+**Stateful pattern** (maintains background state across steps):
+
+```python
+def create_generator(
+ self,
+ x: xr.DataArray, # initial state
+) -> Generator[xr.DataArray, pd.DataFrame | None, None]:
+ """Creates a stateful generator for sequential data assimilation.
+
+ Parameters
+ ----------
+ x : xr.DataArray
+ Initial state DataArray
+
+ Yields
+ ------
+ xr.DataArray
+ Analysis at each step
+
+ Receives
+ --------
+ pd.DataFrame | None
+ Observations for the next step
+ """
+ # Yield initial state to prime the generator
+ obs = yield x
+
+ try:
+ while True:
+ result = self.__call__(x, obs)
+ x = result # update internal state
+ obs = yield result
+ except GeneratorExit:
+ logger.debug("MyDAModel generator clean up complete.")
+```
+
+**Key points:**
+
+- Always yield first to prime the generator (`yield None` for stateless, `yield x` for stateful)
+- Receive observations via `.send()`: `observations = yield result`
+- **Always** handle `GeneratorExit` for clean-up logic (e.g., releasing GPU resources)
+- Do **NOT** use `create_iterator` — that is a px convention only
+
+### **[CONFIRM — Forward Pass]**
+
+Show the user the implementation for `__call__` and `create_generator`. Ask:
+
+1. Does the computation logic look correct?
+2. Is `@torch.inference_mode()` safe, or does the model need gradient flow?
+3. Are there any special considerations (multiple observation types, custom preprocessing)?
+
+---
+
+## Step 6 — Implement Model Loading
+
+### 6a. Implement load_default_package
+
+```python
+@classmethod
+def load_default_package(cls) -> Package:
+ """Default pre-trained model package.
+
+ Returns
+ -------
+ Package
+ Model package with default checkpoint location
+ """
+ return Package(
+ "hf://nvidia/my-da-model@",
+ cache_options={"same_names": True},
+ )
+```
+
+### 6b. Implement load_model
+
+```python
+@classmethod
+@check_optional_dependencies()
+def load_model(
+ cls,
+ package: Package,
+ time_tolerance: TimeTolerance = np.timedelta64(3, "h"),
+) -> AssimilationModel:
+ """Load assimilation model from package.
+
+ Parameters
+ ----------
+ package : Package
+ Package containing model checkpoint and statistics
+ time_tolerance : TimeTolerance, optional
+ Observation time tolerance window
+
+ Returns
+ -------
+ AssimilationModel
+ Loaded assimilation model
+ """
+ model = CoreModel.from_checkpoint(package.resolve("model.mdlus"))
+ model.eval()
+
+ # Load normalization stats, static fields, etc.
+ stats = np.load(package.resolve("stats.npy"))
+
+ return cls(model=model, time_tolerance=time_tolerance)
+```
+
+**Key patterns:**
+
+- Decorate `load_model` with `@check_optional_dependencies()`
+- Use `package.resolve("filename")` to get cached file paths
+- Call `.eval()` on loaded neural network modules
+- Only expose essential parameters — do **not** over-populate the API
+
+### 6c. Implement .to()
+
+> **Note:** When the wrapper inherits from `torch.nn.Module`,
+> `super().to(device)` already handles moving all registered
+> parameters, buffers, and sub-modules. A custom `to()`
+> override is only needed when there is non-PyTorch state to
+> manage (e.g., ONNX Runtime sessions, JAX device placement).
+
+```python
+def to(self, device: torch.device | str) -> AssimilationModel:
+ """Move model to device.
+
+ Parameters
+ ----------
+ device : torch.device | str
+ Target device
+
+ Returns
+ -------
+ AssimilationModel
+ Model on target device
+ """
+ super().to(device)
+ return self
+```
+
+### **[CONFIRM — Model Loading]**
+
+Present to the user:
+
+1. The checkpoint URL/path for `load_default_package`
+2. The checkpoint file names and loading logic
+3. Whether there are multiple checkpoint files
+4. The `.to()` implementation
+
+---
+
+## Step 7 — Register the Model
+
+### 7a. Add to `__init__.py`
+
+Edit `earth2studio/models/da/__init__.py`:
+
+- Add import in alphabetical order:
+ `from earth2studio.models.da. import `
+- If an `__all__` list exists, add `` to it in alphabetical order
+
+### 7b. Verify pyproject.toml
+
+Confirm the dependency group was added in Step 2 and is included in the `all` aggregate.
+
+---
+
+## Step 8 — Verify Style, Documentation, Format & Lint
+
+Before testing, verify the wrapper passes all code quality checks.
+
+### 8a. Run formatting
+
+```bash
+make format
+```
+
+This runs `black` on the codebase. Fix any formatting issues in
+the new wrapper file.
+
+### 8b. Run linting
+
+```bash
+make lint
+```
+
+This runs `ruff` and `mypy`. Common issues to watch for:
+
+- Missing type annotations on public functions
+- Unused imports
+- Import ordering issues
+- Type errors from incorrect return types or missing annotations
+
+Fix all errors before proceeding.
+
+### 8c. Check license headers
+
+```bash
+make license
+```
+
+Verify that the wrapper file
+(`earth2studio/models/da/.py`) has the correct SPDX
+Apache-2.0 license header (2024-2026 copyright years).
+
+### 8d. Verify documentation
+
+Check that:
+
+- The class docstring follows NumPy-style formatting with
+ `Parameters`, `Note`, `Badges` sections
+- All public methods have complete docstrings with
+ `Parameters`, `Returns`, `Raises` sections as applicable
+- Type hints are present on all public method signatures
+- The model is added to `docs/modules/models.rst` in the
+ `earth2studio.models.da` section (alphabetical order)
+
+If any checks fail, fix the issues and re-run until all pass cleanly.
+
+---
+
+## Reminders
+
+- **DO** return tuples from `input_coords` and `output_coords`,
+ even for single inputs/outputs
+- **DO** use `FrameSchema` for tabular DataFrame inputs and
+ `CoordSystem` for gridded outputs
+- **DO** validate `request_time` from `obs.attrs` — it is set
+ by earth2studio data sources
+- **DO** use `validate_observation_fields()` to check required
+ DataFrame columns early
+- **DO** use `filter_time_range()` for time-window filtering
+ of observations
+- **DO** use `dfseries_to_torch()` for zero-copy cudf to torch
+ column conversion
+- **DO** prime `create_generator` with `yield None` (stateless)
+ or `yield initial_state` (stateful) before the loop
+- **DO** handle `GeneratorExit` in `create_generator` for
+ clean-up
+- **DO** register `device_buffer` and expose a `device` property
+- **DO** return cupy arrays on GPU, numpy arrays on CPU for
+ `xr.DataArray` output
+- **DO** use `loguru.logger` for logging, never `print()`,
+ inside `earth2studio/`
+- **DO** ensure all public functions have full type hints
+- **DO** run formatting (`make format`) and linting
+ (`make lint`) before finalizing
+- **DO** use `@torch.inference_mode()` on `__call__` for
+ inference-only models
+- **DO** set `eval()` on loaded NN models in `load_model`
+- **DO** add the model to `docs/modules/models.rst` in the
+ `earth2studio.models.da` section (alphabetical order)
+- **DO** use `uv run python` for all Python commands (never
+ bare `python`)
+- **DO NOT** use `@batch_func()` or `@batch_coords()` — these
+ are px/dx conventions only
+- **DO NOT** use `PrognosticMixin` — DA models do not need
+ iterator hooks
+- **DO NOT** use `create_iterator` — DA uses
+ `create_generator` with send protocol
+- **DO NOT** assume tensor inputs — inputs are
+ DataFrames/DataArrays
+- **DO NOT** forget cudf/cupy optional import pattern
+- **DO NOT** use `@torch.inference_mode()` if the forward pass
+ requires gradients (e.g., DPS guidance); document the reason
+ if omitted
+- **DO NOT** attempt to make a general base class with intent
+ to reuse the wrapper across models
+- **DO NOT** over-populate the `load_model()` API — only
+ expose essential parameters
diff --git a/.claude/skills/create-diagnostic-wrapper/SKILL.md b/.claude/skills/create-diagnostic-wrapper/SKILL.md
new file mode 100644
index 000000000..9df944198
--- /dev/null
+++ b/.claude/skills/create-diagnostic-wrapper/SKILL.md
@@ -0,0 +1,1440 @@
+---
+name: create-diagnostic-wrapper
+description: Create a new Earth2Studio diagnostic model (dx) wrapper from a reference inference script or repository. Handles both NN-based models (with AutoModelMixin and Package loading) and physics-based calculators (pure torch computations).
+argument-hint: URL or local path to reference inference script/repo (optional — will be asked if not provided)
+---
+
+# Create Diagnostic Model Wrapper
+
+Create a new Earth2Studio diagnostic model wrapper by following every step below in order.
+Each confirmation gate marked by starting with:
+
+```markdown
+### **[CONFIRM — ]**
+```
+
+requires explicit user approval before proceeding.
+
+> **Environment note**: Use `uv run python` for all Python
+> commands. The project uses a `uv`-managed virtual environment
+> — never install packages globally or use bare `python`.
+
+---
+
+## Step 0 — Obtain Reference & Detect Model Type
+
+### 0a. Obtain reference script / repository
+
+If `$ARGUMENTS` is provided, use it as the reference inference script or repository.
+
+- If it is a URL, use WebFetch to retrieve the content.
+- If it is a local file path, read it directly.
+
+If `$ARGUMENTS` is empty or not provided, ask the user:
+
+> Please provide a reference inference script or
+> repository URL/path that demonstrates how this model
+> runs inference. This will be used to understand the
+> model architecture, dependencies, input/output shapes,
+> and variable mapping.
+
+Store the reference code content for use in subsequent steps.
+
+### 0b. Detect model type — NN-based vs physics-based
+
+After obtaining the reference, classify the model into one
+of two categories:
+
+**NN-based indicators** (checkpoint-backed neural network):
+
+- Checkpoint loading (`torch.load`, `.from_pretrained`, ONNX runtime)
+- Trained weights, normalization parameters (mean/std tensors)
+- Model architecture classes (e.g., `torch.nn.Module` subclass with learned layers)
+- Imports like `onnxruntime`, `timm`, `einops`, `physicsnemo`
+- `.pt`, `.onnx`, `.safetensors`, `.mdlus` checkpoint files
+
+**Physics-based indicators** (analytical / derived calculator):
+
+- Analytical formulas (e.g., `sqrt(u² + v²)`, Clausius–Clapeyron)
+- No trained weights or checkpoint files
+- Parameterized by physical constants, pressure levels, or thresholds
+- Pure `torch` math operations
+- No model architecture classes
+
+Present the detection result to the user:
+
+> **Model type detected: [NN-based / Physics-based]**
+>
+> Evidence:
+>
+> - [list key indicators found]
+>
+> This determines branching in Steps 2, 3, and 6.
+> NN-based models require dependency management and
+> checkpoint loading. Physics-based models skip those
+> steps.
+
+Store the model type flag (`nn` or `physics`) for
+conditional branching in Steps 2, 3, and 6.
+
+### **[CONFIRM — Model Type]**
+
+Ask the user to confirm the detected model type before
+proceeding.
+
+---
+
+## Step 1 — Examine Reference & Propose Dependencies
+
+### 1a. Analyze the reference code
+
+Examine the reference inference script/repo to identify:
+
+- **Python packages** required (e.g., `torch`, `onnxruntime`, `einops`, `timm`, custom packages)
+- **Model architecture** (PyTorch, ONNX, JAX, etc.) — NN-based only
+- **Input/output tensor shapes** and variable names
+- **Spatial resolution** (lat/lon grid dimensions and spacing, or flexible for physics-based)
+- **For NN-based**: Checkpoint format (`.pt`, `.onnx`, `.safetensors`, etc.)
+- **For physics-based**: Input variable patterns (e.g., `[u{level}, v{level}]`)
+ and output variable formulas (e.g., `ws = sqrt(u² + v²)`)
+
+### 1b. Propose pyproject.toml dependency group
+
+**NN-based models only:**
+
+Propose a new optional dependency group for `pyproject.toml`. Follow the existing pattern:
+
+```toml
+# In [project.optional-dependencies] section of pyproject.toml
+model-name = [
+ "package1>=version",
+ "package2",
+]
+```
+
+The group name should be lowercase-hyphenated (e.g., `precip-afno`, `windgust-afno`, `climatenet`).
+
+Look at the existing groups in `pyproject.toml`
+(lines ~59-257) for reference on naming and version
+pinning conventions.
+
+**Also propose adding the new group to the `all` aggregate** in the appropriate line (dx models line).
+
+**Physics-based models:**
+
+State: "No external dependencies needed — physics-based model uses
+only `torch` and `numpy` (already core dependencies)."
+
+### **[CONFIRM — Dependencies]**
+
+Present to the user:
+
+1. The proposed dependency group name (NN-based) or
+ "No dependencies" (physics-based)
+2. The list of packages with versions (NN-based only)
+3. Ask if the packages and group name look correct
+
+---
+
+## Step 2 — Add Dependencies to pyproject.toml
+
+**NN-based models:**
+
+After confirmation, edit `pyproject.toml`:
+
+1. Add the new optional dependency group in alphabetical order
+ among the per-model extras
+2. Add the group to the `all` aggregate (in the dx models line)
+
+**Physics-based models:**
+
+Skip — no external dependencies for physics-based model.
+State explicitly: "Step 2 skipped — physics-based model
+has no external dependencies."
+
+---
+
+## Step 3 — Create Skeleton Class File
+
+### 3a. Determine class name and file name
+
+Based on the model name from the reference, propose:
+
+- **Class name**: PascalCase (e.g., `PrecipitationAFNO`, `DerivedWS`, `WindgustAFNO`)
+- **File name**: lowercase with underscores
+ (e.g., `precipitation_afno.py`, `derived.py`, `wind_gust.py`)
+- **File path**: `earth2studio/models/dx/.py`
+
+### 3b. Write skeleton with pseudocode
+
+Create the file with the full structure but pseudocode
+implementations. Every `.py` file in `earth2studio/`
+**must** start with this license header:
+
+```python
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+```
+
+Use one of the two skeleton templates depending on the
+model type detected in Step 0.
+
+#### NN-based skeleton (dual inheritance: `torch.nn.Module, AutoModelMixin`)
+
+```python
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+from earth2studio.models.auto import AutoModelMixin, Package
+from earth2studio.models.batch import batch_coords, batch_func
+from earth2studio.models.dx.base import DiagnosticModel
+from earth2studio.utils import handshake_coords, handshake_dim
+from earth2studio.utils.imports import (
+ OptionalDependencyFailure,
+ check_optional_dependencies,
+)
+from earth2studio.utils.type import CoordSystem
+
+# Optional dependency imports (try/except pattern)
+try:
+ import optional_package
+except ImportError:
+ OptionalDependencyFailure("model-name")
+ optional_package = None
+
+VARIABLES = [...] # List of input variable names from E2STUDIO_VOCAB
+
+@check_optional_dependencies()
+class ModelName(torch.nn.Module, AutoModelMixin):
+ """One-line description.
+
+ Extended description of the model, its source,
+ and any relevant details.
+
+ Parameters
+ ----------
+ core_model : torch.nn.Module
+ Core neural network model
+ ...additional params...
+
+ Note
+ ----
+ For more information see:
+ """
+
+ # 1. Constructor
+ def __init__(self, core_model: torch.nn.Module, ...):
+ super().__init__()
+ self.core_model = core_model
+ self.register_buffer(
+ "device_buffer", torch.empty(0)
+ )
+ # TODO: Register normalization buffers (center, scale, etc.)
+
+ # 2. Input coordinates
+ def input_coords(self) -> CoordSystem:
+ """Input coordinate system of diagnostic model.
+
+ Returns
+ -------
+ CoordSystem
+ Coordinate system dictionary
+ """
+ # TODO: Define input coordinates
+ pass
+
+ # 3. Output coordinates
+ @batch_coords()
+ def output_coords(
+ self, input_coords: CoordSystem,
+ ) -> CoordSystem:
+ """Output coordinate system of diagnostic model.
+
+ Parameters
+ ----------
+ input_coords : CoordSystem
+ Input coordinate system to transform
+
+ Returns
+ -------
+ CoordSystem
+ Output coordinate system
+
+ Raises
+ ------
+ ValueError
+ If input coordinates are invalid
+ """
+ # TODO: Validate and transform coordinates
+ pass
+
+ # 4. Forward pass
+ @torch.inference_mode()
+ @batch_func()
+ def __call__(
+ self,
+ x: torch.Tensor,
+ coords: CoordSystem,
+ ) -> tuple[torch.Tensor, CoordSystem]:
+ """Forward pass of diagnostic model.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor
+ coords : CoordSystem
+ Input coordinate system
+
+ Returns
+ -------
+ tuple[torch.Tensor, CoordSystem]
+ Output tensor and coordinates
+ """
+ # TODO: Validate, forward, return
+ pass
+
+ # 5. Private forward computation (optional)
+ def _forward(
+ self, x: torch.Tensor, coords: CoordSystem,
+ ) -> torch.Tensor:
+ """Internal forward pass.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor on device
+ coords : CoordSystem
+ Input coordinate system
+
+ Returns
+ -------
+ torch.Tensor
+ Output tensor
+ """
+ # TODO: normalize -> core_model -> denormalize
+ pass
+
+ # 6. Device management
+ def to(
+ self, device: torch.device | str,
+ ) -> DiagnosticModel:
+ """Move model to device.
+
+ Parameters
+ ----------
+ device : torch.device | str
+ Target device
+
+ Returns
+ -------
+ DiagnosticModel
+ Model on target device
+ """
+ # TODO: Device management
+ pass
+
+ # 7. Default package location
+ @classmethod
+ def load_default_package(cls) -> Package:
+ """Default pre-trained model package.
+
+ Returns
+ -------
+ Package
+ Model package
+ """
+ # TODO: Default checkpoint location
+ pass
+
+ # 8. Load model from package
+ @classmethod
+ @check_optional_dependencies()
+ def load_model(
+ cls, package: Package,
+ ) -> DiagnosticModel:
+ """Load diagnostic model from package.
+
+ Parameters
+ ----------
+ package : Package
+ Model package with checkpoint files
+
+ Returns
+ -------
+ DiagnosticModel
+ Loaded model instance
+ """
+ # TODO: Load model from package
+ pass
+```
+
+#### Physics-based skeleton (`torch.nn.Module` only)
+
+```python
+from collections import OrderedDict
+
+import numpy as np
+import torch
+
+from earth2studio.models.batch import batch_coords, batch_func
+from earth2studio.models.dx.base import DiagnosticModel
+from earth2studio.utils import handshake_coords, handshake_dim
+from earth2studio.utils.type import CoordSystem
+
+
+class ModelName(torch.nn.Module):
+ """One-line description.
+
+ Extended description of the calculation, formula,
+ and any physical basis.
+
+ Parameters
+ ----------
+ levels : list[int | str], optional
+ Pressure / height levels to compute for
+ ...additional params...
+
+ Note
+ ----
+ For more information see:
+ """
+
+ # 1. Constructor
+ def __init__(
+ self, levels: list[int | str] = [100],
+ ) -> None:
+ super().__init__()
+ self.levels = levels
+ self.in_variables = [...] # Built from levels
+ self.out_variables = [...] # Built from levels
+
+ # 2. Input coordinates
+ def input_coords(self) -> CoordSystem:
+ """Input coordinate system of diagnostic model.
+
+ Returns
+ -------
+ CoordSystem
+ Coordinate system dictionary
+ """
+ # TODO: Define input coordinates
+ pass
+
+ # 3. Output coordinates
+ @batch_coords()
+ def output_coords(
+ self, input_coords: CoordSystem,
+ ) -> CoordSystem:
+ """Output coordinate system of diagnostic model.
+
+ Parameters
+ ----------
+ input_coords : CoordSystem
+ Input coordinate system to transform
+
+ Returns
+ -------
+ CoordSystem
+ Output coordinate system
+
+ Raises
+ ------
+ ValueError
+ If input coordinates are invalid
+ """
+ # TODO: Validate and transform coordinates
+ pass
+
+ # 4. Forward pass
+ @torch.inference_mode()
+ @batch_func()
+ def __call__(
+ self,
+ x: torch.Tensor,
+ coords: CoordSystem,
+ ) -> tuple[torch.Tensor, CoordSystem]:
+ """Forward pass of diagnostic model.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor
+ coords : CoordSystem
+ Input coordinate system
+
+ Returns
+ -------
+ tuple[torch.Tensor, CoordSystem]
+ Output tensor and coordinates
+ """
+ # TODO: Physics computation
+ pass
+
+ # 5. Device management
+ def to(
+ self, device: torch.device | str,
+ ) -> DiagnosticModel:
+ """Move model to device.
+
+ Parameters
+ ----------
+ device : torch.device | str
+ Target device
+
+ Returns
+ -------
+ DiagnosticModel
+ Model on target device
+ """
+ super().to(device)
+ return self
+```
+
+### Canonical method ordering for diagnostic models
+
+Methods in the class **must** appear in this order:
+
+1. `__init__` — constructor
+2. `input_coords` — input coordinate system
+3. `output_coords` — output coordinate system (decorated `@batch_coords()`)
+4. `__call__` — forward pass (decorated `@torch.inference_mode()`, `@batch_func()`)
+5. `_forward` — private computation method (optional helper)
+6. `to` — device management
+7. `load_default_package` — classmethod returning default `Package` (NN-based only)
+8. `load_model` — classmethod loading model from package (NN-based only)
+
+### Key differences from prognostic models
+
+Be aware of these critical differences from the
+`create-prognostic-wrapper` skill:
+
+- **No `PrognosticMixin`** — diagnostic models do NOT
+ inherit from `PrognosticMixin` (no iterator hooks needed)
+- **No `create_iterator`** or `_default_generator` —
+ diagnostic models do NOT perform time integration
+- **Typically no `lead_time`** — most diagnostic models do
+ not have a `lead_time` dimension. However, some models
+ that consume forecast output at specific lead times (e.g.,
+ solar radiation, wind gust) may include `lead_time` in
+ their coordinate system. Only add it if the model
+ genuinely needs temporal context
+- **Match `handshake_dim` indices to coordinate position** —
+ use the actual position of each dimension in the input
+ `CoordSystem` OrderedDict. For simple models with
+ `(batch, variable, lat, lon)` the indices are `1, 2, 3`.
+ For models with `(batch, time, variable, lat, lon)` use
+ `2, 3, 4`. You may also use negative indices (`-3, -2, -1`)
+ if the model must accept flexible leading dimensions.
+ Check existing models in `earth2studio/models/dx/` for
+ the predominant convention used in the codebase.
+
+### **[CONFIRM — Skeleton]**
+
+Present to the user:
+
+1. The proposed class name
+2. The proposed file name and path
+3. The detected model type (NN-based or physics-based)
+4. Ask if these are acceptable
+
+---
+
+## Step 4 — Implement Coordinate System
+
+### 4a. Map variables to E2STUDIO_VOCAB
+
+Read `earth2studio/lexicon/base.py` and verify every
+variable the model uses exists in `E2STUDIO_VOCAB`.
+The vocab contains 282 entries including:
+
+| Category | Examples |
+|---|---|
+| Surface wind | `u10m`, `v10m`, `ws10m`, `u100m`, `v100m` |
+| Surface temp | `t2m`, `d2m`, `sst`, `skt` |
+| Humidity | `r2m`, `q2m`, `tcwv` |
+| Pressure | `sp`, `msl` |
+| Precipitation | `tp`, `lsp`, `cp`, `tp06` |
+| Pressure-level | `u50`-`u1000`, `v50`-`v1000`, `z50`-`z1000` |
+| Cloud/radiation | `tcc`, `rlut`, `rsut` |
+
+Pressure levels available: 50, 100, 150, 200, 250,
+300, 400, 500, 600, 700, 850, 925, 1000.
+
+If a variable in the reference model does NOT exist
+in `E2STUDIO_VOCAB`, flag it to the user and discuss
+whether to:
+
+- Map it to an existing vocab entry
+- Propose adding a new vocab entry (separate step)
+
+### 4b. Implement input_coords
+
+**NN-based** (fixed grid):
+
+```python
+def input_coords(self) -> CoordSystem:
+ """Input coordinate system of diagnostic model.
+
+ Returns
+ -------
+ CoordSystem
+ Coordinate system dictionary
+ """
+ return OrderedDict({
+ "batch": np.empty(0), # MUST be first, MUST be np.empty(0)
+ "time": np.empty(0), # Dynamic time dimension (optional)
+ "variable": np.array(VARIABLES),
+ "lat": np.linspace(90, -90, num_lat, endpoint=...), # From reference
+ "lon": np.linspace(0, 360, num_lon, endpoint=False), # From reference
+ })
+```
+
+**Physics-based** (flexible grid):
+
+```python
+def input_coords(self) -> CoordSystem:
+ """Input coordinate system of diagnostic model.
+
+ Returns
+ -------
+ CoordSystem
+ Coordinate system dictionary
+ """
+ return OrderedDict({
+ "batch": np.empty(0), # MUST be first, MUST be np.empty(0)
+ "variable": np.array(self.in_variables),
+ "lat": np.empty(0), # Flexible — accepts any grid
+ "lon": np.empty(0), # Flexible — accepts any grid
+ })
+```
+
+**Rules:**
+
+- `batch` is always first with `np.empty(0)`
+- `time` is `np.empty(0)` (dynamic) — include only if
+ the model needs time information (e.g., for solar
+ zenith angle calculations)
+- **Typically no `lead_time`** — most diagnostic models do
+ not include `lead_time`. Add it only if the model needs
+ temporal context (e.g., solar radiation models that need
+ time-of-day, or models that consume forecast output at
+ specific lead times)
+- `lat` typically goes from 90 to -90 (north to south)
+ for NN-based; use `np.empty(0)` for physics-based
+ (flexible grid)
+- `lon` typically goes from 0 to 360 for NN-based;
+ use `np.empty(0)` for physics-based (flexible grid)
+
+### 4c. Implement output_coords
+
+```python
+@batch_coords()
+def output_coords(self, input_coords: CoordSystem) -> CoordSystem:
+ """Output coordinate system of diagnostic model.
+
+ Parameters
+ ----------
+ input_coords : CoordSystem
+ Input coordinates to validate and transform
+
+ Returns
+ -------
+ CoordSystem
+ Output coordinates with updated variable list
+
+ Raises
+ ------
+ ValueError
+ If input coordinates are invalid
+ """
+ target_input_coords = self.input_coords()
+
+ # Validate dimensions exist at correct indices
+ # Use the positional index matching each dim's position in input_coords
+ # For (batch, variable, lat, lon) → 1, 2, 3
+ # For (batch, time, variable, lat, lon) → 2, 3, 4
+ # Negative indices (-3, -2, -1) also work for flexible leading dims
+ handshake_dim(input_coords, "variable", 1)
+ handshake_dim(input_coords, "lat", 2)
+ handshake_dim(input_coords, "lon", 3)
+
+ # Validate coordinate values match
+ handshake_coords(input_coords, target_input_coords, "variable")
+ handshake_coords(input_coords, target_input_coords, "lat")
+ handshake_coords(input_coords, target_input_coords, "lon")
+
+ output_coords = input_coords.copy()
+ output_coords["variable"] = np.array(OUTPUT_VARIABLES)
+ return output_coords
+```
+
+**Key points:**
+
+- Use `@batch_coords()` decorator
+- Use `handshake_dim` with indices matching each
+ dimension's position in the `CoordSystem` OrderedDict.
+ Use positive indices (e.g., `1, 2, 3`) or negative
+ indices (e.g., `-3, -2, -1`) — match the convention
+ of the most similar existing model in the codebase.
+- Validate coordinate values with `handshake_coords`
+- Output typically changes the `variable` array (different
+ output variables from input variables)
+- Unlike prognostic models, there is typically no
+ `lead_time` to increment
+
+### **[CONFIRM — Coordinates]**
+
+Present to the user:
+
+1. The input and output variable lists and any mapping
+ issues with `E2STUDIO_VOCAB`
+2. The spatial dimensions (lat/lon grid size and spacing,
+ or "flexible" for physics-based)
+3. Whether `lead_time` is needed (and why — most dx
+ models omit it)
+4. Whether `time` is included (and why, if so)
+
+---
+
+## Step 5 — Implement Forward Pass
+
+### 5a. Implement `__call__`
+
+```python
+@torch.inference_mode()
+@batch_func()
+def __call__(
+ self,
+ x: torch.Tensor,
+ coords: CoordSystem,
+) -> tuple[torch.Tensor, CoordSystem]:
+ """Forward pass of diagnostic model.
+
+ Parameters
+ ----------
+ x : torch.Tensor
+ Input tensor
+ coords : CoordSystem
+ Input coordinate system
+
+ Returns
+ -------
+ tuple[torch.Tensor, CoordSystem]
+ Output tensor and coordinates
+ """
+ # Get output coordinates (this also validates input coords)
+ output_coords = self.output_coords(coords)
+
+ # Move to device
+ device = self.device_buffer.device # NN-based
+ # or: device = next(self.parameters()).device
+ x = x.to(device)
+
+ # Run forward pass
+ out = self._forward(x, coords)
+
+ return out, output_coords
+```
+
+Key implementation notes:
+
+- The `@batch_func()` decorator handles the batch
+ dimension — inside `__call__`, `x` has shape
+ `(batch, ..., variable, lat, lon)` where `...` may
+ include `time` and/or `lead_time` depending on the
+ upstream pipeline
+- Coordinate validation is handled inside `output_coords()`
+ — no need to repeat `handshake_dim`/`handshake_coords`
+ in `__call__` (follow the pattern of existing dx models)
+- All tensor operations should happen on GPU when
+ possible
+- **No `create_iterator`**, no hooks, no yielding —
+ diagnostic models perform a single forward pass only
+
+### 5b. Implement `_forward`
+
+**NN-based** (normalize → model → denormalize):
+
+```python
+def _forward(
+ self, x: torch.Tensor, coords: CoordSystem,
+) -> torch.Tensor:
+ """Internal forward pass."""
+ # Normalize input
+ x = (x - self.center) / self.scale
+
+ # Reshape for core model if needed
+ # TODO: Reshape from (..., variable, lat, lon) to model format
+
+ # Core model forward
+ out = self.core_model(x)
+
+ # Reshape output back
+ # TODO: Reshape from model format to (..., out_variable, lat, lon)
+
+ # Denormalize output (if needed)
+ out = out * self.output_scale + self.output_center
+
+ return out
+```
+
+**Physics-based** (direct torch math):
+
+```python
+@torch.inference_mode()
+@batch_func()
+def __call__(
+ self,
+ x: torch.Tensor,
+ coords: CoordSystem,
+) -> tuple[torch.Tensor, CoordSystem]:
+ """Forward pass of diagnostic."""
+ output_coords = self.output_coords(coords)
+
+ # Example: wind speed from u, v components
+ u = x[..., ::2, :, :] # Every other variable starting from 0
+ v = x[..., 1::2, :, :] # Every other variable starting from 1
+ out = torch.sqrt(u**2 + v**2)
+
+ return out, output_coords
+```
+
+For physics-based models, the computation typically
+goes directly in `__call__` rather than a separate
+`_forward` method.
+
+### **[CONFIRM — Forward Pass]**
+
+Show the user the pseudocode for `__call__`
+(especially the reshape logic for NN-based, or the
+formula for physics-based). Ask:
+
+1. Does the computation logic look correct for this
+ model?
+2. Are there any special considerations (e.g., multiple
+ ONNX models, clamping, special tensor layouts)?
+
+---
+
+## Step 6 — Implement Model Loading (NN-based only)
+
+**Physics-based models:** Skip this step entirely.
+State explicitly: "Step 6 skipped — physics-based models
+have no checkpoints to load."
+
+### 6a. Implement load_default_package
+
+```python
+@classmethod
+def load_default_package(cls) -> Package:
+ """Default pre-trained model package on .
+
+ Returns
+ -------
+ Package
+ Model package
+ """
+ return Package(
+ "ngc://models/org/model@version", # or hf://, s3://, local path
+ cache_options={
+ "cache_storage": Package.default_cache("model_name"),
+ "same_names": True,
+ },
+ )
+```
+
+### 6b. Implement load_model
+
+```python
+@classmethod
+@check_optional_dependencies()
+def load_model(
+ cls,
+ package: Package,
+) -> DiagnosticModel:
+ """Load diagnostic model from package.
+
+ Parameters
+ ----------
+ package : Package
+ Model package with checkpoint files
+
+ Returns
+ -------
+ DiagnosticModel
+ Loaded model instance
+ """
+ # Resolve checkpoint files
+ checkpoint_path = package.resolve("model.pt")
+
+ # Load model
+ core_model = torch.load(
+ checkpoint_path, map_location="cpu", weights_only=False,
+ )
+ core_model.eval()
+ core_model.requires_grad_(False)
+
+ # Load any additional data files (normalization, masks, etc.)
+ # center = torch.Tensor(np.load(package.resolve("center.npy")))
+ # scale = torch.Tensor(np.load(package.resolve("scale.npy")))
+
+ return cls(core_model)
+```
+
+**Key patterns:**
+
+- Use `package.resolve("filename")` to get cached
+ file paths
+- Load with `map_location="cpu"` then let user call
+ `.to(device)`
+- Set model to `eval()` mode and `requires_grad_(False)`
+- Do NOT over-populate `load_model()` API — only
+ expose essential parameters
+- Use `@check_optional_dependencies()` if the model
+ has optional deps
+
+### 6c. Implement .to()
+
+> **Note:** When the wrapper inherits from `torch.nn.Module`,
+> `super().to(device)` already handles moving all registered
+> parameters, buffers, and sub-modules. A custom `to()`
+> override is only needed when there is non-PyTorch state to
+> manage (e.g., ONNX Runtime sessions that must be destroyed
+> and recreated on a new device, or JAX device placement).
+> If `super().to(device)` is sufficient, you can omit the
+> override entirely.
+
+```python
+def to(self, device: torch.device | str) -> DiagnosticModel:
+ """Move model to device.
+
+ Parameters
+ ----------
+ device : torch.device | str
+ Target device
+
+ Returns
+ -------
+ DiagnosticModel
+ Model on target device
+ """
+ super().to(device)
+ # If using ONNX Runtime, destroy and recreate session on new device
+ # If using PyTorch, super().to(device) handles it
+ return self
+```
+
+### **[CONFIRM — Model Loading]**
+
+Present to the user:
+
+1. The checkpoint URL/path for `load_default_package`
+2. The checkpoint file names and loading logic
+3. Whether there are multiple checkpoint files
+4. The `.to()` implementation (especially if ONNX or
+ non-PyTorch backend)
+
+---
+
+## Step 7 — Register the Model
+
+### 7a. Add to `__init__.py`
+
+Edit `earth2studio/models/dx/__init__.py`:
+
+- Add import in alphabetical order:
+ `from earth2studio.models.dx. import `
+- Add `` to the `__all__` list in alphabetical order
+
+### 7b. Verify pyproject.toml (NN-based only)
+
+Confirm the dependency group was added in Step 2 and is included in the `all` aggregate.
+
+For physics-based models, no pyproject.toml verification is needed.
+
+---
+
+## Step 8 — Verify Style, Documentation, Format & Lint
+
+Before testing, verify the wrapper passes all code quality checks.
+
+### 8a. Run formatting
+
+```bash
+make format
+```
+
+This runs `black` on the codebase. Fix any formatting issues in the new wrapper file and test file.
+
+### 8b. Run linting
+
+```bash
+make lint
+```
+
+This runs `ruff` and `mypy`. Common issues to watch
+for:
+
+- Missing type annotations on public functions
+- Unused imports
+- Import ordering issues
+- Type errors from incorrect return types or missing
+ `CoordSystem` annotations
+
+Fix all errors before proceeding.
+
+### 8c. Check license headers
+
+```bash
+make license
+```
+
+Verify that both the wrapper file
+(`earth2studio/models/dx/.py`) and the test
+file (`test/models/dx/test_.py`) have the
+correct SPDX Apache-2.0 license header.
+
+### 8d. Verify documentation
+
+Check that:
+
+- The class docstring follows NumPy-style formatting
+ with `Parameters`, `Note`, etc.
+- All public methods (`__call__`, `input_coords`,
+ `output_coords`, `to`, and for NN-based:
+ `load_default_package`, `load_model`) have complete
+ docstrings with `Parameters`, `Returns`, `Raises`
+ sections as applicable
+- Type hints are present on all public method
+ signatures
+
+If any checks fail, fix the issues and re-run until all pass cleanly.
+
+---
+
+## Step 9 — Test Forward Pass with Random Data
+
+Write and run a quick smoke test script.
+
+**Note:** Diagnostic models do NOT have `create_iterator` — only test `__call__`.
+
+```python
+import torch
+import numpy as np
+from earth2studio.models.dx import ModelName
+
+# Load model (or construct with dummy weights for testing)
+model = ModelName(...) # Use dummy/test weights if real ones aren't available
+model = model.to("cuda" if torch.cuda.is_available() else "cpu")
+
+# Get input coords
+input_coords = model.input_coords()
+
+# Create random input tensor
+shape = tuple(max(len(v), 1) for v in input_coords.values())
+x = torch.randn(shape)
+
+# Test __call__
+output, out_coords = model(x, input_coords)
+print(f"Input shape: {x.shape}")
+print(f"Output shape: {output.shape}")
+print(f"Input variables: {input_coords['variable']}")
+print(f"Output variables: {out_coords['variable']}")
+```
+
+Report results to the user. There is no `create_iterator`
+to test — diagnostic models perform a single forward pass.
+
+---
+
+## Step 10 — Test Data Fetch with Random Source
+
+Test that the model's coordinate system works with Earth2Studio's data pipeline:
+
+```python
+import numpy as np
+from earth2studio.data import Random, fetch_data
+from earth2studio.models.dx import ModelName
+
+model = ModelName(...)
+
+# Create time array
+time = np.array([np.datetime64("2024-01-01T00:00")])
+
+# Fetch data using model's input coords
+input_coords = model.input_coords()
+input_coords["time"] = time
+ds = Random(input_coords)
+x, coords = fetch_data(ds, time, input_coords["variable"])
+
+print(f"Fetched data shape: {x.shape}")
+print(f"Variables: {input_coords['variable']}")
+```
+
+Report results to the user. This validates the coordinate system is compatible with the data pipeline.
+
+---
+
+## Step 11 — Write Pytest Unit Tests
+
+Create `test/models/dx/test_.py` following the existing test patterns.
+
+### 11a. Test file structure
+
+```python
+# License header (same SPDX Apache-2.0 header as above)
+
+from collections import OrderedDict
+
+import numpy as np
+import pytest
+import torch
+
+from earth2studio.data import Random, fetch_data
+from earth2studio.models.auto import Package
+from earth2studio.models.dx import ModelName
+from earth2studio.utils import handshake_dim
+
+
+class PhooModelName(torch.nn.Module):
+ """Dummy model that performs a simple deterministic operation."""
+
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x[..., :NUM_OUTPUT_VARS, :, :] # Simple slice for testing
+
+
+@pytest.fixture(scope="class")
+def test_package(tmp_path_factory):
+ """Create a dummy model package for testing."""
+ tmp_path = tmp_path_factory.mktemp("data")
+ # Export dummy model to checkpoint format
+ model = PhooModelName()
+ torch.save(model, tmp_path / "model.pt")
+ # Save any additional files (normalization, etc.)
+ # np.save(tmp_path / "center.npy", np.zeros(NUM_VARS))
+ return Package(str(tmp_path))
+
+
+class TestModelNameMock:
+ @pytest.mark.parametrize(
+ "time",
+ [
+ np.array([
+ np.datetime64("1999-10-11T12:00"),
+ np.datetime64("2001-06-04T00:00"),
+ ]),
+ ],
+ )
+ @pytest.mark.parametrize(
+ "device",
+ [
+ "cpu",
+ pytest.param(
+ "cuda:0",
+ marks=pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="No GPU"
+ ),
+ ),
+ ],
+ )
+ def test_model_call(self, test_package, time, device):
+ """Test single forward pass."""
+ # NN-based: use load_model
+ model = ModelName.load_model(test_package)
+ # Physics-based: construct directly instead
+ # model = ModelName(levels=[1000])
+ model = model.to(device)
+
+ # Fetch input data
+ dc = model.input_coords()
+ dc["time"] = time
+ ds = Random(dc)
+ x, coords = fetch_data(ds, time, dc["variable"], device=device)
+
+ # Run forward
+ out, out_coords = model(x, coords)
+
+ # Validate output
+ assert isinstance(out_coords, OrderedDict)
+ handshake_dim(out_coords, "variable", 1) # Adjust index to match model's coord ordering
+ # Additional model-specific assertions
+ # e.g., assert output variable count is correct
+
+ @pytest.mark.parametrize(
+ "coords",
+ [
+ OrderedDict({
+ "batch": np.empty(0),
+ "variable": np.array(["wrong_var"]),
+ "lat": np.linspace(90, -90, 10),
+ "lon": np.linspace(0, 360, 20),
+ }),
+ ],
+ )
+ def test_model_exceptions(self, test_package, coords):
+ """Test model raises on invalid coordinates."""
+ # NN-based: use load_model
+ model = ModelName.load_model(test_package)
+ # Physics-based: construct directly instead
+ # model = ModelName(levels=[1000])
+ x = torch.randn(
+ 1,
+ len(coords["variable"]),
+ len(coords["lat"]),
+ len(coords["lon"]),
+ )
+ with pytest.raises((KeyError, ValueError)):
+ model(x, coords)
+```
+
+**Note:** There is NO `test_model_iter` — diagnostic
+models do not have `create_iterator`.
+
+**Physics-based models — add `test_physics_correctness`:**
+
+```python
+def test_physics_correctness(self):
+ """Verify physics formula against known values."""
+ model = ModelName(levels=[1000])
+
+ # Create known input values
+ # e.g., for wind speed: u=3, v=4 → ws=5
+ input_coords = model.input_coords()
+ x = torch.zeros(1, len(input_coords["variable"]), 5, 5)
+ x[:, 0, :, :] = 3.0 # u component
+ x[:, 1, :, :] = 4.0 # v component
+
+ out, _ = model(x, input_coords)
+
+ expected = torch.full_like(out, 5.0) # sqrt(3² + 4²) = 5
+ assert torch.allclose(out, expected, rtol=1e-5)
+```
+
+**NN-based models — add `@pytest.mark.package` integration test:**
+
+```python
+@pytest.mark.package
+def test_model_package():
+ """Integration test with real model weights."""
+ model = ModelName.from_pretrained()
+ input_coords = model.input_coords()
+ time = np.array([np.datetime64("2024-01-01T00:00")])
+ input_coords["time"] = time
+ ds = Random(input_coords)
+ x, coords = fetch_data(ds, time, input_coords["variable"])
+ model = model.to("cuda" if torch.cuda.is_available() else "cpu")
+ x = x.to(model.device_buffer.device)
+ out, out_coords = model(x, coords)
+ # Validate output shape and variables
+ assert out_coords["variable"].shape[0] > 0
+```
+
+Adapt the dummy model (`PhooModelName`) to match the
+actual model's input/output interface so the wrapper's
+reshaping logic is exercised.
+
+---
+
+## Step 11b — Run Tests
+
+### 11b-i. Run mock tests (no package flag)
+
+First, run the unit tests that use mocked / dummy models.
+These do **not** require downloading real checkpoints and
+should run quickly on any machine:
+
+```bash
+uv run python -m pytest test/models/dx/test_.py \
+ -m "not package" -v
+```
+
+All mock tests must pass before proceeding. Fix any
+failures and re-run until green.
+
+### 11b-ii. Run the package integration test (NN-based only)
+
+Once the mock tests pass, run the `@pytest.mark.package`
+test which exercises `from_pretrained()` with real model
+weights:
+
+```bash
+uv run python -m pytest test/models/dx/test_.py \
+ -m "package" -v
+```
+
+### **[CONFIRM — Package Test]**
+
+Before executing the package test, warn the user:
+
+> The package / integration test will:
+>
+> - **Download the model checkpoint** (may be several GB)
+> to the local cache
+> - **Require GPU compute** for models that need CUDA
+> (the test will skip on CPU-only machines if
+> `torch.cuda.is_available()` is `False`)
+> - Take significantly longer than the mock tests
+>
+> Do you want to proceed with the package test?
+
+Only run the package test after the user confirms. Report
+the results back to the user. If the package test fails,
+debug and fix the wrapper or test, then re-run.
+
+For physics-based models, there is no package test to run.
+State: "No package test — physics-based model has no
+checkpoints."
+
+---
+
+## Step 12 — Provide Side-by-Side Comparison Scripts
+
+Present two scripts to the user:
+
+### Reference script (without Earth2Studio)
+
+Reconstruct a minimal inference script based on the original reference code:
+
+**NN-based:**
+
+```python
+# Reference inference (no Earth2Studio)
+import torch
+# ... original model imports ...
+
+# Load model
+model = OriginalModel.from_pretrained("path/to/checkpoint")
+model.eval().cuda()
+
+# Prepare input
+input_data = ... # Load/prepare input data per original repo instructions
+
+# Run inference
+with torch.no_grad():
+ output = model(input_data)
+
+print(f"Output shape: {output.shape}")
+```
+
+**Physics-based:**
+
+```python
+# Reference calculation (no Earth2Studio)
+import torch
+
+# Hand-calculated / reference values
+u = torch.tensor([3.0, 0.0, 1.0])
+v = torch.tensor([4.0, 5.0, 0.0])
+
+# Direct formula
+ws = torch.sqrt(u**2 + v**2)
+
+print(f"Wind speed: {ws}") # Expected: [5.0, 5.0, 1.0]
+```
+
+### Earth2Studio equivalent
+
+**NN-based:**
+
+```python
+# Earth2Studio inference
+import torch
+import numpy as np
+from earth2studio.models.dx import ModelName
+from earth2studio.data import Random, fetch_data
+
+# Load model
+model = ModelName.from_pretrained()
+model = model.to("cuda")
+
+# Prepare input via Earth2Studio data pipeline
+time = np.array([np.datetime64("2024-01-01T00:00")])
+input_coords = model.input_coords()
+input_coords["time"] = time
+ds = Random(input_coords) # Replace with real data source
+x, coords = fetch_data(ds, time, input_coords["variable"], device="cuda")
+
+# Single forward pass (no iterator — this is a diagnostic model)
+with torch.no_grad():
+ output, out_coords = model(x, coords)
+
+print(f"Output shape: {output.shape}")
+print(f"Output variables: {out_coords['variable']}")
+```
+
+**Physics-based:**
+
+```python
+# Earth2Studio inference
+import torch
+import numpy as np
+from earth2studio.models.dx import ModelName
+
+# Construct model (no checkpoint needed)
+model = ModelName(levels=[1000])
+
+# Prepare input
+input_coords = model.input_coords()
+x = torch.zeros(1, len(input_coords["variable"]), 5, 5)
+x[:, 0, :, :] = 3.0 # u component
+x[:, 1, :, :] = 4.0 # v component
+
+# Single forward pass
+output, out_coords = model(x, input_coords)
+
+print(f"Output shape: {output.shape}")
+print(f"Output variables: {out_coords['variable']}")
+print(f"Output values (should be 5.0): {output.mean():.1f}")
+```
+
+### **[CONFIRM — Comparison Scripts]**
+
+Ask the user to compare the two scripts and verify the
+Earth2Studio version is functionally equivalent to the
+reference.
+
+For physics-based models, also ask the user to verify
+the output matches hand-calculated expected values.
+
+---
+
+## Reminders
+
+- **DO NOT** make a general base class with intent to reuse the wrapper across models
+- **DO NOT** over-populate the `load_model()` API — only expose essential parameters
+- **DO NOT** add `lead_time` dimension unless the model genuinely needs temporal context
+ (e.g., solar radiation, wind gust models that depend on forecast lead time)
+- **DO NOT** add `create_iterator` or `_default_generator` — diagnostic models are single-pass
+- **DO NOT** inherit from `PrognosticMixin` — diagnostic models do not need iterator hooks
+- **DO** use `handshake_dim` indices matching each dimension's position in the
+ `CoordSystem` OrderedDict — check existing dx models for the predominant convention
+- **DO** add the model to `docs/modules/models.rst`
+ in the `earth2studio.models.dx` section
+ (alphabetical order)
+- **DO** use `loguru.logger` for logging, never `print()`, inside `earth2studio/`
+- **DO** ensure all public functions have full type hints
+- **DO** run formatting (`make format`) and linting (`make lint`) before finalizing
+- **DO** use `@torch.inference_mode()` on `__call__` for inference-only models
+- **DO** set `eval()` and `requires_grad_(False)` on loaded NN models
+- **DO** use `@batch_func()` on `__call__` and `@batch_coords()` on `output_coords`
+- **DO** validate coordinates with `handshake_coords()` and `handshake_dim()`
+- **DO** move tensors to device before operations: `x.to(device)`
+- **DO** use `uv run python` for all Python commands (never bare `python`)
diff --git a/.claude/skills/validate-assimilation-wrapper/SKILL.md b/.claude/skills/validate-assimilation-wrapper/SKILL.md
new file mode 100644
index 000000000..453430d68
--- /dev/null
+++ b/.claude/skills/validate-assimilation-wrapper/SKILL.md
@@ -0,0 +1,1552 @@
+---
+name: validate-assimilation-wrapper
+description: >-
+ Validate a newly created Earth2Studio data assimilation model wrapper by
+ writing unit tests (90% coverage required), performing reference comparison
+ with DataFrame/DataArray outputs, generating sanity-check plots, and
+ opening a PR with automated code review. Use after completing
+ create-assimilation-wrapper Steps 0-8.
+argument-hint: Name of the DA model class and test file (optional — will be inferred from recent changes if not provided)
+---
+
+# Validate Assimilation Model Wrapper
+
+Validate a newly created Earth2Studio data assimilation (DA) model
+wrapper by writing unit tests, performing reference comparison,
+generating sanity-check outputs, and opening a PR with automated code
+review. This skill picks up after the `create-assimilation-wrapper`
+skill completes implementation (Steps 0-8).
+
+> **Python Environment:** This project uses **uv** for dependency
+> management. Always use the local `.venv` virtual environment
+> (`source .venv/bin/activate` or prefix with `uv run python`) for all
+> Python commands — installing packages, running tests, executing
+> scripts, etc. Use `uv add` / `uv pip install` / `uv lock` instead of
+> `pip install`.
+
+Each confirmation gate marked by:
+
+```markdown
+### **[CONFIRM — ]**
+```
+
+requires **explicit user approval** before proceeding.
+
+---
+
+## Step 1 — Write Pytest Unit Tests
+
+Create a test file at `test/models/da/test_.py`. DA tests
+are fundamentally different from px/dx tests — inputs are DataFrames
+(not tensors), outputs are `xr.DataArray` (not tensor + CoordSystem
+tuples), and the generator uses the send protocol (not an iterator).
+
+### 1a. PhooModelName dummy class
+
+Create a lightweight dummy model that mimics the DA model under test.
+The dummy accepts `pd.DataFrame` input and returns `xr.DataArray`
+output. This is used for unit tests that do not require the real
+model checkpoint.
+
+```python
+class PhooModelName(torch.nn.Module):
+ """Dummy DA model for testing."""
+
+ VARIABLES = ["t2m", "u10m"]
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.register_buffer("device_buffer", torch.empty(0))
+ self._lat = np.linspace(25.0, 50.0, 11, dtype=np.float32)
+ self._lon = np.linspace(235.0, 295.0, 13, dtype=np.float32)
+
+ @property
+ def device(self) -> torch.device:
+ return self.device_buffer.device
+
+ def __call__(self, obs: pd.DataFrame) -> xr.DataArray:
+ request_time = obs.attrs["request_time"]
+ data = torch.randn(
+ len(request_time),
+ len(self.VARIABLES),
+ len(self._lat),
+ len(self._lon),
+ )
+ # Return as xr.DataArray (numpy for CPU)
+ return xr.DataArray(
+ data=data.numpy(),
+ dims=["time", "variable", "lat", "lon"],
+ coords={
+ "time": request_time,
+ "variable": np.array(self.VARIABLES, dtype=str),
+ "lat": self._lat,
+ "lon": self._lon,
+ },
+ )
+
+ def create_generator(self):
+ observations = yield None # Prime
+ try:
+ while True:
+ result = self.__call__(observations)
+ observations = yield result
+ except GeneratorExit:
+ pass
+
+ def to(self, device):
+ return super().to(device)
+```
+
+### 1b. Test fixtures
+
+Define these fixtures at the top of the test file:
+
+```python
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+import xarray as xr
+
+try:
+ import cudf
+except ImportError:
+ cudf = None
+
+try:
+ import cupy as cp
+except ImportError:
+ cp = None
+
+from earth2studio.models.da. import ModelName
+
+
+@pytest.fixture
+def sample_observations_pandas():
+ """Create sample pandas DataFrame observations for testing."""
+ time1 = np.datetime64("2024-01-01T12:00:00")
+ time2 = np.datetime64("2024-01-01T13:00:00")
+ return pd.DataFrame(
+ {
+ "time": [
+ time1, time1, time1, time1,
+ time2, time2, time2, time2,
+ ],
+ "lat": [30.0, 30.0, 40.0, 40.0, 30.0, 30.0, 40.0, 40.0],
+ "lon": [240.0, 250.0, 240.0, 250.0, 240.0, 250.0, 240.0, 250.0],
+ "observation": [10.0, 20.0, 25.0, 35.0, 30.0, 40.0, 35.0, 45.0],
+ "variable": [
+ "t2m", "t2m", "u10m", "u10m",
+ "t2m", "t2m", "u10m", "u10m",
+ ],
+ }
+ )
+
+
+@pytest.fixture
+def sample_observations_cudf():
+ """Create sample cudf DataFrame observations for testing."""
+ if cudf is None:
+ pytest.skip("cudf not available")
+ time1 = np.datetime64("2024-01-01T12:00:00")
+ time2 = np.datetime64("2024-01-01T13:00:00")
+ return cudf.DataFrame(
+ {
+ "time": [
+ time1, time1, time1, time1,
+ time2, time2, time2, time2,
+ ],
+ "lat": [30.0, 30.0, 40.0, 40.0, 30.0, 30.0, 40.0, 40.0],
+ "lon": [240.0, 250.0, 240.0, 250.0, 240.0, 250.0, 240.0, 250.0],
+ "observation": [10.0, 20.0, 30.0, 40.0, 30.0, 40.0, 50.0, 60.0],
+ "variable": [
+ "t2m", "t2m", "u10m", "u10m",
+ "t2m", "t2m", "u10m", "u10m",
+ ],
+ }
+ )
+
+
+@pytest.fixture
+def test_package(tmp_path):
+ """Create a dummy checkpoint directory for integration tests."""
+ # Adapt contents to match the model's load_model expectations
+ checkpoint_dir = tmp_path / "checkpoint"
+ checkpoint_dir.mkdir()
+ # Example: create a dummy weights file
+ torch.save({"state_dict": {}}, checkpoint_dir / "model.pt")
+ return checkpoint_dir
+```
+
+### 1c. Test methods
+
+Write the following test methods. Each test must be **complete and
+runnable** — no placeholder stubs.
+
+> **Note:** Every test file must start with the SPDX Apache-2.0
+> license header (see the create skill for the exact template).
+
+#### test_model_init
+
+```python
+def test_model_init():
+ """Test constructor parameter validation."""
+ model = ModelName(...) # Fill in required constructor args
+
+ # Verify model attributes are set correctly
+ assert hasattr(model, "device_buffer")
+ assert hasattr(model, "device")
+
+ # Test invalid constructor args raise errors
+ with pytest.raises((ValueError, TypeError)):
+ ModelName(invalid_param=...)
+```
+
+#### test_model_call
+
+```python
+@pytest.mark.parametrize(
+ "device",
+ [
+ "cpu",
+ pytest.param(
+ "cuda:0",
+ marks=pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="cuda missing"
+ ),
+ ),
+ ],
+)
+def test_model_call(sample_observations_pandas, device):
+ """Test stateless __call__ with pandas DataFrame input."""
+ model = ModelName(...).to(device)
+
+ request_time = np.array([np.datetime64("2024-01-01T12:00:00")])
+ sample_observations_pandas.attrs = {
+ "request_time": request_time,
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+
+ da = model(sample_observations_pandas)
+
+ assert isinstance(da, xr.DataArray)
+ assert da.dims == ("time", "variable", "lat", "lon")
+ assert da.coords["time"].values[0] == request_time[0]
+
+ # Validate output shape matches model's coordinate system
+ # NOTE: n_variables varies by model - update this check to match the model's output
+ assert da.shape[0] == len(request_time)
+
+ # Validate coordinate values
+ assert "t2m" in da.coords["variable"].values # At least one expected variable
+
+ # Check device-specific return type
+ if device == "cuda:0" and torch.cuda.is_available():
+ if cp is not None:
+ assert isinstance(da.data, cp.ndarray)
+ assert not cp.all(cp.isnan(da.data))
+ else:
+ assert isinstance(da.data, np.ndarray)
+ assert not np.all(np.isnan(da.values))
+```
+
+#### test_model_call_cudf
+
+```python
+@pytest.mark.parametrize(
+ "device",
+ [
+ pytest.param(
+ "cuda:0",
+ marks=pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="cuda missing"
+ ),
+ ),
+ ],
+)
+def test_model_call_cudf(sample_observations_cudf, device):
+ """Test __call__ with cudf DataFrame input on GPU."""
+ if cudf is None:
+ pytest.skip("cudf not available")
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA not available")
+
+ model = ModelName(...).to(device)
+
+ request_time = np.array([np.datetime64("2024-01-01T12:00:00")])
+ sample_observations_cudf.attrs = {
+ "request_time": request_time,
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+
+ da = model(sample_observations_cudf)
+
+ assert isinstance(da, xr.DataArray)
+ assert da.dims == ("time", "variable", "lat", "lon")
+ assert da.coords["time"].values[0] == request_time[0]
+
+ if cp is not None:
+ assert isinstance(da.data, cp.ndarray)
+ assert not cp.all(cp.isnan(da.data))
+ else:
+ assert isinstance(da.data, np.ndarray)
+ assert not np.all(np.isnan(da.values))
+```
+
+#### test_generator_protocol
+
+```python
+@pytest.mark.parametrize(
+ "device",
+ [
+ "cpu",
+ pytest.param(
+ "cuda:0",
+ marks=pytest.mark.skipif(
+ not torch.cuda.is_available(), reason="cuda missing"
+ ),
+ ),
+ ],
+)
+def test_generator_protocol(sample_observations_pandas, device):
+ """Test create_generator prime -> send -> close sequence."""
+ model = ModelName(...).to(device)
+
+ request_time = np.array([np.datetime64("2024-01-01T12:00:00")])
+ sample_observations_pandas.attrs = {
+ "request_time": request_time,
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+
+ generator = model.create_generator()
+
+ # Prime the generator
+ result = generator.send(None)
+ assert result is None
+
+ # Send observations — first step
+ da = generator.send(sample_observations_pandas)
+ assert isinstance(da, xr.DataArray)
+ assert da.dims == ("time", "variable", "lat", "lon")
+ assert da.shape[0] == len(request_time)
+
+ # Send observations — second step
+ da2 = generator.send(sample_observations_pandas)
+ assert isinstance(da2, xr.DataArray)
+ assert da2.shape == da.shape
+
+ # Close generator
+ generator.close()
+```
+
+#### test_init_coords
+
+```python
+def test_init_coords():
+ """Test init_coords returns correct type for the model."""
+ model = ModelName(...)
+
+ result = model.init_coords()
+
+ # Choose ONE of the following patterns based on the model:
+
+ # For stateless models (no initialization data required):
+ assert result is None
+
+ # For stateful models (requires initialization data):
+ # assert isinstance(result, tuple)
+ # assert len(result) > 0
+ # for schema in result:
+ # assert isinstance(schema, (dict, OrderedDict))
+```
+
+#### test_input_coords
+
+```python
+def test_input_coords():
+ """Test input_coords returns tuple of FrameSchema."""
+ model = ModelName(...)
+
+ result = model.input_coords()
+
+ assert isinstance(result, tuple)
+ assert len(result) >= 1
+
+ # Each element should be a FrameSchema (OrderedDict)
+ for schema in result:
+ assert isinstance(schema, dict)
+ # DA observation schemas typically have these columns
+ assert "variable" in schema
+```
+
+#### test_output_coords
+
+```python
+def test_output_coords():
+ """Test output_coords returns valid tuple of CoordSystem."""
+ model = ModelName(...)
+
+ input_coords = model.input_coords()
+ request_time = np.array([np.datetime64("2024-01-01T12:00:00")])
+
+ result = model.output_coords(input_coords, request_time=request_time)
+
+ assert isinstance(result, tuple)
+ assert len(result) >= 1
+
+ # Each output should be a CoordSystem with expected dimensions
+ for coords in result:
+ assert isinstance(coords, dict)
+ assert "time" in coords
+ assert "variable" in coords
+```
+
+#### test_time_tolerance
+
+```python
+def test_time_tolerance():
+ """Test filter_time_range behavior with time tolerance."""
+ model = ModelName(...)
+
+ base_time = np.datetime64("2024-01-01T12:00:00")
+ time_within = base_time - np.timedelta64(30, "m")
+ time_outside = base_time + np.timedelta64(24, "h")
+
+ obs_df = pd.DataFrame(
+ {
+ "time": [time_within, time_outside],
+ "lat": [30.0, 40.0],
+ "lon": [240.0, 250.0],
+ "observation": [10.0, 20.0],
+ "variable": ["t2m", "t2m"],
+ }
+ )
+
+ request_time = np.array([base_time])
+ obs_df.attrs = {
+ "request_time": request_time,
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+
+ da = model(obs_df)
+ assert isinstance(da, xr.DataArray)
+ # Verify output is valid — exact assertions depend on model behavior
+ assert da.shape[0] == len(request_time)
+```
+
+#### test_empty_dataframe
+
+```python
+def test_empty_dataframe():
+ """Test graceful handling of empty DataFrame."""
+ model = ModelName(...)
+
+ empty_df = pd.DataFrame(
+ {
+ "time": pd.Series([], dtype="datetime64[ns]"),
+ "lat": pd.Series([], dtype=np.float32),
+ "lon": pd.Series([], dtype=np.float32),
+ "observation": pd.Series([], dtype=np.float32),
+ "variable": pd.Series([], dtype=str),
+ }
+ )
+
+ request_time = np.array([np.datetime64("2024-01-01T12:00:00")])
+ empty_df.attrs = {
+ "request_time": request_time,
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+
+ # Model should either return a DataArray (possibly with NaN) or raise cleanly
+ try:
+ da = model(empty_df)
+ assert isinstance(da, xr.DataArray)
+ except (ValueError, RuntimeError):
+ pass # Acceptable to raise on empty input
+```
+
+#### test_invalid_attrs
+
+```python
+def test_invalid_attrs():
+ """Test that missing request_time in attrs raises an error."""
+ model = ModelName(...)
+
+ obs_df = pd.DataFrame(
+ {
+ "time": [np.datetime64("2024-01-01T12:00:00")],
+ "lat": [30.0],
+ "lon": [240.0],
+ "observation": [10.0],
+ "variable": ["t2m"],
+ }
+ )
+ # Intentionally do NOT set obs_df.attrs with request_time
+
+ with pytest.raises((ValueError, KeyError, TypeError)):
+ model(obs_df)
+```
+
+#### test_validate_observation_fields
+
+```python
+def test_validate_observation_fields():
+ """Test that invalid DataFrame columns raise an error."""
+ model = ModelName(...)
+
+ bad_df = pd.DataFrame(
+ {
+ "wrong_column": [1.0],
+ "another_bad": [2.0],
+ }
+ )
+ bad_df.attrs = {
+ "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]),
+ }
+
+ with pytest.raises((ValueError, KeyError)):
+ model(bad_df)
+```
+
+#### test_model_exceptions
+
+```python
+def test_model_exceptions():
+ """Test model raises on invalid inputs."""
+ model = ModelName(...)
+
+ # Test with None input (if model requires non-None)
+ with pytest.raises((ValueError, TypeError)):
+ model(None)
+```
+
+#### Integration test (@pytest.mark.package)
+
+```python
+@pytest.mark.package
+def test_model_package(test_package):
+ """Integration test using model checkpoint.
+
+ This test requires the actual model package and is skipped by
+ default. Run with --slow to include.
+ """
+ from earth2studio.models.auto import Package
+
+ package = Package(str(test_package))
+ model = ModelName.load_model(package)
+
+ obs_df = pd.DataFrame(
+ {
+ "time": [np.datetime64("2024-01-01T12:00:00")],
+ "lat": [30.0],
+ "lon": [240.0],
+ "observation": [10.0],
+ "variable": ["t2m"],
+ }
+ )
+ obs_df.attrs = {
+ "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]),
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+
+ da = model(obs_df)
+ assert isinstance(da, xr.DataArray)
+```
+
+### **[CONFIRM — Package Test]**
+
+Before writing the `@pytest.mark.package` integration test, ask the
+user to confirm the model loading path and checkpoint structure:
+
+> The integration test needs to load the actual model checkpoint.
+>
+> 1. What is the checkpoint path or package URL?
+> 2. Does `load_model` require additional arguments beyond `package`?
+> 3. What test inputs should the integration test use?
+>
+> I'll write the `@pytest.mark.package` test based on your answers.
+
+**Do not proceed to Step 2 until the user confirms.**
+
+---
+
+## Step 2 — Run Tests & Achieve 90% Coverage
+
+### 2a. Run the new test file
+
+```bash
+uv run python -m pytest test/models/da/test_.py -v --timeout=60
+```
+
+All tests must pass. Fix failures and re-run until green.
+
+### 2b. Run coverage report with `--slow` tests
+
+Run the new test file **with coverage** and the `--slow` flag to
+include integration tests. The new DA model file must achieve **at
+least 90% line coverage**:
+
+```bash
+uv run python -m pytest test/models/da/test_.py -v \
+ --package --timeout=300 \
+ --cov=earth2studio/models/da/ \
+ --cov-report=term-missing \
+ --cov-fail-under=90
+```
+
+- `--package` enables integration tests marked with `@pytest.mark.package`
+ (the `--package` flag is configured in `conftest.py` to include package
+ tests that download real checkpoints and may require GPU)
+- `--cov=earth2studio/models/da/` scopes coverage to the
+ new model module only
+- `--cov-report=term-missing` shows which lines are not covered
+- `--cov-fail-under=90` fails the run if coverage is below 90%
+
+If coverage is below 90%, add additional tests or mock tests to cover
+the missing lines. Common DA-specific coverage gaps:
+
+- `GeneratorExit` cleanup path in `create_generator`
+- cudf code paths (skipped when cudf is unavailable)
+- Empty DataFrame handling branches
+- Time tolerance edge cases (observations right at boundary)
+- `obs.attrs` validation branches (missing `request_time`)
+- cupy vs numpy output paths (GPU vs CPU return types)
+- `filter_time_range` with no matching observations
+- `dfseries_to_torch` conversion branches
+
+Re-run until coverage is at or above 90%.
+
+### 2c. Run the full model test suite (optional but recommended)
+
+```bash
+make pytest TOX_ENV=test-models
+```
+
+Confirm no regressions in existing model tests.
+
+---
+
+## Step 3 — Reference Comparison & Sanity-Check
+
+This step validates the DA model wrapper produces correct output by
+comparing against the original reference implementation and generating
+visual sanity-check plots.
+
+### 3a. Create reference comparison script
+
+Create a **standalone Python script** in the repo root. This is for
+validation only and should **NOT** be committed to the repo.
+
+The script loads the reference model and the E2S wrapper side by side,
+runs both on identical input, and compares outputs with tolerance.
+
+**For DataArray output** (gridded analysis fields):
+
+```python
+"""Reference comparison for assimilation model.
+
+Compares the Earth2Studio wrapper output against the original reference
+implementation to verify numerical agreement.
+
+This script is for validation only — do NOT commit to the repo.
+"""
+import numpy as np
+import pandas as pd
+import torch
+
+# --- Reference model ---
+# TODO: Load original model per reference repo instructions
+# Uncomment and adapt the following lines:
+# ref_model = ...
+# ref_obs = pd.DataFrame({...})
+# ref_obs.attrs = {"request_time": ..., "request_lead_time": ...}
+# ref_output = ref_model(ref_obs)
+raise NotImplementedError(
+ "Fill in the reference model code above, then remove this line."
+)
+
+# --- Earth2Studio wrapper ---
+from earth2studio.models.da import ModelName
+
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Construct identical observation DataFrame
+obs_df = pd.DataFrame(
+ {
+ "time": [...],
+ "lat": [...],
+ "lon": [...],
+ "observation": [...],
+ "variable": [...],
+ }
+)
+obs_df.attrs = {
+ "request_time": np.array([...], dtype="datetime64[ns]"),
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+}
+
+e2s_output = model(obs_df)
+
+# --- Compare DataArray outputs ---
+ref_values = ref_output.values # numpy array
+e2s_values = e2s_output.values # numpy or cupy array
+if hasattr(e2s_values, "get"):
+ e2s_values = e2s_values.get() # cupy -> numpy
+
+max_abs_diff = np.abs(ref_values - e2s_values).max()
+max_rel_diff = (
+ np.abs(ref_values - e2s_values) / (np.abs(ref_values) + 1e-8)
+).max()
+correlation = np.corrcoef(ref_values.flatten(), e2s_values.flatten())[0, 1]
+
+print(f"Max absolute difference: {max_abs_diff:.2e}")
+print(f"Max relative difference: {max_rel_diff:.2e}")
+print(f"Correlation: {correlation:.8f}")
+
+assert np.allclose(ref_values, e2s_values, rtol=1e-4, atol=1e-5), \
+ f"Output mismatch! Max abs diff: {max_abs_diff:.2e}"
+
+# --- Compare generator outputs ---
+ref_gen = ref_model.create_generator()
+ref_gen.send(None)
+e2s_gen = model.create_generator()
+e2s_gen.send(None)
+
+for step_obs in [obs_df, obs_df]:
+ ref_step = ref_gen.send(step_obs)
+ e2s_step = e2s_gen.send(step_obs)
+ ref_step_vals = ref_step.values
+ e2s_step_vals = e2s_step.values
+ if hasattr(e2s_step_vals, "get"):
+ e2s_step_vals = e2s_step_vals.get()
+ step_diff = np.abs(ref_step_vals - e2s_step_vals).max()
+ print(f"Generator step max abs diff: {step_diff:.2e}")
+ assert np.allclose(ref_step_vals, e2s_step_vals, rtol=1e-4, atol=1e-5)
+
+ref_gen.close()
+e2s_gen.close()
+
+print("PASS: Reference comparison successful.")
+```
+
+**For DataFrame output** (if the model returns tabular data):
+
+```python
+# --- Compare DataFrame outputs ---
+assert len(ref_output) == len(e2s_output), \
+ f"Row count mismatch: ref={len(ref_output)}, e2s={len(e2s_output)}"
+
+for col in ["lat", "lon", "observation", "variable"]:
+ if col in ref_output.columns:
+ ref_vals = ref_output[col].values
+ e2s_vals = e2s_output[col].values
+ if np.issubdtype(ref_vals.dtype, np.floating):
+ max_diff = np.abs(ref_vals - e2s_vals).max()
+ print(f"Column '{col}' max diff: {max_diff:.2e}")
+ else:
+ assert np.array_equal(ref_vals, e2s_vals), \
+ f"Column '{col}' values differ"
+
+# Spatial coverage check
+ref_lat_range = (ref_output["lat"].min(), ref_output["lat"].max())
+e2s_lat_range = (e2s_output["lat"].min(), e2s_output["lat"].max())
+print(f"Lat range: ref={ref_lat_range}, e2s={e2s_lat_range}")
+
+ref_lon_range = (ref_output["lon"].min(), ref_output["lon"].max())
+e2s_lon_range = (e2s_output["lon"].min(), e2s_output["lon"].max())
+print(f"Lon range: ref={ref_lon_range}, e2s={e2s_lon_range}")
+```
+
+### 3b. Summarize model capabilities to user
+
+Before generating sanity-check plots, **present a summary table** to
+the user covering the model's capabilities:
+
+> **Model Summary for ``:**
+>
+> | Property | Value |
+> | ---------------------- | ------------------------------------ |
+> | **Model type** | Stateless / Stateful |
+> | **Input format** | DataFrame / DataArray / Mixed |
+> | **Input schema** | time, lat, lon, observation, variable |
+> | **Output format** | DataArray / DataFrame |
+> | **Output grid** | lat-lon N x M / HRRR / HealPix |
+> | **Output variables** | `var1`, `var2`, ... |
+> | **Time tolerance** | (default value) |
+> | **cudf/cupy support** | Yes / No |
+> | **Observation types** | Surface / Satellite / Radar / Mixed |
+> | **Checkpoint source** | NGC / HuggingFace / N/A |
+
+This summary helps the user verify the wrapper matches their
+expectations for the model.
+
+### 3c. Generate sanity-check plot scripts
+
+Create **standalone Python scripts** in the repo root. These are for
+PR reviewer reference only and should **NOT** be committed to the
+repo.
+
+#### Plot 1: Spatial assimilated output
+
+Contourf of gridded DataArray output from `__call__`:
+
+```python
+"""Sanity-check plot 1: Spatial assimilated output for .
+
+This script is for PR review only — do NOT commit to the repo.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+
+from earth2studio.models.da import ModelName
+
+# Load model
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Create observation DataFrame
+obs_df = pd.DataFrame(
+ {
+ "time": [np.datetime64("2024-01-01T12:00:00")] * 4,
+ "lat": [30.0, 35.0, 40.0, 45.0],
+ "lon": [240.0, 250.0, 260.0, 270.0],
+ "observation": [10.0, 15.0, 20.0, 25.0],
+ "variable": ["t2m", "t2m", "t2m", "t2m"],
+ }
+)
+obs_df.attrs = {
+ "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]),
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+}
+
+# Run forward pass
+da = model(obs_df)
+
+# Get coordinate arrays
+lat = da.coords["lat"].values
+lon = da.coords["lon"].values
+variables = da.coords["variable"].values
+
+# Plot contourf for each output variable
+n_vars = len(variables)
+fig, axes = plt.subplots(1, n_vars, figsize=(6 * n_vars, 5))
+if n_vars == 1:
+ axes = [axes]
+
+for ax, var in zip(axes, variables):
+ data_2d = da.sel(variable=var).isel(time=0).values
+ if hasattr(data_2d, "get"):
+ data_2d = data_2d.get() # cupy -> numpy
+ im = ax.contourf(lon, lat, data_2d, cmap="turbo", levels=20)
+ ax.set_title(f"{var}")
+ ax.set_xlabel("Longitude")
+ ax.set_ylabel("Latitude")
+ plt.colorbar(im, ax=ax, shrink=0.8)
+
+plt.suptitle(f" assimilated output", y=1.02)
+plt.tight_layout()
+plt.savefig("sanity_check_da_spatial.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_da_spatial.png")
+```
+
+#### Plot 2: Observation overlay (unique to DA)
+
+Scatter of input DataFrame observations overlaid on assimilated grid
+output. This visualization is **specific to DA models** — it shows the
+sparse-to-dense mapping from observations to analysis field.
+
+```python
+"""Sanity-check plot 2: Observation overlay for .
+
+This script is for PR review only — do NOT commit to the repo.
+Shows input sparse observations overlaid on the assimilated gridded output.
+"""
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+
+from earth2studio.models.da import ModelName
+
+# Load model
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Create observation DataFrame
+obs_df = pd.DataFrame(
+ {
+ "time": [np.datetime64("2024-01-01T12:00:00")] * 8,
+ "lat": [30.0, 32.0, 35.0, 38.0, 40.0, 42.0, 45.0, 48.0],
+ "lon": [240.0, 245.0, 250.0, 255.0, 260.0, 265.0, 270.0, 275.0],
+ "observation": [10.0, 12.0, 15.0, 18.0, 20.0, 22.0, 25.0, 28.0],
+ "variable": ["t2m"] * 8,
+ }
+)
+obs_df.attrs = {
+ "request_time": np.array([np.datetime64("2024-01-01T12:00:00")]),
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+}
+
+# Run forward pass
+da = model(obs_df)
+
+# Get coordinate arrays
+lat = da.coords["lat"].values
+lon = da.coords["lon"].values
+variables = [v for v in da.coords["variable"].values if v in obs_df["variable"].unique()]
+
+n_vars = len(variables)
+fig, axes = plt.subplots(
+ 1, n_vars, figsize=(8 * n_vars, 6),
+ subplot_kw={"projection": ccrs.PlateCarree()},
+)
+if n_vars == 1:
+ axes = [axes]
+
+for ax, var in zip(axes, variables):
+ # Plot assimilated grid as contourf
+ data_2d = da.sel(variable=var).isel(time=0).values
+ if hasattr(data_2d, "get"):
+ data_2d = data_2d.get() # cupy -> numpy
+ ax.contourf(
+ lon, lat, data_2d,
+ cmap="turbo", alpha=0.5, levels=20,
+ transform=ccrs.PlateCarree(),
+ )
+
+ # Overlay input observations as scatter
+ obs_var = obs_df[obs_df["variable"] == var]
+ scatter = ax.scatter(
+ obs_var["lon"].values, obs_var["lat"].values,
+ c=obs_var["observation"].values,
+ cmap="turbo", edgecolors="k", s=40, zorder=5,
+ transform=ccrs.PlateCarree(),
+ )
+ plt.colorbar(scatter, ax=ax, shrink=0.7, label="Observation value")
+
+ ax.add_feature(cfeature.COASTLINE)
+ ax.add_feature(cfeature.BORDERS, linestyle=":")
+ ax.set_title(f"{var}: grid + observations")
+
+plt.suptitle(" — observation overlay", y=1.02)
+plt.tight_layout()
+plt.savefig("sanity_check_da_overlay.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_da_overlay.png")
+```
+
+#### Plot 3: Generator sequence
+
+Multi-step assimilation evolution using the generator protocol:
+
+```python
+"""Sanity-check plot 3: Generator sequence for .
+
+This script is for PR review only — do NOT commit to the repo.
+Shows evolution of assimilated output across multiple generator steps.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+
+from earth2studio.models.da import ModelName
+
+# Load model
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Create observation sequence — one DataFrame per step
+base_time = np.datetime64("2024-01-01T12:00:00")
+n_steps = 4
+observation_sequence = []
+
+for i in range(n_steps):
+ step_time = base_time + np.timedelta64(i, "h")
+ obs_step = pd.DataFrame(
+ {
+ "time": [step_time] * 4,
+ "lat": [30.0, 35.0, 40.0, 45.0],
+ "lon": [240.0, 250.0, 260.0, 270.0],
+ "observation": [10.0 + i * 5, 15.0 + i * 5, 20.0 + i * 5, 25.0 + i * 5],
+ "variable": ["t2m"] * 4,
+ }
+ )
+ obs_step.attrs = {
+ "request_time": np.array([step_time]),
+ "request_lead_time": np.array([np.timedelta64(0, "h")]),
+ }
+ observation_sequence.append(obs_step)
+
+# Run generator
+gen = model.create_generator()
+gen.send(None) # Prime
+results = []
+for obs_step in observation_sequence:
+ result = gen.send(obs_step)
+ results.append(result)
+gen.close()
+
+# Plot sequence
+var = results[0].coords["variable"].values[0] # Plot first variable
+lat = results[0].coords["lat"].values
+lon = results[0].coords["lon"].values
+
+fig, axes = plt.subplots(1, len(results), figsize=(5 * len(results), 4))
+if len(results) == 1:
+ axes = [axes]
+
+for ax, (i, result) in zip(axes, enumerate(results)):
+ data_2d = result.sel(variable=var).isel(time=0).values
+ if hasattr(data_2d, "get"):
+ data_2d = data_2d.get() # cupy -> numpy
+ im = ax.contourf(lon, lat, data_2d, cmap="turbo", levels=20)
+ ax.set_title(f"Step {i}")
+ ax.set_xlabel("Longitude")
+ ax.set_ylabel("Latitude")
+ plt.colorbar(im, ax=ax, shrink=0.8)
+
+plt.suptitle(f" — generator sequence ({var})", y=1.02)
+plt.tight_layout()
+plt.savefig("sanity_check_da_generator.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_da_generator.png")
+```
+
+### 3d. Create side-by-side comparison script
+
+Create a script that runs both the reference implementation and the
+E2S wrapper with identical inputs and produces a side-by-side plot:
+
+```python
+"""Side-by-side comparison: reference vs Earth2Studio for .
+
+This script is for validation only — do NOT commit to the repo.
+Fill in the TODO sections below before running.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+
+# TODO: Import and initialize both reference and E2S models
+# TODO: Prepare identical input observations (DataFrame)
+raise NotImplementedError(
+ "Fill in the reference and E2S model code above before running."
+)
+
+# ref_output = ... # Run reference model
+# e2s_output = ... # Run E2S wrapper
+
+fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+
+# Panel 1: Reference output
+# axes[0].contourf(lon, lat, ref_data, cmap="turbo")
+# axes[0].set_title("Reference")
+
+# Panel 2: E2S output
+# axes[1].contourf(lon, lat, e2s_data, cmap="turbo")
+# axes[1].set_title("Earth2Studio")
+
+# Panel 3: Difference
+# axes[2].contourf(lon, lat, ref_data - e2s_data, cmap="RdBu_r")
+# axes[2].set_title("Difference (Ref - E2S)")
+
+plt.suptitle(" — reference vs Earth2Studio")
+plt.tight_layout()
+plt.savefig("comparison_.png", dpi=150, bbox_inches="tight")
+print("Saved: comparison_.png")
+```
+
+### 3e. Run comparison and sanity-check scripts
+
+Execute the scripts:
+
+```bash
+uv run python reference_comparison_.py
+uv run python sanity_check_da_spatial.py
+uv run python sanity_check_da_overlay.py
+uv run python sanity_check_da_generator.py
+```
+
+Verify that:
+
+- The reference comparison passes (all assertions hold)
+- All sanity-check scripts run without errors
+- Output PNGs are generated
+- Metrics are printed (max abs diff, max rel diff, correlation)
+
+### 3f. **[CONFIRM — Sanity-Check & Comparison]**
+
+**You MUST ask the user to visually inspect the generated plot(s)
+before proceeding.** Do not skip this step even if the scripts ran
+without errors — a successful run does not guarantee the plots are
+correct (e.g., empty axes, wrong colorbar range, garbled data).
+
+Tell the user the absolute path to the generated image file(s) and
+the reference comparison metrics, then ask them to inspect:
+
+> The reference comparison and sanity-check scripts ran successfully.
+>
+> **Reference comparison metrics:**
+>
+> - Max absolute difference: ``
+> - Max relative difference: ``
+> - Correlation: ``
+>
+> **Sanity-check plots saved to:**
+>
+> 1. `/absolute/path/to/sanity_check_da_spatial.png` — gridded output
+> 2. `/absolute/path/to/sanity_check_da_overlay.png` — observations on grid
+> 3. `/absolute/path/to/sanity_check_da_generator.png` — generator sequence
+>
+> **Please open these images and confirm they look correct.** Check:
+>
+> 1. Data is visible on the axes (not blank/empty)
+> 2. Values are in physically reasonable ranges
+> 3. No obvious artifacts (all-NaN regions, garbled values)
+> 4. Spatial patterns look plausible (geographic features visible)
+> 5. Observation overlay: scatter points visible on top of grid
+> 6. Generator sequence: evolution across steps is coherent
+>
+> Do the plots look correct and do the reference comparison metrics
+> look acceptable?
+
+**Do not proceed to Step 4 until the user explicitly confirms.** If
+the user reports problems, debug and fix the issue, re-run the
+scripts, and ask the user to inspect again.
+
+---
+
+## Step 4 — Branch, Commit & Open PR
+
+### **[CONFIRM — Ready to Submit]**
+
+Before proceeding, confirm with the user:
+
+> All implementation and validation steps are complete:
+>
+> - DA model class implemented with correct method ordering
+> - `init_coords` returns correct type (None for stateless or tuple for stateful)
+> - `input_coords` returns tuple of FrameSchema
+> - `output_coords` returns tuple with `request_time`
+> - `create_generator` uses send protocol with GeneratorExit handling
+> - `validate_observation_fields` used for DataFrame inputs
+> - cupy/cudf optional import pattern present
+> - Registered in `earth2studio/models/da/__init__.py`
+> - Documentation added to `docs/modules/models.rst`
+> - Reference URLs included in class docstrings
+> - CHANGELOG.md updated
+> - Format, lint, and license checks pass
+> - Unit tests written and passing with >= 90% coverage
+> - Dependencies in pyproject.toml confirmed
+> - Reference comparison passes with acceptable tolerance
+> - Sanity-check plots generated and confirmed by user
+>
+> Ready to create a branch, commit, and prepare a PR?
+
+### 4a. Create branch and commit
+
+```bash
+git checkout -b feat/da-model-
+git add earth2studio/models/da/.py \
+ earth2studio/models/da/__init__.py \
+ test/models/da/test_.py \
+ pyproject.toml \
+ CHANGELOG.md \
+ docs/modules/models.rst
+git commit -m "feat: add assimilation model
+
+Add data assimilation model for .
+Includes unit tests and documentation."
+```
+
+Do **NOT** add the sanity-check scripts, comparison scripts, or
+their output images.
+
+### 4b. Identify the fork remote and push branch
+
+The working repository is typically a **fork** of
+`NVIDIA/earth2studio`. Before pushing, confirm which git remote
+points to the user's fork:
+
+```bash
+git remote -v
+```
+
+Ask the user:
+
+> Which git remote is your fork of `NVIDIA/earth2studio`?
+> (Usually `origin` — e.g., `git@github.com:/earth2studio.git`)
+
+Then push the feature branch to the **fork** remote:
+
+```bash
+git push -u feat/da-model-
+```
+
+### 4c. Open Pull Request (fork -> NVIDIA/earth2studio)
+
+> **Important:** PRs must be opened **from the fork** to the
+> **upstream source repository** `NVIDIA/earth2studio`. The branch
+> lives on the fork; the PR targets `main` on the upstream repo.
+
+Use `gh pr create` with explicit `--repo` and `--head` flags:
+
+```bash
+gh pr create \
+ --repo NVIDIA/earth2studio \
+ --base main \
+ --head :feat/da-model- \
+ --title "feat: add assimilation model" \
+ --body "..."
+```
+
+Where `` is the GitHub username that owns the fork.
+
+The PR body should follow this DA-model-specific template:
+
+````markdown
+## Description
+
+Add `` data assimilation model for .
+
+Closes # (if applicable)
+
+### Model details
+
+| Property | Value |
+|---|---|
+| **Model type** | Stateless / Stateful |
+| **Input format** | DataFrame / DataArray / Mixed |
+| **Output format** | DataArray / DataFrame |
+| **Observation schema** | time, lat, lon, observation, variable |
+| **Grid specification** | lat-lon / HRRR / HealPix / etc. |
+| **Time tolerance** | |
+| **cudf/cupy support** | Yes / No |
+| **Reference** | |
+
+### Dependencies added
+
+| Package | Version | License | License URL | Reason |
+|---|---|---|---|---|
+| `` | `>=X.Y` | | [link]() | |
+
+*(or "No new dependencies")*
+
+### Reference comparison
+
+- Max absolute difference:
+- Max relative difference:
+- Correlation:
+
+### Validation
+
+See sanity-check plots in PR comments below.
+
+## Checklist
+
+- [x] I am familiar with the [Contributing Guidelines][contrib].
+- [x] New or existing tests cover these changes.
+- [x] The documentation is up to date with these changes.
+- [x] The [CHANGELOG.md][changelog] is up to date with these changes.
+- [ ] An [issue][issues] is linked to this pull request.
+- [ ] Assess and address Greptile feedback (AI code review bot).
+
+[contrib]: https://github.com/NVIDIA/earth2studio/blob/main/CONTRIBUTING.md
+[changelog]: https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md
+[issues]: https://github.com/NVIDIA/earth2studio/issues
+````
+
+### 4d. Post sanity-check as PR comment
+
+After the PR is created, post the sanity-check visualization as a
+separate **PR comment** so it is immediately visible to reviewers.
+
+#### Image upload limitation
+
+**GitHub has no CLI or REST API for uploading images to PR comments.**
+The only way to embed an image is via the browser's drag-and-drop
+editor or by referencing an already-hosted URL.
+
+**Practical workflow:**
+
+1. Write the comment body to a temp file (avoids shell quoting issues
+ with heredocs containing backticks and markdown).
+2. Post the comment **without** the image — include the validation
+ table, reference comparison metrics, the full sanity-check script,
+ and a placeholder line.
+3. Tell the user to drag the image into the browser editor.
+
+```bash
+# 1. Write body to a temp file (use your editor tool, not heredoc)
+
+# 2. Post the comment
+gh api -X POST repos/NVIDIA/earth2studio/issues//comments \
+ -F "body=@/tmp/pr_comment_body.md" \
+ --jq '.html_url'
+```
+
+Do **not** waste time trying `curl` uploads, GraphQL file mutations,
+or the `uploads.github.com` asset endpoint — they do not work for
+issue/PR comment images.
+
+#### Comment content template
+
+```markdown
+## Sanity-Check Validation
+
+**Model:** `` —
+**Type:** Stateless / Stateful
+**Test environment:**
+
+### Reference Comparison
+
+| Metric | Value |
+|--------|-------|
+| Max absolute difference | |
+| Max relative difference | |
+| Correlation | |
+
+### Model Summary
+
+| Property | Value |
+|----------|-------|
+| Model type | Stateless / Stateful |
+| Input format | DataFrame / DataArray / Mixed |
+| Output format | DataArray / DataFrame |
+| Observation schema | time, lat, lon, observation, variable |
+| Output grid | lat-lon N x M / HRRR / HealPix |
+| Output variables | |
+| Time tolerance | |
+| cudf/cupy support | Yes / No |
+| Inference time | ~XX ms |
+
+**Key findings:**
+-
+-
+-
+
+> **TODO:** Attach sanity-check images by editing this comment in
+> the browser.
+
+
+Sanity-check scripts (click to expand)
+
+```python
+PASTE THE FULL WORKING SCRIPTS HERE — not truncated excerpts.
+The scripts must be copy-pasteable and produce the plots end-to-end.
+```
+
+
+```
+
+**Important:** Always paste the **complete, runnable** scripts — not
+shortened versions. Reviewers should be able to reproduce the plots
+by copying the scripts directly.
+
+#### Finalize
+
+After posting, inform the user of:
+
+1. The comment URL
+2. The local paths to the image files for manual attachment
+3. Instructions: *"Edit the comment in your browser and drag the
+ image files into the editor to embed them."*
+
+> **Note:** The sanity-check images and scripts are for PR review
+> purposes only — they must NOT be committed to the repository.
+
+---
+
+## Step 5 — Automated Code Review (Greptile)
+
+After the PR is created and pushed, an automated code review from
+**greptile-apps** (Greptile) will be posted as PR review comments.
+Wait for this review, then process the feedback.
+
+### 5a. Wait for Greptile review
+
+Poll for review comments from `greptile-apps[bot]` every 30 seconds
+for up to **5 minutes**. Time out gracefully if no review arrives:
+
+```bash
+# Poll loop — check every 30s, timeout after 5 minutes (10 attempts)
+for i in $(seq 1 10); do
+ REVIEW_ID=$(gh api repos/NVIDIA/earth2studio/pulls//reviews \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") | .id' 2>/dev/null)
+ if [ -n "$REVIEW_ID" ]; then
+ echo "Greptile review found: $REVIEW_ID"
+ break
+ fi
+ echo "Attempt $i/10 — no review yet, waiting 30s..."
+ sleep 30
+done
+```
+
+If no review after 5 minutes, inform the user:
+
+> Greptile hasn't posted a review after 5 minutes. This can happen if
+> the review bot is busy or the PR hasn't triggered it. You can:
+>
+> 1. Ask me to check again later
+> 2. Skip this step and proceed without automated review
+> 3. Manually request a review from Greptile on the PR page
+
+### 5b. Pull and parse review comments
+
+Once the review is posted, fetch all comments:
+
+```bash
+# Get all review comments on the PR
+gh api repos/NVIDIA/earth2studio/pulls//comments \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") |
+ {path: .path, line: .diff_hunk, body: .body}'
+```
+
+Also fetch the top-level review body:
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//reviews \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") | .body'
+```
+
+### 5c. Categorize and present to user
+
+Parse each comment and categorize it:
+
+| Category | Description | Default action |
+| --------------------- | --------------------------------- | -------------- |
+| **Bug / correctness** | Logic errors, wrong behavior | Fix |
+| **Style / convention** | Naming, formatting, patterns | Fix if valid |
+| **Performance** | Inefficiency, resource waste | Evaluate |
+| **Documentation** | Missing/wrong docs, docstrings | Fix |
+| **Suggestion** | Alternative approach, nice-to-have | User decides |
+| **False positive** | Incorrect or irrelevant feedback | Dismiss |
+
+### **[CONFIRM — Review Triage]**
+
+Present each comment to the user in a summary table:
+
+```markdown
+| # | File | Line | Category | Summary | Proposed Action |
+|---|------|------|----------|---------|-----------------|
+| 1 | .py | 142 | Bug | Missing null check | Fix: add guard |
+| 2 | .py | 305 | Style | Use f-string | Fix: convert |
+| 3 | .py | 45 | Suggestion | Add type alias | Skip: not needed |
+| ... | ... | ... | ... | ... | ... |
+```
+
+For each comment, briefly explain:
+
+- What Greptile flagged
+- Whether you agree or disagree (with reasoning)
+- Your proposed fix (or why to skip)
+
+Ask the user to confirm which comments to address. The user may:
+
+- Accept all proposed fixes
+- Select specific fixes
+- Override your recommendation on any comment
+- Add their own fixes
+
+### 5d. Implement fixes
+
+For each accepted fix:
+
+1. Make the code change
+2. Run `make format && make lint` after all fixes
+3. Run the relevant tests:
+
+ ```bash
+ uv run python -m pytest test/models/da/test_.py -v --timeout=60
+ ```
+
+4. Commit with a message like:
+
+ ```bash
+ git commit -m "fix: address code review feedback (Greptile)"
+ ```
+
+### 5e. Respond to review comments
+
+For each Greptile comment, post a reply on the PR:
+
+**For fixed comments:**
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//comments//replies \
+ -f body="Fixed in . "
+```
+
+**For dismissed comments (false positives / won't fix):**
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//comments//replies \
+ -f body="Won't fix — "
+```
+
+### 5f. Push and resolve
+
+```bash
+git push feat/da-model-
+```
+
+After pushing, resolve all addressed review threads if possible.
+
+Inform the user of the final state:
+
+- How many comments were fixed
+- How many were dismissed (with reasons)
+- Any remaining open threads
+
+---
+
+## Reminders
+
+- **DO** use the repo's local `uv` `.venv` to run Python with
+ `uv run python`
+- **DO NOT** commit sanity-check/comparison scripts or images to
+ the repo
+- **DO** use `loguru.logger` for logging, never `print()`, inside
+ `earth2studio/`
+- **DO** ensure all public functions have full type hints (mypy-clean)
+- **DO** maintain alphabetical order in `__init__.py` exports,
+ RST file entries, and CHANGELOG entries
+- **DO** return tuples from `input_coords` and `output_coords`
+- **DO** use `FrameSchema` for tabular inputs, `CoordSystem` for
+ gridded outputs
+- **DO** validate `request_time` from `obs.attrs`
+- **DO** use `validate_observation_fields`, `filter_time_range`,
+ `dfseries_to_torch` from `earth2studio.models.da.utils`
+- **DO** prime generator with `yield None` and handle `GeneratorExit`
+- **DO** return cupy arrays on GPU, numpy on CPU
+- **DO** register `device_buffer` and expose `device` property
+- **DO** follow the canonical DA method ordering:
+ `__init__`, `device` property, `init_coords`, `input_coords`,
+ `output_coords`, `load_default_package`, `load_model`, `to`,
+ private methods, `__call__`, `create_generator`
+- **DO** include reference URLs in class docstrings
+- **DO** always update CHANGELOG.md under the current unreleased
+ version
+- **DO** add the model to `docs/modules/models.rst` in the
+ `earth2studio.models.da` section (alphabetical order)
+- **DO NOT** use `@batch_func` or `@batch_coords` — these are
+ px/dx conventions only and do not apply to DA models
+- **DO NOT** use `PrognosticMixin` — DA models do not time-step
+- **DO NOT** use `create_iterator` — DA uses `create_generator`
+ with the send protocol
+- **DO NOT** assume tensor inputs — DA inputs are DataFrames and/or
+ DataArrays
+- **DO NOT** make a general base class with intent to reuse the
+ wrapper across models
+- **DO NOT** over-populate the `load_model()` API — only expose
+ essential parameters
+- **NEVER** commit, hardcode, or include API keys, secrets, tokens,
+ or credentials in source code, sample scripts, commit messages,
+ PR descriptions, or any file tracked by git
diff --git a/.claude/skills/validate-diagnostic-wrapper/SKILL.md b/.claude/skills/validate-diagnostic-wrapper/SKILL.md
new file mode 100644
index 000000000..1f0a6ade3
--- /dev/null
+++ b/.claude/skills/validate-diagnostic-wrapper/SKILL.md
@@ -0,0 +1,831 @@
+---
+name: validate-diagnostic-wrapper
+description: Validate a newly created Earth2Studio diagnostic model wrapper by running tests, performing reference comparison, generating sanity-check outputs, and opening a PR with automated code review. Use this skill after completing diagnostic model implementation (create-diagnostic-wrapper skill Steps 0-12).
+argument-hint: Name of the diagnostic model class and test file (optional — will be inferred from recent changes if not provided)
+---
+
+# Validate Diagnostic Model Wrapper
+
+Validate a newly created Earth2Studio diagnostic model (dx) wrapper by
+running tests, performing reference comparison, generating sanity-check
+outputs, and opening a PR with automated code review. This skill picks
+up after the `create-diagnostic-wrapper` skill completes implementation
+(Steps 0-12).
+
+> **Python Environment:** This project uses **uv** for dependency
+> management. Always use the local `.venv` virtual environment
+> (`source .venv/bin/activate` or prefix with `uv run python`) for all
+> Python commands — installing packages, running tests, executing
+> scripts, etc. Use `uv add` / `uv pip install` / `uv lock` instead of
+> `pip install`.
+
+Each confirmation gate marked by:
+
+```markdown
+### **[CONFIRM — ]**
+```
+
+requires **explicit user approval** before proceeding.
+
+---
+
+## Step 1 — Run Tests
+
+### 1a. Run the new test file
+
+```bash
+uv run python -m pytest test/models/dx/test_.py -v --timeout=60
+```
+
+All tests must pass. Fix failures and re-run until green.
+
+### 1b. Run coverage report with `--slow` tests
+
+Run the new test file **with coverage** and the `--slow` flag to
+include integration tests. The new diagnostic model file must achieve
+**at least 90% line coverage**:
+
+```bash
+uv run python -m pytest test/models/dx/test_.py -v \
+ --slow --timeout=300 \
+ --cov=earth2studio/models/dx/ \
+ --cov-report=term-missing \
+ --cov-fail-under=90
+```
+
+- `--slow` enables integration tests (marked `@pytest.mark.slow`)
+- `--cov=earth2studio/models/dx/` scopes coverage to the
+ new model module only
+- `--cov-report=term-missing` shows which lines are not covered
+- `--cov-fail-under=90` fails the run if coverage is below 90%
+
+If coverage is below 90%, add additional tests or mock tests to
+cover the missing lines. Common coverage gaps for dx models:
+
+- Error handling in `output_coords` (wrong variable names, wrong dims)
+- Device management paths (CPU vs CUDA)
+- Edge cases in physics calculations (zero inputs, extreme values)
+- `load_model` and `load_default_package` (NN-based models, needs mock)
+
+Re-run until coverage is at or above 90%.
+
+### 1c. Run the full model test suite (optional but recommended)
+
+```bash
+make pytest TOX_ENV=test-models
+```
+
+Confirm no regressions in existing model tests.
+
+---
+
+## Step 2 — Reference Comparison & Sanity-Check
+
+This step validates the diagnostic model wrapper produces correct
+output by comparing against the original reference implementation and
+generating visual sanity-check plots.
+
+### 2a. Create reference comparison script
+
+Create a **standalone Python script** in the repo root. This is for
+validation only and should **NOT** be committed to the repo.
+
+The script loads the reference model and the E2S wrapper side by side,
+runs both on identical input (same random seed or real data), and
+compares outputs with tolerance:
+
+```python
+"""Reference comparison for diagnostic model.
+
+Compares the Earth2Studio wrapper output against the original reference
+implementation to verify numerical agreement.
+
+This script is for validation only — do NOT commit to the repo.
+"""
+import torch
+import numpy as np
+
+# --- Reference model ---
+# TODO: Load original model per reference repo instructions
+# Uncomment and adapt the following lines:
+# ref_model = ...
+# ref_input = ...
+# ref_output = ref_model(ref_input)
+raise NotImplementedError(
+ "Fill in the reference model code above, then remove this line."
+)
+
+# --- Earth2Studio wrapper ---
+from earth2studio.models.dx import ModelName
+
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+input_coords = model.input_coords()
+# Construct input tensor matching the reference input
+# Use the same random seed or identical real data for both
+shape = tuple(max(len(v), 1) for v in input_coords.values())
+torch.manual_seed(42)
+x = torch.randn(shape, device=device)
+
+e2s_output, out_coords = model(x, input_coords)
+
+# --- Compare outputs ---
+# Ensure both tensors are on the same device and dtype
+ref_output = ref_output.to(e2s_output.device)
+
+max_abs_diff = (ref_output - e2s_output).abs().max().item()
+max_rel_diff = (
+ (ref_output - e2s_output).abs() / (ref_output.abs() + 1e-8)
+).max().item()
+correlation = torch.corrcoef(
+ torch.stack([ref_output.flatten(), e2s_output.flatten()])
+)[0, 1].item()
+
+print(f"Max absolute difference: {max_abs_diff:.2e}")
+print(f"Max relative difference: {max_rel_diff:.2e}")
+print(f"Correlation: {correlation:.8f}")
+
+assert torch.allclose(ref_output, e2s_output, rtol=1e-4, atol=1e-5), \
+ f"Output mismatch! Max abs diff: {max_abs_diff:.2e}"
+
+print("PASS: Reference comparison successful.")
+```
+
+**For physics-based models**, compare against hand-calculated expected
+values instead of a reference model:
+
+```python
+# Known test case: e.g., u=3, v=4 -> wind_speed=5
+u_input = torch.tensor([3.0])
+v_input = torch.tensor([4.0])
+expected_ws = torch.tensor([5.0])
+
+# Run through E2S wrapper
+# ...
+assert torch.allclose(e2s_output, expected_ws, atol=1e-6), \
+ f"Physics check failed: expected {expected_ws}, got {e2s_output}"
+```
+
+### 2b. Summarize model capabilities to user
+
+Before generating sanity-check plots, **present a summary table** to
+the user covering the model's capabilities:
+
+> **Model Summary for ``:**
+>
+> | Property | Value |
+> |---|---|
+> | **Input variables** | `var1`, `var2`, ... |
+> | **Output variables** | `out1`, `out2`, ... |
+> | **Spatial resolution** | X.XX deg x Y.YY deg (NxM) / Flexible |
+> | **Checkpoint size** | XX MB / N/A (physics-based) |
+> | **Checkpoint source** | NGC / HuggingFace / N/A |
+> | **Inference time** | ~XX ms per forward pass (on GPU/CPU) |
+
+This summary helps the user verify the wrapper matches their
+expectations for the model.
+
+### 2c. Generate sanity-check plot script
+
+Create a **standalone Python script** in the repo root. This is for
+PR reviewer reference only and should **NOT** be committed to the
+repo.
+
+Choose the appropriate template based on the model's output type:
+
+#### Spatial gridded outputs (e.g., precipitation, solar radiation)
+
+```python
+"""Sanity-check plot for diagnostic model.
+
+This script is for PR review only — do NOT commit to the repo.
+Run it to produce a quick visualization confirming the model works.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from earth2studio.models.dx import ModelName
+
+# Load model
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Create or fetch input data
+input_coords = model.input_coords()
+shape = tuple(max(len(v), 1) for v in input_coords.values())
+x = torch.randn(shape, device=device)
+
+# Run forward pass
+output, out_coords = model(x, input_coords)
+output_np = output.cpu().numpy()
+
+# Plot contourf for each output variable
+variables = list(out_coords["variable"])
+n_vars = len(variables)
+fig, axes = plt.subplots(1, n_vars, figsize=(6 * n_vars, 5))
+if n_vars == 1:
+ axes = [axes]
+
+for ax, i_var in zip(axes, range(n_vars)):
+ var = variables[i_var]
+ data_2d = output_np[0, 0, i_var, :, :] # batch=0, time=0
+ im = ax.contourf(data_2d, cmap="turbo", levels=20)
+ ax.set_title(f"{var}")
+ plt.colorbar(im, ax=ax, shrink=0.8)
+
+plt.suptitle(f" diagnostic output", y=1.02)
+plt.tight_layout()
+plt.savefig("sanity_check_.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_.png")
+```
+
+#### Physics-based outputs (e.g., derived quantities like wind speed)
+
+```python
+"""Sanity-check for physics-based diagnostic.
+
+This script is for PR review only — do NOT commit to the repo.
+Validates exact physics results against known test cases.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from earth2studio.models.dx import ModelName
+
+model = ModelName(...)
+
+# Known test cases
+# e.g., u=3, v=4 -> wind_speed=5
+test_inputs = {
+ "u": [0.0, 3.0, -5.0, 10.0],
+ "v": [0.0, 4.0, 12.0, 0.0],
+}
+expected_outputs = {
+ "ws": [0.0, 5.0, 13.0, 10.0],
+}
+
+# Construct input tensor from test cases and run model
+input_coords = model.input_coords()
+n_cases = len(test_inputs["u"])
+n_vars = len(input_coords["variable"])
+# Build tensor: shape (1, n_vars, n_cases, 1) — batch=1, spatial=n_cases x 1
+x = torch.zeros(1, n_vars, n_cases, 1)
+# Map test inputs to the correct variable channels
+# Adapt these indices to match model.input_coords()["variable"]
+var_list = list(input_coords["variable"])
+x[:, var_list.index("u1000"), :, 0] = torch.tensor(test_inputs["u"])
+x[:, var_list.index("v1000"), :, 0] = torch.tensor(test_inputs["v"])
+
+# Update coords to match tensor shape
+test_coords = input_coords.copy()
+test_coords["lat"] = np.arange(n_cases, dtype=np.float64)
+test_coords["lon"] = np.array([0.0])
+
+output, out_coords = model(x, test_coords)
+
+# Validate physics
+for key, expected in expected_outputs.items():
+ actual = output[..., :, 0, 0].cpu().numpy().flatten()
+ expected_arr = np.array(expected)
+ print(f"{key}: expected={expected_arr}, actual={actual}")
+ assert np.allclose(actual, expected_arr, atol=1e-6), \
+ f"Physics validation failed for {key}"
+
+print("PASS: All physics test cases validated.")
+
+# Side-by-side plot: input vs output
+fig, axes = plt.subplots(1, 2, figsize=(12, 5))
+axes[0].bar(range(len(test_inputs["u"])), test_inputs["u"], label="u")
+axes[0].bar(range(len(test_inputs["v"])), test_inputs["v"],
+ alpha=0.7, label="v")
+axes[0].set_title("Input components")
+axes[0].legend()
+
+axes[1].bar(range(len(expected_outputs["ws"])), expected_outputs["ws"],
+ label="Expected", alpha=0.7)
+axes[1].bar(range(len(expected_outputs["ws"])),
+ output[..., :, 0, 0].cpu().numpy().flatten(),
+ label="E2S output", alpha=0.5)
+axes[1].set_title("Output: expected vs actual")
+axes[1].legend()
+
+plt.suptitle(" — physics validation")
+plt.tight_layout()
+plt.savefig("sanity_check_.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_.png")
+```
+
+#### Scalar/classification outputs (e.g., TC tracking, severity index)
+
+```python
+"""Sanity-check for scalar/classification diagnostic.
+
+This script is for PR review only — do NOT commit to the repo.
+Generates histogram and summary statistics of output values.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from earth2studio.models.dx import ModelName
+
+# Load model and run inference
+model = ModelName(...)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+input_coords = model.input_coords()
+shape = tuple(max(len(v), 1) for v in input_coords.values())
+x = torch.randn(shape, device=device)
+
+output, out_coords = model(x, input_coords)
+output_np = output.cpu().numpy().flatten()
+
+# Summary statistics
+print(f"Output shape: {output.shape}")
+print(f"Min: {output_np.min():.4f}")
+print(f"Max: {output_np.max():.4f}")
+print(f"Mean: {output_np.mean():.4f}")
+print(f"Std: {output_np.std():.4f}")
+
+# Histogram of output values
+fig, ax = plt.subplots(figsize=(8, 5))
+ax.hist(output_np, bins=50, edgecolor="black", alpha=0.7)
+ax.set_xlabel("Output value")
+ax.set_ylabel("Count")
+ax.set_title(f" output distribution (n={len(output_np)})")
+ax.axvline(output_np.mean(), color="red", linestyle="--", label=f"Mean: {output_np.mean():.3f}")
+ax.legend()
+
+plt.tight_layout()
+plt.savefig("sanity_check_.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_.png")
+```
+
+### 2d. Run comparison and sanity-check scripts
+
+Execute both scripts:
+
+```bash
+uv run python reference_comparison_.py
+uv run python sanity_check_.py
+```
+
+Verify that:
+
+- The reference comparison passes (all assertions hold)
+- The sanity-check script runs without errors
+- Output PNGs are generated
+- Metrics are printed (max abs diff, max rel diff, correlation)
+
+### 2e. **[CONFIRM — Sanity-Check & Comparison]**
+
+**You MUST ask the user to visually inspect the generated plot(s)
+before proceeding.** Do not skip this step even if the scripts ran
+without errors — a successful run does not guarantee the plots are
+correct (e.g., empty axes, wrong colorbar range, garbled data).
+
+Tell the user the absolute path to the generated image file(s) and
+the reference comparison metrics, then ask them to inspect:
+
+> The reference comparison and sanity-check scripts ran successfully.
+>
+> **Reference comparison metrics:**
+>
+> - Max absolute difference: ``
+> - Max relative difference: ``
+> - Correlation: ``
+>
+> **Sanity-check plot saved to:**
+> `/absolute/path/to/sanity_check_.png`
+>
+> **Please open this image and confirm it looks correct.** Check:
+>
+> 1. Data is visible on the axes (not blank/empty)
+> 2. Values are in physically reasonable ranges
+> 3. No obvious artifacts (all-NaN regions, garbled values)
+> 4. For spatial outputs: geographic patterns look plausible
+> 5. For physics outputs: results match expected analytical values
+>
+> Does the plot look correct and do the reference comparison metrics
+> look acceptable?
+
+**Do not proceed to Step 3 until the user explicitly confirms.** If
+the user reports problems, debug and fix the issue, re-run the
+scripts, and ask the user to inspect again.
+
+---
+
+## Step 3 — Branch, Commit & Open PR
+
+### **[CONFIRM — Ready to Submit]**
+
+Before proceeding, confirm with the user:
+
+> All implementation and validation steps are complete:
+>
+> - Diagnostic model class implemented with correct method ordering
+> - Coordinate system with proper `handshake_dim` indices
+> - Forward pass implemented (NN-based or physics-based)
+> - Model loading implemented (NN-based) or skipped (physics-based)
+> - Registered in `earth2studio/models/dx/__init__.py`
+> - Documentation added to `docs/modules/models.rst`
+> - Reference URLs included in class docstrings
+> - CHANGELOG.md updated
+> - Format, lint, and license checks pass
+> - Unit tests written and passing with >= 90% coverage
+> - Dependencies in pyproject.toml confirmed (NN-based)
+> - Reference comparison passes with acceptable tolerance
+> - Sanity-check plots generated and confirmed by user
+>
+> Ready to create a branch, commit, and prepare a PR?
+
+### 3a. Create branch and commit
+
+```bash
+git checkout -b feat/diagnostic-model-
+git add earth2studio/models/dx/.py \
+ earth2studio/models/dx/__init__.py \
+ test/models/dx/test_.py \
+ pyproject.toml \
+ CHANGELOG.md \
+ docs/modules/models.rst
+git commit -m "feat: add diagnostic model
+
+Add diagnostic model for .
+Includes unit tests and documentation."
+```
+
+Do **NOT** add the sanity-check scripts, comparison scripts, or
+their output images.
+
+### 3b. Identify the fork remote and push branch
+
+The working repository is typically a **fork** of
+`NVIDIA/earth2studio`. Before pushing, confirm which git remote
+points to the user's fork:
+
+```bash
+git remote -v
+```
+
+Ask the user:
+
+> Which git remote is your fork of `NVIDIA/earth2studio`?
+> (Usually `origin` — e.g., `git@github.com:/earth2studio.git`)
+
+Then push the feature branch to the **fork** remote:
+
+```bash
+git push -u feat/diagnostic-model-
+```
+
+### 3c. Open Pull Request (fork -> NVIDIA/earth2studio)
+
+> **Important:** PRs must be opened **from the fork** to the
+> **upstream source repository** `NVIDIA/earth2studio`. The branch
+> lives on the fork; the PR targets `main` on the upstream repo.
+
+Use `gh pr create` with explicit `--repo` and `--head` flags:
+
+```bash
+gh pr create \
+ --repo NVIDIA/earth2studio \
+ --base main \
+ --head :feat/diagnostic-model- \
+ --title "feat: add diagnostic model" \
+ --body "..."
+```
+
+Where `` is the GitHub username that owns the fork.
+
+The PR body should follow this diagnostic-model-specific template:
+
+````markdown
+## Description
+
+Add `` diagnostic model for .
+
+Closes # (if applicable)
+
+### Model details
+
+| Property | Value |
+|---|---|
+| **Model type** | NN-based / Physics-based |
+| **Architecture** | PyTorch / ONNX / Analytical |
+| **Input variables** | |
+| **Output variables** | |
+| **Spatial resolution** | X° x Y° (NxM) / Flexible |
+| **Checkpoint source** | NGC / HuggingFace / N/A |
+| **Reference** | |
+
+### Dependencies added
+
+| Package | Version | License | License URL | Reason |
+|---|---|---|---|---|
+| `` | `>=X.Y` | | [link]() | |
+
+*(or "No new dependencies — physics-based model")*
+
+### Reference comparison
+
+- Max absolute difference:
+- Max relative difference:
+- Correlation:
+
+### Validation
+
+See sanity-check plots in PR comments below.
+
+## Checklist
+
+- [x] I am familiar with the [Contributing Guidelines][contrib].
+- [x] New or existing tests cover these changes.
+- [x] The documentation is up to date with these changes.
+- [x] The [CHANGELOG.md][changelog] is up to date with these changes.
+- [ ] An [issue][issues] is linked to this pull request.
+- [ ] Assess and address Greptile feedback (AI code review bot).
+
+[contrib]: https://github.com/NVIDIA/earth2studio/blob/main/CONTRIBUTING.md
+[changelog]: https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md
+[issues]: https://github.com/NVIDIA/earth2studio/issues
+````
+
+### 3d. Post sanity-check as PR comment
+
+After the PR is created, post the sanity-check visualization as a
+separate **PR comment** so it is immediately visible to reviewers.
+
+#### Image upload limitation
+
+**GitHub has no CLI or REST API for uploading images to PR comments.**
+The only way to embed an image is via the browser's drag-and-drop
+editor or by referencing an already-hosted URL.
+
+**Practical workflow:**
+
+1. Write the comment body to a temp file (avoids shell quoting issues
+ with heredocs containing backticks and markdown).
+2. Post the comment **without** the image — include the validation
+ table, reference comparison metrics, the full sanity-check script,
+ and a placeholder line.
+3. Tell the user to drag the image into the browser editor.
+
+```bash
+# 1. Write body to a temp file (use your editor tool, not heredoc)
+
+# 2. Post the comment
+gh api -X POST repos/NVIDIA/earth2studio/issues//comments \
+ -F "body=@/tmp/pr_comment_body.md" \
+ --jq '.html_url'
+```
+
+Do **not** waste time trying `curl` uploads, GraphQL file mutations,
+or the `uploads.github.com` asset endpoint — they do not work for
+issue/PR comment images.
+
+#### Comment content template
+
+```markdown
+## Sanity-Check Validation
+
+**Model:** `` —
+**Type:** NN-based / Physics-based
+**Test environment:**
+
+### Reference Comparison
+
+| Metric | Value |
+|--------|-------|
+| Max absolute difference | |
+| Max relative difference | |
+| Correlation | |
+
+### Model Summary
+
+| Property | Value |
+|----------|-------|
+| Input variables | |
+| Output variables | |
+| Output shape | |
+| Spatial resolution | X° x Y° / Flexible |
+| Inference time | ~XX ms |
+
+**Key findings:**
+-
+-
+-
+
+> **TODO:** Attach sanity-check image by editing this comment in
+> the browser.
+
+
+Sanity-check script (click to expand)
+
+```python
+PASTE THE FULL WORKING SCRIPT HERE — not a truncated excerpt.
+The script must be copy-pasteable and produce the plot end-to-end.
+```
+
+
+```
+
+**Important:** Always paste the **complete, runnable** script — not
+a shortened version. Reviewers should be able to reproduce the plot
+by copying the script directly.
+
+#### Finalize
+
+After posting, inform the user of:
+
+1. The comment URL
+2. The local path to the image file for manual attachment
+3. Instructions: *"Edit the comment in your browser and drag the
+ image file into the editor to embed it."*
+
+> **Note:** The sanity-check image and script are for PR review
+> purposes only — they must NOT be committed to the repository.
+
+---
+
+## Step 4 — Automated Code Review (Greptile)
+
+After the PR is created and pushed, an automated code review from
+**greptile-apps** (Greptile) will be posted as PR review comments.
+Wait for this review, then process the feedback.
+
+### 4a. Wait for Greptile review
+
+Poll for review comments from `greptile-apps[bot]` every 30 seconds
+for up to **5 minutes**. Time out gracefully if no review arrives:
+
+```bash
+# Poll loop — check every 30s, timeout after 5 minutes (10 attempts)
+for i in $(seq 1 10); do
+ REVIEW_ID=$(gh api repos/NVIDIA/earth2studio/pulls//reviews \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") | .id' 2>/dev/null)
+ if [ -n "$REVIEW_ID" ]; then
+ echo "Greptile review found: $REVIEW_ID"
+ break
+ fi
+ echo "Attempt $i/10 — no review yet, waiting 30s..."
+ sleep 30
+done
+```
+
+If no review after 5 minutes, inform the user:
+
+> Greptile hasn't posted a review after 5 minutes. This can happen if
+> the review bot is busy or the PR hasn't triggered it. You can:
+>
+> 1. Ask me to check again later
+> 2. Skip this step and proceed without automated review
+> 3. Manually request a review from Greptile on the PR page
+
+### 4b. Pull and parse review comments
+
+Once the review is posted, fetch all comments:
+
+```bash
+# Get all review comments on the PR
+gh api repos/NVIDIA/earth2studio/pulls//comments \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") |
+ {path: .path, line: .diff_hunk, body: .body}'
+```
+
+Also fetch the top-level review body:
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//reviews \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") | .body'
+```
+
+### 4c. Categorize and present to user
+
+Parse each comment and categorize it:
+
+| Category | Description | Default action |
+|---|---|---|
+| **Bug / correctness** | Logic errors, wrong behavior | Fix |
+| **Style / convention** | Naming, formatting, patterns | Fix if valid |
+| **Performance** | Inefficiency, resource waste | Evaluate |
+| **Documentation** | Missing/wrong docs, docstrings | Fix |
+| **Suggestion** | Alternative approach, nice-to-have | User decides |
+| **False positive** | Incorrect or irrelevant feedback | Dismiss |
+
+### **[CONFIRM — Review Triage]**
+
+Present each comment to the user in a summary table:
+
+```markdown
+| # | File | Line | Category | Summary | Proposed Action |
+|---|------|------|----------|---------|-----------------|
+| 1 | .py | 142 | Bug | Missing null check | Fix: add guard |
+| 2 | .py | 305 | Style | Use f-string | Fix: convert |
+| 3 | .py | 45 | Suggestion | Add type alias | Skip: not needed |
+| ... | ... | ... | ... | ... | ... |
+```
+
+For each comment, briefly explain:
+
+- What Greptile flagged
+- Whether you agree or disagree (with reasoning)
+- Your proposed fix (or why to skip)
+
+Ask the user to confirm which comments to address. The user may:
+
+- Accept all proposed fixes
+- Select specific fixes
+- Override your recommendation on any comment
+- Add their own fixes
+
+### 4d. Implement fixes
+
+For each accepted fix:
+
+1. Make the code change
+2. Run `make format && make lint` after all fixes
+3. Run the relevant tests:
+
+ ```bash
+ uv run python -m pytest test/models/dx/test_.py -v --timeout=60
+ ```
+
+4. Commit with a message like:
+
+ ```bash
+ git commit -m "fix: address code review feedback (Greptile)"
+ ```
+
+### 4e. Respond to review comments
+
+For each Greptile comment, post a reply on the PR:
+
+**For fixed comments:**
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//comments//replies \
+ -f body="Fixed in . "
+```
+
+**For dismissed comments (false positives / won't fix):**
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//comments//replies \
+ -f body="Won't fix — "
+```
+
+### 4f. Push and resolve
+
+```bash
+git push feat/diagnostic-model-
+```
+
+After pushing, resolve all addressed review threads if possible.
+
+Inform the user of the final state:
+
+- How many comments were fixed
+- How many were dismissed (with reasons)
+- Any remaining open threads
+
+---
+
+## Reminders
+
+- **DO** use the repo's local `uv` `.venv` to run Python with
+ `uv run python`
+- **DO NOT** commit sanity-check/comparison scripts or images to
+ the repo
+- **DO** use `loguru.logger` for logging, never `print()`, inside
+ `earth2studio/`
+- **DO** ensure all public functions have full type hints (mypy-clean)
+- **DO** maintain alphabetical order in `__init__.py` exports,
+ RST file entries, and CHANGELOG entries
+- **DO** follow the canonical DX method ordering:
+ `__init__`, `input_coords`, `output_coords`, `__call__`,
+ `_forward`, `to`, `load_default_package`, `load_model`
+- **DO** use `handshake_dim` indices matching each dimension's position in the
+ `CoordSystem` OrderedDict — check existing dx models for the predominant convention
+- **DO** include reference URLs in class docstrings
+- **DO** always update CHANGELOG.md under the current unreleased
+ version
+- **DO** add the model to `docs/modules/models.rst` in the
+ `earth2studio.models.dx` section (alphabetical order)
+- **DO NOT** inherit from `PrognosticMixin` or add `create_iterator`
+ — diagnostic models are single-pass, not time-stepping
+- **DO NOT** add `lead_time` dimension unless the model genuinely needs temporal
+ context (e.g., solar radiation, wind gust models that depend on forecast lead time)
+- **DO NOT** make a general base class with intent to reuse the
+ wrapper across models
+- **DO NOT** over-populate the `load_model()` API — only expose
+ essential parameters
+- **NEVER** commit, hardcode, or include API keys, secrets, tokens,
+ or credentials in source code, sample scripts, commit messages,
+ PR descriptions, or any file tracked by git
diff --git a/.claude/skills/validate-prognostic-wrapper/SKILL.md b/.claude/skills/validate-prognostic-wrapper/SKILL.md
new file mode 100644
index 000000000..83c3456d7
--- /dev/null
+++ b/.claude/skills/validate-prognostic-wrapper/SKILL.md
@@ -0,0 +1,926 @@
+---
+name: validate-prognostic-wrapper
+description: Validate a newly created Earth2Studio prognostic model wrapper by running tests, performing reference comparison (single-step and multi-step), generating sanity-check plots, and opening a PR with automated code review. Use this skill after completing prognostic model implementation (create-prognostic-wrapper skill Steps 0-12).
+argument-hint: Name of the prognostic model class and test file (optional — will be inferred from recent changes if not provided)
+---
+
+# Validate Prognostic Model Wrapper
+
+Validate a newly created Earth2Studio prognostic model (px) wrapper by
+running tests, performing reference comparison (single-step and
+multi-step), generating sanity-check plots, and opening a PR with
+automated code review. This skill picks up after the
+`create-prognostic-wrapper` skill completes implementation
+(Steps 0-12).
+
+> **Python Environment:** This project uses **uv** for dependency
+> management. Always use the local `.venv` virtual environment
+> (`source .venv/bin/activate` or prefix with `uv run python`) for all
+> Python commands — installing packages, running tests, executing
+> scripts, etc. Use `uv add` / `uv pip install` / `uv lock` instead of
+> `pip install`.
+
+Each confirmation gate marked by:
+
+```markdown
+### **[CONFIRM — ]**
+```
+
+requires **explicit user approval** before proceeding.
+
+---
+
+## Step 1 — Run Tests
+
+### 1a. Run the new test file
+
+```bash
+uv run python -m pytest test/models/px/test_.py -v --timeout=60
+```
+
+All tests must pass. Fix failures and re-run until green.
+
+### 1b. Run coverage report with `--slow` tests
+
+Run the new test file **with coverage** and the `--slow` flag to
+include integration tests. The new prognostic model file must achieve
+**at least 90% line coverage**:
+
+```bash
+uv run python -m pytest test/models/px/test_.py -v \
+ --slow --timeout=300 \
+ --cov=earth2studio/models/px/ \
+ --cov-report=term-missing \
+ --cov-fail-under=90
+```
+
+- `--slow` enables integration tests (marked `@pytest.mark.slow`)
+- `--cov=earth2studio/models/px/` scopes coverage to the
+ new model module only
+- `--cov-report=term-missing` shows which lines are not covered
+- `--cov-fail-under=90` fails the run if coverage is below 90%
+
+If coverage is below 90%, add additional tests or mock tests to
+cover the missing lines. Common coverage gaps for px models:
+
+- Error handling in `output_coords` (wrong variable names, wrong dims)
+- Device management paths (CPU vs CUDA)
+- `create_iterator` edge cases (initial condition yield, hook calls)
+- `load_model` and `load_default_package` (needs mock or package test)
+- ONNX / non-PyTorch backend `.to()` logic
+
+Re-run until coverage is at or above 90%.
+
+### 1c. Run the full model test suite (optional but recommended)
+
+```bash
+make pytest TOX_ENV=test-models
+```
+
+Confirm no regressions in existing model tests.
+
+---
+
+## Step 2 — Reference Comparison & Sanity-Check
+
+This step validates the prognostic model wrapper produces correct
+output by comparing against the original reference implementation
+for both a single time step and a multi-step forecast, then
+generating visual sanity-check plots.
+
+### 2a. Create reference comparison script
+
+Create a **standalone Python script** in the repo root. This is for
+validation only and should **NOT** be committed to the repo.
+
+The script loads the reference model and the E2S wrapper side by side,
+runs both on identical input (same random seed or real data), and
+compares outputs with tolerance. For prognostic models, test **both**
+single-step (`__call__`) and multi-step (`create_iterator`):
+
+```python
+"""Reference comparison for prognostic model.
+
+Compares the Earth2Studio wrapper output against the original reference
+implementation to verify numerical agreement for both single-step and
+multi-step forecasts.
+
+This script is for validation only — do NOT commit to the repo.
+"""
+import torch
+import numpy as np
+
+# --- Reference model ---
+# TODO: Load original model per reference repo instructions
+# Uncomment and adapt the following lines:
+# ref_model = ...
+# ref_input = ...
+# ref_output_single = ref_model(ref_input) # single step
+# ref_outputs_multi = [ref_output_single]
+# current = ref_output_single
+# for step in range(N_STEPS):
+# current = ref_model(current)
+# ref_outputs_multi.append(current)
+raise NotImplementedError(
+ "Fill in the reference model code above, then remove this line."
+)
+
+# --- Earth2Studio wrapper ---
+from earth2studio.models.px import ModelName
+
+model = ModelName(...) # or ModelName.load_model(package)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+input_coords = model.input_coords()
+# Construct input tensor matching the reference input
+# Use the same random seed or identical real data for both
+shape = tuple(max(len(v), 1) for v in input_coords.values())
+torch.manual_seed(42)
+x = torch.randn(shape, device=device)
+
+# --- Single-step comparison ---
+e2s_output_single, out_coords = model(x, input_coords)
+
+ref_output_single = ref_output_single.to(e2s_output_single.device)
+max_abs_single = (ref_output_single - e2s_output_single).abs().max().item()
+max_rel_single = (
+ (ref_output_single - e2s_output_single).abs()
+ / (ref_output_single.abs() + 1e-8)
+).max().item()
+corr_single = torch.corrcoef(
+ torch.stack([
+ ref_output_single.flatten(),
+ e2s_output_single.flatten(),
+ ])
+)[0, 1].item()
+
+print("=== Single-step comparison ===")
+print(f"Max absolute difference: {max_abs_single:.2e}")
+print(f"Max relative difference: {max_rel_single:.2e}")
+print(f"Correlation: {corr_single:.8f}")
+
+assert torch.allclose(
+ ref_output_single, e2s_output_single, rtol=1e-4, atol=1e-5
+), f"Single-step mismatch! Max abs diff: {max_abs_single:.2e}"
+
+# --- Multi-step comparison ---
+N_STEPS = 5 # Adapt to model time step (e.g., 5 steps of 6h = 30h)
+iterator = model.create_iterator(x, input_coords)
+
+# Skip initial condition (step 0)
+step0_x, step0_coords = next(iterator)
+
+print(f"\n=== Multi-step comparison ({N_STEPS} steps) ===")
+for step_i in range(N_STEPS):
+ e2s_step, e2s_coords = next(iterator)
+ ref_step = ref_outputs_multi[step_i + 1].to(e2s_step.device)
+
+ max_abs = (ref_step - e2s_step).abs().max().item()
+ corr = torch.corrcoef(
+ torch.stack([ref_step.flatten(), e2s_step.flatten()])
+ )[0, 1].item()
+ lead = e2s_coords["lead_time"]
+
+ print(f"Step {step_i + 1} (lead_time={lead}): "
+ f"max_abs={max_abs:.2e}, corr={corr:.8f}")
+
+ assert torch.allclose(ref_step, e2s_step, rtol=1e-3, atol=1e-4), \
+ f"Multi-step mismatch at step {step_i + 1}! Max abs: {max_abs:.2e}"
+
+print("\nPASS: Reference comparison successful (single + multi-step).")
+```
+
+### 2b. Summarize model capabilities to user
+
+Before generating sanity-check plots, **present a summary table** to
+the user covering the model's capabilities:
+
+> **Model Summary for ``:**
+>
+> | Property | Value |
+> |---|---|
+> | **Input variables** | `var1`, `var2`, ... (N total) |
+> | **Output variables** | `out1`, `out2`, ... (M total) |
+> | **Time step** | Xh (e.g., 6h, 24h) |
+> | **Spatial resolution** | X.XX deg x Y.YY deg (NxM) |
+> | **History required** | None / Xh (e.g., -6h, 0h) |
+> | **Checkpoint size** | XX MB |
+> | **Checkpoint source** | NGC / HuggingFace / S3 |
+> | **Inference time** | ~XX ms per step (on GPU/CPU) |
+
+This summary helps the user verify the wrapper matches their
+expectations for the model.
+
+### 2c. Generate sanity-check plot script
+
+Create a **standalone Python script** in the repo root. This is for
+PR reviewer reference only and should **NOT** be committed to the
+repo.
+
+Choose the appropriate template based on the model's output type:
+
+#### Spatial forecast evolution (most common — gridded weather models)
+
+```python
+"""Sanity-check plot for prognostic model.
+
+This script is for PR review only — do NOT commit to the repo.
+Runs a multi-step forecast and visualizes the evolution of key
+variables over lead time.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from earth2studio.data import Random, fetch_data
+from earth2studio.models.px import ModelName
+
+# Load model
+model = ModelName.from_pretrained()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+# Prepare input
+time = np.array([np.datetime64("2024-01-01T00:00")])
+input_coords = model.input_coords()
+input_coords["time"] = time
+ds = Random(input_coords)
+x, coords = fetch_data(ds, time, input_coords["variable"], device=device)
+
+# Run multi-step forecast
+N_STEPS = 5
+iterator = model.create_iterator(x, coords)
+
+# Collect outputs
+steps = []
+for i, (step_x, step_coords) in enumerate(iterator):
+ steps.append((step_x.cpu().numpy(), dict(step_coords)))
+ if i >= N_STEPS:
+ break
+
+# Pick 2-3 representative variables
+var_list = list(steps[0][1]["variable"])
+plot_vars = var_list[:3] # First 3 variables, or pick specific ones
+n_vars = len(plot_vars)
+
+# Plot: rows = variables, columns = time steps (0, mid, final)
+step_indices = [0, N_STEPS // 2, N_STEPS]
+n_cols = len(step_indices)
+fig, axes = plt.subplots(n_vars, n_cols, figsize=(5 * n_cols, 4 * n_vars))
+if n_vars == 1:
+ axes = axes[np.newaxis, :]
+if n_cols == 1:
+ axes = axes[:, np.newaxis]
+
+for row, var in enumerate(plot_vars):
+ var_idx = var_list.index(var)
+ for col, si in enumerate(step_indices):
+ data, sc = steps[si]
+ # Shape: (batch, time, lead_time, variable, lat, lon)
+ data_2d = data[0, 0, 0, var_idx, :, :]
+ lead = sc["lead_time"]
+ im = axes[row, col].contourf(data_2d, cmap="turbo", levels=20)
+ axes[row, col].set_title(f"{var} | lead={lead}")
+ plt.colorbar(im, ax=axes[row, col], shrink=0.8)
+
+plt.suptitle(f" forecast evolution", y=1.02)
+plt.tight_layout()
+plt.savefig("sanity_check_.png", dpi=150, bbox_inches="tight")
+print("Saved: sanity_check_.png")
+```
+
+#### Time-series at selected grid points
+
+For models where spatial patterns are less meaningful (e.g., random
+input), or to supplement the spatial plot, show how variables evolve
+over lead time at specific grid points:
+
+```python
+"""Sanity-check time series for prognostic model.
+
+Shows how selected variables evolve over forecast lead time at
+specific grid points.
+"""
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from earth2studio.data import Random, fetch_data
+from earth2studio.models.px import ModelName
+
+# Load model and run forecast
+model = ModelName.from_pretrained()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+time = np.array([np.datetime64("2024-01-01T00:00")])
+input_coords = model.input_coords()
+input_coords["time"] = time
+ds = Random(input_coords)
+x, coords = fetch_data(ds, time, input_coords["variable"], device=device)
+
+N_STEPS = 10
+iterator = model.create_iterator(x, coords)
+
+# Collect time series at a central grid point
+lat_idx = len(input_coords["lat"]) // 2
+lon_idx = len(input_coords["lon"]) // 2
+var_list = list(input_coords["variable"])
+plot_vars = var_list[:4] # Pick representative variables
+
+lead_times = []
+values = {v: [] for v in plot_vars}
+
+for i, (step_x, step_coords) in enumerate(iterator):
+ lead_times.append(step_coords["lead_time"][0])
+ for var in plot_vars:
+ var_idx = var_list.index(var)
+ val = step_x[0, 0, 0, var_idx, lat_idx, lon_idx].cpu().item()
+ values[var].append(val)
+ if i >= N_STEPS:
+ break
+
+# Convert lead times to hours for plotting
+lead_hours = [lt / np.timedelta64(1, "h") for lt in lead_times]
+
+fig, axes = plt.subplots(len(plot_vars), 1,
+ figsize=(10, 3 * len(plot_vars)),
+ sharex=True)
+if len(plot_vars) == 1:
+ axes = [axes]
+
+for ax, var in zip(axes, plot_vars):
+ ax.plot(lead_hours, values[var], marker="o", markersize=3)
+ ax.set_ylabel(var)
+ ax.grid(True, alpha=0.3)
+ ax.set_title(f"{var} at lat_idx={lat_idx}, lon_idx={lon_idx}")
+
+axes[-1].set_xlabel("Lead time (hours)")
+plt.suptitle(f" — time series at grid center", y=1.02)
+plt.tight_layout()
+plt.savefig("sanity_check__timeseries.png",
+ dpi=150, bbox_inches="tight")
+print("Saved: sanity_check__timeseries.png")
+```
+
+#### Iterator behavior validation
+
+For verifying `create_iterator` mechanics (initial condition yield,
+lead_time progression, hook application):
+
+```python
+"""Iterator validation for prognostic model.
+
+Verifies create_iterator yields correct initial condition, increments
+lead_time correctly, and that front/rear hooks are called.
+"""
+import numpy as np
+import torch
+
+from earth2studio.data import Random, fetch_data
+from earth2studio.models.px import ModelName
+
+model = ModelName.from_pretrained()
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(device)
+
+time = np.array([np.datetime64("2024-01-01T00:00")])
+input_coords = model.input_coords()
+input_coords["time"] = time
+ds = Random(input_coords)
+x, coords = fetch_data(ds, time, input_coords["variable"], device=device)
+
+N_STEPS = 5
+iterator = model.create_iterator(x, coords)
+
+print("=== Iterator validation ===")
+for i, (step_x, step_coords) in enumerate(iterator):
+ lead = step_coords["lead_time"]
+ print(f"Step {i}: shape={step_x.shape}, lead_time={lead}, "
+ f"device={step_x.device}")
+
+ if i == 0:
+ # Step 0 must be the initial condition (unchanged input)
+ assert torch.allclose(step_x, x), \
+ "Step 0 should yield the initial condition unchanged"
+ assert np.array_equal(lead, coords["lead_time"]), \
+ "Step 0 lead_time should match input lead_time"
+ print(" -> Initial condition verified")
+
+ if i >= N_STEPS:
+ break
+
+# Verify lead_time progression
+expected_step = input_coords["lead_time"][0]
+print(f"\nExpected time step: {expected_step}")
+print(f"Final lead_time after {N_STEPS} steps: {step_coords['lead_time']}")
+print("PASS: Iterator validation successful.")
+```
+
+### 2d. Run comparison and sanity-check scripts
+
+Execute all scripts:
+
+```bash
+uv run python reference_comparison_.py
+uv run python sanity_check_.py
+```
+
+Verify that:
+
+- The reference comparison passes for both single-step and multi-step
+- The sanity-check script runs without errors
+- Output PNGs are generated
+- `create_iterator` yields correct initial condition at step 0
+- `lead_time` increments correctly at each step
+
+### 2e. **[CONFIRM — Sanity-Check & Comparison]**
+
+**You MUST ask the user to visually inspect the generated plot(s)
+before proceeding.** Do not skip this step even if the scripts ran
+without errors — a successful run does not guarantee the plots are
+correct (e.g., empty axes, wrong colorbar range, garbled data).
+
+Tell the user the absolute path to the generated image file(s) and
+the reference comparison metrics, then ask them to inspect:
+
+> The reference comparison and sanity-check scripts ran successfully.
+>
+> **Single-step reference comparison:**
+>
+> - Max absolute difference: ``
+> - Max relative difference: ``
+> - Correlation: ``
+>
+> **Multi-step reference comparison (N steps):**
+>
+> - Step 1: max_abs=``, corr=``
+> - Step N: max_abs=``, corr=``
+> - (error may grow with lead time — this is expected)
+>
+> **Sanity-check plot saved to:**
+> `/absolute/path/to/sanity_check_.png`
+>
+> **Please open this image and confirm it looks correct.** Check:
+>
+> 1. Data is visible on the axes at all time steps (not blank/empty)
+> 2. Values are in physically reasonable ranges
+> 3. Spatial patterns evolve smoothly over lead time
+> 4. No sudden jumps, NaN explosions, or frozen fields
+> 5. Lead time labels increment correctly
+>
+> Does the plot look correct and do the reference comparison metrics
+> look acceptable?
+
+**Do not proceed to Step 3 until the user explicitly confirms.** If
+the user reports problems, debug and fix the issue, re-run the
+scripts, and ask the user to inspect again.
+
+---
+
+## Step 3 — Branch, Commit & Open PR
+
+### **[CONFIRM — Ready to Submit]**
+
+Before proceeding, confirm with the user:
+
+> All implementation and validation steps are complete:
+>
+> - Prognostic model class implemented with correct method ordering
+> - Triple inheritance: `torch.nn.Module + AutoModelMixin + PrognosticMixin`
+> - Coordinate system with proper `handshake_dim` indices
+> - `__call__` (single step) and `create_iterator` (multi-step) implemented
+> - `create_iterator` yields initial condition first, uses front/rear hooks
+> - Model loading implemented (`load_default_package`, `load_model`)
+> - Registered in `earth2studio/models/px/__init__.py`
+> - Documentation added to `docs/modules/models.rst`
+> - Reference URLs included in class docstrings
+> - CHANGELOG.md updated
+> - Format, lint, and license checks pass
+> - Unit tests written and passing with >= 90% coverage
+> - Dependencies in pyproject.toml confirmed
+> - Reference comparison passes (single-step and multi-step)
+> - Sanity-check plots generated and confirmed by user
+>
+> Ready to create a branch, commit, and prepare a PR?
+
+### 3a. Create branch and commit
+
+```bash
+git checkout -b feat/prognostic-model-
+git add earth2studio/models/px/.py \
+ earth2studio/models/px/__init__.py \
+ test/models/px/test_.py \
+ pyproject.toml \
+ CHANGELOG.md \
+ docs/modules/models.rst
+git commit -m "feat: add prognostic model
+
+Add prognostic model for .
+Includes unit tests and documentation."
+```
+
+Do **NOT** add the sanity-check scripts, comparison scripts, or
+their output images.
+
+### 3b. Identify the fork remote and push branch
+
+The working repository is typically a **fork** of
+`NVIDIA/earth2studio`. Before pushing, confirm which git remote
+points to the user's fork:
+
+```bash
+git remote -v
+```
+
+Ask the user:
+
+> Which git remote is your fork of `NVIDIA/earth2studio`?
+> (Usually `origin` — e.g., `git@github.com:/earth2studio.git`)
+
+Then push the feature branch to the **fork** remote:
+
+```bash
+git push -u feat/prognostic-model-
+```
+
+### 3c. Open Pull Request (fork -> NVIDIA/earth2studio)
+
+> **Important:** PRs must be opened **from the fork** to the
+> **upstream source repository** `NVIDIA/earth2studio`. The branch
+> lives on the fork; the PR targets `main` on the upstream repo.
+
+Use `gh pr create` with explicit `--repo` and `--head` flags:
+
+```bash
+gh pr create \
+ --repo NVIDIA/earth2studio \
+ --base main \
+ --head :feat/prognostic-model- \
+ --title "feat: add prognostic model" \
+ --body "..."
+```
+
+Where `` is the GitHub username that owns the fork.
+
+The PR body should follow this prognostic-model-specific template:
+
+````markdown
+## Description
+
+Add `` prognostic model for .
+
+Closes # (if applicable)
+
+### Model details
+
+| Property | Value |
+|---|---|
+| **Architecture** | PyTorch / ONNX / JAX |
+| **Time step** | Xh (e.g., 6h, 24h) |
+| **Input variables** | N variables (list or link) |
+| **Output variables** | M variables (list or link) |
+| **Spatial resolution** | X° x Y° (NxM grid) |
+| **History required** | None / Xh (e.g., [-6h, 0h]) |
+| **Checkpoint source** | NGC / HuggingFace / S3 |
+| **Checkpoint size** | XX MB |
+| **Reference** | |
+
+### Dependencies added
+
+| Package | Version | License | License URL | Reason |
+|---|---|---|---|---|
+| `` | `>=X.Y` | | [link]() | |
+
+*(or "No new dependencies needed")*
+
+When filling this table, look up each new dependency's license:
+
+1. Check the package's PyPI page or repository for the license type
+2. Link directly to the license file
+3. Flag any **non-permissive licenses** (GPL, AGPL, SSPL) — these
+ may be incompatible with the project's Apache-2.0 license
+
+### Reference comparison
+
+**Single step:**
+
+- Max absolute difference:
+- Max relative difference:
+- Correlation:
+
+**Multi-step (N steps):**
+
+- Step 1: max_abs=, corr=
+- Step N: max_abs=, corr=
+
+### Validation
+
+See sanity-check plots in PR comments below.
+
+## Checklist
+
+- [x] I am familiar with the [Contributing Guidelines][contrib].
+- [x] New or existing tests cover these changes.
+- [x] The documentation is up to date with these changes.
+- [x] The [CHANGELOG.md][changelog] is up to date with these changes.
+- [ ] An [issue][issues] is linked to this pull request.
+- [ ] Assess and address Greptile feedback (AI code review bot).
+
+[contrib]: https://github.com/NVIDIA/earth2studio/blob/main/CONTRIBUTING.md
+[changelog]: https://github.com/NVIDIA/earth2studio/blob/main/CHANGELOG.md
+[issues]: https://github.com/NVIDIA/earth2studio/issues
+````
+
+### 3d. Post sanity-check as PR comment
+
+After the PR is created, post the sanity-check visualization as a
+separate **PR comment** so it is immediately visible to reviewers.
+
+#### Image upload limitation
+
+**GitHub has no CLI or REST API for uploading images to PR comments.**
+The only way to embed an image is via the browser's drag-and-drop
+editor or by referencing an already-hosted URL.
+
+**Practical workflow:**
+
+1. Write the comment body to a temp file (avoids shell quoting issues
+ with heredocs containing backticks and markdown).
+2. Post the comment **without** the image — include the validation
+ table, reference comparison metrics, the full sanity-check script,
+ and a placeholder line.
+3. Tell the user to drag the image into the browser editor.
+
+```bash
+# 1. Write body to a temp file (use your editor tool, not heredoc)
+
+# 2. Post the comment
+gh api -X POST repos/NVIDIA/earth2studio/issues//comments \
+ -F "body=@/tmp/pr_comment_body.md" \
+ --jq '.html_url'
+```
+
+Do **not** waste time trying `curl` uploads, GraphQL file mutations,
+or the `uploads.github.com` asset endpoint — they do not work for
+issue/PR comment images.
+
+#### Comment content template
+
+```markdown
+## Sanity-Check Validation
+
+**Model:** `` —
+**Architecture:** PyTorch / ONNX / JAX
+**Time step:** Xh
+**Test environment:**
+
+### Reference Comparison
+
+**Single step:**
+
+| Metric | Value |
+|--------|-------|
+| Max absolute difference | |
+| Max relative difference | |
+| Correlation | |
+
+**Multi-step (N steps):**
+
+| Step | Lead time | Max abs diff | Correlation |
+|------|-----------|-------------|-------------|
+| 1 | Xh | | |
+| ... | ... | ... | ... |
+| N | NXh | | |
+
+### Model Summary
+
+| Property | Value |
+|----------|-------|
+| Input variables | |
+| Output variables | |
+| Spatial resolution | X° x Y° (NxM) |
+| Time step | Xh |
+| History required | None / [-Xh, 0h] |
+| Inference time | ~XX ms per step |
+
+**Key findings:**
+
+-
+-
+-
+
+> **TODO:** Attach sanity-check image by editing this comment in
+> the browser.
+
+
+Sanity-check script (click to expand)
+
+```python
+PASTE THE FULL WORKING SCRIPT HERE — not a truncated excerpt.
+The script must be copy-pasteable and produce the plot end-to-end.
+```
+
+
+```
+
+**Important:** Always paste the **complete, runnable** script — not
+a shortened version. Reviewers should be able to reproduce the plot
+by copying the script directly.
+
+#### Finalize
+
+After posting, inform the user of:
+
+1. The comment URL
+2. The local path to the image file for manual attachment
+3. Instructions: *"Edit the comment in your browser and drag the
+ image file into the editor to embed it."*
+
+> **Note:** The sanity-check image and script are for PR review
+> purposes only — they must NOT be committed to the repository.
+
+---
+
+## Step 4 — Automated Code Review (Greptile)
+
+After the PR is created and pushed, an automated code review from
+**greptile-apps** (Greptile) will be posted as PR review comments.
+Wait for this review, then process the feedback.
+
+### 4a. Wait for Greptile review
+
+Poll for review comments from `greptile-apps[bot]` every 30 seconds
+for up to **5 minutes**. Time out gracefully if no review arrives:
+
+```bash
+# Poll loop — check every 30s, timeout after 5 minutes (10 attempts)
+for i in $(seq 1 10); do
+ REVIEW_ID=$(gh api repos/NVIDIA/earth2studio/pulls//reviews \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") | .id' 2>/dev/null)
+ if [ -n "$REVIEW_ID" ]; then
+ echo "Greptile review found: $REVIEW_ID"
+ break
+ fi
+ echo "Attempt $i/10 — no review yet, waiting 30s..."
+ sleep 30
+done
+```
+
+If no review after 5 minutes, inform the user:
+
+> Greptile hasn't posted a review after 5 minutes. This can happen if
+> the review bot is busy or the PR hasn't triggered it. You can:
+>
+> 1. Ask me to check again later
+> 2. Skip this step and proceed without automated review
+> 3. Manually request a review from Greptile on the PR page
+
+### 4b. Pull and parse review comments
+
+Once the review is posted, fetch all comments:
+
+```bash
+# Get all review comments on the PR
+gh api repos/NVIDIA/earth2studio/pulls//comments \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") |
+ {path: .path, line: .diff_hunk, body: .body}'
+```
+
+Also fetch the top-level review body:
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//reviews \
+ --jq '.[] | select(.user.login == "greptile-apps[bot]") | .body'
+```
+
+### 4c. Categorize and present to user
+
+Parse each comment and categorize it:
+
+| Category | Description | Default action |
+|---|---|---|
+| **Bug / correctness** | Logic errors, wrong behavior | Fix |
+| **Style / convention** | Naming, formatting, patterns | Fix if valid |
+| **Performance** | Inefficiency, resource waste | Evaluate |
+| **Documentation** | Missing/wrong docs, docstrings | Fix |
+| **Suggestion** | Alternative approach, nice-to-have | User decides |
+| **False positive** | Incorrect or irrelevant feedback | Dismiss |
+
+### **[CONFIRM — Review Triage]**
+
+Present each comment to the user in a summary table:
+
+```markdown
+| # | File | Line | Category | Summary | Proposed Action |
+|---|------|------|----------|---------|-----------------|
+| 1 | .py | 142 | Bug | Missing null check | Fix: add guard |
+| 2 | .py | 305 | Style | Use f-string | Fix: convert |
+| 3 | .py | 45 | Suggestion | Add type alias | Skip: not needed |
+| ... | ... | ... | ... | ... | ... |
+```
+
+For each comment, briefly explain:
+
+- What Greptile flagged
+- Whether you agree or disagree (with reasoning)
+- Your proposed fix (or why to skip)
+
+Ask the user to confirm which comments to address. The user may:
+
+- Accept all proposed fixes
+- Select specific fixes
+- Override your recommendation on any comment
+- Add their own fixes
+
+### 4d. Implement fixes
+
+For each accepted fix:
+
+1. Make the code change
+2. Run `make format && make lint` after all fixes
+3. Run the relevant tests:
+
+ ```bash
+ uv run python -m pytest test/models/px/test_.py -v --timeout=60
+ ```
+
+4. Commit with a message like:
+
+ ```bash
+ git commit -m "fix: address code review feedback (Greptile)"
+ ```
+
+### 4e. Respond to review comments
+
+For each Greptile comment, post a reply on the PR:
+
+**For fixed comments:**
+
+```bash
+gh api repos/NVIDIA/earth2studio/pulls//comments//replies \
+ -f body="Fixed in